如何在Android上使用Paddle Lite部署图像分类模型

Paddle Lite是百度飞桨推出的轻量级端侧推理引擎,专为移动设备和嵌入式系统优化。本教程将指导你如何使用Paddle Lite在Android上部署图像分类模型,包括模型转换、Android预测库准备、项目开发等完整流程。

1. 环境准备

1.1 安装依赖工具

  • Python 3.7+:用于模型转换和训练
  • Android Studio:用于开发Android应用
  • PaddlePaddle 2.0+:用于模型训练和导出

1.2 安装Paddle Lite

pip install paddlelite

2. 模型准备

2.1 导出PaddlePaddle模型

使用PaddlePaddle训练图像分类模型后,通过以下代码导出预测模型:

import paddle.fluid as fluid

# 定义网络结构(以MobileNetV2为例)
image = fluid.layers.data(name='img', shape=[3, 224, 224], dtype='float32')
predict = fluid.layers.fc(input=image, size=1000, act='softmax')  # 1000类分类

# 加载训练好的模型参数
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
fluid.io.load_persistables(exe, "./models/mobilenet_v2", main_program=fluid.default_main_program())

# 保存预测模型
fluid.io.save_inference_model(
    dirname="./models/mobilenet_v2",
    feeded_var_names=["img"],
    target_vars=[predict],
    executor=exe
)

2.2 使用Opt工具转换模型

Paddle Lite模型需要通过opt工具转换为移动端可执行格式:

# 下载opt工具(或源码编译)
wget https://paddle-lite-bin.bj.bcebos.com/2.0.0-beta/opt.linux
chmod +x opt.linux

# 转换模型
./opt.linux \
    --model_file=./models/mobilenet_v2/model \
    --param_file=./models/mobilenet_v2/params \
    --optimize_out_type=naive_buffer \
    --optimize_out=mobilenet_v2.nb \
    --valid_targets=arm  # 仅针对Android ARM架构

转换后得到mobilenet_v2.nb模型文件。

3. Android预测库配置

3.1 下载预编译库

  • Paddle Lite官网下载Android预测库
  • 选择with_extra=ONarm_stl=c++_staticwith_cv=ON的ARMv7/ARMv8版本

3.2 导入库文件

  • PaddlePredictor.jar放入app/libs目录
  • libpaddle_lite_jni.so分别放入app/src/main/jniLibs/armeabi-v7aarm64-v8a

4. Android项目开发

4.1 配置build.gradle

android {
    defaultConfig {
        externalNativeBuild {
            cmake {
                abiFilters 'armeabi-v7a', 'arm64-v8a'
            }
        }
    }
}

dependencies {
    implementation fileTree(dir: 'libs', include: ['*.jar'])
}

4.2 模型加载与预处理

public class PaddleClassifier {
    private PaddlePredictor predictor;
    private Tensor inputTensor;
    private Tensor outputTensor;

    public PaddleClassifier(AssetManager assetManager, String modelPath) throws Exception {
        // 1. 配置模型路径
        MobileConfig config = new MobileConfig();
        config.setModelFromFile(modelPath);
        config.setThreads(4);  // 4线程加速

        // 2. 初始化预测器
        predictor = PaddlePredictor.createPaddlePredictor(config);
        inputTensor = predictor.getInput(0);
        outputTensor = predictor.getOutput(0);

        // 3. 配置输入参数
        inputTensor.resize(new long[]{1, 3, 224, 224});  // [batch, channel, height, width]
    }

    // 图像预处理
    private float[] preprocess(Bitmap bitmap) {
        // 1. 缩放图像至224x224
        Bitmap scaled = Bitmap.createScaledBitmap(bitmap, 224, 224, false);

        // 2. 转换为RGB浮点数组
        int[] pixels = new int[224 * 224];
        scaled.getPixels(pixels, 0, 224, 0, 0, 224, 224);

        // 3. 归一化处理 (减均值/除以标准差)
        float[] inputData = new float[3 * 224 * 224];
        for (int i = 0; i < 224 * 224; i++) {
            int pixel = pixels[i];
            inputData[i * 3 + 0] = ( (pixel >> 16) & 0xFF ) / 255.0f - 0.485f / 0.229f;
            inputData[i * 3 + 1] = ( (pixel >> 8) & 0xFF ) / 255.0f - 0.456f / 0.224f;
            inputData[i * 3 + 2] = ( pixel & 0xFF ) / 255.0f - 0.406f / 0.225f;
        }
        return inputData;
    }

    // 执行预测
    public float[] predict(Bitmap bitmap) throws Exception {
        float[] inputData = preprocess(bitmap);
        inputTensor.setData(inputData);
        predictor.run();
        return outputTensor.getFloatData();
    }
}

4.3 图像分类功能实现

public class MainActivity extends AppCompatActivity {
    private PaddleClassifier classifier;
    private TextView resultView;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // 1. 复制模型到应用目录
        copyModelFromAssets("mobilenet_v2.nb");

        try {
            // 2. 初始化分类器
            classifier = new PaddleClassifier(getAssets(), getFilesDir().getAbsolutePath() + "/mobilenet_v2.nb");
            resultView = findViewById(R.id.result);
        } catch (Exception e) {
            e.printStackTrace();
        }

        // 3. 绑定按钮事件(选择图片)
        Button selectBtn = findViewById(R.id.select_image);
        selectBtn.setOnClickListener(v -> pickImage());
    }

    // 选择图片并预测
    private void pickImage() {
        Intent intent = new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.EXTERNAL_CONTENT_URI);
        startActivityForResult(intent, 1);
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
        if (resultCode == RESULT_OK && requestCode == 1) {
            try {
                Uri uri = data.getData();
                Bitmap bitmap = MediaStore.Images.Media.getBitmap(getContentResolver(), uri);

                // 执行预测
                long start = System.currentTimeMillis();
                float[] result = classifier.predict(bitmap);
                long end = System.currentTimeMillis();

                // 解析结果
                int label = getMaxIndex(result);
                String className = getResources().getStringArray(R.array.class_names)[label];
                resultView.setText("预测: " + className + "\n概率: " + result[label] + "\n耗时: " + (end - start) + "ms");
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    // 获取最大概率索引
    private int getMaxIndex(float[] array) {
        int maxIndex = 0;
        float maxValue = array[0];
        for (int i = 1; i < array.length; i++) {
            if (array[i] > maxValue) {
                maxValue = array[i];
                maxIndex = i;
            }
        }
        return maxIndex;
    }
}

4.4 权限配置

AndroidManifest.xml中添加权限:

<uses-permission android:name="android.permission.CAMERA" />
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />

注意:Android 6.0+需动态申请权限。

5. 优化与扩展

5.1 性能优化

  • 量化模型:使用--quantize参数开启模型量化,减小模型体积
  • 线程优化:根据设备核心数调整config.setThreads(4)
  • 图像压缩:使用Bitmap.compress(Bitmap.CompressFormat.JPEG, 80, outputStream)压缩图片

5.2 功能扩展

  • 相机实时识别:通过TextureView获取相机预览帧进行预测
  • 图像增强:添加paddle_image_preprocess.h中的预处理函数(如resize、归一化)

6. 效果展示

选择图片识别

选择图片识别

相机实时识别

通过TextureView实现实时摄像头预览,每300ms更新一次预测结果。

参考资料

通过以上步骤,你可以快速在Android设备上部署图像分类模型,实现高效的端侧AI推理。

Xiaoye