Run Your TensorFlow Deep Learning Models on Google AI
When working with TensorFlow and deep learning, people tend to put most of their effort into hyperparameter tuning and training. A practical problem, however, is how to integrate models into production: saving trained models, restoring them when needed, and making predictions on incoming requests. Fortunately, Google AI Platform can help!
Once a model is trained, TensorFlow offers two ways to save it. Most people and blog posts use checkpoints, which are geared toward training: loading a checkpoint lets training continue where it left off. The drawback is that you have to define the architecture in code again every time before restoring the checkpoint. The other format, SavedModel, is better suited to serving (i.e. the released product): applications send prediction requests to a server where the SavedModel is deployed, and the server sends the responses back.
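To make the checkpoint drawback concrete, here is a minimal sketch of the checkpoint workflow, assuming TensorFlow 2.x with the `tf.compat.v1` graph API. The toy model and the `build_graph` helper are hypothetical stand-ins for a real network. Note that the restoring side has to call `build_graph()` again: `Saver.restore` only fills values into variables that already exist in the graph.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def build_graph():
    """Hypothetical toy model y = x * w + b, rebuilt identically each time."""
    x = tf1.placeholder(tf.float32, shape=[None, 1], name="x")
    w = tf1.get_variable("w", initializer=tf.constant([[2.0]]))
    b = tf1.get_variable("b", initializer=tf.constant([0.5]))
    pred = tf.matmul(x, w) + b
    return x, pred


ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Training side: build the graph, initialise the variables, write a checkpoint.
with tf.Graph().as_default():
    x, pred = build_graph()
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        # ... training would happen here ...
        saver.save(sess, ckpt_path)

# Restoring side: the checkpoint holds variable values, so the same
# architecture must be defined again before restore() can fill it in.
with tf.Graph().as_default():
    x, pred = build_graph()
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        saver.restore(sess, ckpt_path)
        restored = sess.run(pred, feed_dict={x: np.array([[1.0], [3.0]])})

print(restored)  # values 2.5 and 6.5, one per input row
```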
To serve a model this way, we only need three steps: save the model properly, deploy it to Google AI Platform, and convert the input data to the required format before sending requests. I will go through them one by one.
1. Save the model as a SavedModel:
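In a typical TensorFlow workflow, the architecture is defined first, then the model is trained, and finally it is saved. Let's jump straight to the saving code. The function used here is simple_save (in TensorFlow 2.x it is available as tf.compat.v1.saved_model.simple_save), which takes four arguments: the session, the export directory, a dict mapping input names to input tensors, and a dict mapping output names to output tensors. These names ("x" and "pred" below) are the keys used later in prediction requests and responses.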
tf.saved_model.simple_save(sess, 'simple_save/model',
                           inputs={"x": tf_x}, outputs={"pred": pred})
This writes the SavedModel to the target directory:
saved_model.pb
variables/variables.index
variables/variables.data-00000-of-00001
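Before uploading anything, it can be worth checking that the export loads and answers correctly. The sketch below assumes TensorFlow 2.x (so simple_save is reached through `tf.compat.v1`) and uses a hypothetical one-layer linear model in place of your trained network, plus a temporary folder instead of `'simple_save/model'`. It saves with the same input/output names as above, lists the written files, then reloads the export with the SERVING tag and the `serving_default` signature that simple_save creates. No model code is needed on the loading side, because the graph comes from `saved_model.pb`.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

# simple_save will not overwrite an existing, non-empty export directory,
# so use a fresh path inside a temporary folder.
export_dir = os.path.join(tempfile.mkdtemp(), "simple_save", "model")

with tf.Graph().as_default():
    # Hypothetical toy model standing in for the trained network.
    tf_x = tf1.placeholder(tf.float32, shape=[None, 1], name="x")
    w = tf1.get_variable("w", initializer=tf.constant([[2.0]]))
    b = tf1.get_variable("b", initializer=tf.constant([0.5]))
    pred = tf.matmul(tf_x, w) + b
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        # ... training would happen here ...
        tf1.saved_model.simple_save(sess, export_dir,
                                    inputs={"x": tf_x}, outputs={"pred": pred})

# List what was written (saved_model.pb plus the variables/ folder).
for root, _, files in os.walk(export_dir):
    for name in sorted(files):
        print(os.path.relpath(os.path.join(root, name), export_dir))

# Reload into a fresh graph and run one prediction through the signature.
with tf.Graph().as_default():
    with tf1.Session() as sess:
        meta_graph = tf1.saved_model.loader.load(
            sess, [tf1.saved_model.tag_constants.SERVING], export_dir)
        signature = meta_graph.signature_def["serving_default"]
        x_name = signature.inputs["x"].name        # tensor name, e.g. "x:0"
        pred_name = signature.outputs["pred"].name
        result = sess.run(pred_name,
                          feed_dict={x_name: np.array([[1.0], [3.0]])})

print(result)
```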
2. Deploy the SavedModel to Google AI:
On Google Cloud, files are stored in Cloud Storage buckets, so first the TensorFlow model (the .pb file and the variables folder) needs to be uploaded to a bucket. Then create a model on Google AI Platform, and a version under it. A model can hold multiple versions, which is much like solving one task in different ways: you can even deploy several deep learning architectures as different versions and choose between them when requesting predictions. Each version is bound to the Cloud Storage location that holds its SavedModel.
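The upload and the model/version creation can also be scripted. The following is a sketch under several assumptions: it uses the `google-cloud-storage` client for the upload and the same `ml` v1 discovery client as the prediction code for `projects.models.create` and `projects.models.versions.create`. The bucket, prefix, model and version names are hypothetical, and the default runtime/Python versions are placeholders to replace with ones AI Platform currently supports for your TensorFlow build. Only the pure helpers run locally; the two cloud functions need credentials and are not called here.

```python
import os
import tempfile


def gcs_object_names(local_dir, gcs_prefix):
    """Pair each file under local_dir with its object name under gcs_prefix."""
    prefix = gcs_prefix.rstrip("/")
    pairs = []
    for root, _, files in os.walk(local_dir):
        for name in files:
            local_path = os.path.join(root, name)
            rel = os.path.relpath(local_path, local_dir).replace(os.sep, "/")
            pairs.append((local_path, prefix + "/" + rel))
    return sorted(pairs, key=lambda pair: pair[1])


def version_body(version_name, deployment_uri,
                 runtime_version="1.15", python_version="3.7"):
    """Request body for projects.models.versions.create.

    runtime_version and python_version are assumed placeholders.
    """
    return {
        "name": version_name,
        "deploymentUri": deployment_uri,  # folder that contains saved_model.pb
        "runtimeVersion": runtime_version,
        "framework": "TENSORFLOW",
        "pythonVersion": python_version,
    }


def upload_saved_model(project, bucket_name, local_dir, gcs_prefix):
    """Copy the SavedModel folder to gs://bucket_name/gcs_prefix."""
    from google.cloud import storage  # pip install google-cloud-storage
    bucket = storage.Client(project=project).bucket(bucket_name)
    for local_path, object_name in gcs_object_names(local_dir, gcs_prefix):
        bucket.blob(object_name).upload_from_filename(local_path)
    return "gs://{}/{}".format(bucket_name, gcs_prefix.rstrip("/"))


def create_model_and_version(project, model_name, version_name, deployment_uri):
    """Create the model, then a version bound to deployment_uri."""
    import googleapiclient.discovery  # pip install google-api-python-client
    models = googleapiclient.discovery.build("ml", "v1").projects().models()
    # Raises googleapiclient.errors.HttpError if the model already exists.
    models.create(parent="projects/{}".format(project),
                  body={"name": model_name}).execute()
    # Version creation is asynchronous: this returns a long-running operation.
    return models.versions().create(
        parent="projects/{}/models/{}".format(project, model_name),
        body=version_body(version_name, deployment_uri),
    ).execute()


# Local illustration only (no cloud calls): a stand-in SavedModel layout.
demo_dir = tempfile.mkdtemp()
os.makedirs(os.path.join(demo_dir, "variables"))
for rel in ("saved_model.pb", os.path.join("variables", "variables.index")):
    open(os.path.join(demo_dir, rel), "wb").close()
print(gcs_object_names(demo_dir, "simple_save/model"))

body = version_body("v1", "gs://my-bucket/simple_save/model")
print(body)
```

Keeping the path mapping and the request body as plain functions makes them easy to check before any real resources are created.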
3. Make online predictions:
Since we request predictions from within the application and need an immediate response, we use online prediction rather than batch prediction. The official example code below sends an HTTP request and receives an HTTP response, with both input and output in JSON format. We can convert our input data into a list and pass it to this function.
import googleapiclient.discovery


def predict_json(project, model, instances, version=None):
    """Send json data to a deployed model for prediction.

    Args:
        project (str): project where the AI Platform Model is deployed.
        model (str): model name.
        instances ([Mapping[str: Any]]): Keys should be the names of Tensors
            your deployed model expects as inputs. Values should be datatypes
            convertible to Tensors, or (potentially nested) lists of datatypes
            convertible to tensors.
        version: str, version of the model to target.
    Returns:
        Mapping[str: any]: dictionary of prediction results defined by the
            model.
    """
    # Create the AI Platform service object.
    # To authenticate set the environment variable
    # GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
    service = googleapiclient.discovery.build('ml', 'v1')
    name = 'projects/{}/models/{}'.format(project, model)

    if version is not None:
        name += '/versions/{}'.format(version)

    response = service.projects().predict(
        name=name,
        body={'instances': instances}
    ).execute()

    if 'error' in response:
        raise RuntimeError(response['error'])

    return response['predictions']
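For the "transform our input data into a list" step, one option is a small helper that turns each row of a NumPy batch into a dict keyed by the input name used in simple_save. The helper names `to_instances` and `chunked` are hypothetical. Calling `tolist()` converts NumPy scalars into plain Python numbers, which the JSON request body can hold. Online prediction caps the size of a single request, so the sketch also splits the instances into chunks; the chunk size here is an arbitrary assumption.

```python
import numpy as np


def to_instances(batch, input_name="x"):
    """Turn a NumPy batch into a list of instances, one dict per row."""
    batch = np.asarray(batch)
    # tolist() yields plain Python floats/ints that json can serialise.
    return [{input_name: row.tolist()} for row in batch]


def chunked(instances, size):
    """Split the instances into lists of at most `size` items."""
    return [instances[i:i + size] for i in range(0, len(instances), size)]


batch = np.array([[1.0], [3.0], [5.0]], dtype=np.float32)
instances = to_instances(batch)
print(instances)  # [{'x': [1.0]}, {'x': [3.0]}, {'x': [5.0]}]
print(chunked(instances, 2))
```

Each chunk can then be passed as the `instances` argument of predict_json, and the returned prediction lists concatenated in order.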
The predictions come back as JSON as well: each one is a dictionary keyed by the output name ("pred"). I wrote a small function to convert them into a NumPy array:
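import numpy as np


def from_json_to_array(dict_list):
    # Each prediction is a dict keyed by the output name given to simple_save.
    value_list = []
    for dict_instance in dict_list:
        instance = dict_instance.get('pred')
        value_list.append(instance)
    # Stack the per-instance outputs into a single array.
    value_array = np.asarray(value_list)
    return value_array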
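One caveat with the function above: `dict.get` returns `None` for a missing key, so a mismatch between the output name and `'pred'` silently produces an object array of `None` values instead of an error. A stricter variant, sketched here with a hypothetical name and a made-up sample shaped like the predictions for a model exported with `outputs={"pred": pred}`, fails loudly instead:

```python
import numpy as np


def from_json_to_array_strict(predictions, output_name="pred"):
    """Like from_json_to_array, but raise KeyError instead of storing None."""
    missing = [i for i, p in enumerate(predictions) if output_name not in p]
    if missing:
        raise KeyError("output {!r} missing in predictions {}".format(
            output_name, missing))
    return np.asarray([p[output_name] for p in predictions])


# Hypothetical predictions list for two instances with a one-value output.
sample = [{"pred": [2.5]}, {"pred": [6.5]}]
arr = from_json_to_array_strict(sample)
print(arr.shape)  # (2, 1)
```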
Yeah, that's it! Let's get your hands dirty!
Reference:
https://www.tensorflow.org/guide/saved_model
https://cloud.google.com/blog/products/ai-machine-learning/simplifying-ml-predictions-with-google-cloud-functions
https://cloud.google.com/ml-engine/docs/tensorflow/online-predict