tf.keras 的预训练模型都放在了'tensorflow.python.keras.applications' 目录下,在 tensorflow 1.10 版本中,预训练好的模型有:

DenseNet121, DenseNet169, DenseNet201, InceptionResNetV2, InceptionV3, MobileNet, NASNetLarge, NASNetMobile, ResNet50, VGG16, VGG19, Xception.

找了半天,发现 keras 没有预训练好的 AlexNet。。。

所以本文提供一种从其它框架(如 PyTorch)导入预训练模型的方法,下面以 AlexNet 为例。代码参见 wuliytTaotao · Github

从 PyTorch 中导出模型参数

首先明白一点,当模型的结构一样时,我们只需要导入模型的参数即可复现模型,所以我们要做的就是从 PyTorch 中导出预训练好的模型参数,并用 keras 加载。

这里要介绍一个微软的项目:MMdnn。MMdnn 使我们可以在不同深度学习框架之间转换模型,这里我也使用 MMdnn 来转换 AlexNet(PyTorch to Keras)。

第 0 步:配置环境

必须一致配置:
- PyTorch: 0.4.0 (如果其它版本出现了问题,请退回到 0.4.0 版) 非必须一致配置:
- numpy: 1.14.5
- Keras: 2.1.3 (非 tensorflow 中的 keras)

第 1 步:安装 MMdnn

$ pip3 install mmdnn

我安装的 mmdnn 版本为 0.2.5。

其它安装方式请参考 github

第 2 步:得到 PyTorch 保存完整结构和参数的模型(pth 文件)

PyTorch 保存模型时,可以保存整个模型,也可以仅保存模型的参数,都是存放到 pth 文件中。

mmdnn 操作的 pth 文件是要求含有模型结构的,具体参见 FAQ,而在 PyTorch 中预训练 AlexNet 仅保存了参数。

通过以下程序得到包含有模型结构和权重的 AlexNet 预训练模型(pth 文件):

import torch
import torchvision m = torchvision.models.alexnet(pretrained=True)
torch.save(m, './alexnet.pth')

对于其它模型,如 resnet101,可以通过以下指令直接得到含有结构和权重的预训练模型:

$ mmdownload -f pytorch -n resnet101 -o ./

(不要通过上述指令得到 alexnet.pth,因为其仅仅包含权重,而不含结构,故后面一步会出现错误 "AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'"。)

第 3 步:导出 PyTorch 模型的参数,保存至 hdf5 文件

依次执行以下三条指令,最后会得到一个 'keras_alexnet.h5' 文件,这就是我们想要的 keras 能加载的预训练权重文件。

$ mmtoir -f pytorch -d alexnet --inputShape 3,227,227 -n alexnet.pth
IR network structure is saved as [alexnet.json].
IR network structure is saved as [alexnet.pb].
IR weights are saved as [alexnet.npy].
$ mmtocode -f keras --IRModelPath alexnet.pb --IRWeightPath alexnet.npy --dstModelPath keras_alexnet.py
Using TensorFlow backend.
Parse file [alexnet.pb] with binary format successfully.
Target network code snippet is saved as [keras_alexnet.py].
$ python3 -m mmdnn.conversion.examples.keras.imagenet_test -n keras_alexnet.py -w alexnet.npy --dump keras_alexnet.h5
Using TensorFlow backend.
Keras model file is saved as [keras_alexnet.h5], generated by [keras_alexnet.py.py] and [alexnet.npy].

可能遇到的问题

  • AttributeError: 'Conv2d' object has no attribute 'padding_mode'

Solution:PyTorch 版本问题,1.1.0 版会出现这个问题,回退到 0.4.0 版本即可。

$ pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --upgrade torch==0.4.0 torchvision==0.2.0

Solution:请更改 numpy 版本。

Solution:pth 文件仅含模型参数而不含模型结构,在 PyTorch 中加载一下然后保存含有模型结构和参数的 pth 文件。

验证从 PyTorch 导出的 AlexNet 预训练模型

测试用的几张图片、代码以及生成的 keras_alexnet.h5 文件都存放到了 wuliytTaotao · Github

import torch
import torchvision
import cv2
import numpy as np from torch.autograd import Variable import tensorflow as tf
from tensorflow.keras import layers,regularizers filename_test = 'data/dog2.png' img = cv2.imread(filename_test)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 数据预处理
img = cv2.resize(img, (227, 227))
img = img / 255.0
img = np.reshape(img, (1, 227, 227, 3))
# 标准化,这是 PyTorch 预训练 AlexNet 模型的预处理方式,详情请见 https://pytorch.org/docs/stable/torchvision/models.html
mean = np.array([0.485, 0.456, 0.406]).reshape([1, 1, 1, 3])
std = np.array([0.229, 0.224, 0.225]).reshape([1, 1, 1, 3])
img = (img - mean) / std # PyTorch
# PyTorch 数据输入 channel 排列和 Keras 不一致
img_tmp = np.transpose(img, (0, 3, 1, 2)) model = torchvision.models.alexnet(pretrained=True) # torch.save(model, './model/alexnet.pth')
model = model.double()
model.eval() y = model(Variable(torch.tensor(img_tmp)))
# 预测的类别
print(np.argmax(y.detach().numpy())) # Keras
def get_AlexNet(num_classes=1000, drop_rate=0.5):
"""
PyTorch 中实现的 AlexNet 预训练模型结构,filter 的深度分别为:(64,192,384,256,256)。
返回 AlexNet 的 inputs 和 outputs
"""
inputs = layers.Input(shape=[227, 227, 3]) conv1 = layers.Conv2D(64, (11, 11), strides=(4, 4), padding='valid', activation='relu')(inputs) pool1 = layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2))(conv1) conv2 = layers.Conv2D(192, (5, 5), strides=(1, 1), padding='same', activation='relu')(pool1) pool2 = layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2))(conv2) conv3 = layers.Conv2D(384, (3, 3), strides=(1, 1), padding='same', activation='relu')(pool2) conv4 = layers.Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu')(conv3) conv5 = layers.Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu')(conv4) pool3 = layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2))(conv5) flat = layers.Flatten()(pool3) dense1 = layers.Dense(4096, activation='relu')(flat)
dense1 = layers.Dropout(drop_rate)(dense1)
dense2 = layers.Dense(4096, activation='relu')(dense1)
dense2 = layers.Dropout(drop_rate)(dense2)
outputs = layers.Dense(num_classes, activation='softmax')(dense2) return inputs, outputs inputs, outputs = get_AlexNet()
model2 = tf.keras.Model(inputs, outputs)
model2.load_weights('./keras_alexnet.h5')
# 预测的类别
print(np.argmax(model2.predict(img)))

预测结果代表的类别请看博客 ImageNet图像库1000个类别名称(中文注释不断更新)

Attentions

PyTorch 中的预训练 AlexNet 模型卷积层 filter 的个数和原论文不一致,filter 的个数分别 \(64,192,384,256,256\)。具体参见 GitHub - pytorch: vision/torchvision/models/alexnet.py

PyTorch 给出的解释是,它的预训练 AlexNet 模型用的是论文 Krizhevsky, A. (2014). One weird trick for parallelizing convolutional neural networks. arXiv preprint arXiv:1404.5997. 给出的架构,但 PyTorch 的模型架构和这篇论文还是有区别,这篇论文中第四个卷积层 filter 个数为 384,而 PyTorch 为 256。

而 caffe 中实现的 AlexNet 含有原始的 LRN 层,去掉 LRN 层后,个人感觉预训练的权重就不能直接拿来用了。

References

PyTorch

GitHub - microsoft/MMdnn

GitHub - pytorch: vision/torchvision/models/alexnet.py

ImageNet图像库1000个类别名称(中文注释不断更新)-- 徐小妹

【tf.keras】tf.keras加载AlexNet预训练模型的更多相关文章

  1. javascript图片懒加载与预加载的分析

    javascript图片懒加载与预加载的分析 懒加载与预加载的基本概念.  懒加载也叫延迟加载:前一篇文章有介绍:JS图片延迟加载 延迟加载图片或符合某些条件时才加载某些图片. 预加载:提前加载图片, ...

  2. 基于jQuery的图片异步加载和预加载实例

    如今的网页中有很多图片,比如相册列表,那么如果一次性读取图片将会瞬间加重服务器的负担,所以我们用jQuery来实现图片的异步加载和预加载功能,这样在页面的可视范围内才会加载图片,当拖动页面至可视界面时 ...

  3. 带你认识网站图片img懒加载和预加载的区别

    懒加载 什么是懒加载? 懒加载也就是延迟加载.当访问一个页面的时候,先把img元素或是其他元素的背景图片路径替换成一张大小为1*1px图片的路径(这样就只需请求一次,俗称占位图),只有当图片出现在浏览 ...

  4. [Tensorflow] 使用 tf.train.Checkpoint() 保存 / 加载 keras subclassed model

    在 subclassed_model.py 中,通过对 tf.keras.Model 进行子类化,设计了两个自定义模型. import tensorflow as tf tf.enable_eager ...

  5. keras 从txt加载预测数据

    ImageDataGenerator.flow_from_directory()的用法已经非常多了,优点是简单方便,但数据量很大时,需要组织目录结构和copy数据,很浪费资源和时间 1. 训练时从tx ...

  6. 加载执行预编译的Sql :prepareStatement

    1.获得连接:Connection con = null; con = DBUtil.getConnection(); 2.写sql语句:String sql=""; 3.用连接加 ...

  7. django模型层优化(关联对象) 懒加载和预加载 +长链接

    懒加载 存在于外键和多对多关系不检索关联对象的数据调用关联对象会再次查询数据库 问题根源 查看django orm的数据加载,两次. 查询user,查询menu 预加载的方法 预加载单个关联对象--s ...

  8. 转 Keras 保存与加载网络模型

    https://blog.csdn.net/qq_28413479/article/details/77367665

  9. 微信小程序 - 分包加载(预下载)

    开发者可以通过配置,在进入小程序某个页面时,由框架自动预下载可能需要的分包,提升进入后续分包页面时的启动速度.对于独立分包,也可以预下载主包. 配置方法 预下载分包行为在进入某个页面时触发,通过在 a ...

随机推荐

  1. JWT(JSON WEB TOKEN) / oauth2 / SSL

    1: JWT: 为了在网络应用环境间传递声明而执行的一种基于JSON的开放标准((RFC 7519).该token被设计为紧凑且安全的,特别适用于分布式站点的单点登录(SSO)场景.JWT的声明一般被 ...

  2. arm交叉编译 扫盲贴

    ARM交叉编译工具链 为什么要用交叉编译器? 交叉编译通俗地讲就是在一种平台上编译出能运行在体系结构不同的另一种平台上的程序, 比如在PC平台(X86 CPU)上编译出能运行在以ARM为内核的CPU平 ...

  3. mysql查询语句例题

    1.一条SQL语句查询两表中两个字段 首先描述问题,student表中有字段startID,endID.garde表中的ID需要对应student表中的startID或者student表中的endID ...

  4. Windchill 配置LOG文件,使开发中的代码能显示打印的信息

    如开发代码的类HomeLogic.java, 包路径在pnt.report.home 需求:需监控此类的打印数据 方法:配置D:\ptc\Windchill_10.1\Windchill\codeba ...

  5. PullToRefresh------ListView的使用

    第一步 :写出布局文件的设置 <com.handmark.pulltorefresh.library.PullToRefreshListView android:id="@+id/pu ...

  6. Java 常见注解

    @Retention 1.RetentionPolicy.SOURCE —— 这种类型的Annotations只在源代码级别保留,编译时就会被忽略2.RetentionPolicy.CLASS —— ...

  7. Koa1 框架

    安装创建项目: 1.一定要全局安装(koa1.2和koa2都己经支持) npm install koa-generator -g 2.koa1 生成一个test项目,切到test目录并下载依赖 koa ...

  8. 由hibernate配置inverse="true"而导致的软件错误,并分析解决此问题的过程

    题目背景软件是用来做安装部署的工具,在部署一套系统时会有很多安装包,通过此工具,可以生成一个xml文件用以保存每个安装包的文件位置.顺序.参数.所需脚本.依赖条件验证(OS..net.IIS.数据版本 ...

  9. Python通过调用windows命令行处理sam文件

    Python通过调用windows命令行处理sam文件 以samtools软件为例 一.下载或者索取得到windows版本的samtools软件,解压后如下: 进入文件内部,有如下几个文件: 二.将s ...

  10. 6.2 卸载原来的Ubuntu,重新安装Ubuntu

    6.1日其实已经成功安装了Ubuntu,6.2日打开电脑,进入Ubuntu系统,发现自己6.1日保存的工作,比如下载的文档和做的笔记,都不在Ubuntu系统中了.当时觉得特别奇怪,第一反应就是,我的U ...