开发组在开发过程中,都不可避免地遇到了一些困难或问题,但都最终想出办法克服了。我们认为这样的经验是有必要记录下来的,因此就有了【技术博客】。

Pytorch代码生成经验文档

关于模型代码的生成,主要思路为从根节点开始进行广度优先搜索,从而自顶向下依次生成相关层的代码。这里和搜索相关的主要有三个数据结构:

  • Q:队列,记录后续继续搜索的节点,即为后续的Node。
  • graph:字典,记录整颗搜索树,每个key对应一个Node,Node为自己封装的一个类,里面包含每层的一些信息。记录搜索树的目的是为了后续的正确性验证,如下为Node的定义:
class Node:
def __init__(self, id = None, name = None, in_channels = 1, out_channels = 1, kernel_size = 3,
stride = 1, padding = 0, data = None, activity = None, pool_way = None, cat_dim = None):
self.fa = np.array([], dtype = str)
self.next = np.array([], dtype = str)
self.id = id
self.name = name
self.data = data
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.pool_way = pool_way
self.activity = activity
self.data_shape = np.array([], dtype = int)
self.cat_dim = cat_dim def add_fa(self, f):
self.fa = np.append(self.fa, f) def add_next(self, nx):
self.next = np.append(self.next, nx)
  • done:字典,记录某节点相关代码是否已经生成,每个key对应一个boolean值。

同时还有以下需要关注的地方:

  • 广度优先搜索。BFS为代码的主要框架。从’start’节点开始搜索,直到遍历结束,做一个线性的扫描。代码框架如下(省略了主要代码):

    def make_graph(nets, nets_conn, init_func, forward_func):
    #code here Q = queue.Queue()
    Q.put(‘start’) #code here while not Q.empty():
    cur_id = Q.get()
    if GL.done[cur_id]:
    continue '''''''''''' Main codes here '''''''''''' GL.done[cur_id] = True return init_func, forward_func
  • 关于全局变量的处理。由于一开始忽略了python变量的特性(不需要声明),所以在一开始第一全局变量的时候是直接定义在文件开头的,但是这样存在的问题是:如果在局部函数中引用全局变量,则此时则是重新定义了一个变量而不是引用,用global关键字代码看上去又很臃肿。所以采取的办法是重新定义了一个GLOB模块,里面存放着需要的所有全局变量。类似于这样:

    class GLOB:
    def __init__(self):
    self.graph = {}
    self.done = {}
    self.layer_used_time = {'view_layer': 0, 'linear_layer': 0, 'conv1d_layer': 0, 'conv2d_layer': 0, 'element_wise_add_layer':0, 'concatenate_layer':0}
    self.nn_linear = 'torch.nn.Linear'
    self.nn_conv1d = 'torch.nn.Conv1d'
    self.nn_conv2d = 'torch.nn.Conv2d'
    self.nn_view = '.view'
    self.nn_sequential = 'torch.nn.Sequential'
    self.start_layer = ['start']
    self.norm_layer = ['conv1d_layer', 'conv2d_layer', 'view_layer', 'linaer_layer']
    self.multi_layer = ['element_wise_add_layer', 'concatenate_layer']
    self.layers_except_start = self.norm_layer + self.multi_layer

    这样,只需要在代码里初始化一个GLOB对象GL,这样在任何地方引用全局变量都不会造成困扰。

  • 关于变量名生成。每层的输出数据的名字格式为:层名 + “data_出现的次数”。有一个数据结构”layer_used_time”(字典)专门负责记录每个层出现的次数,同时,会在该层的代码生成结构后更新layer_used_time和done的值。

  • 关于何时初始化和更新graph。在我们的代码中,当从队列中取出一个节点后会执行一个函数:get_next_nodes_and_update_pre_nodes()。该函数的目的是获取和初始化当前节点的儿子节点,记录前端传入该层的其他参数,更新其父子节点,同时返回当前节点的所有祖先节点代码是否已经生成完毕。另外,在该函数内部也会做模型的一部分正确性验证,主要验证搭建的模型里除了拼接层和相加层以外的层是否存在多个父节点或没有节点。该函数实现的功能较多,后期会考虑重构。

  • 关于正确性验证。考虑到用户在搭建模型时不一定能够保证参数的正确,所以我们对参数的合理性是“宽容”的,但是也有硬性的要求,比如只能有一个开始节点,同时除了拼接层和相加层可以有多个父节点以外,其他层有且仅有一个父节点。

  • 关于生成的模型NET中forward函数的返回值。由于搭建的模型允许出现网状结构,所以不能保证模型的出口只有一个,所以现阶段生成的模型会返回所有出度为0的层的输出值,具体顺序参见代码。

附最终生成的代码效果图(例):

【技术博客】Pytorch代码生成的更多相关文章

  1. 如何写出高质量的技术博客 这边文章出自http://www.jianshu.com/p/ae9ab21a5730 觉得不错直接拿过来了 好东西要大家分享嘛

        如何写出高质量的技术博客?答案是:如果你想,就一定能写出高质量的技术博客.看起来很唯心,但这就是事实.有足够愿力去做一件目标明确,有良好反馈系统的事情往往很简单.就是不停地训练,慢慢地,你自己 ...

  2. ******IT公司面试题汇总+优秀技术博客汇总

    滴滴面试题:滴滴打车数据库如何拆分 前端时间去滴滴面试,有一道题目是这样的,滴滴每天有100万的订单,如果让你去设计数据库,你会怎么去设计? 当时我的想法是根据用户id的最后一位对某个特殊的值取%操作 ...

  3. 转: BAT等研发团队的技术博客

    BAT 技术团队博客   1. 美团技术团队博客:  地址: http://tech.meituan.com/ 2. 腾讯社交用户体验设计(ISUX) 地址:http://isux.tencent.c ...

  4. 解决Eclipse中文乱码 - 技术博客 - 51CTO技术博客 http://hsj69106.blog.51cto.com/1017401/595598/

    解决Eclipse中文乱码 - 技术博客 - 51CTO技术博客  http://hsj69106.blog.51cto.com/1017401/595598/

  5. 欢迎访问我的最新个人技术博客http://zhangxuefei.top

    博客园已停止更新,欢迎访问我的最新个人技术博客http://zhangxuefei.top

  6. 技术博客(初用markdown)。

    技术博客 菜鸟教程在这个网站我学到许多有趣的东西,并且弥补了我之前的一些不足之处. 以下为我学习到的内容 输出不同的三位数 以下为代码和输出结果 *** #include<stdio.h> ...

  7. 技术博客(初用markdown)

    技术博客 菜鸟教程在这个网站我学到许多有趣的东西,并且弥补了我之前的一些不足之处. 以下为我学习到的内容. 1 如果想输出多个多位数的时候,可以尝试用多个if语句.如果需要输出3为数的时候,设置三个变 ...

  8. 【转】【技术博客】Spark性能优化指南——高级篇

    http://mp.weixin.qq.com/s?__biz=MjM5NjQ5MTI5OA==&mid=2651745207&idx=1&sn=3d70d59cede236e ...

  9. 作业一:创建个人技术博客、自我介绍、简单的C程序

    年9月14日中午12点: 一.主要内容  建个人技术博客(博客园 www.cnblogs.com) 本学期将通过写博客的方式提交作业,实际上,最终的目的是希望同学们能通过博客的形式记录我们整个学习过程 ...

随机推荐

  1. JS基础 —— call、apply 和 bind

    函数的三个原型方法 作用:改变this指向 call MDN:https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Glo ...

  2. jQuery 前端复选框 全选 反选 下拉菜单联动

    jQuery 页面中复选框全选.反选.下拉联动(级联) <!DOCTYPE html> <html lang="en"> <head> < ...

  3. jQuery选择器与过滤器(二)

    一.jQuery选择器1.基本选择器:所有选择器    *标签选择器    标签名ID选择器    #ID类选择器    .className组合选择器    selector1,selector2 ...

  4. 15、vue项目封装axios并访问接口

    1.在src下新建util文件夹,在util下新建request.js文件: 封装axios: import axios from 'axios' import QS from 'qs'; // im ...

  5. Solr基础理论【倒排索引,模糊查询】

    一.简介 现有的许多不同类型 的技术系统,如关系型数据库.键值存储.操作磁盘文件的map-reduce[映射-规约]引擎.图数据库等,都是为了帮助用户解决颇具挑战性的数据存储与检索问题而设计的.而搜索 ...

  6. mac php Swoole入门

    一. swoole 扩展安装 安装前必须保证系统已经安装了下列软件 php-7.0 或更高版本 gcc-4.8 或更高版本 make autoconf pcre (CentOS系统可以执行命令:yum ...

  7. PHP在无限分类时注意的一些问题(不保证代码完全正确哦)

    转自:PHP在无限分类时注意的一些问题(http://lxiaoke.cn) (注意:代码使用的是原生PHP,旨在提供解决思路)1 无限分类的查找(获取所有节点) 代码: /** * 无限分类查询,默 ...

  8. 记vue nextTick用到的地方

    nextTick是vue提供的全局函数,在下次 DOM 更新循环结束之后执行延迟回调.在修改数据之后立即使用这个方法,获取更新后的 DOM. // 修改数据 vm.msg = 'Hello' // D ...

  9. spring cloud微服务实战教程/pdf/视频/百度云资源

    资源站:http://www.supan.vip 点击进入直接查找资源: http://www.supan.vip/spring%20cloud微服务实战 <Spring Cloud微服务实战& ...

  10. shell脚本返回值问题

    如果学习过高级语言比如java和c语言等,此时你要是获取一个函数的返回值,直接在函数里面写上return即可,然后在函数执行时将返回结果赋值给某个变量即可.但是在shell脚本中限制较多,因此如果我们 ...