import torch
from torch import nn
from torch.nn import functional as F
from torch import optim import torchvision
from matplotlib import pyplot as plt # 小工具 def plot_curve(data):
fig = plt.figure()
plt.plot(range(len(data)),data,color='blue')
plt.legend(['value'],loc='upper right')
plt.xlabel('step')
plt.tlabel('value')
plt.show() def plot_image(img,label,name):
fig = plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt,tight_layout()
plt.imshow(img[i][0]*0.3081+0.1307,cmap='gray',interpolation='none')
plt.title("{}:{}".format(name,label[i].item()))
plt.xticks([])
plt.xticks([]) plt.show() def one_hot(label,depth = 10):
out = torch.zeros(label.size(0),depth)
idx = torch.LongTensor(label).view(-1,1)
out.scatter_(dim=1,index=idx,value=1)
return out # 一次加载多少图片
batch_size = 512
# step1. load dataset 数据加载
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MINST('mnist_data',train=True,download=True,
transform=torchvision.transforms.Compose([
torchvision.transfroms.ToTensor(), torchvision.transfroms.Normalize(
(0.1307,),(0.3081,))
])),
batch_size=batch_size,shuffle=True)
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MINST('mnist_data/',train=False,download=True,
transform=torchvision.transforms.Compose([
torchvision.transfroms.ToTensor(),
torchvision.transfroms.Normalize(
(0.1307,),(0.3081,))
])),
batch_size=batch_size,shuffle=False) # 网络创建
class Net(nn.Module): def __init__(self):
super(Net,self).__init__() #xw+b
self.fc1 = nn.Linear(28*28,256)
self.fc2 = nn.Linear(256,64)
self.fc3 = nn.Linear(64,10) def forward(self,x):
# x:[batch_size,1,28,28]
# h1 = relu(xw1+b1)
x = F.relu(self.fc1(x))
# h1 = relu(h1w2+b2)
x = F.relu(self.fc2(x))
# h3 = h2w3+b3
x = self.fc3(x) return x net = Net()
# [w1,b1,w2,b1,w3,b3]
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9) train_loss = [] # 训练
for epoch in range(3): for batch_idx,(x,y) in enumerate(train_loader): # x: [b,1,28,28], y:[512]
# [b,1,28,28]-->[b,feature]
x = x.view(x.size(0),28*28)
# --> [b,10]
out = net(x)
# --> [b,10]
y_onehot = one_hot(y)
# loss = mse(out,y_onehot)
loss = F.mse_loss(out,y_onehot)
# 清零梯度
optimizer.zero_grad()
# 计算梯度
loss.backward()
#w' = w - lr*grad 更新梯度
optimizer.step() train_loss.append(loss.item()) if batch_idx % 10 == 0:
print(epoch,batch_idx,loss.item()) plot_curve(train_loss) # 得到一个比较好的 [w1,b1,w2,b1,w3,b3] # 验证准确率
total_correct = 0
for x,y in test_loader"
x = x.view(x.size(0),28*28)
out = net(x)
# out: [b,10] --> pred: [b]
pred = out.argmax(dim = 1)
correct = pred.eq(y).sum().float().item()
total_correct += correct total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:',acc) # 直观显示验证
x,y = next(iter(test_loader))
out = net(x.view(x.size(0),28*28))
pred = out.argmax(dim = 1)
plot_image(x,pred,'test')

龙良曲pytorch学习笔记_03的更多相关文章

  1. Pytorch学习笔记(二)---- 神经网络搭建

    记录如何用Pytorch搭建LeNet-5,大体步骤包括:网络的搭建->前向传播->定义Loss和Optimizer->训练 # -*- coding: utf-8 -*- # Al ...

  2. Pytorch学习笔记(一)---- 基础语法

    书上内容太多太杂,看完容易忘记,特此记录方便日后查看,所有基础语法以代码形式呈现,代码和注释均来源与书本和案例的整理. # -*- coding: utf-8 -*- # All codes and ...

  3. 【pytorch】pytorch学习笔记(一)

    原文地址:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 什么是pytorch? pytorch是一个基于p ...

  4. 【深度学习】Pytorch 学习笔记

    目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...

  5. Pytorch学习笔记(一)——简介

    一.Tensor Tensor是Pytorch中重要的数据结构,可以认为是一个高维数组.Tensor可以是一个标量.一维数组(向量).二维数组(矩阵)或者高维数组等.Tensor和numpy的ndar ...

  6. [PyTorch 学习笔记] 1.3 张量操作与线性回归

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson1/linear_regression.py 张量的操作 拼 ...

  7. [PyTorch 学习笔记] 1.1 PyTorch 简介与安装

    PyTorch 的诞生 2017 年 1 月,FAIR(Facebook AI Research)发布了 PyTorch.PyTorch 是在 Torch 基础上用 python 语言重新打造的一款深 ...

  8. [PyTorch 学习笔记] 1.4 计算图与动态图机制

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson1/computational_graph.py 计算图 深 ...

  9. [PyTorch 学习笔记] 2.2 图片预处理 transforms 模块机制

    PyTorch 的数据增强 我们在安装PyTorch时,还安装了torchvision,这是一个计算机视觉工具包.有 3 个主要的模块: torchvision.transforms: 里面包括常用的 ...

随机推荐

  1. linux-free、lscpu、

    1.free -h 以人类可读的形式显示 -m 以MB为单位显示 -w 将buffers和cache分开单独显示(针对centos7系统) centos6上: centos7上: -s 动态查看内存信 ...

  2. 小白进阶之路-python数据类型

    1.数据类型:变量值是我们存储的数据,所以数据类型值得就是变量的不同种类 2.数据分类型的原因:变量值是用来保存现实世界的中的状态的,呢么针对不同的状态就应该用不同类型上午数据去表示 (1)整型int ...

  3. $Noip2018/Luogu5021$ 赛道修建 二分+树形

    $Luogu$ $Sol$ 一直以为是每个点只能经过一次没想到居然是每条边只能经过一次$....$ 首先其实这题$55$分的部分分真的很好写啊,分别是链,数的直径和菊花图,这里就不详细说了. 使得修建 ...

  4. Theia架构

    上一篇:Theia——云端和桌面版的IDE 架构概述 本节描述了Theia的整体架构. Theia被设计为一个可以在本地运行的桌面应用程序,也可以在浏览器和远程服务器之间工作.为了支持这两种工作方式, ...

  5. JDK1.8中的HashMap实现

    1.HashMap概述 在JDK1.8之前,HashMap采用数组+链表实现,即使用链表处理冲突,同一hash值的节点都存储在一个链表里.但是当位于一个桶中的元素较多,即hash值相等的元素较多时,通 ...

  6. VueRouter爬坑第四篇-命名路由、编程式导航

    VueRouter系列的文章示例编写时,项目是使用vue-cli脚手架搭建. 项目搭建的步骤和项目目录专门写了一篇文章:点击这里进行传送 后续VueRouter系列的文章的示例编写均基于该项目环境. ...

  7. React Hooks 实现和由来以及解决的问题

    与React类组件相比,React函数式组件究竟有何不同? 一般的回答都是: 类组件比函数式组件多了更多的特性,比如 state,那如果有 Hooks 之后呢? 函数组件性能比类组件好,但是在现代浏览 ...

  8. 配置IDEA默认作者@author

    IDEA安装目录下,使用文本编辑器打开~/bin/idea64.exe.vmoptions文件 在最后添加:-Duser.name=Your name 保存重启IDEA,Done

  9. Django-视图&网址

    前言 Django第一篇简单的介绍了环境搭建与创建Django项目的两种方式,以及如何启动服务,在前端访问HelloWorld地址,这篇内容首先学习一下Django项目中的各个模块的用途及Django ...

  10. 《C# 爬虫 破境之道》:第二境 爬虫应用 — 第一节:HTTP协议数据采集

    首先欢迎您来到本书的第二境,本境,我们将全力打造一个实际生产环境可用的爬虫应用了.虽然只是刚开始,虽然路漫漫其修远,不过还是有点小鸡冻:P 本境打算针对几大派生类做进一步深耕,包括与应用的结合.对比它 ...