使用VGG模型进行猫狗大战

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import os
  4. import torch
  5. import torch.nn as nn
  6. import torchvision
  7. from torchvision import models,transforms,datasets
  8. import time
  9. import json

1、下载数据

! wget https://static.leiphone.com/cat_dog.rar
! unrar x cat_dog.rar

2、数据处理

datasets 是 torchvision 中的一个包,可以用做加载图像数据。它可以以多线程(multi-thread)的形式从硬盘中读取数据,使用 mini-batch 的形式,在网络训练中向 GPU 输送。在使用CNN处理图像时,需要进行预处理。图片将被整理成 224×224×3 的大小,同时还将进行归一化处理。

  1. normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  2.  
  3. vgg_format = transforms.Compose([
  4. transforms.CenterCrop(224),
  5. transforms.ToTensor(),
  6. normalize,
  7. ])
  8.  
  9. #这里进行了修改,包括训练数据、验证数据、以及测试数据,分别在三个目录train/val/test
  10. import shutil
  11. data_dir = './cat_dog'
  12. os.mkdir("./cat_dog/train/cat")
  13. os.mkdir("./cat_dog/train/dog")
  14. os.mkdir("./cat_dog/val/cat")
  15. os.mkdir("./cat_dog/val/dog")
  16. for i in range(10000):
  17. cat_name = './cat_dog/train/cat_'+str(i)+'.jpg';
  18. dog_name = './cat_dog/train/dog_'+str(i)+'.jpg';
  19. shutil.move(cat_name,"./cat_dog/train/cat")
  20. shutil.move(dog_name,"./cat_dog/train/dog")
  21.  
  22. for i in range(1000):
  23. cat_name = './cat_dog/val/cat_'+str(i)+'.jpg';
  24. dog_name = './cat_dog/val/dog_'+str(i)+'.jpg';
  25. shutil.move(cat_name,"./cat_dog/val/cat")
  26. shutil.move(dog_name,"./cat_dog/val/dog")
  27. #读取测试问题的数据集
  28.  
  29. test_path = "./cat_dog/test/dogs_cats"
  30. os.mkdir(test_path)
  31. #移动到test_path
  32. for i in range(2000):
  33. name = './cat_dog/test/'+str(i)+'.jpg'
  34. shutil.move(name,"./cat_dog/test/dogs_cats")
  35.  
  36. file_list=os.listdir("./cat_dog/test/dogs_cats")
  37. #将图片名补全,防止读取顺序不对
  38. for file in file_list:
  39. #填充0后名字总共10位,包括扩展名
  40. filename = file.zfill(10)
  41. new_name =''.join(filename)
  42. os.rename(test_path+'/'+file,test_path+'/'+new_name)
  43. #将所有图片数据放到dsets内
  44. dsets = {x: datasets.ImageFolder(os.path.join(data_dir, x), vgg_format)
  45. for x in ['train','val','test']}
  46. dset_sizes = {x: len(dsets[x]) for x in ['train','val','test']}
  47. dset_classes = dsets['train'].classes
  1. loader_train = torch.utils.data.DataLoader(dsets['train'], batch_size=64, shuffle=True, num_workers=6)
  2. loader_valid = torch.utils.data.DataLoader(dsets['val'], batch_size=5, shuffle=False, num_workers=6)
  3. #加入测试集
  4. loader_test = torch.utils.data.DataLoader(dsets['test'], batch_size=5,shuffle=False, num_workers=6)
  5.  
  6. '''
  7. valid 数据一共有2000张图,每个batch是5张,因此,下面进行遍历一共会输出到 400
  8. 同时,把第一个 batch 保存到 inputs_try, labels_try,分别查看
  9. '''
  10. count = 1
  11. for data in loader_test:
  12. print(count, end=',')
  13. if count%50==0:
  14. print()
  15. if count == 1:
  16. inputs_try,labels_try = data
  17. count +=1
  18.  
  19. print(labels_try)
  20. print(inputs_try.shape)
  1. # 显示图片的小程序
  2.  
  3. def imshow(inp, title=None):
  4. # Imshow for Tensor.
  5. inp = inp.numpy().transpose((1, 2, 0))
  6. mean = np.array([0.485, 0.456, 0.406])
  7. std = np.array([0.229, 0.224, 0.225])
  8. inp = np.clip(std * inp + mean, 0,1)
  9. plt.imshow(inp)
  10. if title is not None:
  11. plt.title(title)
  12. plt.pause(0.001) # pause a bit so that plots are updated
  1. # 显示 labels_try 的5张图片,即valid里第一个batch的5张图片
  2. out = torchvision.utils.make_grid(inputs_try)
  3. imshow(out, title=[dset_classes[x] for x in labels_try])

3. 创建 VGG Model

  1. !wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json
  1. model_vgg = models.vgg16(pretrained=True)
  2.  
  3. with open('./imagenet_class_index.json') as f:
  4. class_dict = json.load(f)
  5. dic_imagenet = [class_dict[str(i)][1] for i in range(len(class_dict))]
  6.  
  7. inputs_try , labels_try = inputs_try.to(device), labels_try.to(device)
  8. model_vgg = model_vgg.to(device)
  9.  
  10. outputs_try = model_vgg(inputs_try)
  11.  
  12. print(outputs_try)
  13. print(outputs_try.shape)
  14.  
  15. '''
  16. 可以看到结果为5行,1000列的数据,每一列代表对每一种目标识别的结果。
  17. 但是我也可以观察到,结果非常奇葩,有负数,有正数,
  18. 为了将VGG网络输出的结果转化为对每一类的预测概率,我们把结果输入到 Softmax 函数
  19. '''
  20. m_softm = nn.Softmax(dim=1)
  21. probs = m_softm(outputs_try)
  22. vals_try,pred_try = torch.max(probs,dim=1)
  23.  
  24. print( 'prob sum: ', torch.sum(probs,1))
  25. print( 'vals_try: ', vals_try)
  26. print( 'pred_try: ', pred_try)
  27.  
  28. print([dic_imagenet[i] for i in pred_try.data])
  29. imshow(torchvision.utils.make_grid(inputs_try.data.cpu()),
  30. title=[dset_classes[x] for x in labels_try.data.cpu()])

4. 修改最后一层,冻结前面层的参数

  1. print(model_vgg)
  2.  
  3. model_vgg_new = model_vgg;
  4.  
  5. for param in model_vgg_new.parameters():
  6. param.requires_grad = False
  7. model_vgg_new.classifier._modules['6'] = nn.Linear(4096, 2)
  8. model_vgg_new.classifier._modules['7'] = torch.nn.LogSoftmax(dim = 1)
  9.  
  10. model_vgg_new = model_vgg_new.to(device)
  11.  
  12. print(model_vgg_new.classifier)

5. 训练并测试全连接层

包括三个步骤:第1步,创建损失函数和优化器;第2步,训练模型;第3步,测试模型。

  1. '''
  2. 第一步:创建损失函数和优化器
  3.  
  4. 损失函数 NLLLoss() 的 输入 是一个对数概率向量和一个目标标签.
  5. 它不会为我们计算对数概率,适合最后一层是log_softmax()的网络.
  6. '''
  7. criterion = nn.NLLLoss()
  8.  
  9. # 学习率
  10. lr = 0.001
  11.  
  12. # 随机梯度下降
  13. optimizer_vgg = torch.optim.SGD(model_vgg_new.classifier[6].parameters(),lr = lr)
  14.  
  15. '''
  16. 第二步:训练模型
  17. '''
  18.  
  19. def train_model(model,dataloader,size,epochs=1,optimizer=None):
  20. model.train()
  21.  
  22. for epoch in range(epochs):
  23. running_loss = 0.0
  24. running_corrects = 0
  25. count = 0
  26. for inputs,classes in dataloader:
  27. inputs = inputs.to(device)
  28. classes = classes.to(device)
  29. outputs = model(inputs)
  30. loss = criterion(outputs,classes)
  31. optimizer = optimizer
  32. optimizer.zero_grad()
  33. loss.backward()
  34. optimizer.step()
  35. _,preds = torch.max(outputs.data,1)
  36. # statistics
  37. running_loss += loss.data.item()
  38. running_corrects += torch.sum(preds == classes.data)
  39. count += len(inputs)
  40. print('Training: No. ', count, ' process ... total: ', size)
  41. epoch_loss = running_loss / size
  42. epoch_acc = running_corrects.data.item() / size
  43. print('Loss: {:.4f} Acc: {:.4f}'.format(
  44. epoch_loss, epoch_acc))
  45.  
  46. # 模型训练
  47. train_model(model_vgg_new,loader_train,size=dset_sizes['train'], epochs=1,
  48. optimizer=optimizer_vgg)
  1. #验证模型正确率的代码
  2. def test_model(model,dataloader,size):
  3. model.eval()
  4. predictions = np.zeros(size)
  5. all_classes = np.zeros(size)
  6. all_proba = np.zeros((size,2))
  7. i = 0
  8. running_loss = 0.0
  9. running_corrects = 0
  10. for inputs,classes in dataloader:
  11. inputs = inputs.to(device)
  12. classes = classes.to(device)
  13. outputs = model(inputs)
  14. loss = criterion(outputs,classes)
  15. _,preds = torch.max(outputs.data,1)
  16. # statistics
  17. running_loss += loss.data.item()
  18. running_corrects += torch.sum(preds == classes.data)
  19. predictions[i:i+len(classes)] = preds.to('cpu').numpy()
  20. all_classes[i:i+len(classes)] = classes.to('cpu').numpy()
  21. all_proba[i:i+len(classes),:] = outputs.data.to('cpu').numpy()
  22. i += len(classes)
  23. print('validing: No. ', i, ' process ... total: ', size)
  24. epoch_loss = running_loss / size
  25. epoch_acc = running_corrects.data.item() / size
  26. print('Loss: {:.4f} Acc: {:.4f}'.format(
  27. epoch_loss, epoch_acc))
  28. return predictions, all_proba, all_classes
  29.  
  30. #predictions, all_proba, all_classes = test_model(model_vgg_new,loader_valid,size=dset_sizes['val'])
  31. #如果使用的是已有的模型,应该跑下面这行代码
  32. predictions, all_proba, all_classes = test_model(model_new,loader_valid,size=dset_sizes['val'])
  1. #这个是对测试集进行预测的代码
  2. def result_model(model,dataloader,size):
  3. model.eval()
  4. predictions=np.zeros((size,2),dtype='int')
  5. i = 0
  6. for inputs,classes in dataloader:
  7. inputs = inputs.to(device)
  8. outputs = model(inputs)
  9. #_表示的就是具体的value,preds表示下标,1表示在行上操作取最大值,返回类别
  10. _,preds = torch.max(outputs.data,1)
  11. predictions[i:i+len(classes),1] = preds.to('cpu').numpy();
  12. predictions[i:i+len(classes),0] = np.linspace(i,i+len(classes)-1,len(classes))
  13. #可在过程中看到部分结果
  14. print(predictions[i:i+len(classes),:])
  15. i += len(classes)
  16. print('creating: No. ', i, ' process ... total: ', size)
  17. return predictions
  18.  
  19. result = result_model(model_vgg_new,loader_test,size=dset_sizes['test'])
  20. #如果使用的是已有的模型,应该跑下面这行代码
  21. result = result_model(model_new,loader_test,size=dset_sizes['test'])
  22.  
  23. #这里是生成结果的文件,上传到AI研习社可以看到正确率
  24. np.savetxt("./cat_dog/result.csv",result,fmt="%d",delimiter=",")

6. 可视化模型预测结果(主观分析)

主观分析就是把预测的结果和相对应的测试图像输出出来看看,一般有四种方式:

随机查看一些预测正确的图片
随机查看一些预测错误的图片
预测正确,同时具有较大的probability的图片
预测错误,同时具有较大的probability的图片
最不确定的图片,比如说预测概率接近0.5的图片

  1. # 单次可视化显示的图片个数
  2. n_view = 8
  3. correct = np.where(predictions==all_classes)[0]
  4. from numpy.random import random, permutation
  5. idx = permutation(correct)[:n_view]
  6. print('random correct idx: ', idx)
  7. loader_correct = torch.utils.data.DataLoader([dsets['valid'][x] for x in idx],
  8. batch_size = n_view,shuffle=True)
  9. for data in loader_correct:
  10. inputs_cor,labels_cor = data
  11. # Make a grid from batch
  12. out = torchvision.utils.make_grid(inputs_cor)
  13. imshow(out, title=[l.item() for l in labels_cor])
  14.  
  15. print(all_classes)
  16. # 类似的思路,可以显示错误分类的图片,这里不再重复代码

【第4次作业】CNN实战的更多相关文章

  1. CNN实战篇-手把手教你利用开源数据进行图像识别(基于keras搭建)

    我一直强调做深度学习,最好是结合实际的数据上手,参照理论,对知识的掌握才会更加全面.先了解原理,然后找一匹数据来验证,这样会不断加深对理论的理解. 欢迎留言与交流! 数据来源: cifar10  (其 ...

  2. python作业/练习/实战:生成双色球小程序

    作业要求: 每注投注号码由6个红色球号码和1个蓝色球号码组成.红色球号码从1--33中选择:蓝色球号码从1--16中选择 代码范例 import random all_red_ball = [str( ...

  3. python作业/练习/实战:生成随机密码

    作业要求1.写一个函数,函数的功能是生成一批密码,存到文件里面 def gen_password(num): #num代表生成多少条密码2.密码复杂度要求 1)长度在,8-16位之间 2)密码必须包括 ...

  4. python作业/练习/实战:3、实现商品管理的一个程序

    作业要求 实现一个商品管理的一个程序,运行程序有三个选项,输入1添加商品:输入2删除商品:输入3 查看商品信息1.添加商品: 商品名称:xx 商品如果已经存在,提示商品已存在 商品价格:xx数量只能为 ...

  5. python作业/练习/实战:2、注册、登录(文件读写操作)

    作业要求 1.实现注册功能输入:username.passowrd,cpassowrd最多可以输错3次3个都不能为空用户名长度最少6位, 最长20位,用户名不能重复密码长度最少8位,最长15位两次输入 ...

  6. python作业/练习/实战:1、简单登录脚本

    作业要求 写一个登陆的小程序 username = xiaoming passwd = 123456 1.输入账号密码,输入正确就登陆成功, 提示:欢迎xxxx登陆,今天的日期是xxx. 2.输入错误 ...

  7. python作业/练习/实战:下载QQ群所有人的头像

    步骤与提示:1.在腾讯群网页中进入任意一个群,获取相关信息,可以用postman是试一下,可以看到我们要的是mems里面的数据,需要获取到QQ号和群名片,如果没有群名片的话取昵称2.根据QQ号下载头像 ...

  8. 《大数据实时计算引擎 Flink 实战与性能优化》新专栏

    基于 Flink 1.9 讲解的专栏,涉及入门.概念.原理.实战.性能调优.系统案例的讲解. 专栏介绍 扫码下面专栏二维码可以订阅该专栏 首发地址:http://www.54tianzhisheng. ...

  9. 深度学习之tensorflow2实战:多输出模型

    欢迎来到CNN实战,尽管我们刚刚开始,但还是要往前看!让我们开始吧! 数据集 链接:https://pan.baidu.com/s/1zztS32iuNynepLq7jiF6RA 提取码:ilxh,请 ...

  10. Selenium自动化测试,接口自动化测试开发,性能测试从入门到精通

    Selenium自动化测试,接口自动化测试开发,性能测试从入门到精通Selenium接口性能自动化测试基础部分:分层自动化思想Slenium介绍Selenium1.0/2.0/3.0Slenium R ...

随机推荐

  1. COOP/COHP(上)-PROOUT

    晶体轨道重叠布居 COOP(crystal orbital overlap population)的一个更为直观的名称是 重叠布居权重的态密度 (overlap population-weighted ...

  2. Navicate破解安装

    1.安装Navicate客户端     2. 注意安装完毕不要打开navicate,打开后后面可能出现rsa public key not found之类的错误,直接点击注册机,选择版本,点击patc ...

  3. git将本地文件上传到远程仓库

    要记住! 上传代码之前,一定要先下拉代码,如果有冲突(你和别人同时修改了某一个文件的某一行代码),那么就要先解决冲突,才能提交! 这里以将自己的本地文件上传至git仓库为例 1.首先进入需要上传的文件 ...

  4. 蓝桥杯训练赛二-问题 B

    字符串的输入输出处理. 输入 第一行是一个正整数N,最大为100.之后是多行字符串(行数大于N), 每一行字符串可能含有空格,字符数不超过1000. 输出 先将输入中的前N行字符串(可能含有空格)原样 ...

  5. FileStream与StreamReader区别

    FileStream操作字节,更适合大文件. StreamReader操作字符,更适合小文件

  6. pandas通过sqlalchemy写入pgsql报错can't adapt type 'numpy.int64'

    其实以前也遇到过,后来不了了之,但今天又出现了,还是大概记录下. 我个人习惯按我自己的理解搞事情,只要结果对就行,但不一定对. 分析下原因,突然想到dataframe中有一列全是列表,列表中全是整数, ...

  7. mysql 修改字符集相关操作

    修改某个表字段的字符集 ALTER TABLE apply_info MODIFY member_name varchar(128) CHARACTER SET utf8mb4; 查看某个库的字符集类 ...

  8. nginx配置文件过大导致起不来

    更改src/core/ngx_conf_file.c,默认只有4k,将下面值改大重新编译

  9. nginx文件上传模块 nginx_upload_module

    1.编译安装nginx wget https://github.com/fdintino/nginx-upload-module/archive/refs/heads/master.zip PS:原先 ...

  10. 2170. 使数组变成交替数组的最少操作数 (Medium)

    问题描述 2170. 使数组变成交替数组的最少操作数 (Medium) 给你一个下标从 0 开始的数组 nums ,该数组由 n 个正整数组成. 如果满足下述条件,则数组 nums 是一个 交替数组 ...