将pb模型参数提取转成torch模型
1 import tensorflow as tf
2 import onnx
3 import onnxsim
4 import numpy as np
5 import torch
6 from model.facedetector_model import mobilenetv2_yolov3
7
8 #提取pb模型中的参数
9 def extract_params_from_pb():
10 constant_values = {}
11 with tf.compat.v1.Session() as sess:
12 with tf.io.gfile.GFile('model/FaceDetector.pb', 'rb') as f:
13 graph_def = tf.compat.v1.GraphDef()
14 graph_def.ParseFromString(f.read())
15 sess.graph.as_default()
16 tf.import_graph_def(graph_def, name='')
17 # # input
18 # input_x = sess.graph.get_tensor_by_name('input/input_data:0')
19 # # output
20 # output = sess.graph.get_tensor_by_name('pred_bbox/Reshape:0')
21 # sess.run(output, feed_dict={'input/input_data:0': inputimage})
22
23 constant_ops = [op for op in sess.graph.get_operations()]#[op for op in sess.graph.get_operations() if op.type == "Const"]
24 for constant_op in constant_ops:
25 if constant_op.op_def.name == "Const":
26 if "Shape" in constant_op.name or "pred" in constant_op.name:
27 continue
28 constant_values[constant_op.name] = sess.run(constant_op.outputs[0])
29 return constant_values
30
31 #过滤提取出来的params
32 def filter_params(constant_values):
33 total = 0
34 prompt = []
35 res = {}
36 forbidden = ['shape','stack']
37
38 for k,v in constant_values.items():
39 # filtering some by checking ndim and name
40 if v.ndim<1: continue
41 if v.ndim==1:
42 token = k.split(r'/')[-1]
43 flag = False
44 for word in forbidden:
45 if token.find(word)!=-1:
46 flag = True
47 break
48 if flag:
49 continue
50
51 shape = v.shape
52 cnt = 1
53 for dim in shape:
54 cnt *= dim
55 prompt.append('{} with shape {} has {}'.format(k, shape, cnt))
56 res[k] = v
57 print(prompt[-1])
58 total += cnt
59 prompt.append('totaling {}'.format(total))
60 # print(prompt[-1])
61 return res
62
63 #将Tensorflow的张量转换成PyTorch的张量
64 def trans_tensor_pb2pth(k,a):
65
66 v = tf.convert_to_tensor(a).numpy()
67 # tensorflow weights to pytorch weights
68 if len(v.shape) == 4:
69 if "depthwise_weights" in k:#防止深度可分离卷积
70 return np.ascontiguousarray(v.transpose(2,3,0,1))
71 return np.ascontiguousarray(v.transpose(3,2,0,1))
72 elif len(v.shape) == 2:
73 return np.ascontiguousarray(v.transpose())
74 return v
75
76 #将pb的对应params名字转换为pth对应参数名
77 def trans_name_pb2pth(trans_weights):
78 model_dict = {}
79 for name,para in trans_weights.items():
80 name = name.replace('/',".")
81
82 if "MobilenetV2.Conv" in name:#处理MobilenetV2.Conv
83 name = name.replace('weights',"0.weight")
84 name = name.replace('BatchNorm',"1")
85 name = name.replace('gamma',"weight")
86 name = name.replace('beta',"bias")
87 name = name.replace('moving_mean',"running_mean")
88 name = name.replace('moving_variance',"running_var")
89 elif "MobilenetV2.expanded_conv." in name:#处理MobilenetV2.expanded_conv.
90 name = name.replace('depthwise.',"0.")
91 name = name.replace('project',"1")
92 name = name.replace('depthwise_weights',"0.weight")
93 name = name.replace('weights',"0.weight")
94 name = name.replace('BatchNorm',"1")
95 name = name.replace('gamma',"weight")
96 name = name.replace('beta',"bias")
97 name = name.replace('moving_mean',"running_mean")
98 name = name.replace('moving_variance',"running_var")
99 elif "MobilenetV2.expanded_conv_" in name:#处理MobilenetV2.expanded_conv_*
100 name = name.replace('expand.',"0.")
101 name = name.replace('depthwise.',"1.")
102 name = name.replace('project',"2")
103 name = name.replace('depthwise_weights',"0.weight")
104 name = name.replace('weights',"0.weight")
105 name = name.replace('BatchNorm',"1")
106 name = name.replace('gamma',"weight")
107 name = name.replace('beta',"bias")
108 name = name.replace('moving_mean',"running_mean")
109 name = name.replace('moving_variance',"running_var")
110 elif "yolo-v3" in name:
111 if "bbox" in name:
112 continue
113 name = name.replace('yolo-v3',"yolo_v3")
114 name = name.replace('weight',"0.weight")
115 name = name.replace('kernel',"weight")
116 name = name.replace('batch_normalization',"1")
117 name = name.replace('gamma',"weight")
118 name = name.replace('beta',"bias")
119 name = name.replace('moving_mean',"running_mean")
120 name = name.replace('moving_variance',"running_var")
121 print(name)
122 model_dict[name] = torch.Tensor(para)
123 return model_dict
124
125 #将pb参数copy给pth模型
126 def copy_pbParams2pthParams():
127 constant_values = extract_params_from_pb()
128 TF_weights = filter_params(constant_values)
129 trans_weights = {k:trans_tensor_pb2pth(k,v) for (k, v) in TF_weights.items() }
130
131 #创建pytorch模型
132 PyTorchModel = mobilenetv2_yolov3()
133 model_dict = trans_name_pb2pth(trans_weights)
134 # model_dict = PyTorchModel.state_dict()
135 # for name in model_dict.keys():
136 # print(name)
137 PyTorchModel.load_state_dict(model_dict)
138 PyTorchModel.cuda().eval()
139 dummy_input = torch.rand(1,1,224,224,device="cuda").float()
140 # out = PyTorchModel(dummy_input)
141 torch.onnx.export(PyTorchModel,dummy_input,"P3mNet.onnx",verbose = True,opset_version = 11)
142 print("====> Simplifying...")
143 model_opt,_ = onnxsim.simplify("P3mNet.onnx")
144 onnx.save(model_opt, 'P3mNet_sim.onnx')
145 print("onnx model simplify Ok!")
146 copy_pbParams2pthParams()
将pb模型参数提取转成torch模型的更多相关文章
- 利用反射将Datatable、SqlDataReader转换成List模型
1. DataTable转IList public class DataTableToList<T>whereT :new() { ///<summary> ///利用反射将D ...
- (原)torch模型转pytorch模型
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/7839263.html 目前使用的torch模型转pytorch模型的程序为: https://gith ...
- 「新手必看」Python+Opencv实现摄像头调用RGB图像并转换成HSV模型
在ROS机器人的应用开发中,调用摄像头进行机器视觉处理是比较常见的方法,现在把利用opencv和python语言实现摄像头调用并转换成HSV模型的方法分享出来,希望能对学习ROS机器人的新手们一点帮助 ...
- 【tensorflow-v2.0】如何将模型转换成tflite模型
前言 TensorFlow Lite 提供了转换 TensorFlow 模型,并在移动端(mobile).嵌入式(embeded)和物联网(IoT)设备上运行 TensorFlow 模型所需的所有工具 ...
- DEX-6-caffe模型转成pytorch模型办法
在python2.7环境下 文件下载位置:https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/ 1.可视化模型文件prototxt 1)在线可视化 ...
- 使用C#语言,将DataTable 转换成域模型
DataTable dt = SqlHelper.Query(strQuery); ) * size).Take(pagesize); List<Model> listData = new ...
- PB之取下来列修改后的值(AcceptText)
AcceptText()功能 将“漂浮”在数据窗口控件上编辑框的内容放入到数据窗口控件的当前项中(主缓区中).在将数据放入到当前项之前,编辑框中的数据必须通过有效性规则检查语法 dwcontrol. ...
- pytorch1.0 用torch script导出模型
python的易上手和pytorch的动态图特性,使得pytorch在学术研究中越来越受欢迎,但在生产环境,碍于python的GIL等特性,可能达不到高并发.低延迟的要求,存在需要用c++接口的情况. ...
- MxNet 模型转Tensorflow pb模型
用mmdnn实现模型转换 参考链接:https://www.twblogs.net/a/5ca4cadbbd9eee5b1a0713af 安装mmdnn pip install mmdnn 准备好mx ...
- iOS swift HandyJSON组合Alamofire发起网络请求并转换成模型
在swift开发中,发起网络请求大部分开发者应该都是使用Alamofire发起的网络请求,至于请求完成后JSON解析这一块有很多解决方案,我们今天这里使用HandyJSON来解析请求返回的数据并转化成 ...
随机推荐
- .NET AsyncLocal 避坑指南
目录 AsyncLocal 用法简介 AsyncLocal 实现原理 AsyncLocal 的坑 AsyncLocal 的避坑指南 HttpContextAccessor 的实现原理 AsyncLoc ...
- RISC-V核及工具链整理
RISC-V开源核分为开源核(无外设).SOC.FPGA.多核等多种框架. 开源核 SOC框架 平头哥无剑100 包括EDA仿真框架及FPGA框架 https://github.com/T-head- ...
- 08. AssetBundle.LoadFromFile
参数 path 文件在磁盘上的路径. crc 未压缩内容的 CRC-32 校验和(可选).如果该参数不为零,则加载前将内容与校验和进行比较,如果不匹配则给出错误. offset 字节偏移(可选).该值 ...
- c# 游戏设计:地图移动
想实现一个小游戏,先做地图移动.步骤记录如下: 1.百度到一张大的迷宫地图,放在项目的debug目录下,备用. 2.创建一个winform项目,不添加任何界面元素. 3.添加数据成员如下: Pictu ...
- sql(上)例题
一.数据库概述 数据库(DataBase,DB):指长期保存在计算机的存储设备上,按照一定规则组织起来,可以被各种用户或应用共享的数据集合. 数据库管理系统(DataBase Management S ...
- Linux - tar 命令详解 (压缩,解压,加密压缩,解密压缩)
压缩tar -czvf /path/to/file.tar.gz file (第一个参数:文件压缩的位置和名字 第二个参数:需要压缩的文件) 解压 tar -xzvf /path/to/file. ...
- RxJava2.x的理解与总结
RxJava2.x的理解与总结 RxJava是一个基于观察者设计模式将链式编程和异步结合在一起的开源库. 链式编程 通过查看GitHub开源项目的简介开源知道,RxJava有几个基类. 他们分别适用于 ...
- js - 数字转中文
js - 数字转中文 JavaScript 中将阿拉伯数字转换为中文 转换代码 var _change = { ary0: ['零', '一', '二', '三', '四', '五', '六', '七 ...
- 前端下载的方式总结(url,文件流,压缩包)
1.比较常见的是通过a标签的href属性直接访问文件url地址. (1)const downloadUrl = (url: string, file_name?: string) => { if ...
- ES5 绑定 this 的方法
this的动态切换,固然为 JavaScript 创造了巨大的灵活性,但也使得编程变得困难和模糊.有时,需要把this固定下来,避免出现意想不到的情况.JavaScript 提供了call.apply ...