基于TNN在Android手机上实现图像分类
前言
TNN:由腾讯优图实验室打造,移动端高性能、轻量级推理框架,同时拥有跨平台、高性能、模型压缩、代码裁剪等众多突出优势。TNN框架在原有Rapidnet、ncnn框架的基础上进一步加强了移动端设备的支持以及性能优化,同时也借鉴了业界主流开源框架高性能和良好拓展性的优点。
教程源码地址:https://github.com/yeyupiaoling/ClassificationForAndroid/tree/master/TNNClassification
编译Android库
- 安装cmake 3.12
# 卸载旧的cmake
sudo apt-get autoremove cmake
# 下载cmake3.12
wget https://cmake.org/files/v3.12/cmake-3.12.2-Linux-x86_64.tar.gz
tar zxvf cmake-3.12.2-Linux-x86_64.tar.gz
# 移动目录并添加软连接
sudo mv cmake-3.12.2-Linux-x86_64 /opt/cmake-3.12.2
sudo ln -sf /opt/cmake-3.12.2/bin/* /usr/bin/
- 添加Android NDK
wget https://dl.google.com/android/repository/android-ndk-r21b-linux-x86_64.zip
unzip android-ndk-r21b-linux-x86_64.zip
# 添加环境变量,留意你实际下载地址
export ANDROID_NDK=/mnt/d/android-ndk-r21b
- 安装编译环境
sudo apt-get install attr
- 开始编译
git clone https://github.com/Tencent/TNN.git
cd TNN/scripts
vim build_android.sh
ABIA32="armeabi-v7a"
ABIA64="arm64-v8a"
STL="c++_static"
SHARED_LIB="ON" # ON表示编译动态库,OFF表示编译静态库
ARM="ON" # ON表示编译带有Arm CPU版本的库
OPENMP="ON" # ON表示打开OpenMP
OPENCL="ON" # ON表示编译带有Arm GPU版本的库
SHARING_MEM_WITH_OPENGL=0 # 1表示OpenGL的Texture可以与OpenCL共享
执行编译
./build_android.sh
编译完成后,会在当前目录的release
目录下生成对应的armeabi-v7a
库,arm64-v8a
库和include
头文件,这些文件在下一步的Android开发都需要使用到。
模型转换
接下来我们需要把Tensorflow,onnx等其他的模型转换为TNN的模型。目前 TNN 支持业界主流的模型文件格式,包括ONNX、PyTorch、TensorFlow 以及 Caffe 等。TNN 将 ONNX 作为中间层,借助于ONNX 开源社区的力量,来支持多种模型文件格式。如果要将PyTorch、TensorFlow 以及 Caffe 等模型文件格式转换为 TNN,首先需要使用对应的模型转换工具,统一将各种模型格式转换成为 ONNX 模型格式,然后将 ONNX 模型转换成 TNN 模型。
sudo docker pull turandotkay/tnn-convert
sudo docker tag turandotkay/tnn-convert:latest tnn-convert:latest
sudo docker rmi turandotkay/tnn-convert:latest
针对不同的模型转换,有不同的命令,如onnx2tnn,caffe2tnn,tf2tnn。
docker run --volume=$(pwd):/workspace -it tnn-convert:latest python3 ./converter.py tf2tnn \
-tp /workspace/mobilenet_v1.pb \
-in "input[1,224,224,3]" \
-on MobilenetV1/Predictions/Reshape_1 \
-v v1.0 \
-optimize
通过上面的输出,可以发现针对 TF 模型的转换,convert2tnn 工具提供了很多参数,我们一次对下面的参数进行解释:
- tp 参数(必须)
通过 “-tp” 参数指定需要转换的模型的路径。目前只支持单个 TF模型的转换,不支持多个 TF 模型的一起转换。 - in 参数(必须)
通过 “-in” 参数指定模型输入的名称,输入的名称需要放到“”中,例如,-in "name"。如果模型有多个输入,请使用 “;”进行分割。有的 TensorFlow 模型没有指定 batch 导致无法成功转换为 ONNX 模型,进而无法成功转换为 TNN 模型。你可以通过在名称后添加输入 shape 进行指定。shape 信息需要放在 [] 中。例如:-in "name[1,28,28,3]"。 - on 参数(必须)
通过 “-on” 参数指定模型输入的名称,如果模型有多个输出,请使用 “;”进行分割 - output_dir 参数:
可以通过 “-o” 参数指定输出路径,但是在 docker 中我们一般不使用这个参数,默认会将生成的 TNN 模型放在当前和 TF 模型相同的路径下。 - optimize 参数(可选)
可以通过 “-optimize” 参数来对模型进行优化,我们强烈建议你开启这个选项,只有在开启这个选项模型转换失败时,我们才建议你去掉 “-optimize” 参数进行重新尝试。 - v 参数(可选)
可以通过 -v 来指定模型的版本号,以便于后期对模型进行追踪和区分。 - half 参数(可选)
可以通过 -half 参数指定,模型数据通过 FP16 进行存储,减少模型的大小,默认是通过 FP32 的方式进行存储模型数据的。 - align 参数(可选)
可以通过 -align 参数指定,将 转换得到的 TNN 模型和原模型进行对齐,确定 TNN 模型是否转换成功。当前仅支持单输入单输出模型和单输入多输出模型。 align 只支持 FP32 模型的校验,所以使用 align 的时候不能使用 half - input_file 参数(可选)
可以通过 -input_file 参数指定模型对齐所需要的输入文件的名称,输入需要遵循如下格式。 - ref_file 参数(可选)
可以通过 -ref_file 参数指定待对齐的输出文件的名称,输出需遵循如下格式。生成输出的代码可以参考。
成功转换会输出以下的日志。
---------- convert model, please wait a moment ----------
Converter Tensorflow to TNN model
Convert TensorFlow to ONNX model succeed!
Converter ONNX to TNN Model
Converter ONNX to TNN model succeed!
最终会得到这两个模型文件,mobilenet_v1.opt.tnnmodel
mobilenet_v1.opt.tnnproto
。
开发Android项目
- 将转换的模型放在
assets
目录下。 - 把上一步编译得到的
include
目录复制到Android项目的app
目录下。 - 把上一步编译得到的
armeabi-v7a
和arm64-v8a
目录复制到main/jniLibs
下。 - 在
app/src/main/cpp/
目录下编写JNI的C++代码。
TNN工具
编写一个ImageClassifyUtil.java工具类,关于TNN的操作都在这里完成,如加载模型、预测。
下面三个就是TNN的JNI接口,通过这个接口完成模型加载,预测,当不使用的时候和可以调用deinit()
清空对象。
public native int init(String modelPath, String protoPath, int computeUnitType);
public native float[] predict(Bitmap image, int width, int height);
public native int deinit();
通过上面的JNI接口,下面就可以实现图像识别了,WIDTH
和HEIGHT
是模型输入图片的大小。为了兼容图片路径和Bitmap格式的图片预测,这里创建了两个重载方法。
private static final int WIDTH = 224;
private static final int HEIGHT = 224;
public ImageClassifyUtil() {
System.loadLibrary("TNN");
System.loadLibrary("tnn_wrapper");
}
// 重载方法,根据图片路径转Bitmap预测
public float[] predictImage(String image_path) throws Exception {
if (!new File(image_path).exists()) {
throw new Exception("image file is not exists!");
}
FileInputStream fis = new FileInputStream(image_path);
Bitmap bitmap = BitmapFactory.decodeStream(fis);
Bitmap scaleBitmap = Bitmap.createScaledBitmap(bitmap, WIDTH, HEIGHT, false);
float[] result = predictImage(scaleBitmap);
if (bitmap.isRecycled()) {
bitmap.recycle();
}
return result;
}
// 重载方法,直接使用Bitmap预测
public float[] predictImage(Bitmap bitmap) {
Bitmap scaleBitmap = Bitmap.createScaledBitmap(bitmap, WIDTH, HEIGHT, false);
float[] results = predict(scaleBitmap, WIDTH, HEIGHT);
int l = getMaxResult(results);
return new float[]{l, results[l] * 0.01f};
}
这里创建一个获取最大概率值,并把下标返回的方法,其实就是获取概率最大的预测标签。
public static int getMaxResult(float[] result) {
float probability = 0;
int r = 0;
for (int i = 0; i < result.length; i++) {
if (probability < result[i]) {
probability = result[i];
r = i;
}
}
return r;
}
不同的模型,训练的预处理方式可能不一样,TNN 的图像预处理在C++中完成,代码片段。
TNN_NS::MatConvertParam input_cvt_param;
input_cvt_param.scale = {1.0 / (255 * 0.229), 1.0 / (255 * 0.224), 1.0 / (255 * 0.225), 0.0};
input_cvt_param.bias = {-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225, 0.0};
auto status = instance_->SetInputMat(input_mat, input_cvt_param);
选择图片预测
本教程会有两个页面,一个是选择图片进行预测的页面,另一个是使用相机实时预测并显示预测结果。以下为activity_main.xml
的代码,通过按钮选择图片,并在该页面显示图片和预测结果。
<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
tools:context=".MainActivity">
<ImageView
android:id="@+id/image_view"
android:layout_width="match_parent"
android:layout_height="400dp" />
<TextView
android:id="@+id/result_text"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_below="@id/image_view"
android:text="识别结果"
android:textSize="16sp" />
<LinearLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_alignParentBottom="true"
android:orientation="horizontal">
<Button
android:id="@+id/select_img_btn"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="1"
android:text="选择照片" />
<Button
android:id="@+id/open_camera"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="1"
android:text="实时预测" />
</LinearLayout>
</RelativeLayout>
在MainActivity.java
中,进入到页面我们就要先加载模型,我们是把模型放在Android项目的assets目录的,我们需要把模型复制到一个缓存目录,然后再从缓存目录加载模型,同时还有读取标签名,标签名称按照训练的label顺序存放在assets的label_list.txt
,以下为实现代码。
classNames = Utils.ReadListFromFile(getAssets(), "label_list.txt");
String protoContent = getCacheDir().getAbsolutePath() + File.separator + "squeezenet_v1.1.tnnproto";
Utils.copyFileFromAsset(MainActivity.this, "squeezenet_v1.1.tnnproto", protoContent);
String modelContent = getCacheDir().getAbsolutePath() + File.separator + "squeezenet_v1.1.tnnmodel";
Utils.copyFileFromAsset(MainActivity.this, "squeezenet_v1.1.tnnmodel", modelContent);
imageClassifyUtil = new ImageClassifyUtil();
int status = imageClassifyUtil.init(modelContent, protoContent, USE_GPU ? 1 : 0);
if (status == 0){
Toast.makeText(MainActivity.this, "模型加载成功!", Toast.LENGTH_SHORT).show();
}else {
Toast.makeText(MainActivity.this, "模型加载失败!", Toast.LENGTH_SHORT).show();
finish();
}
添加两个按钮点击事件,可以选择打开相册读取图片进行预测,或者打开另一个Activity进行调用摄像头实时识别。
Button selectImgBtn = findViewById(R.id.select_img_btn);
Button openCamera = findViewById(R.id.open_camera);
imageView = findViewById(R.id.image_view);
textView = findViewById(R.id.result_text);
selectImgBtn.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
// 打开相册
Intent intent = new Intent(Intent.ACTION_PICK);
intent.setType("image/*");
startActivityForResult(intent, 1);
}
});
openCamera.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
// 打开实时拍摄识别页面
Intent intent = new Intent(MainActivity.this, CameraActivity.class);
startActivity(intent);
}
});
当打开相册选择照片之后,回到原来的页面,在下面这个回调方法中获取选择图片的Uri,通过Uri可以获取到图片的绝对路径。如果Android8以上的设备获取不到图片,需要在AndroidManifest.xml
配置文件中的application
添加android:requestLegacyExternalStorage="true"
。拿到图片路径之后,调用TFLiteClassificationUtil
类中的predictImage()
方法预测并获取预测值,在页面上显示预测的标签、对应标签的名称、概率值和预测时间。
@Override
protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
super.onActivityResult(requestCode, resultCode, data);
String image_path;
if (resultCode == Activity.RESULT_OK) {
if (requestCode == 1) {
if (data == null) {
Log.w("onActivityResult", "user photo data is null");
return;
}
Uri image_uri = data.getData();
image_path = getPathFromURI(MainActivity.this, image_uri);
try {
// 预测图像
FileInputStream fis = new FileInputStream(image_path);
imageView.setImageBitmap(BitmapFactory.decodeStream(fis));
long start = System.currentTimeMillis();
float[] result = imageClassifyUtil.predictImage(image_path);
long end = System.currentTimeMillis();
String show_text = "预测结果标签:" + (int) result[0] +
"\n名称:" + classNames[(int) result[0]] +
"\n概率:" + result[1] +
"\n时间:" + (end - start) + "ms";
textView.setText(show_text);
} catch (Exception e) {
e.printStackTrace();
}
}
}
}
上面获取的Uri可以通过下面这个方法把Url转换成绝对路径。
// get photo from Uri
public static String getPathFromURI(Context context, Uri uri) {
String result;
Cursor cursor = context.getContentResolver().query(uri, null, null, null, null);
if (cursor == null) {
result = uri.getPath();
} else {
cursor.moveToFirst();
int idx = cursor.getColumnIndex(MediaStore.Images.ImageColumns.DATA);
result = cursor.getString(idx);
cursor.close();
}
return result;
}
摄像头实时预测
在调用相机实时预测我就不再介绍了,原理都差不多,具体可以查看https://github.com/yeyupiaoling/ClassificationForAndroid/tree/master/TFLiteClassification中的源代码。核心代码如下,创建一个子线程,子线程中不断从摄像头预览的AutoFitTextureView
上获取图像,并执行预测,并在页面上显示预测的标签、对应标签的名称、概率值和预测时间。每一次预测完成之后都立即获取图片继续预测,只要预测速度够快,就可以看成实时预测。
private Runnable periodicClassify =
new Runnable() {
@Override
public void run() {
synchronized (lock) {
if (runClassifier) {
// 开始预测前要判断相机是否已经准备好
if (getApplicationContext() != null && mCameraDevice != null && mnnClassification != null) {
predict();
}
}
}
if (mInferThread != null && mInferHandler != null && mCaptureHandler != null && mCaptureThread != null) {
mInferHandler.post(periodicClassify);
}
}
};
// 预测相机捕获的图像
private void predict() {
// 获取相机捕获的图像
Bitmap bitmap = mTextureView.getBitmap();
try {
// 预测图像
long start = System.currentTimeMillis();
float[] result = imageClassifyUtil.predictImage(bitmap);
long end = System.currentTimeMillis();
String show_text = "预测结果标签:" + (int) result[0] +
"\n名称:" + classNames[(int) result[0]] +
"\n概率:" + result[1] +
"\n时间:" + (end - start) + "ms";
textView.setText(show_text);
} catch (Exception e) {
e.printStackTrace();
}
}
本项目中使用的了读取图片的权限和打开相机的权限,所以不要忘记在AndroidManifest.xml
添加以下权限申请。
<uses-permission android:name="android.permission.CAMERA"/>
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
如果是Android 6 以上的设备还要动态申请权限。
// check had permission
private boolean hasPermission() {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
return checkSelfPermission(Manifest.permission.CAMERA) == PackageManager.PERMISSION_GRANTED &&
checkSelfPermission(Manifest.permission.READ_EXTERNAL_STORAGE) == PackageManager.PERMISSION_GRANTED &&
checkSelfPermission(Manifest.permission.WRITE_EXTERNAL_STORAGE) == PackageManager.PERMISSION_GRANTED;
} else {
return true;
}
}
// request permission
private void requestPermission() {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
requestPermissions(new String[]{Manifest.permission.CAMERA,
Manifest.permission.READ_EXTERNAL_STORAGE,
Manifest.permission.WRITE_EXTERNAL_STORAGE}, 1);
}
}
效果图: