LeNet-5 pytorch+torchvision+visdom
# ====================LeNet-5_main.py===============
# pytorch+torchvision+visdom
# -*- coding: utf-8 -*-
"""
Created on Sun May 26 22:53:52 2019 @author: jiangshan
"""
#A modified LeNet-5 [LeCun et al., 1998a] on the MNIST dataset.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import visdom
from collections import OrderedDict class LeNet5(nn.Module):
"""
Input - 1x32x32
C1 - 6@28x28 (5x5 kernel)
relu
S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling
C3 - 16@10x10 (5x5 kernel, complicated shit)
relu
S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling
C5 - 120@1x1 (5x5 kernel)
F6 - 84
relu
F7 - 10 (Output)
"""
def __init__(self):
super(LeNet5, self).__init__() self.convnet = nn.Sequential(OrderedDict([
('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
('relu1', nn.ReLU()),
('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),
('relu3', nn.ReLU()),
('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),
('relu5', nn.ReLU())
])) self.fc = nn.Sequential(OrderedDict([
('f6', nn.Linear(120, 84)),
('relu6', nn.ReLU()),
('f7', nn.Linear(84, 10)),
('sig7', nn.LogSoftmax(dim=-1))
])) def forward(self, img):
output = self.convnet(img)
output = output.view(img.size(0), -1)
output = self.fc(output)
return output viz = visdom.Visdom()
data_train = MNIST('./data/mnist',
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
data_test = MNIST('./data/mnist',
train=False,
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8) net = LeNet5()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=2e-3) cur_batch_win = None
cur_batch_win_opts = {
'title': 'Epoch Loss Trace',
'xlabel': 'Batch Number',
'ylabel': 'Loss',
'width': 1200,
'height': 600,
} def train(epoch):
global cur_batch_win
net.train()
loss_list, batch_list = [], []
for i, (images, labels) in enumerate(data_train_loader):
optimizer.zero_grad() output = net(images) loss = criterion(output, labels) loss_list.append(loss.detach().cpu().item())
batch_list.append(i+1) if i % 10 == 0:
print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item())) # Update Visualization
if viz.check_connection():
cur_batch_win = viz.line(torch.Tensor(loss_list), torch.Tensor(batch_list),
win=cur_batch_win, name='current_batch_loss',
update=(None if cur_batch_win is None else 'replace'),
opts=cur_batch_win_opts)
loss.backward()
optimizer.step() def test():
net.eval()
total_correct = 0
avg_loss = 0.0
for i, (images, labels) in enumerate(data_test_loader):
output = net(images)
avg_loss += criterion(output, labels).sum()
pred = output.detach().max(1)[1]
total_correct += pred.eq(labels.view_as(pred)).sum() avg_loss /= len(data_test)
print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.detach().cpu().item(), float(total_correct) / len(data_test))) def train_and_test(epoch):
train(epoch)
test() def main():
for e in range(1, 16):
train_and_test(e) if __name__ == '__main__':
main()
先开启visdom 进行可视化
python -m visdom.server
运行程序
python LeNet-5_main.py
打开浏览器查看live graph
LeNet-5 pytorch+torchvision+visdom的更多相关文章
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- Linux服务器配置GPU版本的pytorch Torchvision TensorFlow
最近在Linux服务器上配置项目,项目需要使用GPU版本的pytorch和TensorFlow,而且该项目内会同时使用TensorFlow的GPU和CPU. 在服务器上装环境,如果重新开始,就需要下载 ...
- pytorch的visdom启动不了、蓝屏
pytorch的visdom启动不了.蓝屏 问题描述:我是在ubuntu16.04上用python3.5安装的visdom.可是启动是蓝屏:在网上找了很久的解决方案:有三篇博文: https://bl ...
- 云服务器搭建anaconda pytorch torchvision
(因为在普通用户上安装有些权限问题安装出错,所以我在root用户下相对容易安装,但是anaconda官网说可以直接在普通用户下安装,不过,在root下安装,其他用户也是能用的. 访问Anaconda官 ...
- Pytorch Torchvision Transform
Torchvision.Transforms Transforms包含常用图像转换操作.可以使用Compose将它们链接在一起. 此外,还有torchvision.transforms.functio ...
- pytorch torchvision.ImageFolder的使用
参考:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/ torchvision.dataset ...
- pytorch torchvision对图像进行变换
class torchvision.transforms.Compose(转换) 多个将transform组合起来使用. class torchvision.transforms.CenterCrop ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上
总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...
- PyTorch常用代码段整理合集
PyTorch常用代码段整理合集 转自:知乎 作者:张皓 众所周知,程序猿在写代码时通常会在网上搜索大量资料,其中大部分是代码段.然而,这项工作常常令人心累身疲,耗费大量时间.所以,今天小编转载了知乎 ...
随机推荐
- PHP 之Html标签转义与反转义
1.htmlentities()函数转义html 2.html_entity_decode()函数反转义html 我这里是用来反转义富文本编辑器的内容
- 更加方便的使用git上传自己的代码
经过以上的培训,同学们肯定对git的基本使用没有什么问题了.但是每次代码有更改后,依旧需要 git add * git commit * git 打开vim编辑器,编辑提交信息 或者 git ...
- Manjaro Linux无备份迁移home目录
前几天安装了最新的manjaro kde 18.10,速度刚开始非常快,后来几乎每次重启都会出现无法挂在home分区的情况,刚开始以为是分区对齐的问题,但是后来发现根本不是.算了,干脆迁移下home分 ...
- C#MD5方法
不同形式,一样结果 /// <summary> /// 获取大写的MD5签名结果 /// </summary> /// <param name="encypSt ...
- 使用Adivisor配置增强处理,来实现数据库读写分离
一.先写一个demo来概述Adivisor的简单使用步骤 实现步骤: 1.通过MethodBeforeAdivice接口实现前置增强处理 public class ServiceBeforeAdvis ...
- Kafka 最新版配置
当前基于kafaka最新版 kafka_2.12-2.2.1.tgz 进行配置 . 官网地址:http://kafka.apache.org/intro kafka的一些基础知识 参考:http:// ...
- SpringBoot-自动装配对象及源码ImportSelector分析
SpringBoot框架已经很流行了,笔者做项目也一直在用,使用久了,越来越觉得有必要理解SpringBoot框架中的一些原理了,目前的面试几乎都会用问到底层原理.我们在使用过程中基本上是搭建有一个框 ...
- Tosca Connection Validation error:40 - Could not open a connection to SQL Server (不知道怎么解决)
谁知道下面这个错怎么解决,请给我留言,谢谢. 数据库能正常链接,服务也是 normal running
- JSP页面中如何注入Spring容器中的bean
第一步在JSP页面中导入下面的包: <%@page import="org.springframework.web.context.support.WebApplicationCont ...
- Spring Bootz之热部署
在项目的pom.xml文件添加如下两段 <dependency> <groupId>org.springframework.boot</groupId> <a ...