使用NCNN在Android上实现图像分类

前言

在之前的文章中,我们介绍了如何使用百度的PaddleMobile框架在Android设备上实现图像分类。本章将介绍使用腾讯开源的手机深度学习框架ncnn来实现同样的功能。ncnn具有较长的开源历史,相对更为稳定。

ncnn的GitHub地址:https://github.com/Tencent/ncnn

使用Ubuntu编译ncnn库

1. 下载和解压NDK

wget https://dl.google.com/android/repository/android-ndk-r17b-linux-x86_64.zip
unzip android-ndk-r17b-linux-x86_64.zip

2. 设置NDK环境变量

export ANDROID_NDK="/home/test/paddlepaddle/android-ndk-r17b"

验证配置:

# Verify the variable that was actually exported above (ANDROID_NDK,
# not NDK_ROOT, which is never set in this tutorial).
echo $ANDROID_NDK
/home/test/paddlepaddle/android-ndk-r17b

3. 安装CMake

下载CMake源码:

wget https://cmake.org/files/v3.11/cmake-3.11.2.tar.gz
tar -zxvf cmake-3.11.2.tar.gz

编译安装:

cd cmake-3.11.2
./bootstrap
make
make install

验证安装:

cmake --version
cmake version 3.11.2

4. 克隆ncnn源码

git clone https://github.com/Tencent/ncnn.git

5. 编译ncnn源码

# Enter the ncnn source directory
cd ncnn
# Create an out-of-source build directory for the armv7 Android build
mkdir -p build-android-armv7
cd build-android-armv7
# Configure with the NDK toolchain file; target armeabi-v7a with NEON,
# minimum platform android-14
cmake -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
    -DANDROID_ABI="armeabi-v7a" -DANDROID_ARM_NEON=ON \
    -DANDROID_PLATFORM=android-14 ..
# Parallel build (4 jobs)
make -j4
make install

6. 编译结果

编译完成后,在build-android-armv7目录下会生成install目录,包含:
- include:ncnn头文件,用于Android项目的src/main/cpp目录
- lib:编译得到的libncnn.a,用于Android项目的src/main/jniLibs/armeabi-v7a/

转换预测模型

1. 克隆Caffe源码

git clone https://github.com/BVLC/caffe.git

2. 编译Caffe源码

cd caffe
cmake .
make -j4
make install

3. 升级Caffe模型

cd tools
# Upgrade the model definition (prototxt) to the current Caffe format
./upgrade_net_proto_text mobilenet_v2_deploy.prototxt mobilenet_v2_deploy_new.prototxt
# Upgrade the binary weights to the current Caffe format
./upgrade_net_proto_binary mobilenet_v2.caffemodel mobilenet_v2_new.caffemodel

4. 检查模型配置文件

确保输入层的dim设置为1(单张图片预测):

name: "MOBILENET_V2"
layer {
  name: "input"
  type: "Input"
  top: "data"
  input_param {
    shape {
      dim: 1    # batch size: must be 1 for single-image inference
      dim: 3    # channels (BGR)
      dim: 224  # height
      dim: 224  # width
    }
  }
}

5. 编译ncnn工具

cd ncnn/
mkdir -p build
cd build
cmake ..
make -j4
make install

6. 转换Caffe模型到NCNN模型

cd tools/caffe/
# Convert the Caffe model to the ncnn format:
# inputs: upgraded prototxt + caffemodel; outputs: .param (structure) + .bin (weights)
./caffe2ncnn mobilenet_v2_deploy_new.prototxt mobilenet_v2_new.caffemodel mobilenet_v2.param mobilenet_v2.bin

7. 模型加密

cd ..
# Convert the plain-text .param into the binary mobilenet_v2.param.bin and
# generate the mobilenet_v2.id.h / mobilenet_v2.mem.h headers used at runtime
./ncnn2mem mobilenet_v2.param mobilenet_v2.bin mobilenet_v2.id.h mobilenet_v2.mem.h

8. 转换后的文件

  • mobilenet_v2.param.bin:网络模型参数
  • mobilenet_v2.bin:网络权重
  • mobilenet_v2.id.h:预测时使用

开发Android项目

项目结构设置

  1. 创建Android Studio项目:确保启用C++支持和C++11标准

  2. 添加模型文件到assets
    - mobilenet_v2.param.bin
    - mobilenet_v2.bin
    - synset.txt(标签文件)

  3. 添加ncnn库文件
    - 将编译得到的libncnn.a放入jniLibs/armeabi-v7a/

  4. 添加头文件
    - 将ncnn/include目录下的头文件放入cpp/include
    - 将mobilenet_v2.id.h放入cpp目录

JNI实现(ncnn_jni.cpp)

#include <android/bitmap.h>
#include <android/log.h>
#include <jni.h>
#include <string>
#include <vector>
#include "include/net.h"
#include "mobilenet_v2.id.h"
#include <sys/time.h>

// Pool allocators shared by all extractors created from the global net.
static ncnn::UnlockedPoolAllocator g_blob_pool_allocator;
static ncnn::PoolAllocator g_workspace_pool_allocator;

// Backing storage for the model structure and weights passed from Java;
// kept alive for the process lifetime because ncnn::Net references the
// memory in-place rather than copying it.
static ncnn::Mat ncnn_param;
static ncnn::Mat ncnn_bin;
static ncnn::Net ncnn_net;

extern "C" {

// 初始化模型
// Initialize the global ncnn network from the in-memory param/bin byte
// arrays passed from Java (the raw contents of mobilenet_v2.param.bin and
// mobilenet_v2.bin read from assets).
// Returns JNI_TRUE only when both structure and weights load successfully;
// the original version unconditionally returned JNI_TRUE even on failure.
JNIEXPORT jboolean JNICALL
Java_com_example_ncnn1_NcnnJni_Init(JNIEnv *env, jobject thiz, jbyteArray param, jbyteArray bin) {
    // Load the network structure. load_param(const unsigned char*) returns
    // the number of bytes consumed; 0 indicates a parse failure.
    {
        int len = env->GetArrayLength(param);
        ncnn_param.create(len, 1u);
        env->GetByteArrayRegion(param, 0, len, reinterpret_cast<jbyte*>(ncnn_param));
        int ret = ncnn_net.load_param(reinterpret_cast<const unsigned char*>(ncnn_param));
        __android_log_print(ANDROID_LOG_DEBUG, "NcnnJni", "load_param %d", ret);
        if (ret == 0) return JNI_FALSE;  // fail fast instead of pretending success
    }

    // Load the network weights; same bytes-consumed convention as above.
    {
        int len = env->GetArrayLength(bin);
        ncnn_bin.create(len, 1u);
        env->GetByteArrayRegion(bin, 0, len, reinterpret_cast<jbyte*>(ncnn_bin));
        int ret = ncnn_net.load_model(reinterpret_cast<const unsigned char*>(ncnn_bin));
        __android_log_print(ANDROID_LOG_DEBUG, "NcnnJni", "load_model %d", ret);
        if (ret == 0) return JNI_FALSE;
    }

    // Configure default extractor options: light mode, 4 threads, and the
    // shared pool allocators declared at file scope.
    ncnn::Option opt;
    opt.lightmode = true;
    opt.num_threads = 4;
    opt.blob_allocator = &g_blob_pool_allocator;
    opt.workspace_allocator = &g_workspace_pool_allocator;
    ncnn::set_default_option(opt);

    return JNI_TRUE;
}

// 图像分类预测
// Run image classification on an RGBA_8888 Android Bitmap.
// Returns the network's output probability vector, or NULL when the bitmap
// cannot be read or inference fails. The original version ignored the
// return codes of AndroidBitmap_getInfo/lockPixels and ex.extract.
JNIEXPORT jfloatArray JNICALL Java_com_example_ncnn1_NcnnJni_Detect(JNIEnv* env, jobject thiz, jobject bitmap) {
    ncnn::Mat in;
    {
        AndroidBitmapInfo info;
        if (AndroidBitmap_getInfo(env, bitmap, &info) < 0) return NULL;
        // Only RGBA_8888 bitmaps are supported (matches PIXEL_RGBA2BGR below).
        if (info.format != ANDROID_BITMAP_FORMAT_RGBA_8888) return NULL;

        void* indata;
        if (AndroidBitmap_lockPixels(env, bitmap, &indata) < 0) return NULL;
        // Convert RGBA pixels to a BGR ncnn::Mat (Caffe-trained model expects BGR).
        in = ncnn::Mat::from_pixels((const unsigned char*)indata, ncnn::Mat::PIXEL_RGBA2BGR,
                                    (int)info.width, (int)info.height);
        AndroidBitmap_unlockPixels(env, bitmap);
    }

    // MobileNet preprocessing: subtract Caffe BGR channel means, then scale
    // by 0.017 (~1/58.8).
    const float mean_vals[3] = {103.94f, 116.78f, 123.68f};
    const float scale[3] = {0.017f, 0.017f, 0.017f};
    in.substract_mean_normalize(mean_vals, scale);

    // Inference: feed the input blob and extract the probability blob.
    ncnn::Extractor ex = ncnn_net.create_extractor();
    ex.input(mobilenet_v2_param_id::BLOB_data, in);
    ncnn::Mat out;
    if (ex.extract(mobilenet_v2_param_id::BLOB_prob, out) != 0) return NULL;
    if (out.empty()) return NULL;  // guard against a zero-sized output

    // Copy the flat float vector into a Java float[].
    jfloatArray jOutputData = env->NewFloatArray(out.w);
    if (jOutputData) {
        env->SetFloatArrayRegion(jOutputData, 0, out.w,
                                 reinterpret_cast<const jfloat*>(out.data));
    }
    return jOutputData;
}
}

Java接口(NcnnJni.java)

package com.example.ncnn1;

import android.graphics.Bitmap;

/**
 * JNI bridge to the native ncnn classifier implemented in ncnn_jni.cpp
 * and packaged as libncnn_jni.so.
 */
public class NcnnJni {
    // Load the model: param/bin are the raw bytes of mobilenet_v2.param.bin
    // and mobilenet_v2.bin read from assets; returns true on success.
    public native boolean Init(byte[] param, byte[] bin);

    // Run image classification on the bitmap; returns the model's output
    // probability vector (one float per class).
    public native float[] Detect(Bitmap bitmap);

    static {
        System.loadLibrary("ncnn_jni");
    }
}

主Activity(MainActivity.java)

package com.example.ncnn1;

import android.Manifest;
import android.app.Activity;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.net.Uri;
import android.os.Bundle;
import android.provider.MediaStore;
import android.util.Log;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;

import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;

import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class MainActivity extends Activity {
    private static final String TAG = "MainActivity";
    private static final int USE_PHOTO = 1001;

    // JNI bridge to the native ncnn classifier.
    private NcnnJni ncnnJni = new NcnnJni();
    // Class labels from assets/synset.txt; index matches the model output.
    private List<String> labels = new ArrayList<>();

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        try {
            initModel();
        } catch (IOException e) {
            Log.e(TAG, "模型初始化失败: " + e.getMessage());
        }

        initView();
        readLabels();
    }

    // Load the model files from assets and initialize the native network.
    private void initModel() throws IOException {
        AssetManager assets = getAssets();
        byte[] param = readAssetFile(assets, "mobilenet_v2.param.bin");
        byte[] bin = readAssetFile(assets, "mobilenet_v2.bin");
        boolean initSuccess = ncnnJni.Init(param, bin);
        Log.d(TAG, "模型初始化结果: " + initSuccess);
    }

    // Read the label file, one class name per line.
    private void readLabels() {
        try {
            AssetManager assets = getAssets();
            BufferedReader reader = new BufferedReader(new InputStreamReader(assets.open("synset.txt")));
            String line;
            while ((line = reader.readLine()) != null) {
                labels.add(line);
            }
            reader.close();
        } catch (IOException e) {
            Log.e(TAG, "读取标签失败: " + e.getMessage());
        }
    }

    // Read an asset fully into a byte array.
    // Fix: the previous version used InputStream.available() plus a single
    // read(), which is not guaranteed to return the whole file; read in a loop.
    private byte[] readAssetFile(AssetManager assets, String filename) throws IOException {
        InputStream is = assets.open(filename);
        try {
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            byte[] chunk = new byte[8192];
            int n;
            while ((n = is.read(chunk)) != -1) {
                bos.write(chunk, 0, n);
            }
            return bos.toByteArray();
        } finally {
            is.close();
        }
    }

    // Wire up the "pick photo" button.
    private void initView() {
        Button usePhotoBtn = findViewById(R.id.use_photo);
        usePhotoBtn.setOnClickListener(v -> {
            if (checkPermission()) {
                Intent intent = new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.EXTERNAL_CONTENT_URI);
                startActivityForResult(intent, USE_PHOTO);
            } else {
                requestPermissions();
            }
        });
    }

    // True when READ_EXTERNAL_STORAGE has been granted.
    private boolean checkPermission() {
        return ContextCompat.checkSelfPermission(this, Manifest.permission.READ_EXTERNAL_STORAGE)
                == PackageManager.PERMISSION_GRANTED;
    }

    // Ask the user for storage permission.
    private void requestPermissions() {
        ActivityCompat.requestPermissions(this,
                new String[]{Manifest.permission.READ_EXTERNAL_STORAGE},
                1);
    }

    @Override
    public void onRequestPermissionsResult(int requestCode, String[] permissions, int[] grantResults) {
        if (requestCode == 1 && grantResults.length > 0 && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
            // Permission granted; nothing else to do.
        } else {
            Toast.makeText(this, "需要存储权限才能使用", Toast.LENGTH_SHORT).show();
        }
    }

    // Handle the picked image: decode, run inference, and show the result.
    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
        if (requestCode == USE_PHOTO && resultCode == RESULT_OK && data != null) {
            try {
                Uri imageUri = data.getData();
                // Fix: getBitmap declares IOException, which the original
                // code neither caught nor declared (compile error).
                Bitmap bitmap = MediaStore.Images.Media.getBitmap(getContentResolver(), imageUri);
                // Fix: measure actual inference time instead of printing
                // the raw epoch timestamp as "时间".
                long start = System.currentTimeMillis();
                float[] result = ncnnJni.Detect(bitmap);
                long elapsedMs = System.currentTimeMillis() - start;
                showResult(result, elapsedMs);
            } catch (IOException e) {
                Log.e(TAG, "读取图片失败: " + e.getMessage());
            }
        }
    }

    // Render the top-1 prediction and the inference time.
    private void showResult(float[] result, long elapsedMs) {
        TextView textView = findViewById(R.id.result_text);
        if (result == null || result.length == 0) return;

        // argmax over the probability vector
        int maxIndex = 0;
        float maxValue = result[0];
        for (int i = 1; i < result.length; i++) {
            if (result[i] > maxValue) {
                maxValue = result[i];
                maxIndex = i;
            }
        }

        // Guard against a missing/short label file.
        String label = maxIndex < labels.size() ? labels.get(maxIndex) : "unknown";
        String resultText = String.format("类别: %d\n名称: %s\n概率: %.4f\n时间: %dms",
                maxIndex, label, maxValue, elapsedMs);
        textView.setText(resultText);
    }
}

布局文件(activity_main.xml)

<?xml version="1.0" encoding="utf-8"?>
<!-- Main screen: full-bleed preview image on top, a pick-photo button
     floating above the result area, and a fixed-height result text
     panel pinned to the bottom. -->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <!-- Selected photo preview; fills everything above the result panel. -->
    <ImageView
        android:id="@+id/show_image"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:layout_above="@id/result_text"
        android:scaleType="centerCrop" />

    <!-- Prediction output (class index, label, probability, time). -->
    <TextView
        android:id="@+id/result_text"
        android:layout_width="match_parent"
        android:layout_height="100dp"
        android:layout_alignParentBottom="true"
        android:scrollbars="vertical"
        android:padding="8dp" />

    <!-- Launches the system image picker (wired up in MainActivity.initView). -->
    <Button
        android:id="@+id/use_photo"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_above="@id/result_text"
        android:layout_centerHorizontal="true"
        android:layout_marginBottom="16dp"
        android:text="选择图片" />

</RelativeLayout>

CMake配置(CMakeLists.txt)

cmake_minimum_required(VERSION 3.4.1)

# Path to the prebuilt static ncnn library produced by the
# build-android-armv7 step earlier in this article.
set(NCNN_LIB_PATH ${CMAKE_SOURCE_DIR}/src/main/jniLibs/armeabi-v7a/libncnn.a)

# Import the prebuilt ncnn library as a CMake target.
add_library(ncnn STATIC IMPORTED)
set_target_properties(ncnn PROPERTIES IMPORTED_LOCATION ${NCNN_LIB_PATH})

# Build the JNI shared library loaded by NcnnJni (System.loadLibrary("ncnn_jni")).
add_library(ncnn_jni SHARED src/main/cpp/ncnn_jni.cpp)

# Link against ncnn plus the NDK log and jnigraphics (AndroidBitmap_*) libraries.
find_library(log-lib log)
target_link_libraries(ncnn_jni ncnn ${log-lib} jnigraphics)

Gradle配置(build.gradle)

```gradle
android {
compileSdkVersion 28
defaultConfig {
applicationId “com.example.ncnn1”
minSdkVersion 21
targetSdkVersion 28
versionCode 1
versionName “1.0”
externalNativeBuild {
cmake {
cppFlags “-

Xiaoye