模型训练的三要素:数据处理、损失函数、优化算法 

   数据处理(模块torch.utils.data)

从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始

from torch.nn import init   # pytorch的init模块提供了多中参数初始化方法

init.normal_(net[0].weight, mean=0, std=0.01)    #初始化net[0].weight的期望为0,标准差为0.01的正态分布tensor
init.constant_(net[0].bias, val=0) #初始化net[0].bias,值为0的常数tensor
# 此外还封装了好多
# init.ones_(w) 初始化一个形状如w的全1分布的tensor,如w是3行5列,则初始化为3行5列的全1tensor
# init.zeros_(w) 初始化一个形状如w的全0分布的tensor
# init.eye_(w) 初始化一个形状如w的对角线为1,其余为0的tensor
# init.sparse_(w,sparsity=0.1) 初始化一个形状如w稀疏性为0.1的稀疏矩阵

 损失函数(模块torch.nn含有大量的神经网络层)

 pytorch的nn模块中定义了各种损失函数,这些损失函数可以看成一种特殊的网络层 

loss = nn.MSELoss()  # 均方误差损失函数
# torch.nn.MSELoss(reduce=True, size_average=True)
# reduce=True,返回标量形式的loss,reduce=False,返回向量形式的loss
# size_average=True,返回loss.mean(),size_average=False,返回loss.sum()
# 默认两者皆为True

 优化算法(模块torch.optim)

torch.optim模块定义了很多的优化算法,如SGD、Adam、RMSProp等

import torch.optim as optim
optimizer = optim.SGD(net.parameters(), lr=0.03)
print(optimizer) # 对不同的子网络设置不同的学习率
optimizer = optim.SGD([
          # 如果对某个参数不指定学习率,就使用最外层的默认学习率
          {'params':net.subnet1.parameters()}, # lr=0.03
          {'params':net.subnet2.parameters(),'lr':0.01}
],lr=0.03)

  设置动态学习率,不是固定一个常数

  方法1、修改optimizer.param_groups中的学习率

#调整学习率
for param_group in optimizer.param_groups:
param_group['lr'] *= 0.1 # 学习率是之前的0.1倍

  方法2、新建优化器,即构建新的optimizer。使用动量的优化器(如Adam),可能会丢失动量等状态信息,可能会造成损失函数的收敛出现震荡等情况。

optimizer = optim.SGD([
{'param':net.subnet1.parameters()},
{'param':net.subnet2.parameters(),'lr':old_lr*0.1}],lr=0.03)

  上述代码若不理解net.subnet1.parameters(),可参考博客 https://www.cnblogs.com/hellcat/p/8496727.html   万分感谢博主

view(-1,1)   # -1是不确定几行的意思,在这就是我不确定要取几行,但是肯定是一列,故view(-1,1);

  torch.view()和numpy.reshape()效果一样,view操作的是tensor,且view后的tensor和原tensor共享内存,修改其中一个,另一个也会改变,reshape()操作的是nparray。

  线性回归  

  torch.nn.Linear(in_features,out_features,bias)

  参数解析:

    in_features:输入特征的数量(或称为特征数特征向量X的维度),即在房价预测中仅和房龄与面积有关,则in_features=2

    out_features:输出特征的数量(同in_features)

    bias:偏置,默认为True

  例子请参考 https://www.cnblogs.com/Archer-Fang/p/10645473.html  感谢博主

小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()的更多相关文章

  1. 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播

    下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...

  2. 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())

    在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...

  3. 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)

    我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...

  4. 小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)

    本篇博客代码来自于<动手学深度学习>pytorch版,也是代码较多,解释较少的一篇.不过好多方法在我以前的博客都有提,所以这次没提.还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂( ...

  5. 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())

    学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...

  6. 小白学习之pytorch框架(5)-多层感知机(MLP)-(tensor、variable、计算图、ReLU()、sigmoid()、tanh())

    先记录一下一开始学习torch时未曾记录(也未好好弄懂哈)导致又忘记了的tensor.variable.计算图 计算图 计算图直白的来说,就是数学公式(也叫模型)用图表示,这个图即计算图.借用 htt ...

  7. 全面解析Pytorch框架下模型存储,加载以及冻结

    最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...

  8. Pytorch修改ResNet模型全连接层进行直接训练

    之前在用预训练的ResNet的模型进行迁移训练时,是固定除最后一层的前面层权重,然后把全连接层输出改为自己需要的数目,进行最后一层的训练,那么现在假如想要只是把 最后一层的输出改一下,不需要加载前面层 ...

  9. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

随机推荐

  1. CSU-ACM2020寒假集训比赛2

    A - Messenger Simulator CodeForces - 1288E 两种解法,我选择了第二种 mn很好求,联系过就是1,没联系过就是初始位置 第一种:统计同一个人两次联系之间的出现的 ...

  2. office(CVE-2012-0158)漏洞分析报告

    2019/9/12 1.漏洞复现 ①发现崩溃 ②找到漏洞所在的函数,下断点,重新跑起来,单步调试,找到栈被改写的地方 ③分析该函数 把MSCOMCTL拖入IDA,查看该函数代码 ④查看调用栈,回溯. ...

  3. TD信息通(无课表)使用体验

    首先,在注册账户的时候,TD信息通还是比较严谨的.用户名字符数.密码字符数.邮箱格式等都有要求,我认为,这对App的长远发展来说,是很重要的一个细节.而且,在登陆之前,会有一项关于是否自动登陆的选择, ...

  4. 远程控制使用kill软件映射内网进行远程控制(9.28 第十四天)

    1.能ping通IP情况下远程控制 设置kill软件中的端口.密码.上线列表 2.在软件的Bin\Plugins目录下找到Consys21.dll复制到/phpstudy/www目录下留作生成软件 3 ...

  5. POJ 3984:迷宫问题 bfs+递归输出路径

    迷宫问题 Time Limit: 1000MS   Memory Limit: 65536K Total Submissions: 11844   Accepted: 7094 Description ...

  6. vs使用opencv总提示igdrclneo64.dll异常.exe: 0xC0000005:的解决方法

    最近项目中要使用opencv库,搭建好环境,使用接口的时候,总提示 igdrclneo64.dll报错崩溃,一直怀疑是自己程序的问题,后面经过一系列的查资料才解决 解决办法: 本地环境:vs2015+ ...

  7. Python风格规范分享

    今天给大家分享Python 风格规范,以下代码中 Yes 表示推荐,No 表示不推荐. 分号 不要在行尾加分号, 也不要用分号将两条命令放在同一行. 行长度 每行不超过80个字符 以下情况除外: 长的 ...

  8. python可移植支持代码;用format.节省打印输出参数代码;math模块;

    1.多平台移植代码: #!/usr/bin/env python3 这一行比较特殊,称为 shebang 行,在 Python 脚本中,你应该一直将它作为第一行. 请注意行中的第一个字符是井号(#). ...

  9. python刷LeetCode:20. 有效的括号

    难度等级:简单 题目描述: 给定一个只包括 '(',')','{','}','[',']' 的字符串,判断字符串是否有效. 有效字符串需满足: 左括号必须用相同类型的右括号闭合.左括号必须以正确的顺序 ...

  10. Java算法练习——两数相加

    题目链接 题目描述 给出两个 非空 的链表用来表示两个非负的整数.其中,它们各自的位数是按照 逆序 的方式存储的,并且它们的每个节点只能存储 一位 数字. 如果,我们将这两个数相加起来,则会返回一个新 ...