把PaddlePaddle模型移植到Android手机上

1. 编译PaddlePaddle Android库

首先需要编译PaddlePaddle的Android版本库。这里我们使用CMake来编译生成适合Android平台的动态链接库。

1.1 创建CMakeLists.txt

cmake_minimum_required(VERSION 3.4.1)

# 设置Android NDK路径
set(ANDROID_NDK /path/to/your/android-ndk)
set(CMAKE_TOOLCHAIN_FILE ${ANDROID_NDK}/build/cmake/android.toolchain.cmake)
set(ANDROID_PLATFORM android-21)
set(ANDROID_ABI armeabi-v7a)

# 项目配置
project(paddle-android)

# 添加PaddlePaddle源文件
add_subdirectory(paddle)

# 编译PaddlePaddle库
add_library(paddle_image_recognizer SHARED
            src/main/cpp/image_recognizer.cpp)

# 链接PaddlePaddle库
target_link_libraries(paddle_image_recognizer
                      paddle
                      log)

1.2 编译生成.so文件

使用ndk-build编译生成Android平台的PaddlePaddle库:

ndk-build -C path/to/paddle-prefix/src

2. 准备训练好的模型

2.1 训练模型

首先需要训练一个模型,这里以CIFAR-10分类为例,使用MobileNet网络结构:

# mobile_net.py
import paddle.v2 as paddle

def mobile_net(img_size, class_num):
    img = paddle.layer.data(name="image", type=paddle.data_type.dense_vector(img_size))
    # 定义MobileNet网络结构...
    # ...
    return out

# 训练代码
trainer = paddle.trainer.SGD(...)
# ...

2.2 模型合并

将训练好的模型参数与网络结构合并成一个可直接使用的模型文件:

# merge_model.py
from paddle.utils.merge_model import merge_v2_model

# 导入模型定义
from mobile_net import mobile_net

if __name__ == "__main__":
    img_size = 3 * 32 * 32
    class_num = 10
    net = mobile_net(img_size, class_num)
    param_file = 'models/mobile_net.tar.gz'
    output_file = 'models/mobile_net.paddle'
    merge_v2_model(net, param_file, output_file)

3. Android项目配置

3.1 项目结构

TestPaddleApp/
├── app/
│   ├── src/
│   │   └── main/
│   │       ├── assets/
│   │       │   └── models/
│   │       │       └── mobile_net.paddle
│   │       ├── cpp/
│   │       │   └── image_recognizer.cpp
│   │       └── java/
│   │           └── com/example/testpaddle/
│   │               └── ImageRecognition.java
│   ├── CMakeLists.txt
│   └── build.gradle
└── settings.gradle

3.2 配置CMakeLists.txt

cmake_minimum_required(VERSION 3.4.1)

# 设置Android NDK路径
set(ANDROID_NDK /path/to/android-ndk)
set(CMAKE_TOOLCHAIN_FILE ${ANDROID_NDK}/build/cmake/android.toolchain.cmake)
set(ANDROID_PLATFORM android-21)
set(ANDROID_ABI armeabi-v7a)

project(testpaddle)

# 添加PaddlePaddle源文件
add_subdirectory(paddle)

# 编译Paddle Android库
add_library(paddle_image_recognizer SHARED
            src/main/cpp/image_recognizer.cpp)

# 链接Paddle库
target_link_libraries(paddle_image_recognizer
                      paddle
                      log)

3.3 添加权限

AndroidManifest.xml中添加必要权限:

<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.CAMERA" />
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<uses-feature android:name="android.hardware.camera" />

4. 编写C++预测代码

src/main/cpp/image_recognizer.cpp中实现预测逻辑:

#include <jni.h>
#include <android/bitmap.h>
#include <paddle/paddle_api.h>
#include <android/log.h>
#include <vector>
#include <string>

using namespace paddle;

#define LOG_TAG "PaddlePaddle"
#define LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__)

// 全局变量存储梯度机
GradientMachine* gradient_machine_ = nullptr;

// 从Assets加载模型
bool load_model(JNIEnv* env, jobject context, const char* model_path) {
    AAssetManager* mgr = AAssetManager_fromJava(env, context);
    BinaryReader reader(mgr);
    std::string model_file = std::string(model_path);
    const void* model_data = reader.Read(model_file);
    size_t model_size = reader.Size(model_file);

    if (!model_data) {
        LOGD("无法加载模型文件: %s", model_path);
        return false;
    }

    // 创建梯度机
    gradient_machine_ = new GradientMachine();
    gradient_machine_->CreateForInferenceWithParameters(model_data, model_size);
    return true;
}

// 图像预处理
void preprocess_image(JNIEnv* env, jobject bitmap, float* input_data) {
    AndroidBitmapInfo info;
    AndroidBitmap_getInfo(env, bitmap, &info);

    void* pixels;
    AndroidBitmap_lockPixels(env, bitmap, &pixels);

    // 转换图像为BGR格式
    uint8_t* src = (uint8_t*)pixels;
    int width = info.width;
    int height = info.height;

    // 调整图像大小为32x32
    cv::Mat img(height, width, CV_8UC3, src);
    cv::resize(img, img, cv::Size(32, 32));

    // 转换为BGR格式
    std::vector<cv::Mat> channels;
    cv::split(img, channels);
    cv::Mat bgr[3];
    channels[0].copyTo(bgr[2]); // BGR
    channels[1].copyTo(bgr[1]);
    channels[2].copyTo(bgr[0]);

    // 归一化处理
    float mean[3] = {127.5, 127.5, 127.5};
    float std[3] = {1.0/127.5, 1.0/127.5, 1.0/127.5};

    for (int i = 0; i < 32; i++) {
        for (int j = 0; j < 32; j++) {
            for (int c = 0; c < 3; c++) {
                int idx = i * 32 * 3 + j * 3 + c;
                input_data[idx] = (img.at<cv::Vec3b>(i, j)[c] - mean[c]) * std[c];
            }
        }
    }

    AndroidBitmap_unlockPixels(env, bitmap);
}

// JNI调用方法
extern "C" JNIEXPORT jfloatArray JNICALL
Java_com_example_testpaddle_MainActivity_predict(
        JNIEnv* env,
        jobject thiz,
        jobject bitmap) {

    if (!gradient_machine_) {
        load_model(env, thiz, "model/mobile_net.paddle");
    }

    float input_data[1*3*32*32]; // 输入数据: 1×3×32×32
    preprocess_image(env, bitmap, input_data);

    // 创建输入参数
    std::vector<Argument> inputs;
    Argument input(1, 3, 32, 32);
    input.set_data(input_data);
    inputs.push_back(input);

    // 执行预测
    std::vector<Argument> outputs;
    gradient_machine_->Predict(inputs, &outputs);

    // 提取结果
    jfloatArray result = env->NewFloatArray(10);
    for (int i = 0; i < 10; i++) {
        float value = outputs[0].value()[i];
        env->SetFloatArrayRegion(result, i, 1, &value);
    }

    return result;
}

5. Java端代码实现

MainActivity.java中实现图像选择和预测:

public class MainActivity extends AppCompatActivity {
    private static final int REQUEST_IMAGE_CAPTURE = 1;
    private ImageView imageView;
    private Button predictButton;
    private ImageRecognizer imageRecognizer;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        imageView = findViewById(R.id.imageView);
        predictButton = findViewById(R.id.predictButton);
        imageRecognizer = new ImageRecognizer();

        // 加载Paddle模型
        imageRecognizer.loadModel(getAssets(), "model/mobile_net.paddle");

        // 按钮点击事件
        predictButton.setOnClickListener(v -> {
            // 打开相机或相册选择图像
            Intent intent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
            startActivityForResult(intent, REQUEST_IMAGE_CAPTURE);
        });
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
        if (requestCode == REQUEST_IMAGE_CAPTURE && resultCode == RESULT_OK) {
            Bundle extras = data.getExtras();
            Bitmap imageBitmap = (Bitmap) extras.get("data");
            imageView.setImageBitmap(imageBitmap);

            // 进行预测
            float[] result = imageRecognizer.predict(imageBitmap);
            // 显示预测结果
            StringBuilder sb = new StringBuilder();
            sb.append("预测结果: ");
            for (int i = 0; i < result.length; i++) {
                sb.append("类别 ").append(i).append(": ").append(result[i]).append("\n");
            }
            Toast.makeText(this, sb.toString(), Toast.LENGTH_LONG).show();
        }
    }
}

6. ImageRecognizer类实现

public class ImageRecognizer {
    static {
        System.loadLibrary("paddle_image_recognizer");
    }

    private Context context;

    public ImageRecognizer() {
        context = MyApplication.getContext();
    }

    // 加载模型
    public boolean loadModel(AssetManager assets, String modelPath) {
        return nativeLoadModel(assets, modelPath);
    }

    // 预测图像
    public float[] predict(Bitmap bitmap) {
        float[] result = new float[10];
        if (bitmap != null) {
            result = nativePredict(bitmap);
        }
        return result;
    }

    // JNI方法声明
    private native boolean nativeLoadModel(AssetManager assets, String modelPath);
    private native float[] nativePredict(Bitmap bitmap);
}

7. 配置Android项目

7.1 添加依赖

build.gradle中添加:

android {
    defaultConfig {
        // ...
        externalNativeBuild {
            cmake {
                abiFilters 'armeabi-v7a'
            }
        }
    }
    externalNativeBuild {
        cmake {
            path "CMakeLists.txt"
        }
    }
}

dependencies {
    // ...
}

7.2 配置权限

AndroidManifest.xml中添加:

<uses-permission android:name="android.permission.CAMERA" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />

8. 优化与性能调优

  1. 模型压缩:使用PaddlePaddle提供的模型压缩工具减小模型体积
  2. 量化优化:将模型参数从float32转换为int8以减小内存占用
  3. 图像预处理加速:使用OpenCV进行图像加速处理
  4. 线程优化:将预测过程放入后台线程执行,避免阻塞UI

9. 常见问题解决

  1. 模型加载失败:检查模型文件路径是否正确,确保Assets文件夹中模型文件正确放置
  2. 图像格式不匹配:确保图像格式为RGB,且通道顺序与模型要求一致
  3. 权限问题:检查应用是否有相机和存储权限
  4. NDK版本兼容性:确保NDK版本与PaddlePaddle编译版本兼容

通过以上步骤,你可以成功将训练好的PaddlePaddle模型部署到Android设备上,并实现图像分类功能。PaddlePaddle提供了丰富的API和工具,方便开发者进行模型优化和移植。

完整代码已上传至GitHub: https://github.com/yeyupiaoling/LearnPaddle

Xiaoye