I. Background

A Python TensorFlow project recently had to be deployed to production. The model is built entirely from TensorFlow's off-the-shelf Estimator API, and a Kaggle dataset (adult census income) is used as the example data.

Dataset:

https://www.kaggle.com/johnfarrell/gpu-example-from-prepared-data-try-deepfm

II. Python Code

1. Python code

# author: adrian.wu
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split

# Set to INFO for tracking training, default is WARN
tf.logging.set_verbosity(tf.logging.INFO)

print("Using TensorFlow version %s" % (tf.__version__))

CATEGORICAL_COLUMNS = ["workclass", "education",
                       "marital.status", "occupation",
                       "relationship", "race",
                       "sex", "native.country"]

# Columns of the input csv file
COLUMNS = ["age", "workclass", "fnlwgt", "education",
           "education.num", "marital.status",
           "occupation", "relationship", "race",
           "sex", "capital.gain", "capital.loss",
           "hours.per.week", "native.country", "income"]

FEATURE_COLUMNS = ["age", "workclass", "education",
                   "education.num", "marital.status",
                   "occupation", "relationship", "race",
                   "sex", "capital.gain", "capital.loss",
                   "hours.per.week", "native.country"]

df = pd.read_csv("/Users/adrian.wu/Desktop/learn/kaggle/adult-census-income/adult.csv")

BATCH_SIZE = 40
num_epochs = 1
shuffle = True

y = df["income"].apply(lambda x: ">50K" in x).astype(int)
del df["fnlwgt"]  # Unused column
del df["income"]  # Label column, already saved to y
X = df

print(X.describe())

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20)

train_input_fn = tf.estimator.inputs.pandas_input_fn(
    x=X_train,
    y=y_train,
    batch_size=BATCH_SIZE,
    num_epochs=num_epochs,
    shuffle=shuffle)

eval_input_fn = tf.estimator.inputs.pandas_input_fn(
    x=X_test,
    y=y_test,
    batch_size=BATCH_SIZE,
    num_epochs=num_epochs,
    shuffle=shuffle)


def generate_input_fn(filename, num_epochs=None, shuffle=True, batch_size=BATCH_SIZE):
    df = pd.read_csv(filename)  # , header=None, names=COLUMNS)
    labels = df["income"].apply(lambda x: ">50K" in x).astype(int)
    del df["fnlwgt"]  # Unused column
    del df["income"]  # Label column, already saved to labels

    # (leftover debug statement; has no effect)
    type(df['age'].iloc[3])

    return tf.estimator.inputs.pandas_input_fn(
        x=df,
        y=labels,
        batch_size=batch_size,
        num_epochs=num_epochs,
        shuffle=shuffle)


# Columns with a small, known vocabulary
# NB: adult.csv capitalizes "Female"/"Male"; values missing from the vocabulary
# are mapped to an all-zero indicator vector.
sex = tf.feature_column.categorical_column_with_vocabulary_list(
    key="sex",
    vocabulary_list=["female", "male"])
race = tf.feature_column.categorical_column_with_vocabulary_list(
    key="race",
    vocabulary_list=["Amer-Indian-Eskimo",
                     "Asian-Pac-Islander",
                     "Black", "Other", "White"])

# Hash the remaining categorical columns
education = tf.feature_column.categorical_column_with_hash_bucket(
    "education", hash_bucket_size=1000)
marital_status = tf.feature_column.categorical_column_with_hash_bucket(
    "marital.status", hash_bucket_size=100)
relationship = tf.feature_column.categorical_column_with_hash_bucket(
    "relationship", hash_bucket_size=100)
workclass = tf.feature_column.categorical_column_with_hash_bucket(
    "workclass", hash_bucket_size=100)
occupation = tf.feature_column.categorical_column_with_hash_bucket(
    "occupation", hash_bucket_size=1000)
native_country = tf.feature_column.categorical_column_with_hash_bucket(
    "native.country", hash_bucket_size=1000)

print('Categorical columns configured')

age = tf.feature_column.numeric_column("age")
deep_columns = [
    # Multi-hot indicator columns for columns with fewer possibilities
    tf.feature_column.indicator_column(workclass),
    tf.feature_column.indicator_column(marital_status),
    tf.feature_column.indicator_column(sex),
    tf.feature_column.indicator_column(relationship),
    tf.feature_column.indicator_column(race),
    # Embeddings for categories with more possibilities.
    # Rule of thumb: use at least (number of possibilities) ** 0.25 dimensions.
    tf.feature_column.embedding_column(education, dimension=8),
    tf.feature_column.embedding_column(native_country, dimension=8),
    tf.feature_column.embedding_column(occupation, dimension=8),
    age
]

m2 = tf.estimator.DNNClassifier(
    model_dir="model/dir",
    feature_columns=deep_columns,
    hidden_units=[100, 50])

m2.train(input_fn=train_input_fn)

start, end = 0, 5
data_predict = df.iloc[start:end]
predict_labels = y.iloc[start:end]
print(predict_labels)
print(data_predict.head(12))  # Show the rows being predicted so they can be compared with the labels
predict_input_fn = tf.estimator.inputs.pandas_input_fn(
    x=data_predict,
    batch_size=1,
    num_epochs=1,
    shuffle=False)

predictions = m2.predict(input_fn=predict_input_fn)

for prediction in predictions:
    print("Predictions: {} with probabilities {}\n".format(
        prediction["classes"], prediction["probabilities"]))


def column_to_dtype(column):
    if column in CATEGORICAL_COLUMNS:
        return tf.string
    else:
        return tf.float32


# Features that the serving signature will accept as input
FEATURE_COLUMNS_FOR_SERVE = ["workclass", "education",
                             "marital.status", "occupation",
                             "relationship", "race",
                             "sex", "native.country", "age"]

serving_features = {column: tf.placeholder(shape=[1], dtype=column_to_dtype(column), name=column)
                    for column in FEATURE_COLUMNS_FOR_SERVE}
# There are several kinds of serving_input_receiver_fn; a raw receiver is used here
export_dir = m2.export_savedmodel(export_dir_base="models/export",
                                  serving_input_receiver_fn=tf.estimator.export.build_raw_serving_input_receiver_fn(
                                      serving_features), as_text=True)
export_dir = export_dir.decode("utf8")

2. Calling export_savedmodel generates the variables checkpoint files and a saved_model.pbtxt graph definition (text format, because as_text=True was passed).
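
For reference, a SavedModel export directory typically looks like the sketch below; the timestamped folder name is assigned at export time and will differ on every run (the value shown here is only an example):

models/export/
└── 1552000000/
    ├── saved_model.pbtxt              <- graph definition in text form (as_text=True)
    └── variables/
        ├── variables.data-00000-of-00001
        └── variables.index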

3. Open saved_model.pbtxt and skim through it: it is a node-by-node description of the TensorFlow graph, listing node names, operation types, dtypes, and so on. When calling the model from Java you need the exact node names (a programmatic way to list them is sketched after this excerpt).

node {
  name: "dnn/head/predictions/probabilities"
  op: "Softmax"
  input: "dnn/head/predictions/two_class_logits"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: -1
          }
          dim {
            size: 2
          }
        }
      }
    }
  }
}
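
Instead of reading the pbtxt by hand, the operation names can also be listed through the loaded SavedModelBundle's graph. Below is a minimal sketch; the export path is a placeholder and should be replaced with the actual timestamped export directory:

import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.SavedModelBundle;

import java.util.Iterator;

public class PrintOperations {

    public static void main(String[] args) {
        // Placeholder path: point this at the timestamped export directory
        try (SavedModelBundle model = SavedModelBundle.load("models/export/1552000000", "serve")) {
            Graph graph = model.graph();
            Iterator<Operation> ops = graph.operations();
            while (ops.hasNext()) {
                Operation op = ops.next();
                // Prints e.g. "dnn/head/predictions/probabilities (Softmax)"
                System.out.println(op.name() + " (" + op.type() + ")");
            }
        }
    }
}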

III. Java Code

1. First, copy the variables directory and the saved_model.pbtxt file into the project's resources directory.
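
The example in the next subsection loads the model through an absolute filesystem path. If you would rather resolve the copied model from the classpath, one possible approach is sketched below, assuming the model folder is named adultincomemodel and the resources are available on the filesystem (e.g. when running from an IDE, not from inside a packed jar):

import org.tensorflow.SavedModelBundle;

import java.nio.file.Paths;

public class LoadFromResources {

    public static SavedModelBundle load() throws Exception {
        // "adultincomemodel" is the folder name used in the example below; adjust if yours differs.
        String modelDir = Paths.get(
                LoadFromResources.class.getResource("/adultincomemodel").toURI()).toString();
        return SavedModelBundle.load(modelDir, "serve");
    }
}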

2. Java code

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

/**
 * Created by adrian.wu on 2019/3/14.
 */
public class TestAdultIncome {

    public static void main(String[] args) throws Exception {

        SavedModelBundle model = SavedModelBundle.load("/Users/adrian.wu/Desktop/sc/adrian_test/src/main/resources/adultincomemodel", "serve");
        Session sess = model.session();

        String sex = "Female";
        String workclass = "?";
        String education = "HS-grad";
        String ms = "Widowed";
        String occupation = "?";
        String relationship = "Not-in-family";
        String race = "White";
        String nc = "United-States";

        // Strings cannot be passed to Tensor.create() directly; feed them as UTF-8 bytes instead.
        Tensor sexTensor = Tensor.create(new byte[][]{sex.getBytes()});
        Tensor workclassTensor = Tensor.create(new byte[][]{workclass.getBytes()});
        Tensor eduTensor = Tensor.create(new byte[][]{education.getBytes()});
        Tensor msTensor = Tensor.create(new byte[][]{ms.getBytes()});
        Tensor occuTensor = Tensor.create(new byte[][]{occupation.getBytes()});
        Tensor relaTensor = Tensor.create(new byte[][]{relationship.getBytes()});
        Tensor raceTensor = Tensor.create(new byte[][]{race.getBytes()});
        Tensor ncTensor = Tensor.create(new byte[][]{nc.getBytes()});

        float[][] age = {{90f}};

        Tensor ageTensor = Tensor.create(age);

        // Feed names match the serving placeholders; the fetch name is the
        // operation name looked up in saved_model.pbtxt.
        Tensor result = sess.runner()
                .feed("workclass", workclassTensor)
                .feed("education", eduTensor)
                .feed("marital.status", msTensor)
                .feed("relationship", relaTensor)
                .feed("race", raceTensor)
                .feed("sex", sexTensor)
                .feed("native.country", ncTensor)
                .feed("occupation", occuTensor)
                .feed("age", ageTensor)
                .fetch("dnn/head/predictions/probabilities")
                .run()
                .get(0);

        // buffer[0][0] is the probability of class 0 (<=50K), buffer[0][1] of class 1 (>50K).
        float[][] buffer = new float[1][2];
        result.copyTo(buffer);
        System.out.println(buffer[0][0]);
    }

}

IV. Comparing the Results

The Python and Java predictions agree:

Java:   0.9432887
Python: 0.9432887

  
