pytorch官网上两个例程
caffe用起来太笨重了,最近转到pytorch,用起来实在不要太方便,上手也非常快,这里贴一下pytorch官网上的两个小例程,掌握一下它的用法:
例程一:利用nn 这个module构建网络,实现一个图像分类的小功能;
链接:http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# -*- coding:utf-8 -*-
import torch
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import pylab

# Preprocessing: convert PIL images to tensors, then normalize every channel
# with mean 0.5 and std 0.5 (first tuple = per-channel means, second = stds).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 train/test splits, each loaded with 2 worker subprocesses.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=True, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
def imshow(img):
    """Undo the 0.5/0.5 normalization and display a CHW tensor as an image."""
    unnormalized = img / 2 + 0.5
    plt.imshow(np.transpose(unnormalized.numpy(), (1, 2, 0)))  # CHW -> HWC
    pylab.show()
# Show a few training images together with their labels.
dataiter = iter(trainloader)
# Use the builtin next() instead of the Python-2-only .next() method,
# so the script also runs under Python 3.
images, labels = next(dataiter)
for i in range(4):
    p = plt.subplot()
    p.set_title("label: %5s" % classes[labels[i]])
    imshow(images[i])
# Build the network.
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
    """Small LeNet-style CNN for 3x32x32 CIFAR-10 images (10 output classes)."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)    # 3 -> 6 channels, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)     # shared 2x2 max-pool
        self.conv2 = nn.Conv2d(6, 16, 5)   # 6 -> 16 channels, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten the 16x5x5 feature maps into a 400-vector per sample.
        # Keeping the batch dimension explicit (instead of view(-1, 400))
        # guarantees one row per input sample.
        x = x.view(x.size(0), 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
net = Net()
net.cuda()

# Loss function and optimizer: cross-entropy with SGD + momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Train the network for two passes over the training set.
for epoch in range(2):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.data[0]
        if i % 2000 == 1999:  # report the mean loss every 2000 mini-batches
            print('[%d , %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
print('Finished Training')
# Compare ground-truth labels with the network's predictions on one test batch.
dataiter = iter(testloader)
# builtin next() instead of Python-2-only .next(), for Python 3 compatibility
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
print('GroundTruth:', ' '.join(classes[labels[j]] for j in range(4)))
outputs = net(Variable(images.cuda()))
_, predicted = torch.max(outputs.data, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
# Overall accuracy over the whole 10000-image test set.
correct = 0
total = 0
for images, labels in testloader:
    outputs = net(Variable(images.cuda()))
    _, predicted = torch.max(outputs.data, 1)
    correct += (predicted == labels.cuda()).sum()
    total += labels.size(0)
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
# Per-class accuracy.
# BUG FIX: the counters must start at zero — the original used torch.ones(10),
# which adds a phantom correct hit and a phantom sample to every class and
# skews the reported percentages.
class_correct = torch.zeros(10).cuda()
class_total = torch.zeros(10).cuda()
for data in testloader:
    images, labels = data
    outputs = net(Variable(images.cuda()))
    _, predicted = torch.max(outputs.data, 1)
    c = (predicted == labels.cuda()).squeeze()
    for i in range(4):  # test batch_size is 4
        label = labels[i]
        class_correct[label] += c[i]
        class_total[label] += 1
for i in range(10):
    print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
例程二:在resnet18的预训练模型上进行finetune,然后实现一个ants和bees的二分类功能:
链接:http://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
- # -*- coding:utf-8 -*-
- from __future__ import print_function , division
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.optim import lr_scheduler
- from torch.autograd import Variable
- import numpy as np
- import torchvision
- from torchvision import datasets , models , transforms
- import matplotlib.pyplot as plt
- import time
- import os
- import pylab
# Data processing: augmentation + ImageNet normalization for training,
# deterministic resize + center-crop for validation.
# NOTE(review): RandomSizedCrop and Scale were later renamed to
# RandomResizedCrop and Resize in torchvision; switch if you upgrade.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
dataloders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(class_names)
use_gpu = torch.cuda.is_available()
# Show several images so the labels can be sanity-checked.
def imshow(inp, title=None):
    """Display a normalized CHW tensor, reversing the ImageNet normalization."""
    inp = inp.numpy().transpose((1, 2, 0))  # CHW -> HWC
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = np.clip(std * inp + mean, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    pylab.show()
    plt.pause(0.001)

inputs, classes = next(iter(dataloders['train']))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])
# Train the model.
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train `model` and return it loaded with the weights of the epoch that
    achieved the best validation accuracy.

    Relies on the module-level `dataloders`, `dataset_sizes` and `use_gpu`.
    """
    import copy  # local import so this block is self-contained
    since = time.time()
    # BUG FIX: state_dict() returns references to the live parameter tensors,
    # so the "best" snapshot must be deep-copied or it silently tracks every
    # later weight update.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()    # advance the LR decay schedule
                model.train(True)   # training mode (dropout/batchnorm active)
            else:
                model.train(False)  # evaluation mode
            running_loss = 0.0
            running_corrects = 0
            # Iterate over the data for this phase.
            for data in dataloders[phase]:
                inputs, labels = data
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs = Variable(inputs)
                    # BUG FIX: the original assigned to a misspelled 'lables'
                    # here, leaving the CPU-path labels unwrapped.
                    labels = Variable(labels)
                optimizer.zero_grad()
                # forward
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                # backward + optimize only during training
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                running_loss += loss.data[0]
                running_corrects += torch.sum(preds.data == labels.data)
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            # Keep a snapshot of the best-performing validation weights.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    # Restore the best weights before handing the model back.
    model.load_state_dict(best_model_wts)
    return model
# Visualize a handful of validation-set predictions.
def visualize_model(model, num_images=6):
    """Plot `num_images` validation images titled with the model's predictions."""
    shown = 0
    plt.figure()
    for batch in dataloders['val']:
        inputs, labels = batch
        if use_gpu:
            inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
        else:
            inputs, labels = Variable(inputs), Variable(labels)
        outputs = model(inputs)
        _, preds = torch.max(outputs.data, 1)
        for j in range(inputs.size()[0]):
            shown += 1
            ax = plt.subplot(num_images // 2, 2, shown)
            ax.axis('off')
            ax.set_title('predicted: {}'.format(class_names[preds[j]]))
            imshow(inputs.cpu().data[j])
            if shown == num_images:
                return
# Finetuning the convnet: swap the final FC layer for a 2-class one and
# train the whole network.
from torchvision.models.resnet import model_urls

# Fall back to plain http for environments where the https download of the
# pretrained weights fails.
model_urls['resnet18'] = model_urls['resnet18'].replace('https://', 'http://')

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)  # ants vs. bees
if use_gpu:
    model_ft = model_ft.cuda()

criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Decay the learning rate by 10x every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# Start finetuning, then save and visualize the best model.
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)
torch.save(model_ft.state_dict(), '/home/zf/resnet18.pth')
visualize_model(model_ft)
当然finetune的话有两种方式:在这个例子里
(1)只修改最后一层全连接层,输出类数改为2,然后在预训练模型上进行finetune;
(2)固定全连接层前面的卷积层参数,也就是它们不反向传播,只对最后一层进行反向传播;实现的时候前面这些层的requires_grad就设为False就OK了;
代码见下:
# Feature extraction: freeze every pretrained layer and train only the newly
# added final layer (frozen parameters get requires_grad=False, so no
# gradients are computed for them).
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

if use_gpu:
    model_conv = model_conv.cuda()

criterion = nn.CrossEntropyLoss()

# Observe that, as opposed to before, only the parameters of the final layer
# are handed to the optimizer.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
可以说,从构建网络,到训练网络,再到测试,由于完全是python风格,实在是太方便了~
pytorch官网上两个例程的更多相关文章
- spring官网上下载历史版本的spring插件,springsource-tool-suite
spring官网下载地址(https://spring.io/tools/sts/all),历史版本地址(https://spring.io/tools/sts/legacy). 注:历史版本下载的都 ...
- jquery ui中的dialog,官网上经典的例子
jquery ui中的dialog,官网上经典的例子 jquery ui中dialog和easy ui中的dialog很像,但是最近用到的时候全然没有印象,一段时间不用就忘记了,这篇随笔介绍一下这 ...
- [pytorch] 官网教程+注释
pytorch官网教程+注释 Classifier import torch import torchvision import torchvision.transforms as transform ...
- iOS开发:创建推送开发证书和生产证书,以及往极光推送官网上传证书的步骤方法
在极光官网上面上传应用的极光推送证书的实质其实就是上传导出的p12文件,在极光推送应用管理里面,需要上传两个p12文件,一个是生产证书,一个是开发证书 ,缺一不可,具体如下所示: 在开发者账号里面创建 ...
- 自己封装的Windows7 64位旗舰版,微软官网上下载的Windows7原版镜像制作,绝对纯净版
MSDN官网上下载的Windows7 64位 旗舰版原版镜像制作,绝对纯净版,无任何精简,不捆绑任何第三方软件.浏览器插件,不含任何木马.病毒等. 集成: 1.Office2010 2.DirectX ...
- 关于在官网上查看和下载特定版本的webrtc代码
注:这个方法已经不适用了,帖子没删只是留个纪念而已 gclient:如果不知道gclient是什么东西 ... 就别再往下看了. 下载特定版本的代码: #gclient sync --revision ...
- echarts官网上的动态加载数据bug被我解决。咳咳/。
又是昨天,为什么昨天发生了这么多事.没办法,谁让我今天没事可做呢. 昨天需求是动态加载数据,画一个实时监控的折线图.大概长这样. 我屁颠屁颠的把代码copy过来,一运行,caocaocao~bug出现 ...
- 训练DCGAN(pytorch官网版本)
将pytorch官网的python代码当下来,然后下载好celeba数据集(百度网盘),在代码旁新建celeba文件夹,将解压后的img_align_celeba文件夹放进去,就可以运行代码了. 输出 ...
- Jenkins利用官网上的rpm源安装
官网网址:https://pkg.jenkins.io/redhat/ (官网上有安装的命令,参考网址) 安装jdk yum install -y java-1.8.0- ...
随机推荐
- Android Service服务的生命周期
与activity类似,服务也存在生命周期回调方法,你可以实现这些方法来监控服务的状态变化,并在适当的时机执行一些操作. 以下代码提纲展示了服务的每个生命周期回调方法: public class Ex ...
- 设置outlook 2013 默认的ost路径
How To Change Default Data File (.OST) Location in Office 2013 To set the default location of an out ...
- 平衡树Splay
维护区间添加,删除,前驱,后继,排名,逆排名 普通平衡树 #include <cstdio> #define ls t[now].ch[0] #define rs t[now].ch[1] ...
- 【洛谷P4054】计数问题
题目大意:维护 N*M 个点,每个点有三个权值,支持单点修改,查询矩形区间内权值等于某个值的点的个数. 题解:矩阵可以看成两个维度,权值为第三个维度,为一个三维偏序维护问题.发现第三维仅仅为单点修改和 ...
- javascript面向对象精要第五章继承整理精要
javascript中使用原型链支持继承,当一个对象的[prototype]设置为另一个对象时, 就在这两个对象之间创建了一条原型对象链.如果要创建一个继承自其它对象的对象, 使用Object.cre ...
- c/c++ 整形转字符串
int findex;char instr[10]; sprintf(instr,"%d",findex); 好像ltoa用不了...
- 第二节,TensorFlow 使用前馈神经网络实现手写数字识别
一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...
- poj 1330(RMQ&LCA入门题)
传送门:Problem 1330 https://www.cnblogs.com/violet-acmer/p/9686774.html 参考资料: http://dongxicheng.org/st ...
- vs widows服务的发布
1.在service1.cs里空白处点击右键,弹出菜单选择 添加安装程序 2.自动生成ProjectInstaller.cs文件后 可在InitializeComponent()方法里自定义服务名称 ...
- Yosimite10.10(Mac os)安装c/c++内存检测工具valgrind
1.下载支持包m4-1.4.13.tar.gz $ curl -O http://mirrors.kernel.org/gnu/m4/m4-1.4.13.tar.gz 2. 解压m4-1.4.13.t ...