1.首先官网上下载libtorch,放到当前项目下

2.将pytorch训练好的模型使用torch.jit.trace导出为.pt格式

  1. import torch
  2. from skimage import io, transform, color
  3. import numpy as np
  4. import os
  5. import torch.nn.functional as F
  6. import warnings
  7. warnings.filterwarnings("ignore")
  8. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  9.  
  10. labels = ['cock', 'drawing', 'neutral', 'porn', 'sexy']
  11. path = "test/n_1.jpg"
  12. im = io.imread(path)
  13. if im.shape[2] == 4:
  14. im = color.rgba2rgb(im)
  15.  
  16. im = transform.resize(im, (224, 224))
  17. im = np.transpose(im, (2, 0, 1))
  18. dummy_input = np.expand_dims(im, 0)
  19. inp = torch.from_numpy(dummy_input)
  20. inp = inp.float()
  21. model = torch.load(
  22. "models/resnet50-epoch-0-accu-0.9213857428381079.pth", map_location='cpu')
  23. traced_script_module = torch.jit.trace(model, inp)
  24. output = model(inp)
  25. probs = F.softmax(output).detach().numpy()[0]
  26. pred = np.argmax(probs)
  27.  
  28. traced_script_module.save("models/traced_resnet_model.pt")

torchscript加载.pt模型

  1. // One-stop header.
  2. #include <torch/script.h>
  3.  
  4. // headers for opencv
  5. #include <opencv2/highgui/highgui.hpp>
  6. #include <opencv2/imgproc/imgproc.hpp>
  7. #include <opencv2/opencv.hpp>
  8.  
  9. #include <cmath>
  10. #include <iostream>
  11. #include <memory>
  12. #include <string>
  13. #include <vector>
  14.  
  15. #define kIMAGE_SIZE 224
  16. #define kCHANNELS 3
  17. #define kTOP_K 1 //print top k predicted results
  18.  
  19. bool LoadImage(std::string file_name, cv::Mat &image)
  20. {
  21. image = cv::imread(file_name); // CV_8UC3
  22. if (image.empty() || !image.data)
  23. {
  24. return false;
  25. }
  26. cv::cvtColor(image, image, CV_BGR2RGB);
  27. // scale image to fit
  28. cv::Size scale(kIMAGE_SIZE, kIMAGE_SIZE);
  29. cv::resize(image, image, scale);
  30.  
  31. // convert [unsigned int] to [float]
  32. image.convertTo(image, CV_32FC3,1.0/255);
  33.  
  34. return true;
  35. }
  36.  
  37. bool LoadImageNetLabel(std::string file_name,
  38. std::vector<std::string> &labels)
  39. {
  40. std::ifstream ifs(file_name);
  41. if (!ifs)
  42. {
  43. return false;
  44. }
  45. std::string line;
  46. while (std::getline(ifs, line))
  47. {
  48. labels.push_back(line);
  49. }
  50. return true;
  51. }
  52.  
  53. int main(int argc, const char *argv[])
  54. {
  55. if (argc != 3)
  56. {
  57. std::cerr << "Usage:classifier <path-to-exported-script-module> <path-to-lable-file> " << std::endl;
  58. return -1;
  59. }
  60.  
  61. //load model
  62. torch::jit::script::Module module = torch::jit::load(argv[1]);
  63. // to GPU
  64. // module->to(at::kCUDA);
  65. std::cout << "== ResNet50 loaded!\n";
  66.  
  67. //load labels(classes names)
  68. std::vector<std::string> labels;
  69. if (LoadImageNetLabel(argv[2], labels))
  70. {
  71. std::cout << "== Label loaded! Let's try it\n";
  72. }
  73. else
  74. {
  75. std::cerr << "Please check your label file path." << std::endl;
  76. return -1;
  77. }
  78.  
  79. std::string file_name = "";
  80. cv::Mat image;
  81. while (true)
  82. {
  83. std::cout << "== Input image path: [enter q to exit]" << std::endl;
  84. std::cin >> file_name;
  85. if (file_name == "Q" || file_name == "q")
  86. {
  87. break;
  88. }
  89. if (LoadImage(file_name, image))
  90. {
  91. //read image tensor
  92. auto input_tensor = torch::from_blob(
  93. image.data, {1, kIMAGE_SIZE, kIMAGE_SIZE, kCHANNELS});
  94. input_tensor = input_tensor.permute({0, 3, 1, 2});
  95. input_tensor[0][0] = input_tensor[0][0].sub_(0.485).div_(0.229);
  96. input_tensor[0][1] = input_tensor[0][1].sub_(0.456).div_(0.224);
  97. input_tensor[0][2] = input_tensor[0][2].sub_(0.406).div_(0.225);
  98. // to GPU
  99. // input_tensor = input_tensor.to(at::kCUDA);
  100.  
  101. torch::Tensor out_tensor = module.forward({input_tensor}).toTensor();
  102.  
  103. auto results = out_tensor.sort(-1, true);
  104. auto softmaxs = std::get<0>(results)[0].softmax(0);
  105. auto indexs = std::get<1>(results)[0];
  106.  
  107. for (int i = 0; i < kTOP_K; ++i)
  108. {
  109. auto idx = indexs[i].item<int>();
  110. std::cout << " ============= Top-" << i + 1 << " =============" << std::endl;
  111. std::cout << " Label: " << labels[idx] << std::endl;
  112. std::cout << " With Probability: "
  113. << softmaxs[i].item<float>() * 100.0f << "%" << std::endl;
  114. }
  115. }
  116. else
  117. {
  118. std::cout << "Can't load the image, please check your path." << std::endl;
  119. }
  120. }
  121. }

CMakeLists.txt编译

  1. cmake_minimum_required(VERSION 2.8)
  2. project(predict_demo)
  3. SET(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS} "-std=c++11 -O3")
  4.  
  5. set(OpenCV_DIR /home/buyizhiyou/opencv-3.4./build)
  6. find_package(OpenCV REQUIRED)
  7. find_package(Torch REQUIRED)
  8.  
  9. # 添加头文件
  10. include_directories( ${OpenCV_INCLUDE_DIRS} )
  11.  
  12. add_executable(resnet_demo resnet_demo.cpp)
  13. target_link_libraries(resnet_demo ${TORCH_LIBRARIES} ${OpenCV_LIBS})
  14. set_property(TARGET resnet_demo PROPERTY CXX_STANDARD )

运行

  1. ./resnet_demo models/traced_resnet_model.pt labels.txt

c++ 使用torchscript 加载训练好的pytorch模型的更多相关文章

  1. vue中加载three.js的gltf模型

    vue中加载three.js的gltf模型 一.开始引入three.js相关插件.首先利用淘宝镜像,操作命令为: cnpm install three //npm install three也行 二. ...

  2. pytorch 加载训练好的模型做inference

    前提: 模型参数和结构是分别保存的 1. 构建模型(# load model graph) model = MODEL() 2.加载模型参数(# load model state_dict) mode ...

  3. Tensorflow加载预训练模型和保存模型(ckpt文件)以及迁移学习finetuning

    转载自:https://blog.csdn.net/huachao1001/article/details/78501928 使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我 ...

  4. Tensorflow加载预训练模型和保存模型

    转载自:https://blog.csdn.net/huachao1001/article/details/78501928 使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我 ...

  5. 关于Tensorflow 加载和使用多个模型的方式

    在Tensorflow中,所有操作对象都包装到相应的Session中的,所以想要使用不同的模型就需要将这些模型加载到不同的Session中并在使用的时候申明是哪个Session,从而避免由于Sessi ...

  6. [原][osgearth]earth文件加载道路一初步看见模型道路

    时间是2017年2月5日17:16:32 由于OE2.9还没有发布,但是我又急于使用OE的道路. 所以,我先编译了正在github上调试中的OE2.9 github网址是:https://github ...

  7. Three.js中加载外部fbx格式的模型素材

    index.html部分: index.js部分: Scene.js部分:

  8. 学习笔记TF016:CNN实现、数据集、TFRecord、加载图像、模型、训练、调试

    AlexNet(Alex Krizhevsky,ILSVRC2012冠军)适合做图像分类.层自左向右.自上向下读取,关联层分为一组,高度.宽度减小,深度增加.深度增加减少网络计算量. 训练模型数据集 ...

  9. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

随机推荐

  1. CentOS下启动和停止Tomcat

    启动Tomcat: 进入tomcat目录/bin,然后./startup.sh 停止Tomcat: 进入tomcat目录/bin,然后./shutdown.sh

  2. oracle update from多表性能优化一例

    这几天测试java内存数据库,和oracle比较时发下一个update from语句很慢,如下: update business_new set fare1_balance_ratio = (sele ...

  3. Android 加密之文件级加密(CE/DE)

    https://blog.csdn.net/myfriend0/article/details/77094890/   Android加密之文件级加密

  4. 关于H5项目开发中TS(或JS)文件按照顺序编译成一个文件的记录

    由于js的执行特性,多个js文件合成一个文件或者进行多个js文件加载时,时需要按照指定的顺序进行的,否则会出现报错的情况. 我们看一下目前几个主流H5引擎的做法. 白鹭的做法 当前版本的做法 在tsc ...

  5. Java的方法类型

    1.无参数无返回值的方法 package com.imooc.method; public class MethodDemo { public static void printStar() { Sy ...

  6. python语言使用yaml 管理selenium元素

    1.所有元素都在PageElement下的.yaml,如图 login_page.yaml文件: username: dec: 登录页 type: xpath value: //input[@clas ...

  7. Archer和ArcherUI配置说明

    如果Bladex的网关端口是80,则需要修改Archer服务端口,并修改ArcherUI的vue.config.js的端口

  8. ChrW函数

    ChrW 函数返回包含 Unicode 的 String,若在不支持 Unicode 的平台上,则其功能与 Chr 函数相同.相反的函数是 ASCW() 在access当中用到了

  9. LODOP打印维护适应不同的客户端

    之前的博文:Lodop打印设计.维护.预览.直接打印简单介绍,介绍了打印设计.打印维护.打印预览,直接打印等的区别和使用. 如上面以前博文描述的,打印维护是针对客户端进行调整的,开放打印维护给客户端, ...

  10. python:单例模式--使用__new__(cls)实现

    单例模式:即一个类有且仅有一个实例. 那么通过python怎么实现一个类只能有一个实例呢. class Earth: """ 假如你是神,你可以创造地球 "&q ...