1 import tensorflow as tf
2 import onnx
3 import onnxsim
4 import numpy as np
5 import torch
6 from model.facedetector_model import mobilenetv2_yolov3
7
8 #提取pb模型中的参数
9 def extract_params_from_pb():
10 constant_values = {}
11 with tf.compat.v1.Session() as sess:
12 with tf.io.gfile.GFile('model/FaceDetector.pb', 'rb') as f:
13 graph_def = tf.compat.v1.GraphDef()
14 graph_def.ParseFromString(f.read())
15 sess.graph.as_default()
16 tf.import_graph_def(graph_def, name='')
17 # # input
18 # input_x = sess.graph.get_tensor_by_name('input/input_data:0')
19 # # output
20 # output = sess.graph.get_tensor_by_name('pred_bbox/Reshape:0')
21 # sess.run(output, feed_dict={'input/input_data:0': inputimage})
22
23 constant_ops = [op for op in sess.graph.get_operations()]#[op for op in sess.graph.get_operations() if op.type == "Const"]
24 for constant_op in constant_ops:
25 if constant_op.op_def.name == "Const":
26 if "Shape" in constant_op.name or "pred" in constant_op.name:
27 continue
28 constant_values[constant_op.name] = sess.run(constant_op.outputs[0])
29 return constant_values
30
31 #过滤提取出来的params
32 def filter_params(constant_values):
33 total = 0
34 prompt = []
35 res = {}
36 forbidden = ['shape','stack']
37
38 for k,v in constant_values.items():
39 # filtering some by checking ndim and name
40 if v.ndim<1: continue
41 if v.ndim==1:
42 token = k.split(r'/')[-1]
43 flag = False
44 for word in forbidden:
45 if token.find(word)!=-1:
46 flag = True
47 break
48 if flag:
49 continue
50
51 shape = v.shape
52 cnt = 1
53 for dim in shape:
54 cnt *= dim
55 prompt.append('{} with shape {} has {}'.format(k, shape, cnt))
56 res[k] = v
57 print(prompt[-1])
58 total += cnt
59 prompt.append('totaling {}'.format(total))
60 # print(prompt[-1])
61 return res
62
63 #将Tensorflow的张量转换成PyTorch的张量
64 def trans_tensor_pb2pth(k,a):
65
66 v = tf.convert_to_tensor(a).numpy()
67 # tensorflow weights to pytorch weights
68 if len(v.shape) == 4:
69 if "depthwise_weights" in k:#防止深度可分离卷积
70 return np.ascontiguousarray(v.transpose(2,3,0,1))
71 return np.ascontiguousarray(v.transpose(3,2,0,1))
72 elif len(v.shape) == 2:
73 return np.ascontiguousarray(v.transpose())
74 return v
75
76 #将pb的对应params名字转换为pth对应参数名
77 def trans_name_pb2pth(trans_weights):
78 model_dict = {}
79 for name,para in trans_weights.items():
80 name = name.replace('/',".")
81
82 if "MobilenetV2.Conv" in name:#处理MobilenetV2.Conv
83 name = name.replace('weights',"0.weight")
84 name = name.replace('BatchNorm',"1")
85 name = name.replace('gamma',"weight")
86 name = name.replace('beta',"bias")
87 name = name.replace('moving_mean',"running_mean")
88 name = name.replace('moving_variance',"running_var")
89 elif "MobilenetV2.expanded_conv." in name:#处理MobilenetV2.expanded_conv.
90 name = name.replace('depthwise.',"0.")
91 name = name.replace('project',"1")
92 name = name.replace('depthwise_weights',"0.weight")
93 name = name.replace('weights',"0.weight")
94 name = name.replace('BatchNorm',"1")
95 name = name.replace('gamma',"weight")
96 name = name.replace('beta',"bias")
97 name = name.replace('moving_mean',"running_mean")
98 name = name.replace('moving_variance',"running_var")
99 elif "MobilenetV2.expanded_conv_" in name:#处理MobilenetV2.expanded_conv_*
100 name = name.replace('expand.',"0.")
101 name = name.replace('depthwise.',"1.")
102 name = name.replace('project',"2")
103 name = name.replace('depthwise_weights',"0.weight")
104 name = name.replace('weights',"0.weight")
105 name = name.replace('BatchNorm',"1")
106 name = name.replace('gamma',"weight")
107 name = name.replace('beta',"bias")
108 name = name.replace('moving_mean',"running_mean")
109 name = name.replace('moving_variance',"running_var")
110 elif "yolo-v3" in name:
111 if "bbox" in name:
112 continue
113 name = name.replace('yolo-v3',"yolo_v3")
114 name = name.replace('weight',"0.weight")
115 name = name.replace('kernel',"weight")
116 name = name.replace('batch_normalization',"1")
117 name = name.replace('gamma',"weight")
118 name = name.replace('beta',"bias")
119 name = name.replace('moving_mean',"running_mean")
120 name = name.replace('moving_variance',"running_var")
121 print(name)
122 model_dict[name] = torch.Tensor(para)
123 return model_dict
124
125 #将pb参数copy给pth模型
126 def copy_pbParams2pthParams():
127 constant_values = extract_params_from_pb()
128 TF_weights = filter_params(constant_values)
129 trans_weights = {k:trans_tensor_pb2pth(k,v) for (k, v) in TF_weights.items() }
130
131 #创建pytorch模型
132 PyTorchModel = mobilenetv2_yolov3()
133 model_dict = trans_name_pb2pth(trans_weights)
134 # model_dict = PyTorchModel.state_dict()
135 # for name in model_dict.keys():
136 # print(name)
137 PyTorchModel.load_state_dict(model_dict)
138 PyTorchModel.cuda().eval()
139 dummy_input = torch.rand(1,1,224,224,device="cuda").float()
140 # out = PyTorchModel(dummy_input)
141 torch.onnx.export(PyTorchModel,dummy_input,"P3mNet.onnx",verbose = True,opset_version = 11)
142 print("====> Simplifying...")
143 model_opt,_ = onnxsim.simplify("P3mNet.onnx")
144 onnx.save(model_opt, 'P3mNet_sim.onnx')
145 print("onnx model simplify Ok!")
146 copy_pbParams2pthParams()

将pb模型参数提取转成torch模型的更多相关文章

  1. 利用反射将Datatable、SqlDataReader转换成List模型

    1. DataTable转IList public class DataTableToList<T>whereT :new() { ///<summary> ///利用反射将D ...

  2. (原)torch模型转pytorch模型

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/7839263.html 目前使用的torch模型转pytorch模型的程序为: https://gith ...

  3. 「新手必看」Python+Opencv实现摄像头调用RGB图像并转换成HSV模型

    在ROS机器人的应用开发中,调用摄像头进行机器视觉处理是比较常见的方法,现在把利用opencv和python语言实现摄像头调用并转换成HSV模型的方法分享出来,希望能对学习ROS机器人的新手们一点帮助 ...

  4. 【tensorflow-v2.0】如何将模型转换成tflite模型

    前言 TensorFlow Lite 提供了转换 TensorFlow 模型,并在移动端(mobile).嵌入式(embeded)和物联网(IoT)设备上运行 TensorFlow 模型所需的所有工具 ...

  5. DEX-6-caffe模型转成pytorch模型办法

    在python2.7环境下 文件下载位置:https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/ 1.可视化模型文件prototxt 1)在线可视化 ...

  6. 使用C#语言,将DataTable 转换成域模型

    DataTable dt = SqlHelper.Query(strQuery); ) * size).Take(pagesize); List<Model> listData = new ...

  7. PB之取下来列修改后的值(AcceptText)

    AcceptText()功能 将“漂浮”在数据窗口控件上编辑框的内容放入到数据窗口控件的当前项中(主缓区中).在将数据放入到当前项之前,编辑框中的数据必须通过有效性规则检查语法  dwcontrol. ...

  8. pytorch1.0 用torch script导出模型

    python的易上手和pytorch的动态图特性,使得pytorch在学术研究中越来越受欢迎,但在生产环境,碍于python的GIL等特性,可能达不到高并发.低延迟的要求,存在需要用c++接口的情况. ...

  9. MxNet 模型转Tensorflow pb模型

    用mmdnn实现模型转换 参考链接:https://www.twblogs.net/a/5ca4cadbbd9eee5b1a0713af 安装mmdnn pip install mmdnn 准备好mx ...

  10. iOS swift HandyJSON组合Alamofire发起网络请求并转换成模型

    在swift开发中,发起网络请求大部分开发者应该都是使用Alamofire发起的网络请求,至于请求完成后JSON解析这一块有很多解决方案,我们今天这里使用HandyJSON来解析请求返回的数据并转化成 ...

随机推荐

  1. 使用Shapefile-js读取shp文件并使用WebGL绘制

    1. 引言 坐标数据是空间数据文件的核心,空间数据的数据量往往是很大的.数据可视化是GIS的一个核心应用,绘制海量的坐标数据始终是一个考验设备性能的难题,使用GPU进行绘制可有效减少CPU的负载,提升 ...

  2. 微信小程序分享百度网盘文件的实现思路

    需求: 在小程序中点击按钮,获取百度网盘文件的下载地址. 实现思路: 1.网盘文件的下载地址,使用官方API只能自己下载,别人通过dlink无法下载,所以采用网页端生成接口. 好处是可以自定义提取码, ...

  3. Solidity8.0-02

    对应崔棉大师 26-40课程https://www.bilibili.com/video/BV1yS4y1N7yu/?spm_id_from=333.788&vd_source=c81b130 ...

  4. perlist

    1 <!DOCTYPE html> 2 <html> 3 <head> 4 <meta charset="utf-8"> 5 < ...

  5. 数据类型之字符串(string)(一)

    1.引号括起的都是字符串(可以时空格),可以是''(单引号).""(双引号).''''''(三引号).""""""(我还 ...

  6. 四大组件之广播接收者BroadcastReceiver

    参考:Android开发基础之广播接收者BroadcastReceiver 什么是广播接收者? 我们小时候都知道,听广播,收听广播!什么是收听广播呢?打开收音机,调频就可以收到对应的广播节目了.其实我 ...

  7. docker自动化启动停止脚本

    docker一键启动命令 sh auto.sh [start|restart|stop] [keywords...] keywords可选(包含编号,镜像名,容器名称,端口) 其中defaultLis ...

  8. golang json 字符串 用 json.Number 解析字段

    不定义结构体,用map 解析json 字符串字段 func main() { jsonString := `{"age": 20, "height": 180 ...

  9. vue2和vue3配置全局自定义参数及vue3动态绑定ref

    在 Vue2.x 中我们可以通过 Vue.prototype 添加全局属性 property.但是在 Vue3.x 中需要将 Vue.prototype 替换为 config.globalProper ...

  10. 图片视频二进制流base64加密

    一:读取图片或者视频,转换二进制流,进行Base64加密 @PostMapping("/base64Encoder") public StringBuilder changeIma ...