pytorch移植到安卓,pytorch部署到手机
本文演示了如何将经过训练的pytorch模型部署到Android设备上。通过示例代码详细介绍,具有一定的参考价值。感兴趣的朋友可以参考一下。
00-1010模型转换Android部署新项目导入包页面文件模型推理这篇文章演示了如何将训练好的pytorch模型部署到Android设备上。我也是刚开始学安卓,代码简单。
环境:
Pytorch版本:1.10.0
目录
pytorch_android支持的型号是。pt模型,而我们训练的模型是。pth。所以需要转换后才能使用。先看官方在线给出的换算方法:
进口火炬
进口火炬视觉
从torch.utils.mobile_optimizer导入优化_for_mobile
model=torch vision . models . mobilenet _ v3 _ small(pre trained=True)
model.eval()
example=torch.rand(1,3,224,224)
traced _ script _ module=torch . JIT . trace(模型,示例)
optimized _ traced _ model=optimize _ for _ mobile(traced _ script _ module)
优化跟踪模型。_ save _ for _ lite _ interpreter( app/src/main/assets/model . ptl )
这个型号在安卓对应的包里:
存储库{
jcenter()
}
依赖关系{
实现 org . py torch : py torch _ Android _ lite :1 . 9 . 0
实现 org . py torch : py torch _ Android _ torch vision :1 . 9 . 0
}
注意:pytorch_android_lite版本要和转换模型用的版本一致,否则会报错各种错误。
目前这种方法是有问题的。我在用另一种方法。
转换代码如下:
进口火炬
导入torch.utils.data.distributed
# pytorch环境
Model_pth=model_31_0.96.pth #模型的参数文件
Mobile_pt=model.pt #将模型保存为Android可以调用的文件
model=torch.load(型号_pth)
model.eval() #模型设置为评估模式
device=torch.device(cpu )
型号至(设备)
# 1画面,3个通道,224*224
Input _ tensor=torch.rand (1,3,224,224) #设置输入数据格式
Mobile=torch.jit.trace (model,input _张量)#模型转换
Mobile.save(mobile_pt) #保存文件
相应的包:
//pytorch
实现 org . py torch : py torch _ Android :1 . 10 . 0
实现 org . py torch : py torch _ Android _ torch vision :1 . 10 . 0
定义模型文件和转换文件的路径。
负载模型。请注意,如果您保存模型
torch.save(model, models.pth )
装载模型是
model=torch.load(models.pth )
如果保存的模型是
torch.save(model.state_dict(), models.pth )
装载模型是
model . load _ state _ dict(torch . load( models . PTH ))
re>
定义输入数据格式。
模型转化,然后再保存模型。
安卓部署
新建项目
新建安卓项目,选择Empy Activity,然后选择Next
然后,填写项目信息,选择安卓版本,我用的4.4,点击完成
导入包
导入pytorch_android的包
//pytorchimplementation org.pytorch:pytorch_android:1.10.0
implementation org.pytorch:pytorch_android_torchvision:1.10.0
如果有参数报错请参照我的完整的配置,代码如下:
plugins {id com.android.application
}
android {
compileSdk 32
defaultConfig {
applicationId "com.example.myapplication"
minSdk 21
targetSdk 32
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile(proguard-android-optimize.txt), proguard-rules.pro
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
dependencies {
implementation androidx.appcompat:appcompat:1.3.0
implementation com.google.android.material:material:1.4.0
implementation androidx.constraintlayout:constraintlayout:2.0.4
testImplementation junit:junit:4.13.2
androidTestImplementation androidx.test.ext:junit:1.1.3
androidTestImplementation androidx.test.espresso:espresso-core:3.4.0
//pytorch
implementation org.pytorch:pytorch_android:1.10.0
implementation org.pytorch:pytorch_android_torchvision:1.10.0
}
页面文件
页面的配置如下:
<?xml version="1.0" encoding="utf-8"?><FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<ImageView
android:id="@+id/image"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:scaleType="fitCenter" />
<TextView
android:id="@+id/text"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_gravity="top"
android:textSize="24sp"
android:background="#80000000"
android:textColor="@android:color/holo_red_light" />
</FrameLayout>
这个页面只有两个空间,一个展示图片,一个显示文字。
模型推理
新增assets文件夹,然后将转化的模型和待测试的图片放进去。
新增ImageNetClasses类,这个类存放类别名字。
代码如下:
package com.example.myapplication;public class ImageNetClasses {
public static String[] IMAGENET_CLASSES = new String[]{
"Black-grass",
"Charlock",
"Cleavers",
"Common Chickweed",
"Common wheat",
"Fat Hen",
"Loose Silky-bent",
"Maize",
"Scentless Mayweed",
"Shepherds Purse",
"Small-flowered Cranesbill",
"Sugar beet",
};
}
在MainActivity类中,增加模型推理的逻辑。完成代码如下:
package com.example.myapplication;import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import org.pytorch.MemoryFormat;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import androidx.appcompat.app.AppCompatActivity;
public class MainActivity extends AppCompatActivity {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Bitmap bitmap = null;
Module module = null;
try {
// creating bitmap from packaged into app android asset image.jpg,
// app/src/main/assets/image.jpg
bitmap = BitmapFactory.decodeStream(getAssets().open("1.png"));
// loading serialized torchscript module from packaged into app android asset model.pt,
// app/src/model/assets/model.pt
module = Module.load(assetFilePath(this, "models.pt"));
} catch (IOException e) {
Log.e("PytorchHelloWorld", "Error reading assets", e);
finish();
}
// showing image on UI
ImageView imageView = findViewById(R.id.image);
imageView.setImageBitmap(bitmap);
// preparing input tensor
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);
// running the model
final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
// getting tensor content as java array of floats
final float[] scores = outputTensor.getDataAsFloatArray();
// searching for the index with maximum score
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
System.out.println(maxScoreIdx);
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
// showing className on UI
TextView textView = findViewById(R.id.text);
textView.setText(className);
}
/**
* Copies specified asset to the file in /files app directory and returns this file absolute path.
*
* @return absolute file path
*/
public static String assetFilePath(Context context, String assetName) throws IOException {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}
}
然后运行。
到此这篇关于如何将pytorch模型部署到安卓上的方法示例的文章就介绍到这了,更多相关pytorch模型部署到安卓内容请搜索盛行IT软件开发工作室以前的文章或继续浏览下面的相关文章希望大家以后多多支持盛行IT软件开发工作室!
郑重声明:本文由网友发布,不代表盛行IT的观点,版权归原作者所有,仅为传播更多信息之目的,如有侵权请联系,我们将第一时间修改或删除,多谢。