前言¶
Tensorflow2之後,訓練保存的模型也有所變化,基於Keras接口搭建的網絡模型默認保存的模型是h5格式的,而之前的模型格式是pb。Tensorflow2的h5格式的模型轉換成tflite格式模型非常方便。本教程就是介紹如何使用Tensorflow2的Keras接口訓練分類模型並使用Tensorflow Lite部署到Android設備上。
本教程源碼:https://github.com/yeyupiaoling/ClassificationForAndroid/tree/master/TFLiteClassification
訓練和轉換模型¶
以下是使用Tensorflow2的keras搭建的一個MobileNetV2模型並訓練自定義數據集,本教程主要是介紹如何在Android設備上使用Tensorflow Lite部署分類模型,所以關於訓練模型只是簡單介紹,代碼並不完整。通過下面的訓練模型,我們最終會得到一個mobilenet_v2.h5模型。
import os
import tensorflow as tf
import reader
import config as cfg
# 獲取模型
input_shape = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE, cfg.IMAGE_CHANNEL)
model = tf.keras.Sequential(
[tf.keras.applications.MobileNetV2(input_shape=input_shape, include_top=False, pooling='max'),
tf.keras.layers.Dense(units=cfg.CLASS_DIM, activation='softmax')])
model.summary()
# 獲取訓練數據
train_data = reader.train_reader(data_list_path=cfg.TRAIN_LIST_PATH, batch_size=cfg.BATCH_SIZE)
# 定義訓練參數
model.compile(optimizer=tf.keras.optimizers.RMSprop(),
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
# 開始訓練
model.fit(train_data, epochs=cfg.EPOCH_SUM, workers=4)
# 保存h5模型
if not os.path.exists(os.path.dirname(cfg.H5_MODEL_PATH)):
os.makedirs(os.path.dirname(cfg.H5_MODEL_PATH))
model.save(filepath=cfg.H5_MODEL_PATH)
print('saved h5 model!')
通過上面得到的mobilenet_v2.h5模型,我們需要轉換爲tflite格式的模型,在Tensorflow2之後,這個轉換就變動很簡單了,通過下面的幾行代碼即可完成轉換,最終我們會得到一個mobilenet_v2.tflite模型。
import tensorflow as tf
import config as cfg
# 加載模型
model = tf.keras.models.load_model(cfg.H5_MODEL_PATH)
# 生成非量化的tflite模型
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open(cfg.TFLITE_MODEL_FILE, 'wb').write(tflite_model)
print('saved tflite model!')
如果保存的模型格式不是h5,而是tf格式的,如下代碼,保存的模型是tf格式的。
import tensorflow as tf
model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3))
model.save(filepath='mobilenet_v2', save_format='tf')
如果是tf格式的模型,那需要使用以下轉換模型的方式。
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model('mobilenet_v2')
tflite_model = converter.convert()
open("mobilenet_v2.tflite", "wb").write(tflite_model)
在部署到Android中可能需要到輸入輸出層的名稱,通過下面代碼可以獲取到輸入輸出層的名稱和shape。
import tensorflow as tf
model_path = 'models/mobilenet_v2.tflite'
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
# 獲取輸入和輸出張量。
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)
部署到Android設備¶
首先要在build.gradle導入這三個庫,如果不使用GPU可以只導入兩個庫。
implementation 'org.tensorflow:tensorflow-lite:2.3.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.3.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.1.0-rc1'
在以前還需要在android下添加以下代碼,避免在打包apk的是對模型有壓縮操作,損壞模型。現在好像不加也沒有關係,但是爲了安全起見,還是添加上去。
aaptOptions {
noCompress "tflite"
}
複製轉換的預測模型到app/src/main/assets目錄下,還有類別的標籤,每一行對應一個標籤名稱。
Tensorflow Lite工具¶
編寫一個TFLiteClassificationUtil工具類,關於Tensorflow Lite的操作都在這裏完成,如加載模型、預測。在構造方法中,通過參數傳遞的模型路徑加載模型,在加載模型的時候配置預測信息,例如是否使用Android底層神經網絡APINnApiDelegate或者是否使用GPUGpuDelegate,同時獲取網絡的輸入輸出層。有了tensorflow-lite-support庫,數據預處理就變得非常簡單,通過ImageProcessor創建一個數據預處理的工具,之後在預測之前使用這個工具對圖像進行預處理,處理速度還是挺快的,要注意的是圖像的均值IMAGE_MEAN和標準差IMAGE_STD,因爲在訓練的時候圖像預處理可能不一樣的,有些讀者出現在電腦上準確率很高,但在手機上準確率很低,多數情況下就是這個圖像預處理做得不對。
private static final float[] IMAGE_MEAN = new float[]{128.0f, 128.0f, 128.0f};
private static final float[] IMAGE_STD = new float[]{128.0f, 128.0f, 128.0f};
public TFLiteClassificationUtil(String modelPath) throws Exception {
File file = new File(modelPath);
if (!file.exists()) {
throw new Exception("model file is not exists!");
}
try {
Interpreter.Options options = new Interpreter.Options();
// 使用多線程預測
options.setNumThreads(NUM_THREADS);
// 使用Android自帶的API或者GPU加速
NnApiDelegate delegate = new NnApiDelegate();
// GpuDelegate delegate = new GpuDelegate();
options.addDelegate(delegate);
tflite = new Interpreter(file, options);
// 獲取輸入,shape爲{1, height, width, 3}
int[] imageShape = tflite.getInputTensor(tflite.getInputIndex("input_1")).shape();
DataType imageDataType = tflite.getInputTensor(tflite.getInputIndex("input_1")).dataType();
inputImageBuffer = new TensorImage(imageDataType);
// 獲取輸入,shape爲{1, NUM_CLASSES}
int[] probabilityShape = tflite.getOutputTensor(tflite.getOutputIndex("Identity")).shape();
DataType probabilityDataType = tflite.getOutputTensor(tflite.getOutputIndex("Identity")).dataType();
outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);
// 添加圖像預處理方式
imageProcessor = new ImageProcessor.Builder()
.add(new ResizeOp(imageShape[1], imageShape[2], ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
.add(new NormalizeOp(IMAGE_MEAN, IMAGE_STD))
.build();
} catch (Exception e) {
e.printStackTrace();
throw new Exception("load model fail!");
}
}
爲了兼容圖片路徑和Bitmap格式的圖片預測,這裏創建了兩個重載方法,它們都是通過調用predict()
public int predictImage(String image_path) throws Exception {
if (!new File(image_path).exists()) {
throw new Exception("image file is not exists!");
}
FileInputStream fis = new FileInputStream(image_path);
Bitmap bitmap = BitmapFactory.decodeStream(fis);
int result = predictImage(bitmap);
if (bitmap.isRecycled()) {
bitmap.recycle();
}
return result;
}
public int predictImage(Bitmap bitmap) throws Exception {
return predict(bitmap);
}
這裏創建一個獲取最大概率值,並把下標返回的方法,其實就是獲取概率最大的預測標籤。
public static int getMaxResult(float[] result) {
float probability = 0;
int r = 0;
for (int i = 0; i < result.length; i++) {
if (probability < result[i]) {
probability = result[i];
r = i;
}
}
return r;
}
這個方法就是Tensorflow Lite執行預測的最後一步,通過執行tflite.run()對輸入的數據進行預測並得到預測結果,通過解析獲取到最大的概率的預測標籤,並返回。到這裏Tensorflow Lite的工具就完成了。
private int predict(Bitmap bmp) throws Exception {
inputImageBuffer = loadImage(bmp);
try {
tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());
} catch (Exception e) {
throw new Exception("predict image fail! log:" + e);
}
float[] results = outputProbabilityBuffer.getFloatArray();
Log.d(TAG, Arrays.toString(results));
return getMaxResult(results);
}
選擇圖片預測¶
本教程會有兩個頁面,一個是選擇圖片進行預測的頁面,另一個是使用相機即時預測並顯示預測結果。以下爲activity_main.xml的代碼,通過按鈕選擇圖片,並在該頁面顯示圖片和預測結果。
<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
tools:context=".MainActivity">
<ImageView
android:id="@+id/image_view"
android:layout_width="match_parent"
android:layout_height="400dp" />
<TextView
android:id="@+id/result_text"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_below="@id/image_view"
android:text="識別結果"
android:textSize="16sp" />
<LinearLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_alignParentBottom="true"
android:orientation="horizontal">
<Button
android:id="@+id/select_img_btn"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="1"
android:text="選擇照片" />
<Button
android:id="@+id/open_camera"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="1"
android:text="即時預測" />
</LinearLayout>
</RelativeLayout>
在MainActivity.java中,進入到頁面我們就要先加載模型,我們是把模型放在Android項目的assets目錄的,但是Tensorflow Lite並不建議直接在assets讀取模型,所以我們需要把模型複製到一個緩存目錄,然後再從緩存目錄加載模型,同時還有讀取標籤名,標籤名稱按照訓練的label順序存放在assets的label_list.txt,以下爲實現代碼。
classNames = Utils.ReadListFromFile(getAssets(), "label_list.txt");
String classificationModelPath = getCacheDir().getAbsolutePath() + File.separator + "mobilenet_v2.tflite";
Utils.copyFileFromAsset(MainActivity.this, "mobilenet_v2.tflite", classificationModelPath);
try {
tfLiteClassificationUtil = new TFLiteClassificationUtil(classificationModelPath);
Toast.makeText(MainActivity.this, "模型加載成功!", Toast.LENGTH_SHORT).show();
} catch (Exception e) {
Toast.makeText(MainActivity.this, "模型加載失敗!", Toast.LENGTH_SHORT).show();
e.printStackTrace();
finish();
}
添加兩個按鈕點擊事件,可以選擇打開相冊讀取圖片進行預測,或者打開另一個Activity進行調用攝像頭即時識別。
Button selectImgBtn = findViewById(R.id.select_img_btn);
Button openCamera = findViewById(R.id.open_camera);
imageView = findViewById(R.id.image_view);
textView = findViewById(R.id.result_text);
selectImgBtn.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
// 打開相冊
Intent intent = new Intent(Intent.ACTION_PICK);
intent.setType("image/*");
startActivityForResult(intent, 1);
}
});
openCamera.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
// 打開即時拍攝識別頁面
Intent intent = new Intent(MainActivity.this, CameraActivity.class);
startActivity(intent);
}
});
當打開相冊選擇照片之後,回到原來的頁面,在下面這個回調方法中獲取選擇圖片的Uri,通過Uri可以獲取到圖片的絕對路徑。如果Android8以上的設備獲取不到圖片,需要在AndroidManifest.xml配置文件中的application添加android:requestLegacyExternalStorage="true"。拿到圖片路徑之後,調用TFLiteClassificationUtil類中的predictImage()方法預測並獲取預測值,在頁面上顯示預測的標籤、對應標籤的名稱、概率值和預測時間。
@Override
protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
super.onActivityResult(requestCode, resultCode, data);
String image_path;
if (resultCode == Activity.RESULT_OK) {
if (requestCode == 1) {
if (data == null) {
Log.w("onActivityResult", "user photo data is null");
return;
}
Uri image_uri = data.getData();
image_path = getPathFromURI(MainActivity.this, image_uri);
try {
// 預測圖像
FileInputStream fis = new FileInputStream(image_path);
imageView.setImageBitmap(BitmapFactory.decodeStream(fis));
long start = System.currentTimeMillis();
float[] result = tfLiteClassificationUtil.predictImage(image_path);
long end = System.currentTimeMillis();
String show_text = "預測結果標籤:" + (int) result[0] +
"\n名稱:" + classNames[(int) result[0]] +
"\n概率:" + result[1] +
"\n時間:" + (end - start) + "ms";
textView.setText(show_text);
} catch (Exception e) {
e.printStackTrace();
}
}
}
}
上面獲取的Uri可以通過下面這個方法把Url轉換成絕對路徑。
// get photo from Uri
public static String getPathFromURI(Context context, Uri uri) {
String result;
Cursor cursor = context.getContentResolver().query(uri, null, null, null, null);
if (cursor == null) {
result = uri.getPath();
} else {
cursor.moveToFirst();
int idx = cursor.getColumnIndex(MediaStore.Images.ImageColumns.DATA);
result = cursor.getString(idx);
cursor.close();
}
return result;
}
攝像頭即時預測¶
在調用相機即時預測我就不再介紹了,原理都差不多,具體可以查看https://github.com/yeyupiaoling/ClassificationForAndroid/tree/master/TFLiteClassification中的源代碼。核心代碼如下,創建一個子線程,子線程中不斷從攝像頭預覽的AutoFitTextureView上獲取圖像,並執行預測,並在頁面上顯示預測的標籤、對應標籤的名稱、概率值和預測時間。每一次預測完成之後都立即獲取圖片繼續預測,只要預測速度夠快,就可以看成即時預測。
private Runnable periodicClassify =
new Runnable() {
@Override
public void run() {
synchronized (lock) {
if (runClassifier) {
// 開始預測前要判斷相機是否已經準備好
if (getApplicationContext() != null && mCameraDevice != null && tfLiteClassificationUtil != null) {
predict();
}
}
}
if (mInferThread != null && mInferHandler != null && mCaptureHandler != null && mCaptureThread != null) {
mInferHandler.post(periodicClassify);
}
}
};
// 預測相機捕獲的圖像
private void predict() {
// 獲取相機捕獲的圖像
Bitmap bitmap = mTextureView.getBitmap();
try {
// 預測圖像
long start = System.currentTimeMillis();
float[] result = tfLiteClassificationUtil.predictImage(bitmap);
long end = System.currentTimeMillis();
String show_text = "預測結果標籤:" + (int) result[0] +
"\n名稱:" + classNames[(int) result[0]] +
"\n概率:" + result[1] +
"\n時間:" + (end - start) + "ms";
textView.setText(show_text);
} catch (Exception e) {
e.printStackTrace();
}
}
本項目中使用的了讀取圖片的權限和打開相機的權限,所以不要忘記在AndroidManifest.xml添加以下權限申請。
<uses-permission android:name="android.permission.CAMERA"/>
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
如果是Android 6 以上的設備還要動態申請權限。
// check had permission
private boolean hasPermission() {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
return checkSelfPermission(Manifest.permission.CAMERA) == PackageManager.PERMISSION_GRANTED &&
checkSelfPermission(Manifest.permission.READ_EXTERNAL_STORAGE) == PackageManager.PERMISSION_GRANTED &&
checkSelfPermission(Manifest.permission.WRITE_EXTERNAL_STORAGE) == PackageManager.PERMISSION_GRANTED;
} else {
return true;
}
}
// request permission
private void requestPermission() {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
requestPermissions(new String[]{Manifest.permission.CAMERA,
Manifest.permission.READ_EXTERNAL_STORAGE,
Manifest.permission.WRITE_EXTERNAL_STORAGE}, 1);
}
}
選擇圖片識別效果圖:

相機即時識別效果圖:
