莫烦 - Pytorch学习笔记 [ 一 ]
1. Numpy VS Torch
#相互转换
np_data = torch_data.numpy()
torch_data = torch.from_numpy(np_data)
#abs
data = [1, 2, -2, -1] #array
tensor = torch.FloatTensor(data) #32bit 传入普通数组
np.abs(data); torch.abs(tensor);
#矩阵相乘
data.dot(data) #但是要先转换为numpy的data data=np.array(data)
torch.mm(tensor, tensor)
2. Variable
#引入
from torch.autograd import Variable
#声明
variable = Varible(tensor, requires_grad=True)
variable.data #type是tensor
3. Activation Function 激励函数
y = AF(Wx)
画图
#引入
import torch.nn.function as F
import matplotlib.pyplot as plt
#fake data
x = torch.linspace(-5, 5, 200)
x = Variable(x)
x_np = x.data.numpy() ***
#activation
y_relu = F.relu(x).data.numpy() ***
plt.plot(x_np, y_relu, c='red', label='relu')
4. Regression 回归
# 动态更新画图
plt.ion()
plt.show()
#for循环中的if条件内部
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data, fontdict={'size':20, 'color': 'red'})
plt.pause(0.1)
#for外部
plt.ioff()
plt.show()
#net层的定义看regression代码!
5. Classification 分类
#二元分类 模拟数据 及 画图
n_data = torch.ones(100, 2) # shape(100, 2)
x0 = torch.normal(2*n_data, 1)
y0 = torch.zeros(100)
x1 = torch.normal(-2*n_data, 1)
y1 = torch.ones(100)
x = torch.cat((x0, x1), 0).type(torch.FloatTensor)
y = torch.cat((y0, y1)).type(torch.LongTensor) #label 只能是integer类型
x, y = Variable(x), Variable(y)
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()
#输入二维 hiddenlayer10个神经元 输入也是二维
net = Net(2, 10, 2)
#优化使用
loss_func = torch.nn.CrossEntropyLoss()
#for循环内部 区分out 和 prediction
out = net(x) #此时的out格式是很乱的
loss = loss_func(out, y) #两者的误差
prediction = torch.max(F.softmax(out), 1)[1] # 过了一道 softmax 的激励函数后的最大概率才是预测值
accuracy = sum(pre_y == target_y) / 200 #预测有多少和真实值一样
6. 快速搭建法
net = torch.nn.Sequential(
torch.nn.Linear(2, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 2)
)
7. 保存提取
使用两种方式提取整个神经网络:提取整个网络或只提取参数。
两段式声明,在save中保存,在restore中提取,最后显示。
def save():
#建网络#
#训练#
#保存
torch.save(net1, 'net.pkl') #保存整个网络
torch.save(net1.state_dict(), 'net_params.pkl') #只保存网络中的参数
#提取整个网络
def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
#只提取网络参数
def restore_params():
net3 = ... #net3 = net1
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)
#显示结果
save()
restore_net()
restore_params()
8. 批数据训练
#数据引入
import torch.utils.data as Data
# 先定义batchsize
BATCH_SIZE = 5
# 转换torch为Dataset
torch_dataset = Data.TensorDataset(x, y) #(1)
loader = Data.DataLoader(...)
#for循环内的读取
for step, (batch_x, batch_y) in enumerate(loader):
#如果在loader中开了多线程
if __name__ == '__main__': #加上双线程的入口
#(1)
9. Optimizer 优化器
#给每个优化器优化一个神经网络
net_SGD = Net()
net_Momentum = Net()
net_RMSprop = Net()
net_Adam = Net()
nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]
#创建不同的优化器来训练不同的网络,并创建loss_func来计算误差
opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]
loss_func = torch.nn.MSELoss()
losses_his = [[], [], [], []] # 记录 training 时不同神经网络的 loss
#训练每个优化器,优化属于自己的神经网络
for epoch in range(EPOCH):
print('Epoch: ', epoch)
for step, (b_x, b_y) in enumerate(loader):
for net, opt, l_his in zip(nets, optimizers, losses_his): #都是列表
output = net(b_x) # get output for every net
loss = loss_func(output, b_y) # compute loss for every net
opt.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
opt.step() # apply gradients
l_his.append(loss.data.numpy()) # loss recoder
#画图
labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
for i, l_his in enumerate(losses_his):
plt.plot(l_his, label=labels[i])
plt.legend(loc='best')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.ylim((0, 0.2))
plt.show()
莫烦 - Pytorch学习笔记 [ 一 ]的更多相关文章
- 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)
莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...
- 莫烦pytorch学习笔记(七)——Optimizer优化器
各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...
- 莫烦PyTorch学习笔记(五)——模型的存取
import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...
- 莫烦PyTorch学习笔记(六)——批处理
1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...
- 莫烦pytorch学习笔记(二)——variable
.简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...
- 莫烦 - Pytorch学习笔记 [ 二 ] CNN ( 1 )
CNN原理和结构 观点提出 关于照片的三种观点引出了CNN的作用. 局部性:某一特征只出现在一张image的局部位置中. 相同性: 同一特征重复出现.例如鸟的羽毛. 不变性:subsampling下图 ...
- 莫烦PyTorch学习笔记(五)——分类
import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...
- 莫烦PyTorch学习笔记(四)——回归
下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...
- 莫烦PyTorch学习笔记(三)——激励函数
1. sigmod函数 函数公式和图表如下图 在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...
- 莫烦pytorch学习笔记(一)——torch or numpy
Q1:什么是神经网络? Q2:torch vs numpy Numpy:NumPy系统是Python的一种开源的数值计算扩展.这种工具可用来存储和处理大型矩阵,比Python自身的嵌套列表(neste ...
随机推荐
- [vue学习] 卡片展示分行功能简单实现
如图所示,实现简单的卡片展示分行功能. 分行功能较多地用于展示商品.相册等,本人在学习的过程中也是常常需要用到这个功能:虽然说现在有很多插件都能实现这个功能,但是自己写出来,能够理解原理,相信能够进步 ...
- oracle中以dba_、user_、v$_、all_、session_、index_开头
原 oracle中以dba_.user_.v$_.all_.session_.index_开头 2011年07月05日 11:26:06 clbxp 阅读数:3279 oracle中以dba_.u ...
- k8s搭建
K8s官方文档地址:https://kubernetes.io/docs/reference/command-line-tools-reference/kube-apiserver/ 如果用云主机部 ...
- element-ui表头render-header 传自定义参数
最近用到 element 的表格的 render-header 这个属性查了文档 发现: 发现它会返回部分参数 但是因为考虑要工程化,需要自定义传入参数,后来找度娘 ,发现是可以自定义传参的 :re ...
- Java 枚举(enum)的学习
Java 枚举(enum)的学习 本文转自:https://blog.csdn.net/javazejian/article/details/71333103 枚举的定义 在定义枚举类型时我们使用的关 ...
- B站上线互动视频背后,是一场谁都输不起的未来之战
毋庸置疑的是,视频网站的竞争已愈发激烈.而它们的竞争体现在多个维度,比如买视频会员赠送购物网站会员.依靠各自的社交体系不断尝试打破圈层瓶颈等.当然,最直接的竞争还是体现在内容层面.购买独家版权.制作原 ...
- 【MySQL】安装及配置
" 目录 #. 概述 1. 什么是数据(Data) 2. 什么是数据库(DataBase, 简称DB) 3. 什么是数据库管理系统(DataBase Management System) 4 ...
- 实现简单HttpServer案例
<html> <head> <title>第一个表单</title> </head> <body> <pre> me ...
- 使用13行Python代码实现四则运算计算器函数
原创的刷新行数记录的代码!!! 支持带小括号,支持多个连续+-号,如-7.9/(-1.2-++--99.3/-4.44)*---(2998.654+-+-+-(+1.3-7.654/(-1.36-99 ...
- 【SSM sql.xml】日志查询mapper.xml
LogInfoMapper.xml <?xml version="1.0" encoding="UTF-8"?> <!DOCTYPE mapp ...