pytorch源码解析:Python层 pytorchmodule源码
尝试使用了pytorch,相比其他深度学习框架,pytorch显得简洁易懂。花时间读了部分源码,主要结合简单例子带着问题阅读,不涉及源码中C拓展库的实现。
一个简单例子
实现单层softmax二分类,输入特征维度为4,输出为2,经过softmax函数得出输入的类别概率。代码示意:定义网络结构;使用SGD优化;迭代一次,随机初始化三个样例,每个样例四维特征,target分别为1,0,1;前向传播,使用交叉熵计算loss;反向传播,最后由优化算法更新权重,完成一次迭代。
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- class Net(nn.Module):
- def __init__(self):
- super(Net, self).__init__()
- self.linear = nn.Linear(4, 2)
- def forward(self, input):
- out = F.softmax(self.linear(input))
- return out
- net = Net()
- sgd = torch.optim.SGD(net.parameters(), lr=0.001)
- for epoch in range(1):
- features = torch.autograd.Variable(torch.randn(3, 4), requires_grad=True)
- target = torch.autograd.Variable(torch.LongTensor([1, 0, 1]))
- sgd.zero_grad()
- out = net(features)
- loss = F.cross_entropy(out, target)
- loss.backward()
- sgd.step()
从上面的例子,带着下面的问题阅读源码:
- pytorch的主要概念:Tensor、autograd、Variable、Function、Parameter、Module(Layers)、Optimizer;
- 自定义Module如何组织网络结构和网络参数;
- 前向传播、反向传播实现流程
- 优化算法类如何实现,如何和自定义Module联系并更新参数。
pytorch的主要概念
pytorch的主要概念官网有很人性化的教程Deep Learning with PyTorch: A 60 Minute Blitz, 这里简单概括这些概念:
Tensor
类似numpy的ndarrays,强化了可进行GPU计算的特性,由C拓展模块实现。如上面的torch.randn(3, 4) 返回一个3*4的Tensor。和numpy一样,也有一系列的Operation,如
- x = torch.rand(5, 3)
- y = torch.rand(5, 3)
- print x + y
- print torch.add(x, y)
- print x.add_(y)
Varaiable与autograd
Variable封装了Tensor,包括了几乎所有的Tensor可以使用的Operation方法,主要使用在自动求导(autograd),Variable类继承_C._VariableBase,由C拓展类定义实现。
Variable是autograd的计算单元,Variable通过Function组织成函数表达式(计算图):
- data 为其封装的tensor值
- grad 为其求导后的值
- creator 为创建该Variable的Function,实现中grad_fn属性则指向该Function。
如:- import torch
- from torch.autograd import Variable
- x = Variable(torch.ones(2, 2), requires_grad=True)
- y = x + 2
- print y.grad_fn
- print "before backward: ", x.grad
- y.backward()
- print "after backward: ", x.grad
输出结果:
- <torch.autograd.function.AddConstantBackward object at 0x7faa6f3bdd68>
- before backward: None
- after backward: Variable containing:
- 1
- [torch.FloatTensor of size 1x1]
调用y的backward方法,则会对创建y的Function计算图中所有requires_grad=True的Variable求导(这里的x)。例子中显然dy/dx = 1。
Parameter
Parameter 为Variable的一个子类,后面还会涉及,大概两点区别:
- 作为Module参数会被自动加入到该Module的参数列表中;
- 不能被volatile, 默认require gradient。
Module
Module为所有神经网络模块的父类,如开始的例子,Net继承该类,____init____中指定网络结构中的模块,并重写forward方法实现前向传播得到指定输入的输出值,以此进行后面loss的计算和反向传播。
Optimizer
Optimizer是所有优化算法的父类(SGD、Adam、...),____init____中传入网络的parameters, 子类实现父类step方法,完成对parameters的更新。
自定义Module
该部分说明自定义的Module是如何组织定义在构造函数中的子Module,以及自定义的parameters的保存形式,eg:
- class Net(nn.Module):
- def __init__(self):
- super(Net, self).__init__()
- self.linear = nn.Linear(4, 2)
- def forward(self, input):
- out = F.softmax(self.linear(input))
- return out
首先看构造函数,Module的构造函数初始化了Module的基本属性,这里关注_parameters和_modules,两个属性初始化为OrderedDict(),pytorch重写的有序字典类型。_parameters保存网络的所有参数,_modules保存当前Module的子Module。
module.py:
- class Module(object):
- def __init__(self):
- self._parameters = OrderedDict()
- self._modules = OrderedDict()
- ...
下面来看自定义Net类中self.linear = nn.Linear(4, 2)语句和_modules、_parameters如何产生联系,或者self.linear及其参数如何被添加到_modules、_parameters字典中。答案在Module的____setattr____方法,该Python内建方法会在类的属性被赋值时调用。
module.py:
- def __setattr__(self, name, value):
- def remove_from(*dicts):
- for d in dicts:
- if name in d:
- del d[name]
- params = self.__dict__.get('_parameters')
- if isinstance(value, Parameter): # ----------- <1>
- if params is None:
- raise AttributeError(
- "cannot assign parameters before Module.__init__() call")
- remove_from(self.__dict__, self._buffers, self._modules)
- self.register_parameter(name, value)
- elif params is not None and name in params:
- if value is not None:
- raise TypeError("cannot assign '{}' as parameter '{}' "
- "(torch.nn.Parameter or None expected)"
- .format(torch.typename(value), name))
- self.register_parameter(name, value)
- else:
- modules = self.__dict__.get('_modules')
- if isinstance(value, Module):# ----------- <2>
- if modules is None:
- raise AttributeError(
- "cannot assign module before Module.__init__() call")
- remove_from(self.__dict__, self._parameters, self._buffers)
- modules[name] = value
- elif modules is not None and name in modules:
- if value is not None:
- raise TypeError("cannot assign '{}' as child module '{}' "
- "(torch.nn.Module or None expected)"
- .format(torch.typename(value), name))
- modules[name] = value
- ......
调用self.linear = nn.Linear(4, 2)时,父类____setattr____被调用,参数name为“linear”, value为nn.Linear(4, 2),内建的Linear类同样是Module的子类。所以<2>中的判断为真,接着modules[name] = value,该linear被加入_modules字典。
同样自定义Net类的参数即为其子模块Linear的参数,下面看Linear的实现:
linear.py:
- class Linear(Module):
- def __init__(self, in_features, out_features, bias=True):
- super(Linear, self).__init__()
- self.in_features = in_features
- self.out_features = out_features
- self.weight = Parameter(torch.Tensor(out_features, in_features))
- if bias:
- self.bias = Parameter(torch.Tensor(out_features))
- else:
- self.register_parameter('bias', None)
- self.reset_parameters()
- def reset_parameters(self):
- stdv = 1. / math.sqrt(self.weight.size(1))
- self.weight.data.uniform_(-stdv, stdv)
- if self.bias is not None:
- self.bias.data.uniform_(-stdv, stdv)
- def forward(self, input):
- return F.linear(input, self.weight, self.bias)
同样继承Module类,____init____中参数为输入输出维度,是否需要bias参数。在self.weight = Parameter(torch.Tensor(out_features, in_features))的初始化时,同样会调用父类Module的____setattr____, name为“weight”,value为Parameter,此时<1>判断为真,调用self.register_parameter(name, value),该方法中对参数进行合法性校验后放入self._parameters字典中。
Linear在reset_parameters方法对权重进行了初始化。
最终可以得出结论自定义的Module以树的形式组织子Module,子Module及其参数以字典的方式保存。
前向传播、反向传播
前向传播
例子中out = net(features)实现了网络的前向传播,该语句会调用Module类的forward方法,该方法被继承父类的子类实现。net(features)使用对象作为函数调用,会调用Python内建的____call____方法,Module重写了该方法。
module.py:
- def __call__(self, *input, **kwargs):
- for hook in self._forward_pre_hooks.values():
- hook(self, input)
- result = self.forward(*input, **kwargs)
- for hook in self._forward_hooks.values():
- hook_result = hook(self, input, result)
- if hook_result is not None:
- raise RuntimeError(
- "forward hooks should never return any values, but '{}'"
- "didn't return None".format(hook))
- if len(self._backward_hooks) > 0:
- var = result
- while not isinstance(var, Variable):
- var = var[0]
- grad_fn = var.grad_fn
- if grad_fn is not None:
- for hook in self._backward_hooks.values():
- wrapper = functools.partial(hook, self)
- functools.update_wrapper(wrapper, hook)
- grad_fn.register_hook(wrapper)
- return result
____call____方法中调用result = self.forward(*input, **kwargs)前后会查看有无hook函数需要调用(预处理和后处理)。
例子中Net的forward方法中out = F.softmax(self.linear(input)),同样会调用self.linear的forward方法F.linear(input, self.weight, self.bias)进行矩阵运算(仿射变换)。
functional.py:
- def linear(input, weight, bias=None):
- if input.dim() == 2 and bias is not None:
- # fused op is marginally faster
- return torch.addmm(bias, input, weight.t())
- output = input.matmul(weight.t())
- if bias is not None:
- output += bias
- return output
最终经过F.softmax,得到前向输出结果。F.softmax和F.linear类似前面说到的Function(Parameters的表达式或计算图)。
反向传播
得到前向传播结果后,计算loss = F.cross_entropy(out, target),接下来反向传播求导数d(loss)/d(weight)和d(loss)/d(bias):
loss.backward()
backward()方法同样底层由C拓展,这里暂不深入,调用该方法后,loss计算图中的所有Variable(这里linear的weight和bias)的grad被求出。
Optimizer参数更新
在计算出参数的grad后,需要根据优化算法对参数进行更新,不同的优化算法有不同的更新策略。
optimizer.py:
- class Optimizer(object):
- def __init__(self, params, defaults):
- if isinstance(params, Variable) or torch.is_tensor(params):
- raise TypeError("params argument given to the optimizer should be "
- "an iterable of Variables or dicts, but got " +
- torch.typename(params))
- self.state = defaultdict(dict)
- self.param_groups = list(params)
- ......
- def zero_grad(self):
- """Clears the gradients of all optimized :class:`Variable` s."""
- for group in self.param_groups:
- for p in group['params']:
- if p.grad is not None:
- if p.grad.volatile:
- p.grad.data.zero_()
- else:
- data = p.grad.data
- p.grad = Variable(data.new().resize_as_(data).zero_())
- def step(self, closure):
- """Performs a single optimization step (parameter update).
- Arguments:
- closure (callable): A closure that reevaluates the model and
- returns the loss. Optional for most optimizers.
- """
- raise NotImplementedError
Optimizer在init中将传入的params保存到self.param_groups,另外两个重要的方法zero_grad负责将参数的grad置零方便下次计算,step负责参数的更新,由子类实现。
以列子中的sgd = torch.optim.SGD(net.parameters(), lr=0.001)为例,其中net.parameters()返回Net参数的迭代器,为待优化参数;lr指定学习率。
SGD.py:
- class SGD(Optimizer):
- def __init__(self, params, lr=required, momentum=0, dampening=0,
- weight_decay=0, nesterov=False):
- defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
- weight_decay=weight_decay, nesterov=nesterov)
- if nesterov and (momentum <= 0 or dampening != 0):
- raise ValueError("Nesterov momentum requires a momentum and zero dampening")
- super(SGD, self).__init__(params, defaults)
- def __setstate__(self, state):
- super(SGD, self).__setstate__(state)
- for group in self.param_groups:
- group.setdefault('nesterov', False)
- def step(self, closure=None):
- """Performs a single optimization step.
- Arguments:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- loss = closure()
- for group in self.param_groups:
- weight_decay = group['weight_decay']
- momentum = group['momentum']
- dampening = group['dampening']
- nesterov = group['nesterov']
- for p in group['params']:
- if p.grad is None:
- continue
- d_p = p.grad.data
- if weight_decay != 0:
- d_p.add_(weight_decay, p.data)
- if momentum != 0:
- param_state = self.state[p]
- if 'momentum_buffer' not in param_state:
- buf = param_state['momentum_buffer'] = d_p.clone()
- else:
- buf = param_state['momentum_buffer']
- buf.mul_(momentum).add_(1 - dampening, d_p)
- if nesterov:
- d_p = d_p.add(momentum, buf)
- else:
- d_p = buf
- p.data.add_(-group['lr'], d_p)
- return loss
SGD的step方法中,判断是否使用权重衰减和动量更新,如果不使用,直接更新权重param := param - lr * d(param)。例子中调用sgd.step()后完成一次epoch。这里由于传递到Optimizer的参数集是可更改(mutable)的,step中对参数的更新同样是Net中参数的更新。
小结
到此,根据一个简单例子阅读了pytorch中Python实现的部分源码,没有深入到底层Tensor、autograd等部分的C拓展实现,后面再继续读一读C拓展部分的代码。
转自链接:https://www.jianshu.com/p/f5eb8c2e671c
pytorch源码解析:Python层 pytorchmodule源码的更多相关文章
- 【vuejs深入二】vue源码解析之一,基础源码结构和htmlParse解析器
写在前面 一个好的架构需要经过血与火的历练,一个好的工程师需要经过无数项目的摧残. vuejs是一个优秀的前端mvvm框架,它的易用性和渐进式的理念可以使每一个前端开发人员感到舒服,感到easy.它内 ...
- Django生命周期 URL ----> CBV 源码解析-------------- 及rest_framework APIView 源码流程解析
一.一个请求来到Django 的生命周期 FBV 不讨论 CBV: 请求被代理转发到uwsgi: 开始Django的流程: 首先经过中间件process_request (session等) 然后 ...
- Tensorflow源码解析1 -- 内核架构和源码结构
1 主流深度学习框架对比 当今的软件开发基本都是分层化和模块化的,应用层开发会基于框架层.比如开发Linux Driver会基于Linux kernel,开发Android app会基于Android ...
- java架构之路-(SpringMVC篇)SpringMVC主要流程源码解析(上)源码执行流程
做过web项目的小伙伴,对于SpringMVC,Struts2都是在熟悉不过了,再就是我们比较古老的servlet,我们先来复习一下我们的servlet生命周期. servlet生命周期 1)初始化阶 ...
- Redis源码解析(1)——源码目录介绍
概念 redis是一个key-value存储系统.和Memcached类似,它支持存储的value类型相对更多,包括string(字符串).list(链表).set(集合)和zset(有序集合).这些 ...
- ubuntu源与常用python配置pip源(win)、pip常用命令
pip常用命令 ubuntu更新系统源 首先备份/etc/apt/sources.list mv /etc/apt/sources.list /etc/apt/sources.list.bak 然后下 ...
- SOFARPC源码解析-搭建环境
文档地址:https://www.sofastack.tech 简介摘要 SOFA 是蚂蚁金服自主研发的金融级分布式中间件,包含构建金融级云原生架构所需的各个组件,包括微服务研发框架,RPC 框架,服 ...
- Java并发包源码学习系列:基于CAS非阻塞并发队列ConcurrentLinkedQueue源码解析
目录 非阻塞并发队列ConcurrentLinkedQueue概述 结构组成 基本不变式 head的不变式与可变式 tail的不变式与可变式 offer操作 源码解析 图解offer操作 JDK1.6 ...
- Android 开源项目源码解析(第二期)
Android 开源项目源码解析(第二期) 阅读目录 android-Ultra-Pull-To-Refresh 源码解析 DynamicLoadApk 源码解析 NineOldAnimations ...
随机推荐
- js中定义变量之②var let const的区别
var 上一篇文章有讲过,是js定义变量的关键词. 但是在es6中,新添加了两个关键词,用于变量声明的关键词:let 和const 接下来就说一下var let 和const的区别: 首先说var 用 ...
- web前端学习(四)JavaScript学习笔记部分(10)-- JavaScript正则表达式
1.JavaScript正则表达式课程概要 方便查找字符串.数字.特殊字串等等 2.正则表达式的介绍 RegExp是正则表达式的缩写 当检索某个文本时,可以使用一种模式来描述要检索的内容.RegExp ...
- Python之路,Day2 - Python基础(转载Alex)
Day2-转自金角大王 本节内容 列表.元组操作 字符串操作 字典操作 集合操作 文件操作 字符编码与转码 1. 列表.元组操作 列表是我们最以后最常用的数据类型之一,通过列表可以对数据实现最方便的存 ...
- Django--登录功能
登录功能: 1.路由访问如果不加斜杠,内部会重定向加斜杠的路由 所有的html文件都默认卸载templates文件夹下面 所有的(css,js,前端第三方的类库)默认都放在static文件夹下 htm ...
- 初探iview
我的js功力还是TCL,太差了~ 运行iview官网例子还有它的工程文件都运行不出来.我非常感谢那些无私开源的博主,它们无私分享自己的技术,让我学到了很多东西. iview是vue的一个UI框架之一, ...
- 100个常用的原生JavaScript函数
1.原生JavaScript实现字符串长度截取 复制代码代码如下: function cutstr(str, len) { var temp; var icount = 0; var ...
- CentOS8/RHEL8--恢复root用户密码及简易加固GRUB
CentOS8/RHEL8--简易加固GRUB 今天突然想到放在数据中心的虚拟化平台下的Linux服务器,都是采用默认方式安装的,没有设置太多的安全选项,如果有恶意用户重启服务器后,通过GRUB调整启 ...
- web App libraries跟referenced libraries的一些问题
该博文内容经参看网上其他资料归纳所成,并注明出处: 问题一:myeclipse中Web App Libraries无法自动识别lib下的jar包(http://blog.csdn.net/tianca ...
- 模拟退火解TSP问题MATLAB代码
分别把前四个函数存成m文件,再运行最后一个. swap.m function [ newpath , position ] = swap( oldpath , number ) % 对 oldpath ...
- Vue 实现展开折叠效果
Vue 实现展开折叠效果 效果参见:https://segmentfault.com/q/1010000011359250/a-1020000011360185 上述链接中,大佬给除了解决方法,再次进 ...