Convolutional neural network (CNN) - Pytorch版
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms # 配置GPU或CPU设置
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # 超参数设置
num_epochs = 5
num_classes = 10
batch_size = 100
learning_rate = 0.001 # 下载 MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='./data/',
train=True,
transform=transforms.ToTensor(),# 将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1],归一化至[0-1]是直接除以255
download=True) test_dataset = torchvision.datasets.MNIST(root='./data/',
train=False,
transform=transforms.ToTensor())# 将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1],归一化至[0-1]是直接除以255 # 训练数据加载,按照batch_size大小加载,并随机打乱
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
# 测试数据加载,按照batch_size大小加载
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False) # Convolutional neural network (two convolutional layers) 2层卷积
class ConvNet(nn.Module):
def __init__(self, num_classes=10):
super(ConvNet, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.layer2 = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.fc = nn.Linear(7 * 7 * 32, num_classes) def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = out.reshape(out.size(0), -1)
out = self.fc(out)
return out model = ConvNet(num_classes).to(device)
print(model) # ConvNet(
# (layer1): Sequential(
# (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
# (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU()
# (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))
# (layer2): Sequential(
# (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
# (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU()
# (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))
# (fc): Linear(in_features=1568, out_features=10, bias=True)) # 损失函数与优化器设置
# 损失函数
criterion = nn.CrossEntropyLoss()
# 优化器设置 ,并传入CNN模型参数和相应的学习率
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # 训练CNN模型
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device) # 前向传播
outputs = model(images)
# 计算损失 loss
loss = criterion(outputs, labels) # 反向传播与优化
# 清空上一步的残余更新参数值
optimizer.zero_grad()
# 反向传播
loss.backward()
# 将参数更新值施加到RNN model的parameters上
optimizer.step()
# 每迭代一定步骤,打印结果值
if (i + 1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch + 1, num_epochs, i + 1, total_step, loss.item())) # 测试模型
# model.train model.eval 在测试模型时在前面使用:model.eval() ; 在训练模型时会在前面加上:model.train()
# 让model变成测试模式,是针对model 在训练时和评价时不同的 Batch Normalization 和 Dropout 方法模式
# eval()时,让model变成测试模式, pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值,
# 不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大。
model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) # 保存已经训练好的模型
# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')
Convolutional neural network (CNN) - Pytorch版的更多相关文章
- 卷积神经网络(Convolutional Neural Network, CNN)简析
目录 1 神经网络 2 卷积神经网络 2.1 局部感知 2.2 参数共享 2.3 多卷积核 2.4 Down-pooling 2.5 多层卷积 3 ImageNet-2010网络结构 4 DeepID ...
- Recurrent neural network (RNN) - Pytorch版
import torch import torch.nn as nn import torchvision import torchvision.transforms as transforms # ...
- 斯坦福大学卷积神经网络教程UFLDL Tutorial - Convolutional Neural Network
Convolutional Neural Network Overview A Convolutional Neural Network (CNN) is comprised of one or mo ...
- 卷积神经网络(Convolutional Neural Network,CNN)
全连接神经网络(Fully connected neural network)处理图像最大的问题在于全连接层的参数太多.参数增多除了导致计算速度减慢,还很容易导致过拟合问题.所以需要一个更合理的神经网 ...
- 深度学习FPGA实现基础知识10(Deep Learning(深度学习)卷积神经网络(Convolutional Neural Network,CNN))
需求说明:深度学习FPGA实现知识储备 来自:http://blog.csdn.net/stdcoutzyx/article/details/41596663 说明:图文并茂,言简意赅. 自今年七月份 ...
- 【转载】 卷积神经网络(Convolutional Neural Network,CNN)
作者:wuliytTaotao 出处:https://www.cnblogs.com/wuliytTaotao/ 本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可,欢迎 ...
- CNN(Convolutional Neural Network)
CNN(Convolutional Neural Network) 卷积神经网络(简称CNN)最早可以追溯到20世纪60年代,Hubel等人通过对猫视觉皮层细胞的研究表明,大脑对外界获取的信息由多层的 ...
- 论文阅读(Weilin Huang——【TIP2016】Text-Attentional Convolutional Neural Network for Scene Text Detection)
Weilin Huang--[TIP2015]Text-Attentional Convolutional Neural Network for Scene Text Detection) 目录 作者 ...
- 论文笔记:(CVPR2019)Relation-Shape Convolutional Neural Network for Point Cloud Analysis
目录 摘要 一.引言 二.相关工作 基于视图和体素的方法 点云上的深度学习 相关性学习 三.形状意识表示学习 3.1关系-形状卷积 建模 经典CNN的局限性 变换:从关系中学习 通道提升映射 3.2性 ...
随机推荐
- jmeter非GUI界面常用参数详解
压力测试或者接口自动化测试常常用到的jmeter非GUI参数,以下记录作为以后的参考 讲解:非GUI界面,压测参数讲解(欢迎加入QQ群一起讨论性能测试:537188253) -h 帮助 -n 非GUI ...
- PHP chmod() 函数
chmod() 函数改变文件模式. 如果成功则返回 TRUE,否则返回 FALSE. 例子 <?php // 所有者可读写,其他人没有任何权限 chmod(); // 所有者可读写,其他人可读 ...
- python 链表的实现
code #!/usr/bin/python # -*- coding: utf- -*- class Node(object): def __init__(self,val,p=): self.da ...
- jquery中语法初学必备
$(this).hide() - 隐藏当前元素 $("p").hide() - 隐藏所有段落 $(".test").hide() - 隐藏所有 class=&q ...
- arts 打卡12周
一 算法: 字符串转换整数 (atoi) 请你来实现一个 atoi 函数,使其能将字符串转换成整数. 首先,该函数会根据需要丢弃无用的开头空格字符,直到寻找到第一个非空格的字符为止. 当我们寻找 ...
- Thingsboard 重新启动docker-compose容器基础数据存在的问题
在重启了thingsboard的容器后,想再次重新启动容器,发现已经出现了错误 查看posttres中,持久化的地址是tb-node/postgres中 再查看相应的文件夹 删除以上log和postg ...
- [spring-boot] 健康状况监控
pom文件 <dependency> <groupId>org.springframework.boot</groupId> <artifactId>s ...
- Python-matplotlib画图(莫烦笔记)
https://www.zhihu.com/collection/260736383 https://blog.csdn.net/gaotihong/article/details/80983937 ...
- Centos7修改为固定IP后 yum 出现could not retrieve mirrorlist
Centos7修改为固定IP后 yum 出现could not retrieve mirrorlist,发现yum源的域名无法解析 按照6,修改/etc/resovle.conf,新增域名解析服务器1 ...
- 【idea】idea远程调试代码
一.前置条件 1.idea的代码和远程服务器代码保持一致 二.远程服务器配置 服务启动时,需要给jvm添加指定参数,进行启动 -agentlib:jdwp=transport=dt_socket,se ...