深度学习之 cnn 进行 CIFAR10 分类
深度学习之 cnn 进行 CIFAR10 分类
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage()
import torch as t
import torch.nn as nn
import torch.nn.functional as F
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5, 0.5, 0.5)),
])
# 下载数据
trainset = tv.datasets.CIFAR10(root=".",train=True, download=True, transform=transform)
trainloader = t.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)
testset = tv.datasets.CIFAR10('.', train=False, download=True, transform=transform)
testloader = t.utils.data.DataLoader(testset, batch_size=4,shuffle=False,num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
from torch import optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr = 0.001, momentum=0.9)
from torch.autograd import Variable
for epoch in range(2):
running_loss = 0.0
for i,data in enumerate(trainloader, 0):
inputs, labels = data
inputs, labels = Variable(inputs), Variable(labels)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.data[0]
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
# 测试
correct = 0
total = 0
for data in testloader:
images, labels = data
outputs = net(Variable(images))
# print(outputs.data)
_, predicted = t.max(outputs.data, 1)
print(outputs.data,_, predicted)
total += labels.size(0)
correct += (predicted == labels).sum()
print('10000张测式中: %d %%' % (100 * correct / total) )
深度学习之 cnn 进行 CIFAR10 分类的更多相关文章
- [转] 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践
转自知乎上看到的一篇很棒的文章:用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文 ...
- 【深度学习】CNN 中 1x1 卷积核的作用
[深度学习]CNN 中 1x1 卷积核的作用 最近研究 GoogLeNet 和 VGG 神经网络结构的时候,都看见了它们在某些层有采取 1x1 作为卷积核,起初的时候,对这个做法很是迷惑,这是因为之前 ...
- 深度学习入门: CNN与LSTM(RNN)
1. 理解深度学习与CNN: 台湾李宏毅教授的入门视频<一天搞懂深度学习>:https://www.bilibili.com/video/av16543434/ 其中对CNN算法的矩阵卷积 ...
- 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践
https://zhuanlan.zhihu.com/p/25928551 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文题目便是文本分类问题,趁此机会总结下文本分类 ...
- 深度学习笔记(一):logistic分类【转】
本文转载自:https://blog.csdn.net/u014595019/article/details/52554582 这个系列主要记录我在学习各个深度学习算法时候的笔记,因为之前已经学过大概 ...
- PyTorch中使用深度学习(CNN和LSTM)的自动图像标题
介绍 深度学习现在是一个非常猖獗的领域 - 有如此多的应用程序日复一日地出现.深入了解深度学习的最佳方法是亲自动手.尽可能多地参与项目,并尝试自己完成.这将帮助您更深入地掌握主题,并帮助您成为更好的深 ...
- keras框架下的深度学习(二)二分类和多分类问题
本文第一部分是对数据处理中one-hot编码的讲解,第二部分是对二分类模型的代码讲解,其模型的建立以及训练过程与上篇文章一样:在最后我们将训练好的模型保存下来,再用自己的数据放入保存下来的模型中进行分 ...
- 自己动手实现深度学习框架-8 RNN文本分类和文本生成模型
代码仓库: https://github.com/brandonlyg/cute-dl 目标 上阶段cute-dl已经可以构建基础的RNN模型.但对文本相模型的支持不够友好, 这个阶段 ...
- Python深度学习案例1--电影评论分类(二分类问题)
我觉得把课本上的案例先自己抄一遍,然后将书看一遍.最后再写一篇博客记录自己所学过程的感悟.虽然与课本有很多相似之处.但自己写一遍感悟会更深 电影评论分类(二分类问题) 本节使用的是IMDB数据集,使用 ...
随机推荐
- Linux tar包安装Nginx-1.7.6 (yum方式安装依赖)
1.首先安装依赖包(依赖包有点多,我们采用yum的方式来安装) yum -y install zlib zlib-devel openssl openssl-devel pcre pcre-devel ...
- gulp配置
/* gulp配置 */ /* gulp配置 */ var gulp = require('gulp'), concat = require('gulp-concat'), rename = requ ...
- c#开发wps插件
wps 2016版比旧版感觉大气多了,加载速度快,操作方便,一直是wps的优点.随着wps的稳定性提高(当然比office还是差了很多),政府等一些部门采用几乎免费的wps来办公.我们公司决定把业务扩 ...
- 通过返回动态改变textview和imageview
//获取并显示优惠券ID Intent intent = getIntent(); awardID=(TextView)findViewById(R.id.awardID); String id = ...
- kubernetes1.9中部署dashboard
在1.9k8s中 dashboard可以有两种访问方式 kubeconfig(HTTPS)和token(http) 2018-03-18 一.基于token的访问1.下载官方的dashboardwge ...
- PHP基础入门(一)
php现在很火的后台开发语言,它融合了许多其他的语言,所以它的灵活性不用多说.话不多说,我们开始php的学习吧! 整数类型:$变量名=132;浮点类型:$变量名=1.32;字符串类型:$变量名=&qu ...
- guava cache使用和源码分析
guava cache的优点和使用场景,用来判断业务中是否适合使用此缓存 介绍常用的方法,并给出示例,作为使用的参考 深入解读源码. guava简介 guava cache是一个本地缓存.有以下优点: ...
- Mysql使用规范文档 20180223版
强制:不允许在跳板机上/生产服务器上手工连接,查询或更改线上数据 强制:所有上线脚本必须先在测试环境执行,验证通过以后方可在生产环境执行. 强制:上线脚本的编码格式统一为UTF-8 强制:访问数据库需 ...
- 在oracle中,group by后将字符拼接,以及自定义排序
1.在oracle中,group by后将字符拼接.任务:在学生表中,有studentid和subject两个字段.要求对studentid进行group by分组,并将所选科目拼接在一起.oracl ...
- Redis学习笔记01--主从数据库配置
1.创建公共配置文件 所有配置文件添加到以下目录: /xxxx/redis-slave-master 创建公共的redis配置文件,直接使用redis的默认配置文件,修改以下配置项: bind 127 ...