Caffe2 载入预训练模型(Loading Pre-Trained Models)[7]
这一节我们主要讲述如何使用预训练模型。Ipython notebook链接在这里。
模型下载
你可以去Model Zoo下载预训练好的模型,或者使用Caffe2的models.download
模块获取预训练的模型。caffe2.python.models.download
需要模型的名字所谓参数。你可以去看看有什么模型可用,然后替换下面代码中的squeezenet
。
python -m caffe2.python.models.download -i squeezenet
译者注:如果不明白为什么用python -m 执行,可以看看这个帖子。
如果上面下载成功,那么你应该下载了 squeezenet到你的文件夹中。如果你使用i那么模型文件将下载到/caffe2/python/models
文件夹中。当然,你也可以下载所有模型文件:git clone https://github.com/caffe2/models
。
Overview
在这个教程中,我们将会使用squeezenet
模型进行图片的目标识别。如果,你读了前面的预处理章节,那么你会看到我们使用rescale和crop对图像进行处理。同时做了CHW和BGR的转换,最后的图像数据是NCHW。我们也统计了图像均值,而不是简单地将图像减去128.
你会发现载入预处理模型是相当简单的,仅仅需要几行代码就可以了。
- 读取protobuf文件
with open("init_net.pb") as f:
init_net = f.read()
with open("predict_net.pb") as f:
predict_net = f.read()
- 使用
Predictor
函数从protobuf中载入blobs数据
p = workspace.Predictor(init_net, predict_net)
- 跑网络并获取结果
results = p.run([img])
返回的结果是一个多维概率的矩阵,每一行是一个百分比,表示网络识别出图像属于某一个物体的概率。当你使用前面那张花图来测试时,网络的返回应该告诉你超过95的概率是雏菊。
Configuration
网络设置如下:
# 你安装caffe2的路径
CAFFE2_ROOT = "~/caffe2"
# 假设是caffe2的子目录
CAFFE_MODELS = "~/caffe2/caffe2/python/models"
#如果你有mean file,把它放在模型文件那个目录里面
%matplotlib inline
from caffe2.proto import caffe2_pb2
import numpy as np
import skimage.io
import skimage.transform
from matplotlib import pyplot
import os
from caffe2.python import core, workspace
import urllib2
print("Required modules imported.")
传递图像的路径,或者网络图像的URL。物体编码参照Alex Net,比如“985”代表是“雏菊”。其他编码参照这里。
IMAGE_LOCATION = "https://cdn.pixabay.com/photo/2015/02/10/21/28/flower-631765_1280.jpg"
# 参数格式: folder, INIT_NET, predict_net, mean , input image size
MODEL = 'squeezenet', 'init_net.pb', 'predict_net.pb', 'ilsvrc_2012_mean.npy', 227
# AlexNet的物体编码
codes = "https://gist.githubusercontent.com/aaronmarkham/cd3a6b6ac071eca6f7b4a6e40e6038aa/raw/9edb4038a37da6b5a44c3b5bc52e448ff09bfe5b/alexnet_codes"
print "Config set!"
处理图像
def crop_center(img,cropx,cropy):
y,x,c = img.shape
startx = x//2-(cropx//2)
starty = y//2-(cropy//2)
return img[starty:starty+cropy,startx:startx+cropx]
def rescale(img, input_height, input_width):
print("Original image shape:" + str(img.shape) + " and remember it should be in H, W, C!")
print("Model's input shape is %dx%d") % (input_height, input_width)
aspect = img.shape[1]/float(img.shape[0])
print("Orginal aspect ratio: " + str(aspect))
if(aspect>1):
# landscape orientation - wide image
res = int(aspect * input_height)
imgScaled = skimage.transform.resize(img, (input_width, res))
if(aspect<1):
# portrait orientation - tall image
res = int(input_width/aspect)
imgScaled = skimage.transform.resize(img, (res, input_height))
if(aspect == 1):
imgScaled = skimage.transform.resize(img, (input_width, input_height))
pyplot.figure()
pyplot.imshow(imgScaled)
pyplot.axis('on')
pyplot.title('Rescaled image')
print("New image shape:" + str(imgScaled.shape) + " in HWC")
return imgScaled
print "Functions set."
# set paths and variables from model choice and prep image
CAFFE2_ROOT = os.path.expanduser(CAFFE2_ROOT)
CAFFE_MODELS = os.path.expanduser(CAFFE_MODELS)
# 均值最好从训练集中计算得到
MEAN_FILE = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[3])
if not os.path.exists(MEAN_FILE):
mean = 128
else:
mean = np.load(MEAN_FILE).mean(1).mean(1)
mean = mean[:, np.newaxis, np.newaxis]
print "mean was set to: ", mean
# 输入大小
INPUT_IMAGE_SIZE = MODEL[4]
# 确保所有文件存在
if not os.path.exists(CAFFE2_ROOT):
print("Houston, you may have a problem.")
INIT_NET = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[1])
print 'INIT_NET = ', INIT_NET
PREDICT_NET = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[2])
print 'PREDICT_NET = ', PREDICT_NET
if not os.path.exists(INIT_NET):
print(INIT_NET + " not found!")
else:
print "Found ", INIT_NET, "...Now looking for", PREDICT_NET
if not os.path.exists(PREDICT_NET):
print "Caffe model file, " + PREDICT_NET + " was not found!"
else:
print "All needed files found! Loading the model in the next block."
#载入一张图像
img = skimage.img_as_float(skimage.io.imread(IMAGE_LOCATION)).astype(np.float32)
img = rescale(img, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)
img = crop_center(img, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)
print "After crop: " , img.shape
pyplot.figure()
pyplot.imshow(img)
pyplot.axis('on')
pyplot.title('Cropped')
# 转换为CHW
img = img.swapaxes(1, 2).swapaxes(0, 1)
pyplot.figure()
for i in range(3):
pyplot.subplot(1, 3, i+1)
pyplot.imshow(img[i])
pyplot.axis('off')
pyplot.title('RGB channel %d' % (i+1))
#转换为BGR
img = img[(2, 1, 0), :, :]
# 减均值
img = img * 255 - mean
# 增加batch size
img = img[np.newaxis, :, :, :].astype(np.float32)
print "NCHW: ", img.shape
状态输出:
Functions set.
mean was set to: 128
INIT_NET = /home/aaron/models/squeezenet/init_net.pb
PREDICT_NET = /home/aaron/models/squeezenet/predict_net.pb
Found /home/aaron/models/squeezenet/init_net.pb ...Now looking for /home/aaron/models/squeezenet/predict_net.pb
All needed files found! Loading the model in the next block.
Original image shape:(751, 1280, 3) and remember it should be in H, W, C!
Model's input shape is 227x227
Orginal aspect ratio: 1.70439414115
New image shape:(227, 386, 3) in HWC
After crop: (227, 227, 3)
NCHW: (1, 3, 227, 227)
既然图像准备好了,那么放进CNN里面吧。打开protobuf,载入到workspace中,并跑起网络。
#初始化网络
with open(INIT_NET) as f:
init_net = f.read()
with open(PREDICT_NET) as f:
predict_net = f.read()
p = workspace.Predictor(init_net, predict_net)
# 进行预测
results = p.run([img])
# 把结果转换为np矩阵
results = np.asarray(results)
print "results shape: ", results.shape
results shape: (1, 1, 1000, 1, 1)
看到1000没。如果我们batch很大,那么这个矩阵将会很大,但是中间的维度仍然是1000。它记录着模型预测的每一个类别的概率。现在,让我们继续下一步。
results = np.delete(results, 1)#这句话不是很明白
index = 0
highest = 0
arr = np.empty((0,2), dtype=object)#创建一个0x2的矩阵?
arr[:,0] = int(10)#这是什么个意思?
arr[:,1:] = float(10)
for i, r in enumerate(results):
# imagenet的索引从1开始
i=i+1
arr = np.append(arr, np.array([[i,r]]), axis=0)
if (r > highest):
highest = r
index = i
print index, " :: ", highest
# top 3 结果
# sorted(arr, key=lambda x: x[1], reverse=True)[:3]
# 获取 code list
response = urllib2.urlopen(codes)
for line in response:
code, result = line.partition(":")[::2]
if (code.strip() == str(index)):
print result.strip()[1:-2]
最后输出:
985 :: 0.979059
daisy
译者注:上面最后一段处理结果的代码,译者也不是很明白,有木有明白的同学在下面回复下?
转载请注明出处:http://www.jianshu.com/c/cf07b31bb5f2
Caffe2 载入预训练模型(Loading Pre-Trained Models)[7]的更多相关文章
- 预训练模型与Keras.applications.models权重资源地址
什么是预训练模型 简单来说,预训练模型(pre-trained model)是前人为了解决类似问题所创造出来的模型.你在解决问题的时候,不用从零开始训练一个新模型,可以从在类似问题中训练过的模型入手. ...
- 我的Keras使用总结(3)——利用bottleneck features进行微调预训练模型VGG16
Keras的预训练模型地址:https://github.com/fchollet/deep-learning-models/releases 一个稍微讲究一点的办法是,利用在大规模数据集上预训练好的 ...
- 我的Keras使用总结(4)——Application中五款预训练模型学习及其应用
本节主要学习Keras的应用模块 Application提供的带有预训练权重的模型,这些模型可以用来进行预测,特征提取和 finetune,上一篇文章我们使用了VGG16进行特征提取和微调,下面尝试一 ...
- tensorflow利用预训练模型进行目标检测(二):预训练模型的使用
一.运行样例 官网链接:https://github.com/tensorflow/models/blob/master/research/object_detection/object_detect ...
- tensorflow 预训练模型列表
tensorflow 预训练模型列表 https://github.com/tensorflow/models/tree/master/research/slim Pre-trained Models ...
- 最强 NLP 预训练模型库 PyTorch-Transformers 正式开源:支持 6 个预训练框架,27 个预训练模型
先上开源地址: https://github.com/huggingface/pytorch-transformers#quick-tour 官网: https://huggingface.co/py ...
- Paddle预训练模型应用工具PaddleHub
Paddle预训练模型应用工具PaddleHub 本文主要介绍如何使用飞桨预训练模型管理工具PaddleHub,快速体验模型以及实现迁移学习.建议使用GPU环境运行相关程序,可以在启动环境时,如下图所 ...
- 【AI】Pytorch_预训练模型
1. 模型下载 import re import os import glob import torch from torch.hub import download_url_to_file from ...
- NLP与深度学习(五)BERT预训练模型
1. BERT简介 Transformer架构的出现,是NLP界的一个重要的里程碑.它激发了很多基于此架构的模型,其中一个非常重要的模型就是BERT. BERT的全称是Bidirectional En ...
随机推荐
- samba对外开放的端口
前言搭建samba的时候,如果是在内网\测试环境中,可以直接关闭防火墙,但是如果是在外网情况下,需要对防火墙开放某些端口.开放的具体步骤,下面我们来看. 操作步骤1.添加端口 firewall-cmd ...
- CSS学习(4)常见样式声明
1.文本 color 文字颜色 预设值:定义好的单词,如red blue 光学的三原色(红,绿,蓝),如 rgb(32,45,255) HEX十六进制,如#008CFF(#112233可以简写为#12 ...
- 每天进步一点点------Sobel算子(2)
转载 http://blog.csdn.net/tianhai110 索贝尔算子(Sobel operator)主要用作边缘检测,在技术上,它是一离散性差分算子,用来运算图像亮度函数的灰度之近似值. ...
- hadoop学习笔记(四):HDFS文件权限,安全模式,以及整体注意点总结
本文原创,转载注明作者和原文链接! 一:总结注意点: 到现在为止学习到的角色:三个NameNode.SecondaryNameNode.DataNode 1.存储的是每一个文件分割存储之后的元数据信息 ...
- JS-禁用浏览器前进后退
使用jQuery: <script type="text/javascript" language="javascript"> $(document ...
- 红帽RHCE培训-课程3笔记内容2
9 NFS 9.1 NFS基础 目标 .使用NFS将文件系统连接到客户端,并使用IP 地址控制访问 .使用NFS将文件系统连接到客户端,并使用kerberos 来控制访问 .配置用户名和密码控制访问的 ...
- C#面向对象三大特性:封装
什么是封装 定义:把一个或多个项目封闭在一个物理的或者逻辑的包中.在面向对象程序设计方法论中,封装是为了防止对实现细节的访问. 封装的优点 1. 隔离性,安全性.被封装后的对象(这里的对象是泛指代码的 ...
- ES6-三点运算符
首先理解一下函数总的arguments变量,这个变量是函数内部自动生成的,他用来保存传入函数的实参,是一个伪数组. 例: function fun(a,b){ console.log(argument ...
- 关注Ionic底部导航按钮tabs在android情况下浮在上面的处理
Ionic是一款流行的移动端开发框架,但是刚入门的同学会发现,Ionic在IOS和android的底部tabs显示不一样.在安卓情况下底部tabs会浮上去. 如下图展示: 网上也有很多此类的解决方案 ...
- vim 一些操作
在ESC下 gg # 光标跳到开头 dG # 删除光标后的数据 dd # 删除光标所在行 gg dG # 删除全部 (光标跳到开头&删除光标后的数据) x # 删除当前光标下的字符 i # 编 ...