使用TensorFlow Lite进行图像分类(Android端)

本文将详细介绍如何使用TensorFlow Lite在Android平台上实现图像分类功能,包括模型转换、项目搭建、代码实现及运行效果。

1. 准备工作

1.1 安装必要工具

  • Android Studio:用于开发Android应用
  • Bazel:用于构建TensorFlow Lite模型(可选,若直接使用预编译模型可跳过)
  • TensorFlow模型:推荐使用预训练的MobileNet模型

1.2 模型准备

  1. 下载预训练的TensorFlow模型(如MobileNet)
  2. 使用TensorFlow官方工具转换为TFLite格式:
   bazel build tensorflow/python/tools:freeze_graph
   bazel build tensorflow/lite/toco:toco

2. 创建Android项目

2.1 配置build.gradle

app/build.gradle中添加依赖:

dependencies {
    implementation 'org.tensorflow:tensorflow-lite:2.0.0'
    implementation 'com.bumptech.glide:glide:4.9.0'
}

2.2 配置资源文件

  1. res目录下创建assets文件夹,存放TFLite模型文件
  2. 添加权限:
   <uses-permission android:name="android.permission.CAMERA" />
   <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
   <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />

2.3 配置FileProvider(适配Android 7.0+)

AndroidManifest.xml中添加:

<provider
    android:name="android.support.v4.content.FileProvider"
    android:authorities="com.yourpackage.name.fileprovider"
    android:exported="false"
    android:grantUriPermissions="true">
    <meta-data
        android:name="android.support.FILE_PROVIDER_PATHS"
        android:resource="@xml/file_paths" />
</provider>

3. 核心代码实现

3.1 模型加载与初始化

private Interpreter tflite;
private MappedByteBuffer loadModelFile(String modelName) throws IOException {
    AssetFileDescriptor fileDescriptor = getAssets().openFd(modelName + ".tflite");
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, 
                          fileDescriptor.getStartOffset(), 
                          fileDescriptor.getDeclaredLength());
}

private void loadModel(String modelName) {
    try {
        tflite = new Interpreter(loadModelFile(modelName));
        Toast.makeText(this, "模型加载成功", Toast.LENGTH_SHORT).show();
        tflite.setNumThreads(4); // 设置线程数
    } catch (IOException e) {
        Toast.makeText(this, "模型加载失败", Toast.LENGTH_SHORT).show();
        e.printStackTrace();
    }
}

3.2 图像处理与预测

private Bitmap getScaledBitmap(String imagePath) {
    BitmapFactory.Options options = new BitmapFactory.Options();
    options.inSampleSize = 4; // 压缩图片
    return BitmapFactory.decodeFile(imagePath, options);
}

private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
    ByteBuffer buffer = ByteBuffer.allocateDirect(1 * 224 * 224 * 3 * 4);
    buffer.order(ByteOrder.nativeOrder());

    int[] pixels = new int[224 * 224];
    bitmap.getPixels(pixels, 0, 224, 0, 0, 224, 224);

    for (int pixel : pixels) {
        // 像素归一化处理
        float r = ((pixel >> 16) & 0xFF) / 255.0f;
        float g = ((pixel >> 8) & 0xFF) / 255.0f;
        float b = (pixel & 0xFF) / 255.0f;

        // 转换为[-1,1]范围
        buffer.putFloat(r - 0.5f);
        buffer.putFloat(g - 0.5f);
        buffer.putFloat(b - 0.5f);
    }
    return buffer;
}

private void classifyImage(String imagePath) {
    Bitmap bitmap = getScaledBitmap(imagePath);
    ByteBuffer input = convertBitmapToByteBuffer(bitmap);

    float[][] output = new float[1][1001]; // 1001为ImageNet类别数
    long startTime = System.currentTimeMillis();
    tflite.run(input, output);
    long endTime = System.currentTimeMillis();

    // 找到置信度最高的类别
    int maxIndex = 0;
    float maxConfidence = 0;
    for (int i = 0; i < output[0].length; i++) {
        if (output[0][i] > maxConfidence) {
            maxConfidence = output[0][i];
            maxIndex = i;
        }
    }

    // 显示结果
    Log.d("Result", "类别: " + maxIndex + ", 置信度: " + maxConfidence);
}

3.3 界面实现

public class MainActivity extends AppCompatActivity {
    private Button loadModelBtn, usePhotoBtn, takePhotoBtn;
    private ImageView imageView;
    private TextView resultText;
    private String modelName = "mobilenet_v1_1.0_224"; // 模型文件名

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // 初始化视图
        loadModelBtn = findViewById(R.id.load_model);
        usePhotoBtn = findViewById(R.id.use_photo);
        takePhotoBtn = findViewById(R.id.take_photo);
        imageView = findViewById(R.id.image_view);
        resultText = findViewById(R.id.result_text);

        // 加载模型按钮点击事件
        loadModelBtn.setOnClickListener(v -> loadModel(modelName));

        // 相册选择按钮点击事件
        usePhotoBtn.setOnClickListener(v -> {
            if (tflite != null) {
                Intent intent = new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.EXTERNAL_CONTENT_URI);
                startActivityForResult(intent, 1);
            } else {
                Toast.makeText(this, "请先加载模型", Toast.LENGTH_SHORT).show();
            }
        });

        // 拍照按钮点击事件
        takePhotoBtn.setOnClickListener(v -> {
            if (tflite != null) {
                takePhoto();
            } else {
                Toast.makeText(this, "请先加载模型", Toast.LENGTH_SHORT).show();
            }
        });
    }

    private void takePhoto() {
        // 实现拍照逻辑,保存图片路径
        // 此处省略具体实现,参考系统相机API
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
        if (resultCode == RESULT_OK && requestCode == 1) {
            Uri imageUri = data.getData();
            String imagePath = getRealPathFromURI(imageUri);
            classifyImage(imagePath);
        }
    }

    private String getRealPathFromURI(Uri contentUri) {
        // 实现URI转路径逻辑
        // 此处省略具体实现
        return "";
    }
}

3.4 布局文件

<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical"
    android:padding="16dp">

    <Button
        android:id="@+id/load_model"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:text="加载模型" />

    <Button
        android:id="@+id/use_photo"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:text="选择相册图片"
        android:layout_marginTop="8dp"/>

    <Button
        android:id="@+id/take_photo"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:text="拍摄照片"
        android:layout_marginTop="8dp"/>

    <ImageView
        android:id="@+id/image_view"
        android:layout_width="match_parent"
        android:layout_height="250dp"
        android:layout_marginTop="16dp"
        android:scaleType="centerCrop"/>

    <TextView
        android:id="@+id/result_text"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_marginTop="16dp"
        android:text="预测结果将显示在这里"
        android:scrollbars="vertical"
        android:maxLines="5"/>
</LinearLayout>

4. 关键注意事项

4.1 模型优化

  • 使用量化模型(Quantized Models)减少内存占用和推理时间
  • 模型输入输出格式需与TensorFlow Lite兼容

4.2 性能优化

  • 设置合适的线程数:tflite.setNumThreads(4)
  • 图片压缩处理:避免OOM错误
  • 使用直接字节缓冲区(DirectByteBuffer)提高效率

4.3 权限处理

  • Android 6.0+需要动态申请相机和存储权限
  • 使用FileProvider处理文件权限

5. 效果展示

运行应用后,流程如下:
1. 点击”加载模型”按钮加载TFLite模型
2. 选择”选择相册图片”或”拍摄照片”
3. 系统自动处理图像并显示分类结果(类别名称和置信度)

6. 参考资料

  1. TensorFlow Lite官方文档
  2. Android Camera API文档
  3. ImageNet分类标签

通过以上步骤,你可以在Android设备上实现高效的图像分类功能,适用于移动端应用开发、计算机视觉等场景。

Xiaoye