Pytorch入门

简单容易上手,感觉比keras好理解多了,和mxnet很像(似乎mxnet有点借鉴pytorch),记一记。

直接从例子开始学,基础知识咱已经看了很多论文了。。。

import torch
import torch.nn as nn
import torch.nn.functional as F
# Linear 层 就是全连接层
class Net(nn.Module): # 继承nn.Module,只用定义forward,反向传播会自动生成
def __init__(self): # 初始化方法,这里的初始化是为了forward函数可以直接调过来
super(Net,self).__init__() # 调用父类初始化方法
# (input_channel,output_channel,kernel_size)
self.conv1 = nn.Conv2d(1,6,5) # 第一层卷积
self.conv2 = nn.Conv2d(6,16,5)# 第二层卷积
self.fc1 = nn.Linear(16*5*5,120) # 这里16*5*5是前向算的
self.fc2 = nn.Linear(120,84) # 第二层全连接
self.fc3 = nn.Linear(84,10) # 第三层全连接->分类
def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),(2,2)) # 卷积一次激活一次然后2*2池化一次
x = F.max_pool2d(F.relu(self.conv2(x)),2) # (2,2)与直接写 2 等价
x = x.view(-1,self.num_flatten_features(x)) # 将x展开成向量
x = F.relu(self.fc1(x)) # 全连接 + 激活
x = F.relu(self.fc2(x)) # 全连接+ 激活
x = self.fc3(x) # 最后再全连接
return x
def num_flatten_features(self,x):
size = x.size()[1:] # 除了batch_size以外的维度,(batch_size,channel,h,w)
num_features = 1
for s in size:
num_features*=s
return num_features
# ok,模型定义完毕。
net = Net()
print(net)
'''
Net(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
'''
params = list(net.parameters())
print(len(params))
print(params[0].size())
'''
10
torch.Size([6, 1, 5, 5])
'''
inpt = torch.randn(1,1,32,32)
out = net(inpt)
print(out)
'''
tensor([[-0.0265, -0.1246, -0.0796, 0.1028, -0.0595, 0.0383, 0.0038, -0.0019,
0.1181, 0.1373]], grad_fn=<AddmmBackward>)
'''
target = torch.randn(10)
criterion = nn.MSELoss()
loss = criterion(out,target)
print(loss)
'''
tensor(0.5742, grad_fn=<MseLossBackward>)
'''
net.zero_grad()# 梯度归零
print(net.conv1.bias.grad)
loss.backward()
print(net.conv1.bias.grad)
'''
None
tensor([-0.0039, 0.0052, 0.0034, -0.0002, 0.0018, 0.0096])
'''
import torch.optim as optim
optimizer = optim.SGD(net.parameters(),lr = 0.01)
optimizer.zero_grad()
output = net(inpt)
loss = criterion(output,target)
loss.backward()
optimizer.step()
# 一个step完成,多个step就写在循环里

pytorch简直太好理解了。。继续蓄力!!

[pytorch] Pytorch入门的更多相关文章

  1. 《深度学习框架PyTorch:入门与实践》的Loss函数构建代码运行问题

    在学习陈云的教程<深度学习框架PyTorch:入门与实践>的损失函数构建时代码如下: 可我运行如下代码: output = net(input) target = Variable(t.a ...

  2. 《深度学习框架PyTorch:入门与实践》读书笔记

    https://github.com/chenyuntc/pytorch-book Chapter2 :PyTorch快速入门 + Chapter3: Tensor和Autograd + Chapte ...

  3. pytorch怎么入门学习

    pytorch怎么入门学习 https://www.zhihu.com/question/55720139

  4. pytorch从入门到放弃(目录)

    目录 前置基础 Pytorch从入门到放弃 推荐阅读 前置基础 Python从入门到放弃(目录) 人工智能(目录) Pytorch从入门到放弃 01_pytorch和tensorflow的区别 02_ ...

  5. 【笔记】PyTorch快速入门:基础部分合集

    PyTorch快速入门 Tensors Tensors贯穿PyTorch始终 和多维数组很相似,一个特点是可以硬件加速 Tensors的初始化 有很多方式 直接给值 data = [[1,2],[3, ...

  6. 图神经网络 PyTorch Geometric 入门教程

    简介 Graph Neural Networks 简称 GNN,称为图神经网络,是深度学习中近年来一个比较受关注的领域.近年来 GNN 在学术界受到的关注越来越多,与之相关的论文数量呈上升趋势,GNN ...

  7. Pytorch快速入门及在线体验

    本文搭配了Pytorch在线环境,可以直接在线体验. Pytorch是Facebook 的 AI 研究团队发布了一个基于 Python的科学计算包,旨在服务两类场合: 1.替代numpy发挥GPU潜能 ...

  8. PyTorch快速入门教程七(RNN做自然语言处理)

    以下内容均来自: https://ptorch.com/news/11.html word embedding也叫做word2vec简单来说就是语料中每一个单词对应的其相应的词向量,目前训练词向量的方 ...

  9. pytorch 从入门到实战

    一.安装 按照 http://pytorch.org 官网上的说明来做,遇到了几个坑.记录如下: 1.用 conda 安装 pytorch 时,下载安装包非常慢,无法忍受. 解决办法:用蓝灯FQ,将蓝 ...

随机推荐

  1. es6数组新方法

    (1)Array.from(aarr,fn,obj) function fn(dr, sd, d) { /*Array.from 类数组转化为数组*/ console.log(arguments) v ...

  2. SQL SERVER linked server Login failed for user 'NT AUTHORITY\ANONYMOUS LOGON'

    昨天创建了一个View, 这个view是一系列的表达式(CTE)组成,封装了好多的业务逻辑,简化下语句如下 ;with CTE AS( ...) SELECT a.company_id ,b.comp ...

  3. 转 使用SwingBench 对Oracle RAC DB性能 压力测试

    ###########说明1: 1 Swingbench 简述 1.1 概述 这是Oracle UK的一个员工在一个被抛弃的项目的基础上开发的.目前稳定版本2.2,最新版本2.3,基于JDK1.5.该 ...

  4. Dev Express Report 学习总结(六)Dev Express Reports自定义Summary

    在我们使用DevExpress开发报表的过程中,对于页面中复杂的数据合计,我们可能会使用到自定义Summary.下面通过一个例子来进行说明: 首先,我建立了如上图所示的报表页面,其中的数据源来自cla ...

  5. JavaSE---位运算符

    1.Java支持的位运算符有7个: &:按位与 [2个相同取相同.2个不同取0] |:按位或 [2个相同取相同.2个不同取1] ~:按位非 ^:按位异或 [2个相同取0.2个不同取1] < ...

  6. STM32中管脚利用

    如果利用4线SWD则剩余的调试引脚可以作为IO使用: void JTAG_Set(unsigned char Mode){ u32 temp; temp=Mode; temp<<=25; ...

  7. git使用笔记-日志

    1.查看函数的历史修改 git log -L :git_deflate_bound:zlib.c2.查看HEAD的所有记录 git reflog $ git reflog 1a410ef HEAD@{ ...

  8. tinkphp3.2.3 关于事务处理。

    自己做一个测试,关于事务处理的. 在对多表进行操作的时候 基本上都离不开事务. 有的操作,是要由上一操作后,产的值(如主表里插入后,要获取插入的主键ID值,返回给下面处理表用.)带到后面的表处理当中去 ...

  9. Spark生态系统

    在大数据非常流行的今天,每个行业都在谈论大数据,每个公司(互联网公司,传统企业,金融行业等)都在讨论大数据.高层管理者利用大数据来进行决策:数据科学家利用大数据来进行业务创新:程序员利用大数据来完成项 ...

  10. hduoj 2955Robberies

    Robberies Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others) Total ...