【技术博客】Pytorch代码生成
开发组在开发过程中,都不可避免地遇到了一些困难或问题,但都最终想出办法克服了。我们认为这样的经验是有必要记录下来的,因此就有了【技术博客】。
Pytorch代码生成经验文档
关于模型代码的生成,主要思路为从根节点开始进行广度优先搜索,从而自顶向下依次生成相关层的代码。这里和搜索相关的主要有三个数据结构:
- Q:队列,记录后续继续搜索的节点,即为后续的Node。
- graph:字典,记录整颗搜索树,每个key对应一个Node,Node为自己封装的一个类,里面包含每层的一些信息。记录搜索树的目的是为了后续的正确性验证,如下为Node的定义:
class Node:
def __init__(self, id = None, name = None, in_channels = 1, out_channels = 1, kernel_size = 3,
stride = 1, padding = 0, data = None, activity = None, pool_way = None, cat_dim = None):
self.fa = np.array([], dtype = str)
self.next = np.array([], dtype = str)
self.id = id
self.name = name
self.data = data
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.pool_way = pool_way
self.activity = activity
self.data_shape = np.array([], dtype = int)
self.cat_dim = cat_dim
def add_fa(self, f):
self.fa = np.append(self.fa, f)
def add_next(self, nx):
self.next = np.append(self.next, nx)
- done:字典,记录某节点相关代码是否已经生成,每个key对应一个boolean值。
同时还有以下需要关注的地方:
广度优先搜索。BFS为代码的主要框架。从’start’节点开始搜索,直到遍历结束,做一个线性的扫描。代码框架如下(省略了主要代码):
def make_graph(nets, nets_conn, init_func, forward_func):
#code here Q = queue.Queue()
Q.put(‘start’) #code here while not Q.empty():
cur_id = Q.get()
if GL.done[cur_id]:
continue '''''''''''' Main codes here '''''''''''' GL.done[cur_id] = True return init_func, forward_func
关于全局变量的处理。由于一开始忽略了python变量的特性(不需要声明),所以在一开始第一全局变量的时候是直接定义在文件开头的,但是这样存在的问题是:如果在局部函数中引用全局变量,则此时则是重新定义了一个变量而不是引用,用global关键字代码看上去又很臃肿。所以采取的办法是重新定义了一个GLOB模块,里面存放着需要的所有全局变量。类似于这样:
class GLOB:
def __init__(self):
self.graph = {}
self.done = {}
self.layer_used_time = {'view_layer': 0, 'linear_layer': 0, 'conv1d_layer': 0, 'conv2d_layer': 0, 'element_wise_add_layer':0, 'concatenate_layer':0}
self.nn_linear = 'torch.nn.Linear'
self.nn_conv1d = 'torch.nn.Conv1d'
self.nn_conv2d = 'torch.nn.Conv2d'
self.nn_view = '.view'
self.nn_sequential = 'torch.nn.Sequential'
self.start_layer = ['start']
self.norm_layer = ['conv1d_layer', 'conv2d_layer', 'view_layer', 'linaer_layer']
self.multi_layer = ['element_wise_add_layer', 'concatenate_layer']
self.layers_except_start = self.norm_layer + self.multi_layer这样,只需要在代码里初始化一个GLOB对象GL,这样在任何地方引用全局变量都不会造成困扰。
关于变量名生成。每层的输出数据的名字格式为:层名 + “data_出现的次数”。有一个数据结构”layer_used_time”(字典)专门负责记录每个层出现的次数,同时,会在该层的代码生成结构后更新layer_used_time和done的值。
关于何时初始化和更新graph。在我们的代码中,当从队列中取出一个节点后会执行一个函数:get_next_nodes_and_update_pre_nodes()。该函数的目的是获取和初始化当前节点的儿子节点,记录前端传入该层的其他参数,更新其父子节点,同时返回当前节点的所有祖先节点代码是否已经生成完毕。另外,在该函数内部也会做模型的一部分正确性验证,主要验证搭建的模型里除了拼接层和相加层以外的层是否存在多个父节点或没有节点。该函数实现的功能较多,后期会考虑重构。
关于正确性验证。考虑到用户在搭建模型时不一定能够保证参数的正确,所以我们对参数的合理性是“宽容”的,但是也有硬性的要求,比如只能有一个开始节点,同时除了拼接层和相加层可以有多个父节点以外,其他层有且仅有一个父节点。
关于生成的模型NET中forward函数的返回值。由于搭建的模型允许出现网状结构,所以不能保证模型的出口只有一个,所以现阶段生成的模型会返回所有出度为0的层的输出值,具体顺序参见代码。
附最终生成的代码效果图(例):
【技术博客】Pytorch代码生成的更多相关文章
- 如何写出高质量的技术博客 这边文章出自http://www.jianshu.com/p/ae9ab21a5730 觉得不错直接拿过来了 好东西要大家分享嘛
如何写出高质量的技术博客?答案是:如果你想,就一定能写出高质量的技术博客.看起来很唯心,但这就是事实.有足够愿力去做一件目标明确,有良好反馈系统的事情往往很简单.就是不停地训练,慢慢地,你自己 ...
- ******IT公司面试题汇总+优秀技术博客汇总
滴滴面试题:滴滴打车数据库如何拆分 前端时间去滴滴面试,有一道题目是这样的,滴滴每天有100万的订单,如果让你去设计数据库,你会怎么去设计? 当时我的想法是根据用户id的最后一位对某个特殊的值取%操作 ...
- 转: BAT等研发团队的技术博客
BAT 技术团队博客 1. 美团技术团队博客: 地址: http://tech.meituan.com/ 2. 腾讯社交用户体验设计(ISUX) 地址:http://isux.tencent.c ...
- 解决Eclipse中文乱码 - 技术博客 - 51CTO技术博客 http://hsj69106.blog.51cto.com/1017401/595598/
解决Eclipse中文乱码 - 技术博客 - 51CTO技术博客 http://hsj69106.blog.51cto.com/1017401/595598/
- 欢迎访问我的最新个人技术博客http://zhangxuefei.top
博客园已停止更新,欢迎访问我的最新个人技术博客http://zhangxuefei.top
- 技术博客(初用markdown)。
技术博客 菜鸟教程在这个网站我学到许多有趣的东西,并且弥补了我之前的一些不足之处. 以下为我学习到的内容 输出不同的三位数 以下为代码和输出结果 *** #include<stdio.h> ...
- 技术博客(初用markdown)
技术博客 菜鸟教程在这个网站我学到许多有趣的东西,并且弥补了我之前的一些不足之处. 以下为我学习到的内容. 1 如果想输出多个多位数的时候,可以尝试用多个if语句.如果需要输出3为数的时候,设置三个变 ...
- 【转】【技术博客】Spark性能优化指南——高级篇
http://mp.weixin.qq.com/s?__biz=MjM5NjQ5MTI5OA==&mid=2651745207&idx=1&sn=3d70d59cede236e ...
- 作业一:创建个人技术博客、自我介绍、简单的C程序
年9月14日中午12点: 一.主要内容 建个人技术博客(博客园 www.cnblogs.com) 本学期将通过写博客的方式提交作业,实际上,最终的目的是希望同学们能通过博客的形式记录我们整个学习过程 ...
随机推荐
- Matlab桥接模式
桥接模式(Bridge)是一种结构型设计模式.它是用组合关系代替继承关系来实现,可以处理多维度变化的场景(https://blog.csdn.net/qq_31156277/article/detai ...
- SpringBoot,SSM和SSH
Springboot的概念: 是提供的全新框架,使用来简化Spring的初始搭建和开发过程,使用了特定的方式来进行配置,让开发人员不在需要定义样板化的配置.此框架不需要配置xml,依赖于想MAVEN这 ...
- 为什么会有jQuery、Dojo、Ext、Prototype、YUI、Zepto这么多JS包?
目前流行的JS框架很多Dojo .Scriptaculous .Prototype .yui-ext .Jquery .Mochikit.mootools .moo.fx 等.当然还有很多我都不熟悉的 ...
- android黑白屏的问题
你会很奇怪,为什么有些app启动时,会出现一会儿的黑屏或者白屏才进入Activity的界面显示,但是有些app却不会如QQ手机端,的确这里要做处理一下.这里先了解一下为什么会出现这样的现象,其实很简单 ...
- Filter和Listener
Filter: 1.概念: web中的过滤器:当访问服务器的资源时,过滤器可以将请求拦截下来,做一些事. 过滤器的作用:一般用于完成一些通用的操作:登录验证.统一编码处理,敏感字符处理.... 2.快 ...
- Request和Response。
复习点:1.重定向问题 2.输出字符串到浏览器.3.文件下载需求:1. 页面显示超链接2. 点击超链接后弹出下载提示框3. 完成图片文件下载 Request和Response Request: 1. ...
- H3C 什么是漫游
- 算法dfs——二叉搜索树中最接近的值 II
901. 二叉搜索树中最接近的值 II 中文 English 给定一棵非空二叉搜索树以及一个target值,找到 BST 中最接近给定值的 k 个数. 样例 样例 1: 输入: {1} 0.00000 ...
- 利用form.submit提交表单导出文件到客户端浏览器, 提示下载!
本来是想利用ajax提交json数据到服务端, 让服务端生成一个excel文件并提示客户端浏览器下载的. 但是搞了很久发现ajax方式是无法触发浏览器弹出文件下载的. 网上很多的方案都是说利用form ...
- sed 用法
sed 用法 sed的其他用法如下: 1.删除行首空格 sed 's/^[ ]*//g' filename sed 's/^ *//g' filename sed 's/^[[:space:]]*// ...