技术背景

在MindSpore深度学习框架中,我们可以向construct函数传输必备参数或者关键字参数,这跟普通的Python函数没有什么区别。但是对于MindSpore中的自定义反向传播bprop函数,因为标准化格式决定了最后的两位函数输入必须是必备参数outdout用于接收函数值和导数值。那么对于一个自定义的反向传播函数而言,我们有可能要传入多个参数。例如这样的一个案例:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnp class Net(nn.Cell):
def bprop(self, x, y=1, out, dout):
return msnp.cos(x) + y
def construct(self, x, y=1):
return msnp.sin(x) + y x = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0))

但是因为在Python的函数传参规则下,必备参数必须放在关键字参数之前,也就是out和dout这两个参数要放在前面,否则就会出现这样的报错:

  File "test_rand.py", line 53
def bprop(self, x, y=1, out, dout):
^
SyntaxError: non-default argument follows default argument

按照普通Python函数的传参规则,我们可以把y这个关键字参数的放到最后面去:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnp class Net(nn.Cell):
def bprop(self, x, out, dout, y=1):
return msnp.cos(x) + y
def construct(self, x, y=1):
return msnp.sin(x) + y x = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0))

经过这一番调整之后,我们发现没有报错了,可以正常输出结果,但是这个结果似乎不太正常:

[1.0015925] (Tensor(shape=[1], dtype=Float32, value= [ 1.59254798e-03]), Tensor(shape=[1], dtype=Float32, value= [ 1.25169754e-06]))

因为这里x传入了一个近似的\(\pi\),所以在construct函数计算函数值时,得到的结果应该是\(\sin(\pi)+y\),那么这里面\(y\)取\(0\)和\(1\)所得到的结果都是对的。但是关键问题在反向传播函数的计算,原本应该是\(\cos(\pi)+y=y-1\),但是在这里输入的\(y=0\),而导数的计算结果却是\(0\)而不是正确结果\(-1\)。这就说明,在MindSpore的自定义反向传播函数中,并不支持传入关键字参数。

解决方案

刚好前面写了一篇关于PyTorch的文章,这篇文章中提到的两个Issue就针对此类问题。受到这两个Issue的启发,我们在MindSpore中如果需要自定义反向传播函数,可以这么写:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnp class Net(nn.Cell):
def bprop(self, x, y, out, dout):
return msnp.cos(x) + y if y is not None else msnp.cos(x)
def construct(self, x, y=1):
return msnp.sin(x) + y x = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0))

简单来说就是,把原本要传给bprop的关键字参数,转换成必备参数的方式进行传入,然后做一个条件判断:当给定了该输入的时候,执行计算一,如果不给定参数值,或者给一个None,执行计算二。上述代码的执行结果如下所示:

[1.0015925] (Tensor(shape=[1], dtype=Float32, value= [ 1.59254798e-03]), Tensor(shape=[1], dtype=Float32, value= [-9.99998748e-01]))

这里输出的结果都是正确的。

当然,这里因为我们其实是强行把关键字参数按照顺序变成了必备参数进行输入,所以在顺序上一定要严格遵守bprop所定义的必备参数的顺序,否则计算结果也会出错:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnp class Net(nn.Cell):
def bprop(self, x, w, y, out, dout):
return w*msnp.cos(x) + y if y is not None else msnp.cos(x)
def construct(self, x, w=1, y=1):
return msnp.sin(x) + y x = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0, w=2))

输出的结果为:

[1.0015925] (Tensor(shape=[1], dtype=Float32, value= [ 1.59254798e-03]), Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]))

那么很显然,这个结果就是因为在执行函数时给定的关键字参数跟必备参数顺序不一致,所以才出错的。

总结概要

继上一篇文章从Torch的两个Issue中找到一些类似的问题之后,可以发现深度学习框架对于自定义反向传播函数中的传参还是比较依赖于必备参数,而不是关键字参数,MindSpore深度学习框架也是如此。但是我们可以使用一些临时的解决方案,对此问题进行一定程度上的规避,只要能够自定义的传参顺序传入关键字参数即可。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/bprop-kwargs.html

作者ID:DechinPhy

更多原著文章:https://www.cnblogs.com/dechinphy/

请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

参考链接

  1. https://www.cnblogs.com/dechinphy/p/18179248/torch

MindSpore反向传播配置关键字参数的更多相关文章

  1. 深度学习原理与框架-神经网络结构与原理 1.得分函数 2.SVM损失函数 3.正则化惩罚项 4.softmax交叉熵损失函数 5. 最优化问题(前向传播) 6.batch_size(批量更新权重参数) 7.反向传播

    神经网络由各个部分组成 1.得分函数:在进行输出时,对于每一个类别都会输入一个得分值,使用这些得分值可以用来构造出每一个类别的概率值,也可以使用softmax构造类别的概率值,从而构造出loss值, ...

  2. 一个batch的数据如何做反向传播

    一个batch的数据如何做反向传播 对于一个batch内部的数据,更新权重我们是这样做的: 假如我们有三个数据,第一个数据我们更新一次参数,不过这个更新只是在我们脑子里,实际的参数没有变化,然后使用原 ...

  3. 神经网络(NN)+反向传播算法(Backpropagation/BP)+交叉熵+softmax原理分析

    神经网络如何利用反向传播算法进行参数更新,加入交叉熵和softmax又会如何变化? 其中的数学原理分析:请点击这里.

  4. Django---路由系统,URLconf的配置,正则表达式的说明(位置参数),分组命名(捕获关键字参数),传递额外的参数给视图,命名url和url的反向解析,url名称空间

    Django---路由系统,URLconf的配置,正则表达式的说明(位置参数),分组命名(捕获关键字参数),传递额外的参数给视图,命名url和url的反向解析,url名称空间 一丶URLconf配置 ...

  5. 深度学习原理与框架-卷积神经网络基本原理 1.卷积层的前向传播 2.卷积参数共享 3. 卷积后的维度计算 4. max池化操作 5.卷积流程图 6.卷积层的反向传播 7.池化层的反向传播

    卷积神经网络的应用:卷积神经网络使用卷积提取图像的特征来进行图像的分类和识别       分类                        相似图像搜索                        ...

  6. 神经网络(9)--如何求参数: backpropagation algorithm(反向传播算法)

    Backpropagation algorithm(反向传播算法) Θij(l) is a real number. Forward propagation 上图是给出一个training examp ...

  7. <反向传播(backprop)>梯度下降法gradient descent的发展历史与各版本

    梯度下降法作为一种反向传播算法最早在上世纪由geoffrey hinton等人提出并被广泛接受.最早GD由很多研究团队各自发表,可他们大多无人问津,而hinton做的研究完整表述了GD方法,同时hin ...

  8. [源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播

    [源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播 目录 [源码解析] PyTorch 分布式(13) ----- Distribu ...

  9. 一文弄懂神经网络中的反向传播法——BackPropagation

    最近在看深度学习的东西,一开始看的吴恩达的UFLDL教程,有中文版就直接看了,后来发现有些地方总是不是很明确,又去看英文版,然后又找了些资料看,才发现,中文版的译者在翻译的时候会对省略的公式推导过程进 ...

  10. Backpropagation反向传播算法(BP算法)

    1.Summary: Apply the chain rule to compute the gradient of the loss function with respect to the inp ...

随机推荐

  1. #min-max容斥,FWT#洛谷 3175 [HAOI2015]按位或

    题目 分析 按位去看,最终的答案要求 \(E(\max S)\) 就是 \(S\) 出现的期望时间. 根据min-max容斥,\(E(\max S)=\sum_{T\subset S}(-1)^{|T ...

  2. 什么是慢SQL且如何查看慢SQL

    什么是慢 SQL 且如何查看慢 SQL? 介绍 某个 SQL 执行时间超过指定时间时称为慢 SQL.我们可以查看慢 SQL,包括历史慢 SQL 以及当前慢 SQL. 查看历史慢 SQL 首先要设置 l ...

  3. VSCode如何通过Ctrl+P快速打开node_modules中的文件

    背景 咱们新建一个NodeJS项目,必然会安装许多依赖包,因此经常需要查阅某些依赖包的源码文件.但是,由于node_modules目录包含的文件太多,出于性能考虑,在VSCode中默认情况下是禁止搜索 ...

  4. AI数字人克隆人直播源码独立部署的应用!

    AI虚拟数字人正在从概念性试验品逐步落地到实际应用场景,特别是在电商直播领域,AI数字人虚拟主播应用可以说是大放异彩,目前,以真人形象为基础的数字人主播,不受场地.真人.布景.灯光.直播设备的限制,相 ...

  5. How Python Handles Big Files

     The Python programming language has become more and more popular in handling data analysis and proc ...

  6. mysql 必知必会整理—全球化与本地化[十六]

    前言 简单介绍一下字符集. 数据库表被用来存储和检索数据.不同的语言和字符集需要以不同的方式存储和检索. 因此,MySQL需要适应不同的字符集(不同的字母和字符),适应不同的排序和检索数据的方法. 字 ...

  7. 重新整理数据结构与算法(c#)——算法套路k克鲁斯算法[三十]

    前言 这个和前面一节有关系,是这样子的,前面是用顶点作为参照条件,这个是用边作为参照条件. 正文 图解如下: 每次选择最小的边. 但是会遇到一个小问题,就是会构成回路. 比如说第四步中,最小边是CE, ...

  8. C#的窗体假关闭操作例子 - 开源研究系列文章

    晚上编码的时候,想到了以前编写的窗体关闭的事情,就是带托盘图标的应用,有一个主显示操作窗体,但是主窗体点击关闭按钮的时候,实际上是窗体隐藏而非真正关闭,这个在其它的一些应用程序里有这个效果.于是就想到 ...

  9. bs4、selenium的使用

    爬取新闻 # 1 爬取网页---requests # 2 解析 ---xml格式,用了re匹配的 ---html,bs4,lxml... ---json: -python :内置的 -java : f ...

  10. 力扣645(java)-错误的集合(简单)

    题目: 集合 s 包含从 1 到 n 的整数.不幸的是,因为数据错误,导致集合里面某一个数字复制了成了集合里面的另外一个数字的值,导致集合 丢失了一个数字 并且 有一个数字重复 . 给定一个数组 nu ...