训练的代码,以cifar为例

# -*- coding: utf-8 -*-
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
dataiter = iter(trainloader)
images, labels = dataiter.next() class Net(torch.jit.ScriptModule):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) @torch.jit.script_method
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) for epoch in range(1):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
if i == 2:
break
net.save("/Users/zhouyang3/CLionProjects/hello_world/a.pt")
print('Finished Training')

c++推理代码

#include <torch/script.h>
#include <typeinfo>
#include <iostream>
#include <memory> int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "usage: example-app <path-to-exported-script-module>\n";
return -1;
} // Deserialize the ScriptModule from a file using torch::jit::load().
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]); assert(module != nullptr);
std::cout << "ok\n";
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 32, 32}));
at::Tensor output = module->forward(inputs).toTensor();
std::cout << "AAAAA" << '\n';
std::cout << output.argmax() << '\n';
std::cout << "BBBBB" << '\n';
}

ARTS-S pytorch用c++实现推理的更多相关文章

  1. 库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)

    项目地址:https://github.com/bharathgs/Awesome-pytorch-list 列表结构: NLP 与语音处理 计算机视觉 概率/生成库 其他库 教程与示例 论文实现 P ...

  2. ONNXRuntime学习笔记(二)

    继上一篇计划的实践项目,这篇记录我训练模型相关的工作. 首先要确定总体目标:训练一个pytorch模型,CIFAR-100数据集测试集acc达到90%:部署后推理效率达到50ms/张, 部署平台为wi ...

  3. 从零教你使用MindStudio进行Pytorch离线推理全流程

    摘要:MindStudio的是一套基于华为自研昇腾AI处理器开发的AI全栈开发工具平台,该IDE上功能很多,涵盖面广,可以进行包括网络模型训练.移植.应用开发.推理运行及自定义算子开发等多种任务. 本 ...

  4. 使用TensorRT对caffe和pytorch onnx版本的mnist模型进行fp32和fp16 推理 | tensorrt fp32 fp16 tutorial with caffe pytorch minist model

    本文首发于个人博客https://kezunlin.me/post/bcdfb73c/,欢迎阅读最新内容! tensorrt fp32 fp16 tutorial with caffe pytorch ...

  5. ArXiv最受欢迎开源深度学习框架榜单:TensorFlow第一,PyTorch第四

    [导读]Kears作者François Chollet刚刚在Twitter贴出最近三个月在arXiv提到的深度学习框架,TensorFlow不出意外排名第一,Keras排名第二.随后是Caffe.Py ...

  6. 深度学习框架PyTorch一书的学习-第三章-Tensor和autograd-2-autograd

    参考https://github.com/chenyuntc/pytorch-book/tree/v1.0 希望大家直接到上面的网址去查看代码,下面是本人的笔记 torch.autograd就是为了方 ...

  7. 『PyTorch』第五弹_深入理解autograd_上:Variable属性方法

    在PyTorch中计算图的特点可总结如下: autograd根据用户对variable的操作构建其计算图.对变量的操作抽象为Function. 对于那些不是任何函数(Function)的输出,由用户创 ...

  8. 【转载】 Pytorch(1) pytorch中的BN层的注意事项

    原文地址: https://blog.csdn.net/weixin_40100431/article/details/84349470 ------------------------------- ...

  9. 吐血整理:PyTorch项目代码与资源列表 | 资源下载

    http://www.sohu.com/a/164171974_741733   本文收集了大量基于 PyTorch 实现的代码链接,其中有适用于深度学习新手的“入门指导系列”,也有适用于老司机的论文 ...

随机推荐

  1. CSS复合选择器是什么?复合选择器是如何工作

    复合选择器介绍 复合选择器其实很好理解,说白了就跟我们生活中的有血缘关系家庭成员一样,通过标签或者class属性或id属性,去找对应的有血缘关系的某个选择器,具体的大家往下看哦. 如果是初学者对基本的 ...

  2. spark基于yarn的两种提交模式

    一.spark的三种提交模式 1.第一种,Spark内核架构,即standalone模式,基于Spark自己的Master-Worker集群. 2.第二种,基于YARN的yarn-cluster模式. ...

  3. Python 常用模块系列(2)--time module and datatime module

    import time print (help(time)) #time帮助文档 1. time模块--三种时间表现形式: 1° 时间戳--如:time.time()  #从python创立以来,到当 ...

  4. SpringSecurity退出功能实现的正确方式

    本文将介绍在Spring Security框架下如何实现用户的"退出"logout的功能.其实这是一个非常简单的功能,我见过很多的程序员在使用了Spring Security之后, ...

  5. GitHub远程库的搭建以及使用

    GitHub远程库的搭建 一).配置SSH 步骤: 1).注册GitHub账号 2).本地git仓库与远程的GitHub仓库的传输要通过SSH进行加密 3).创建SSH key ​ 1.检查在用户主目 ...

  6. 前端的构建化工具Webpack

    经常看到如jquery-3.0.0.js和jquery-3.0.0-min.js等两相似的文件名. 其实以上两个文件名的内容是一样的,不过带min代表的是占用最小的空间,为项目提高性能.压缩的部分如换 ...

  7. 基于Galera Cluster多主结构的Mysql高可用集群

    Galera Cluster特点 1.多主架构:真正的多点读写的集群,在任何时候读写数据,都是最新的 2.同步复制:集群不同节点之间数据同步,没有延迟,在数据库挂掉之后,数据不会丢失 3.并发复制:从 ...

  8. day 07 复习总结

    今日主要内容 1. 补充基础数据类型的相关知识点 1. str. join() 把列表变成字符串 对应的是split () 表示把字符串变成列表.  ()里面为分隔符,不写默认为空格分隔 1.吧 2. ...

  9. C语言基础 -- 变量

    常用变量类型 ​​ 地址 小端 低地址保存低位,高地址保存高位 常用于 PC(复杂指令集) 大端 低地址保存高位,高地址保存低位 常用于 ARM/手机/网络(精简指令集)

  10. NIO-概览

    目录 NIO-概览 目录 前言 什么是NIO 通道 缓冲区 选择器 其他 管道 FileLock 参考文档 NIO-概览 目录 NIO-概览 前言 本来是想学习Netty的,但是Netty是一个NIO ...