使用TensorFlow Lite进行图像分类(Android端)¶
本文将详细介绍如何使用TensorFlow Lite在Android平台上实现图像分类功能,包括模型转换、项目搭建、代码实现及运行效果。
1. 准备工作¶
1.1 安装必要工具¶
- Android Studio:用于开发Android应用
- Bazel:用于构建TensorFlow Lite模型(可选,若直接使用预编译模型可跳过)
- TensorFlow模型:推荐使用预训练的MobileNet模型
1.2 模型准备¶
- 下载预训练的TensorFlow模型(如MobileNet)
- 使用TensorFlow官方工具转换为TFLite格式:
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/lite/toco:toco
2. 创建Android项目¶
2.1 配置build.gradle¶
在app/build.gradle中添加依赖:
dependencies {
implementation 'org.tensorflow:tensorflow-lite:2.0.0'
implementation 'com.bumptech.glide:glide:4.9.0'
}
2.2 配置资源文件¶
- 在
res目录下创建assets文件夹,存放TFLite模型文件 - 添加权限:
<uses-permission android:name="android.permission.CAMERA" />
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
2.3 配置FileProvider(适配Android 7.0+)¶
在AndroidManifest.xml中添加:
<provider
android:name="android.support.v4.content.FileProvider"
android:authorities="com.yourpackage.name.fileprovider"
android:exported="false"
android:grantUriPermissions="true">
<meta-data
android:name="android.support.FILE_PROVIDER_PATHS"
android:resource="@xml/file_paths" />
</provider>
3. 核心代码实现¶
3.1 模型加载与初始化¶
private Interpreter tflite;
private MappedByteBuffer loadModelFile(String modelName) throws IOException {
AssetFileDescriptor fileDescriptor = getAssets().openFd(modelName + ".tflite");
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
return fileChannel.map(FileChannel.MapMode.READ_ONLY,
fileDescriptor.getStartOffset(),
fileDescriptor.getDeclaredLength());
}
private void loadModel(String modelName) {
try {
tflite = new Interpreter(loadModelFile(modelName));
Toast.makeText(this, "模型加载成功", Toast.LENGTH_SHORT).show();
tflite.setNumThreads(4); // 设置线程数
} catch (IOException e) {
Toast.makeText(this, "模型加载失败", Toast.LENGTH_SHORT).show();
e.printStackTrace();
}
}
3.2 图像处理与预测¶
private Bitmap getScaledBitmap(String imagePath) {
BitmapFactory.Options options = new BitmapFactory.Options();
options.inSampleSize = 4; // 压缩图片
return BitmapFactory.decodeFile(imagePath, options);
}
private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
ByteBuffer buffer = ByteBuffer.allocateDirect(1 * 224 * 224 * 3 * 4);
buffer.order(ByteOrder.nativeOrder());
int[] pixels = new int[224 * 224];
bitmap.getPixels(pixels, 0, 224, 0, 0, 224, 224);
for (int pixel : pixels) {
// 像素归一化处理
float r = ((pixel >> 16) & 0xFF) / 255.0f;
float g = ((pixel >> 8) & 0xFF) / 255.0f;
float b = (pixel & 0xFF) / 255.0f;
// 转换为[-1,1]范围
buffer.putFloat(r - 0.5f);
buffer.putFloat(g - 0.5f);
buffer.putFloat(b - 0.5f);
}
return buffer;
}
private void classifyImage(String imagePath) {
Bitmap bitmap = getScaledBitmap(imagePath);
ByteBuffer input = convertBitmapToByteBuffer(bitmap);
float[][] output = new float[1][1001]; // 1001为ImageNet类别数
long startTime = System.currentTimeMillis();
tflite.run(input, output);
long endTime = System.currentTimeMillis();
// 找到置信度最高的类别
int maxIndex = 0;
float maxConfidence = 0;
for (int i = 0; i < output[0].length; i++) {
if (output[0][i] > maxConfidence) {
maxConfidence = output[0][i];
maxIndex = i;
}
}
// 显示结果
Log.d("Result", "类别: " + maxIndex + ", 置信度: " + maxConfidence);
}
3.3 界面实现¶
public class MainActivity extends AppCompatActivity {
private Button loadModelBtn, usePhotoBtn, takePhotoBtn;
private ImageView imageView;
private TextView resultText;
private String modelName = "mobilenet_v1_1.0_224"; // 模型文件名
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
// 初始化视图
loadModelBtn = findViewById(R.id.load_model);
usePhotoBtn = findViewById(R.id.use_photo);
takePhotoBtn = findViewById(R.id.take_photo);
imageView = findViewById(R.id.image_view);
resultText = findViewById(R.id.result_text);
// 加载模型按钮点击事件
loadModelBtn.setOnClickListener(v -> loadModel(modelName));
// 相册选择按钮点击事件
usePhotoBtn.setOnClickListener(v -> {
if (tflite != null) {
Intent intent = new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.EXTERNAL_CONTENT_URI);
startActivityForResult(intent, 1);
} else {
Toast.makeText(this, "请先加载模型", Toast.LENGTH_SHORT).show();
}
});
// 拍照按钮点击事件
takePhotoBtn.setOnClickListener(v -> {
if (tflite != null) {
takePhoto();
} else {
Toast.makeText(this, "请先加载模型", Toast.LENGTH_SHORT).show();
}
});
}
private void takePhoto() {
// 实现拍照逻辑,保存图片路径
// 此处省略具体实现,参考系统相机API
}
@Override
protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
super.onActivityResult(requestCode, resultCode, data);
if (resultCode == RESULT_OK && requestCode == 1) {
Uri imageUri = data.getData();
String imagePath = getRealPathFromURI(imageUri);
classifyImage(imagePath);
}
}
private String getRealPathFromURI(Uri contentUri) {
// 实现URI转路径逻辑
// 此处省略具体实现
return "";
}
}
3.4 布局文件¶
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
android:padding="16dp">
<Button
android:id="@+id/load_model"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:text="加载模型" />
<Button
android:id="@+id/use_photo"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:text="选择相册图片"
android:layout_marginTop="8dp"/>
<Button
android:id="@+id/take_photo"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:text="拍摄照片"
android:layout_marginTop="8dp"/>
<ImageView
android:id="@+id/image_view"
android:layout_width="match_parent"
android:layout_height="250dp"
android:layout_marginTop="16dp"
android:scaleType="centerCrop"/>
<TextView
android:id="@+id/result_text"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_marginTop="16dp"
android:text="预测结果将显示在这里"
android:scrollbars="vertical"
android:maxLines="5"/>
</LinearLayout>
4. 关键注意事项¶
4.1 模型优化¶
- 使用量化模型(Quantized Models)减少内存占用和推理时间
- 模型输入输出格式需与TensorFlow Lite兼容
4.2 性能优化¶
- 设置合适的线程数:
tflite.setNumThreads(4) - 图片压缩处理:避免OOM错误
- 使用直接字节缓冲区(DirectByteBuffer)提高效率
4.3 权限处理¶
- Android 6.0+需要动态申请相机和存储权限
- 使用
FileProvider处理文件权限
5. 效果展示¶
运行应用后,流程如下:
1. 点击”加载模型”按钮加载TFLite模型
2. 选择”选择相册图片”或”拍摄照片”
3. 系统自动处理图像并显示分类结果(类别名称和置信度)
6. 参考资料¶
通过以上步骤,你可以在Android设备上实现高效的图像分类功能,适用于移动端应用开发、计算机视觉等场景。