【tf.keras】TensorFlow 1.x 到 2.0 的 API 变化
TensorFlow 2.0 版本将 keras 作为高级 API,对于 keras boy/girl 来说,这就很友好了。tf.keras 从 1.x 版本迁移到 2.0 版本,需要注意几个地方。
1. 设置随机种子
import tensorflow as tf
# TF 1.x
tf.set_random_seed(args.seed)
# TF 2.0
tf.random.set_seed(args.seed)
2. 设置并行线程数和动态分配显存
import tensorflow as tf
from tensorflow.python.keras import backend as K
import os
# 将程序限定在一块GPU上
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# TF 1.x
config = tf.ConfigProto(intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True # 不全部占满显存, 按需分配
K.set_session(tf.Session(config=config))
# TF 2.0,由于之前限定了GPU可见范围,这里只能看到0号GPU
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
print(gpus)
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, enable=True)
3. model.compile() 中设置 metrics=['acc'] 或者 ['accuracy'],会影响 model.fit() 生成的 log,callbacks.ModelCheckpoint 需要对应填 val_acc 或者 val_accuracy:
from tensorflow.python.keras import callbacks
# TF 2.0, acc and val_acc
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['acc'])
ck_callback = callbacks.ModelCheckpoint('./model.h5', monitor='val_acc', mode='max',
verbose=1, save_best_only=True, save_weights_only=True)
# TF 2.0, accuracy and val_accuracy
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
ck_callback = callbacks.ModelCheckpoint('./model.h5', monitor='val_accuracy', mode='max',
verbose=1, save_best_only=True, save_weights_only=True)
4. 舍弃 model.fit_generator() 函数
model.fit_generator() 函数在 TF 2.x 中合并到 model.fit() 函数中,并且在 TF 2.0 版本,该函数有问题,不能很好利用 GPU,训练速度很慢:
Performance: Training is much slower in TF v2.0.0 VS v1.14.0 when using Tf.Keras and model.fit_generator #33024
TF 2.0 版本的 model.fit() 在传入 generator 时需要手动设置 model.fit(shuffle=False)。
解决办法:直接使用 model.fit() 函数,并且升级到 TF 2.1。
【tf.keras】TensorFlow 1.x 到 2.0 的 API 变化的更多相关文章
- 【tf.keras】使用手册
目录 0. 简介 1. 安装 1.1 安装 CUDA 和 cuDNN 2. 数据集 2.1 使用 tensorflow_datasets 导入公共数据集 2.2 数据集过大导致内存溢出 2.3 加载 ...
- tensorflow 2.0 技巧 | 自定义tf.keras.Model的坑
自定义tf.keras.Model需要注意的点 model.save() subclass Model 是不能直接save的,save成.h5,但是能够save_weights,或者save_form ...
- 【tf.keras】实现 F1 score、precision、recall 等 metric
tf.keras.metric 里面竟然没有实现 F1 score.recall.precision 等指标,一开始觉得真不可思议.但这是有原因的,这些指标在 batch-wise 上计算都没有意义, ...
- 基于tensorflow2.0 使用tf.keras实现Fashion MNIST
本次使用的是2.0测试版,正式版估计会很快就上线了 tf2好像更新了蛮多东西 虽然教程不多 还是找了个试试 的确简单不少,但是还是比较喜欢现在这种写法 老样子先导入库 import tensorflo ...
- TensorFlow2.0(11):tf.keras建模三部曲
.caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...
- 【tf.keras】tf.keras使用tensorflow中定义的optimizer
Update:2019/09/21 使用 tf.keras 时,请使用 tf.keras.optimizers 里面的优化器,不要使用 tf.train 里面的优化器,不然学习率衰减会出现问题. 使用 ...
- 【tf.keras】Resource exhausted: OOM when allocating tensor with shape [9216,4096] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
运行以下类似代码: while True: inputs, outputs = get_AlexNet() model = tf.keras.Model(inputs=inputs, outputs= ...
- 一文上手Tensorflow2.0之tf.keras(三)
系列文章目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU ...
- 【tf.keras】tensorflow datasets,tfds
一些最常用的数据集如 MNIST.Fashion MNIST.cifar10/100 在 tf.keras.datasets 中就能找到,但对于其它也常用的数据集如 SVHN.Caltech101,t ...
随机推荐
- CodeForces1006E- Military Problem
E. Military Problem time limit per test 3 seconds memory limit per test 256 megabytes input standard ...
- Python之Flask项目开发【入门必学】
前言 本文的文字及图片来源于网络,仅供学习.交流使用,不具有任何商业用途,版权归原作者所有,如有问题请及时联系我们以作处理.作者:藤藤菜丶 Flask 安装Flask模块 创建一个Flask项目 运行 ...
- django上传并显示图片
环境 python 3.5 django 1.10.6 步骤 创建名为 testupload的项目 django-admin startproject testupload 在项目testupload ...
- TypeScript躬行记(2)——接口
在传统的面向对象语言中,接口(Interface)好比协议,它会列出一系列的规则(即对行为进行抽象),再由类来实现这些规则.而TypeScript中的接口更加灵活,除了包含常规的作用之外,它还能扩展其 ...
- Python3 类与对象
目录 面向对象基础 面向过程编程 面向对象编程 类 什么是类 如何定义类 类的基本操作 对象 实例化对象 对象添加特有属性 对象与类的查找顺序 对象的绑定方法 面向对象基础 面向过程编程 面向过程的核 ...
- DARLING in the FRANXX
那是一种叫比翼的鸟,这种鸟只有半边翅膀,如果雌鸟和雄鸟不紧靠在一起的话,它们就无法飞上天空,是不完整的生物. 但是,为什么,我总认为这种生活方式是如此美丽,感觉如此美丽. 那是一种叫比翼的鸟,我以前从 ...
- 函数计算: 让小程序开发进入 Serverless 时代
点击下载<不一样的 双11 技术:阿里巴巴经济体云原生实践> 本文节选自<不一样的 双11 技术:阿里巴巴经济体云原生实践>一书,点击上方图片即可下载! 作者 | 吴天龙(木吴 ...
- 前端表单提交,提交有图片出现的问题,及解决方案 兼容ie9
更新一下我的小园子,主要说的是jq文件上传的过程中,如果出现上传的文件里有图片问题 其实文件上传有图片的情况下,不是什么大问题,对于前端来说,但是,如果需要兼容ie9的时候,就需要处理一下 文件上传如 ...
- 输出错误long类型
Microsoft Visual C++ 输出不了long 类型的数字怎么办? 在C/C++中,64为整型一直是一种没有确定规范的数据类型.现今主流的编译器中,对64为整型的支持也是标准不一,形态各异 ...
- Centos7 Openresty 开发环境
首先要安装编译环境 yum group install "Development Tools" yum install pcre-devel openssl-devel gcc c ...