dataset.py

'''
准备数据集
'''
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor,Compose,Normalize
import torchvision
import config def mnist_dataset(train):
func = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
mean=(0.1307,),
std = (0.3081,)
)
]) #准备Mnist数据集
return MNIST(root="../mnist",train=train,download=False,transform=func) def get_dataloader(train = True):
mnist = mnist_dataset(train)
batch_size = config.train_batch_size if train else config.test_batch_size
return DataLoader(mnist,batch_size=batch_size,shuffle=True) if __name__ == '__main__':
for (images,labels) in get_dataloader():
print(images.size())
print(labels)
break

  model.py

'''定义模型'''

import torch.nn as nn
import torch.nn.functional as F class MnistModel(nn.Module):
def __init__(self):
super(MnistModel,self).__init__()
self.fc1 = nn.Linear(28*28,100)
self.fc2 = nn.Linear(100,10) def forward(self,image):
image_viwed = image.view(-1,28*28)
fc1_out = self.fc1(image_viwed)
fc1_out_relu = F.relu(fc1_out)
out = self.fc2(fc1_out_relu) return F.log_softmax(out,dim=-1)

  config.py

'''
项目配置
'''
import torch train_batch_size = 128
test_batch_size = 128 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  train.py

'''
进行模型的训练
'''
from dataset import get_dataloader
from models import MnistModel
from torch import optim
import torch.nn.functional as F
import config
from tqdm import tqdm
import numpy as np
import torch
import os
from eval import eval #实例化模型、优化器、损失函数
model = MnistModel().to(config.device)
optimizer = optim.Adam(model.parameters(),lr=0.001) if os.path.exists("./model/mnist_net.pt"):
model.load_state_dict(torch.load("./model/mnist_net.pt"))
optimizer.load_state_dict(torch.load("model/mnist_optimizer.pt")) #迭代训练 def train(epoch):
train_dataloader = get_dataloader(train=True)
bar = tqdm(enumerate(train_dataloader),total=len(train_dataloader))
total_loss = []
for idx,(input,target) in bar:
input = input.to(config.device)
target = target.to(config.device)
#梯度置为0
optimizer.zero_grad()
#计算得到预测值
output = model(input)
#得到损失
loss = F.nll_loss(output,target)
total_loss.append(loss.item())
#反向传播,计算损失
loss.backward()
#参数更新
optimizer.step() if idx%10 ==0:
bar.set_description("epoch:{} idx:{},loss:{}".format(epoch,idx,np.mean(total_loss)))
torch.save(model.state_dict(),"model/mnist_net.pt")
torch.save(optimizer.state_dict(),"model/mnist_optimizer.pt") if __name__ == '__main__':
for i in range(10):
train(i)
eval()

  eval.py

'''
进行模型的训练
'''
from dataset import get_dataloader
from models import MnistModel
from torch import optim
import torch.nn.functional as F
import config
import numpy as np
import torch
import os #迭代训练 def eval():
# 实例化模型、优化器、损失函数
model = MnistModel().to(config.device)
optimizer = optim.Adam(model.parameters(), lr=0.01) if os.path.exists("./model/mnist_net.pt"):
model.load_state_dict(torch.load("./model/mnist_net.pt"))
optimizer.load_state_dict(torch.load("model/mnist_optimizer.pt"))
test_dataloader = get_dataloader(train=False)
total_loss = []
total_acc = []
with torch.no_grad():
for input,target in test_dataloader:
input = input.to(config.device)
target = target.to(config.device)
#计算得到预测值
output = model(input)
#计算损失
loss = F.nll_loss(output,target)
#反向传播,计算损失
total_loss.append(loss.item())
#计算准确率
pred = output.max(dim=-1)[-1]
total_acc.append(pred.eq(target).float().mean().item())
print("test loss:{},test acc:{}".format(np.mean(total_loss),np.mean(total_acc))) if __name__ == '__main__':
eval()

  

D:\anaconda\python.exe C:/Users/liuxinyu/Desktop/pytorch_test/day3/手写数字识别/train.py
epoch:0 idx:460,loss:0.32289110562095413: 100%|██████████| 469/469 [00:24<00:00, 19.05it/s]
test loss:0.17968503131142147,test acc:0.9453125
epoch:1 idx:460,loss:0.15012750004513145: 100%|█████████▉| 468/469 [00:20<00:00, 22.10it/s]epoch:1 idx:460,loss:0.15012750004513145: 100%|██████████| 469/469 [00:20<00:00, 22.52it/s]
test loss:0.12370304338916947,test acc:0.9624208860759493
epoch:2 idx:460,loss:0.10398845713577534: 99%|█████████▉| 464/469 [00:21<00:00, 22.78it/s]epoch:2 idx:460,loss:0.10398845713577534: 100%|█████████▉| 467/469 [00:21<00:00, 22.71it/s]epoch:2 idx:460,loss:0.10398845713577534: 100%|██████████| 469/469 [00:21<00:00, 21.82it/s]
test loss:0.10385569722592077,test acc:0.9697389240506329
epoch:3 idx:460,loss:0.07973297938720653: 100%|█████████▉| 467/469 [00:22<00:00, 23.12it/s]epoch:3 idx:460,loss:0.07973297938720653: 100%|██████████| 469/469 [00:22<00:00, 20.84it/s]
test loss:0.08691684670652015,test acc:0.9754746835443038
epoch:4 idx:460,loss:0.0650228117158285: 100%|█████████▉| 468/469 [00:21<00:00, 24.06it/s]epoch:4 idx:460,loss:0.0650228117158285: 100%|██████████| 469/469 [00:21<00:00, 21.79it/s]
test loss:0.0803159438309413,test acc:0.9760680379746836
epoch:5 idx:460,loss:0.05270117848966101: 100%|██████████| 469/469 [00:21<00:00, 21.92it/s]
test loss:0.08102699166423158,test acc:0.9759691455696202
epoch:6 idx:460,loss:0.04386751471317642: 100%|██████████| 469/469 [00:19<00:00, 24.58it/s]
test loss:0.07991968260347089,test acc:0.9769580696202531
epoch:7 idx:460,loss:0.03656852366544161: 100%|██████████| 469/469 [00:15<00:00, 31.20it/s]
test loss:0.07767781678917288,test acc:0.9774525316455697
epoch:8 idx:460,loss:0.03112584312896925: 100%|██████████| 469/469 [00:14<00:00, 32.41it/s]
test loss:0.07755146227494071,test acc:0.9773536392405063
epoch:9 idx:460,loss:0.025217091969725495: 100%|██████████| 469/469 [00:14<00:00, 31.53it/s]
test loss:0.07112929566845863,test acc:0.9802215189873418

  接口interface.py

'''
进行模型的训练
'''
from models import MnistModel
from torch import optim
import config
import torch
import os
import cv2
import torchvision.transforms as transforms tranform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=(0.1307,),
std = (0.3081,)
)]) # 实例化模型、优化器、损失函数
model = MnistModel()
optimizer = optim.Adam(model.parameters(), lr=0.01) if os.path.exists("./model/mnist_net.pt"):
model.load_state_dict(torch.load("./model/mnist_net.pt",map_location=lambda storage, loc: storage))
optimizer.load_state_dict(torch.load("model/mnist_optimizer.pt",map_location=lambda storage, loc: storage)) #预测接口
def interface(pic_path):
img = cv2.imread(pic_path)
img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
img = tranform(img_gray)
# img = np.transpose(img, (2,0,1))
img = img.unsqueeze(0)
with torch.no_grad():
input = img
#计算得到预测值
output = model(input)
pred = output.max(dim=-1)[1]
print("识别结果为:",pred[0].to("cpu").numpy()) if __name__ == '__main__':
while True:
path = input("请输入图片地址:")
path = "./pic_test/"+path+".png"
print(path)
interface(path)

  

pytorch 手写数字识别项目 增量式训练的更多相关文章

  1. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

  2. 用pytorch做手写数字识别,识别l率达97.8%

    pytorch做手写数字识别 效果如下: 工程目录如下 第一步  数据获取 下载MNIST库,这个库在网上,执行下面代码自动下载到当前data文件夹下 from torchvision.dataset ...

  3. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  4. Pytorch入门——手把手教你MNIST手写数字识别

    MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...

  5. 手写数字识别 卷积神经网络 Pytorch框架实现

    MNIST 手写数字识别 卷积神经网络 Pytorch框架 谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打 说明 下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释, ...

  6. 使用AI算法进行手写数字识别

    人工智能   人工智能(Artificial Intelligence,简称AI)一词最初是在1956年Dartmouth学会上提出的,从那以后,研究者们发展了众多理论和原理,人工智能的概念也随之扩展 ...

  7. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  8. 手把手教你使用LabVIEW OpenCV DNN实现手写数字识别(含源码)

    @ 目录 前言 一.OpenCV DNN模块 1.OpenCV DNN简介 2.LabVIEW中DNN模块函数 二.TensorFlow pb文件的生成和调用 1.TensorFlow2 Keras模 ...

  9. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

随机推荐

  1. 使用tensorflow的softmax进行mnist识别

    tensorflow真是方便,看来深度学习需要怎么使用框架.如何建模- ''' softmax classifier for mnist created on 2019.9.28 author: vi ...

  2. JS 剑指Offer(四) 从尾到头打印链表

    题目:输入一个链表的头节点,从尾到头反过来返回每个节点的值(用数组返回). 首先定义一下链表中的节点,关于链表这个数据结构在另外一篇文章中会详细讲 function ListNode(val) { t ...

  3. jmeter发送Query String Parameters格式参数报错

    当发起一次GET请求时,参数会以url string的形式进行传递.即?后的字符串则为其请求参数,并以&作为分隔符 当参数为json格式时,这时需要勾选编码,否则会报错

  4. Nginx 是如何处理 HTTP 头部的?

    Nginx 处理 HTTP 头部的过程 Nginx 在处理 HTTP 请求之前,首先需要 Nginx 的框架先和客户端建立好连接,然后接收用户发来的 HTTP 的请求行,比如方法.URL 等,然后接收 ...

  5. Maven多模块项目+MVC框架+AJAX技术+layui分页对数据库增删改查实例

    昨天刚入门Maven多模块项目,所以简单写了一个小测试,就是对数据库单表的增删改查,例子比较综合,写得哪里不妥还望大神赐教,感谢! 首先看一下项目结构: 可以看到,一个项目MavenEmployee里 ...

  6. Zabbix监控平台

                                                                     Zabbix监控平台 案例1:常用系统监控命令 案例2:部署Zabbi ...

  7. 大O表示法是什么?

    1.什么是大O表示法: 1.在算法描述中,我们用这种方式来描述计算机算法的效率. 2.在计算机中,这种粗略的量度叫做 "大O" 表示法. 3.在具体的情境中,利用大O表示法来描述具 ...

  8. Python+Tornado开发微信公众号

    本文已同步到专业技术网站 www.sufaith.com, 该网站专注于前后端开发技术与经验分享, 包含Web开发.Nodejs.Python.Linux.IT资讯等板块. 本教程针对的是已掌握Pyt ...

  9. Python zipfile模块学习

    转载自https://www.j4ml.com/t/15270 import zipfile import os from zipfile import ZipFile class ZipManage ...

  10. iphone se2的优缺点分析:

    4月15日晚间消息,在毫无征兆的情况下苹果公司刚刚正式发布iPhone SE二代手机,这款传闻多年的产品终于出现,国内定价人民币3299元起.本周五开始预定,4月24日开始送货. Phone SE‭‮ ...