1 import tensorflow as tf
2 import onnx
3 import onnxsim
4 import numpy as np
5 import torch
6 from model.facedetector_model import mobilenetv2_yolov3
7
8 #提取pb模型中的参数
9 def extract_params_from_pb():
10 constant_values = {}
11 with tf.compat.v1.Session() as sess:
12 with tf.io.gfile.GFile('model/FaceDetector.pb', 'rb') as f:
13 graph_def = tf.compat.v1.GraphDef()
14 graph_def.ParseFromString(f.read())
15 sess.graph.as_default()
16 tf.import_graph_def(graph_def, name='')
17 # # input
18 # input_x = sess.graph.get_tensor_by_name('input/input_data:0')
19 # # output
20 # output = sess.graph.get_tensor_by_name('pred_bbox/Reshape:0')
21 # sess.run(output, feed_dict={'input/input_data:0': inputimage})
22
23 constant_ops = [op for op in sess.graph.get_operations()]#[op for op in sess.graph.get_operations() if op.type == "Const"]
24 for constant_op in constant_ops:
25 if constant_op.op_def.name == "Const":
26 if "Shape" in constant_op.name or "pred" in constant_op.name:
27 continue
28 constant_values[constant_op.name] = sess.run(constant_op.outputs[0])
29 return constant_values
30
31 #过滤提取出来的params
32 def filter_params(constant_values):
33 total = 0
34 prompt = []
35 res = {}
36 forbidden = ['shape','stack']
37
38 for k,v in constant_values.items():
39 # filtering some by checking ndim and name
40 if v.ndim<1: continue
41 if v.ndim==1:
42 token = k.split(r'/')[-1]
43 flag = False
44 for word in forbidden:
45 if token.find(word)!=-1:
46 flag = True
47 break
48 if flag:
49 continue
50
51 shape = v.shape
52 cnt = 1
53 for dim in shape:
54 cnt *= dim
55 prompt.append('{} with shape {} has {}'.format(k, shape, cnt))
56 res[k] = v
57 print(prompt[-1])
58 total += cnt
59 prompt.append('totaling {}'.format(total))
60 # print(prompt[-1])
61 return res
62
63 #将Tensorflow的张量转换成PyTorch的张量
64 def trans_tensor_pb2pth(k,a):
65
66 v = tf.convert_to_tensor(a).numpy()
67 # tensorflow weights to pytorch weights
68 if len(v.shape) == 4:
69 if "depthwise_weights" in k:#防止深度可分离卷积
70 return np.ascontiguousarray(v.transpose(2,3,0,1))
71 return np.ascontiguousarray(v.transpose(3,2,0,1))
72 elif len(v.shape) == 2:
73 return np.ascontiguousarray(v.transpose())
74 return v
75
76 #将pb的对应params名字转换为pth对应参数名
77 def trans_name_pb2pth(trans_weights):
78 model_dict = {}
79 for name,para in trans_weights.items():
80 name = name.replace('/',".")
81
82 if "MobilenetV2.Conv" in name:#处理MobilenetV2.Conv
83 name = name.replace('weights',"0.weight")
84 name = name.replace('BatchNorm',"1")
85 name = name.replace('gamma',"weight")
86 name = name.replace('beta',"bias")
87 name = name.replace('moving_mean',"running_mean")
88 name = name.replace('moving_variance',"running_var")
89 elif "MobilenetV2.expanded_conv." in name:#处理MobilenetV2.expanded_conv.
90 name = name.replace('depthwise.',"0.")
91 name = name.replace('project',"1")
92 name = name.replace('depthwise_weights',"0.weight")
93 name = name.replace('weights',"0.weight")
94 name = name.replace('BatchNorm',"1")
95 name = name.replace('gamma',"weight")
96 name = name.replace('beta',"bias")
97 name = name.replace('moving_mean',"running_mean")
98 name = name.replace('moving_variance',"running_var")
99 elif "MobilenetV2.expanded_conv_" in name:#处理MobilenetV2.expanded_conv_*
100 name = name.replace('expand.',"0.")
101 name = name.replace('depthwise.',"1.")
102 name = name.replace('project',"2")
103 name = name.replace('depthwise_weights',"0.weight")
104 name = name.replace('weights',"0.weight")
105 name = name.replace('BatchNorm',"1")
106 name = name.replace('gamma',"weight")
107 name = name.replace('beta',"bias")
108 name = name.replace('moving_mean',"running_mean")
109 name = name.replace('moving_variance',"running_var")
110 elif "yolo-v3" in name:
111 if "bbox" in name:
112 continue
113 name = name.replace('yolo-v3',"yolo_v3")
114 name = name.replace('weight',"0.weight")
115 name = name.replace('kernel',"weight")
116 name = name.replace('batch_normalization',"1")
117 name = name.replace('gamma',"weight")
118 name = name.replace('beta',"bias")
119 name = name.replace('moving_mean',"running_mean")
120 name = name.replace('moving_variance',"running_var")
121 print(name)
122 model_dict[name] = torch.Tensor(para)
123 return model_dict
124
125 #将pb参数copy给pth模型
126 def copy_pbParams2pthParams():
127 constant_values = extract_params_from_pb()
128 TF_weights = filter_params(constant_values)
129 trans_weights = {k:trans_tensor_pb2pth(k,v) for (k, v) in TF_weights.items() }
130
131 #创建pytorch模型
132 PyTorchModel = mobilenetv2_yolov3()
133 model_dict = trans_name_pb2pth(trans_weights)
134 # model_dict = PyTorchModel.state_dict()
135 # for name in model_dict.keys():
136 # print(name)
137 PyTorchModel.load_state_dict(model_dict)
138 PyTorchModel.cuda().eval()
139 dummy_input = torch.rand(1,1,224,224,device="cuda").float()
140 # out = PyTorchModel(dummy_input)
141 torch.onnx.export(PyTorchModel,dummy_input,"P3mNet.onnx",verbose = True,opset_version = 11)
142 print("====> Simplifying...")
143 model_opt,_ = onnxsim.simplify("P3mNet.onnx")
144 onnx.save(model_opt, 'P3mNet_sim.onnx')
145 print("onnx model simplify Ok!")
146 copy_pbParams2pthParams()

将pb模型参数提取转成torch模型的更多相关文章

  1. 利用反射将Datatable、SqlDataReader转换成List模型

    1. DataTable转IList public class DataTableToList<T>whereT :new() { ///<summary> ///利用反射将D ...

  2. (原)torch模型转pytorch模型

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/7839263.html 目前使用的torch模型转pytorch模型的程序为: https://gith ...

  3. 「新手必看」Python+Opencv实现摄像头调用RGB图像并转换成HSV模型

    在ROS机器人的应用开发中,调用摄像头进行机器视觉处理是比较常见的方法,现在把利用opencv和python语言实现摄像头调用并转换成HSV模型的方法分享出来,希望能对学习ROS机器人的新手们一点帮助 ...

  4. 【tensorflow-v2.0】如何将模型转换成tflite模型

    前言 TensorFlow Lite 提供了转换 TensorFlow 模型,并在移动端(mobile).嵌入式(embeded)和物联网(IoT)设备上运行 TensorFlow 模型所需的所有工具 ...

  5. DEX-6-caffe模型转成pytorch模型办法

    在python2.7环境下 文件下载位置:https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/ 1.可视化模型文件prototxt 1)在线可视化 ...

  6. 使用C#语言,将DataTable 转换成域模型

    DataTable dt = SqlHelper.Query(strQuery); ) * size).Take(pagesize); List<Model> listData = new ...

  7. PB之取下来列修改后的值(AcceptText)

    AcceptText()功能 将“漂浮”在数据窗口控件上编辑框的内容放入到数据窗口控件的当前项中(主缓区中).在将数据放入到当前项之前,编辑框中的数据必须通过有效性规则检查语法  dwcontrol. ...

  8. pytorch1.0 用torch script导出模型

    python的易上手和pytorch的动态图特性,使得pytorch在学术研究中越来越受欢迎,但在生产环境,碍于python的GIL等特性,可能达不到高并发.低延迟的要求,存在需要用c++接口的情况. ...

  9. MxNet 模型转Tensorflow pb模型

    用mmdnn实现模型转换 参考链接:https://www.twblogs.net/a/5ca4cadbbd9eee5b1a0713af 安装mmdnn pip install mmdnn 准备好mx ...

  10. iOS swift HandyJSON组合Alamofire发起网络请求并转换成模型

    在swift开发中,发起网络请求大部分开发者应该都是使用Alamofire发起的网络请求,至于请求完成后JSON解析这一块有很多解决方案,我们今天这里使用HandyJSON来解析请求返回的数据并转化成 ...

随机推荐

  1. 慧销平台ThreadPoolExecutor内存泄漏分析

    作者:京东零售 冯晓涛 问题背景 京东生旅平台慧销系统,作为平台系统对接了多条业务线,主要进行各个业务线广告,召回等活动相关内容与能力管理. 最近根据告警发现内存持续升高,每隔2-3天会收到内存超过阈 ...

  2. MTU设置不当导致ssh运行命令卡死

    MTU:最大网络传输单元,计算机网络课会介绍. 场景: 本地通过VPN连接某个机房内网的linux服务器,连接上之后,运行top命令.vi命令.yum update等需要刷新大量内容时导致ssh卡死, ...

  3. python之定时任务APScheduler

    一.APScheduler APScheduler全称Advanced Python Scheduler 作用为在指定的时间规则执行指定的作业. 指定时间规则的方式可以是间隔多久执行,可以是指定日期时 ...

  4. oracle to mogdb 迁移---mtk工具

    ## 一.MTK工具介绍--------- MTK–异构数据迁移工具 MTK全称为 Database Migration Toolkit,是一个可以将Oracle/DB2/MySQL/openGaus ...

  5. react 03 组件传值

    一 基础 props: 父传子  单向 import React from 'react'; import ReactDOM from 'react-dom'; import './index.css ...

  6. redis底层数据结构之跳表(skiplist)

    跳表(跳跃表, skiplist) 跳跃表(skiplist)是用于有序元素序列快速搜索查找的数据结构,跳表是一个随机化的数据结构,实质是一种可以进行二分查找的.具有层次结构的有序链表 跳表在原有的有 ...

  7. 创建异步倒计时触发Task

    https://www.cnblogs.com/shanfeng1000/p/13402152.html //Task关闭 CancellationTokenSource cancel = new C ...

  8. 【SSO单点系列】(7):CAS4.0 二级域名

    CAS4.0 二级域名 一.描述 当cas成功登录后如果访问同一域名下的资源是 被当作同一应用下资源不需要再次请求登录,但是如果二级域名不同会 被当作不同应用在访问 需要请求CAS 在请求时会把TGC ...

  9. react lodash节流this找不到正确用法

    if (!this.throttleLoadDicom) { this.throttleLoadDicom = throttle(this.loadDicomFun, 800, { leading: ...

  10. Jvm 相关记录

    ## 内存分析工具- JConsole.JVisualVM- gperftools Linux 安装- MAT ## JVM Tools• jps: java process status jps - ...