(原)torch模型转pytorch模型
转载请注明出处:
http://www.cnblogs.com/darkknightzh/p/7839263.html
目前使用的torch模型转pytorch模型的程序为:
https://github.com/clcarwin/convert_torch_to_pytorch
该程序中,常见的模型都可以转换,但是对于torch中为BatchNormalization的则会提示出错:
Not Implement BatchNormalization
torch中的SpatialBatchNormalization对应于输入为4d的特征(batchsize*featdim*featHeight*featWidth),对应于pytorch中的nn.BatchNorm2d。
而torch中的BatchNormalization对应于输入为2d的特征(batchsize*featdim),对应于pytorch中的nn.BatchNorm1d。
因而修改方法很简单:
1. 在convert_torch.py的行(elif name == 'ReLU':)之前添加:
elif name == 'BatchNormalization':
n = nn.BatchNorm1d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
copy_param(m,n)
add_submodule(seq,n)
2. 在convert_torch.py的(未修改前的)行(elif name == 'ReLU':)之前添加:
elif name == 'BatchNormalization':
s += ['nn.BatchNorm1d({},{},{},{}),#BatchNorm1d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
3. 在convert_torch.py的(未修改前的)行(s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d',')'),s))之前添加:
s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm1d',')'),s)
s = map(lambda x: x.replace('),#BatchNorm1d',')'),s)
经过上述修改后,torch模型中含有BatchNormalization,转换到pytorch后的模型性能和转换前的模型性能一致。
顺便说一下,2天前更新的该程序,添加了BatchNorm3d的支持,但是在243、244行之后,并没有增加BatchNorm3d的相关代码,不清楚是否会有问题。我这边没有用到BatchNorm3d,因而没有测试。
另一方面,上面的3步中,我是根据BatchNorm2d去修改,没有测试如果不修改某一步(如第3步),程序是否会有问题。反正都改了,模型没有问题。。。
(原)torch模型转pytorch模型的更多相关文章
- 生产与学术之Pytorch模型导出为安卓Apk尝试记录
生产与学术 写于 2019-01-08 的旧文, 当时是针对一个比赛的探索. 觉得可能对其他人有用, 就放出来分享一下 生产与学术, 真实的对立... 这是我这两天对pytorch深度学习->a ...
- 将Pytorch模型从CPU转换成GPU
1. 如何进行迁移 对模型和相应的数据进行.cuda()处理.通过这种方式,我们就可以将内存中的数据复制到GPU的显存中去.从而可以通过GPU来进行运算了. 1.1 判定使用GPU 下载了对应的GPU ...
- 使用C++调用pytorch模型(Linux)
前言 模型转换思路通常为: Pytorch -> ONNX -> TensorRT Pytorch -> ONNX -> TVM Pytorch -> 转换工具 -> ...
- 使用C++调用并部署pytorch模型
1.背景(Background) 上图显示了目前深度学习模型在生产环境中的方法,本文仅探讨如何部署pytorch模型! 至于为什么要用C++调用pytorch模型,其目的在于:使用C++及多线程可以加 ...
- DEX-6-caffe模型转成pytorch模型办法
在python2.7环境下 文件下载位置:https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/ 1.可视化模型文件prototxt 1)在线可视化 ...
- PyTorch模型加载与保存的最佳实践
一般来说PyTorch有两种保存和读取模型参数的方法.但这篇文章我记录了一种最佳实践,可以在加载模型时避免掉一些问题. 第一种方案是保存整个模型: 1 torch.save(model_object, ...
- 从零搭建Pytorch模型教程(三)搭建Transformer网络
前言 本文介绍了Transformer的基本流程,分块的两种实现方式,Position Emebdding的几种实现方式,Encoder的实现方式,最后分类的两种方式,以及最重要的数据格式的介绍. ...
- Pytorch模型量化
在深度学习中,量化指的是使用更少的bit来存储原本以浮点数存储的tensor,以及使用更少的bit来完成原本以浮点数完成的计算.这么做的好处主要有如下几点: 更少的模型体积,接近4倍的减少: 可以更快 ...
- 计算机网络原理和OSI模型与TCP模型
计算机网络原理和OSI模型与TCP模型 一.计算机网络的概述 1.计算机网络的定义 计算机网络是一组自治计算机的互连的集合 2.计算机网络的基本功能 a.资源共享 b.分布式处理与负载均衡 c.综合信 ...
随机推荐
- c#利用SWIG调用c++dll学习总结【转】
开发环境: 操作系统:windows 7 IDE:Microsoft Visual Studio Professional 2015 SWIG: 3.0.12 swig的介绍 详细介绍可看官网,一下贴 ...
- NodeJS错误-throw er; // Unhandled 'error' event
第一眼看以为Express版本出现问题,因为本地已经存在另外一个运行的Node项目,端口重复,修改一下端口号即可,错误提示如下: events.js:85 throw er; // Unhandled ...
- html与css架构的一点体验
css本身,可以说是一门非常简单而容易入门的语言.制作一个页面,或者制作一个小企业站,对于css的要求都是非常低的.只要熟悉语法,通过英文单词的含义猜,都基本可以拼出一套样式.更何况市面上还有各种各样 ...
- Matplotlib.pyplot 常用方法
1.介绍 Matplotlib 是一个 Python 的 2D绘图库,它以各种硬拷贝格式和跨平台的交互式环境生成出版质量级别的图形.通过 Matplotlib,开发者可以仅需要几行代码,便可以生成绘图 ...
- Double-Array Trie分词词典简述
http://www.xuebuyuan.com/1991441.html 一.TRIE树简介(以下简称T树) TRIE树用于确定词条的快速检索,对于给定的一个字符串a1,a2,a3,…an,则采用T ...
- Netflix推荐系统:从评分预测到消费者法则
http://in.sdo.com/?p=11 原文链接:Netflix recommendations: beyond the 5 stars (Part 1), (Part 2) 原文作者:Xav ...
- ASP入门(十一)-Session小案例
一般来说,在实际开发中,对于 Session 对象使用最多的就是用户登录部分了,这个案例将简单模拟一个用户登录表单.用户是否登录的判断以及用户退出的一系列功能,它一共分了以下几个页面. Login.a ...
- [Canvas]人物型英雄出现(前作仅为箭头)
源码点此下载,用浏览器打开index.html观看. 代码: <!DOCTYPE html> <html lang="utf-8"> <meta ht ...
- Oracle——数据库启动与关闭
本文内容 服务器环境 客户端环境 概述 启动数据库 关闭数据库 补充 参考资料 本文说明 Oracle 数据库的启动和关闭,内容虽然基础,但是在数据库很多操作中都需要,因此,基础而重要,必须深入理解. ...
- crm创建启用停用用户
public static readonly string entityName = "systemuser"; public Guid userId = Guid ...