# 注意和前一或二篇Lenet训练并验证的文章从`y_conv = tf.nn.softmax(fc2)`起的不同
# 部分函数请参照前后2篇文章
import tensorflow as tf
import tfrecords2array
import numpy as np
from keras.utils import to_categorical
import matplotlib.pyplot as plt
import cv2
from collections import OrderedDict def lenet(char_classes): # characters_reference
recall_rate = OrderedDict().fromkeys([
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j',
'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
'u', 'v', 'w', 'x', 'y', 'z',
'藏', '川', '鄂', '甘', '赣', '广', '桂', '贵', '黑',
'沪', '吉', '冀', '津', '晋', '京', '辽', '鲁', '蒙',
'闽', '宁', '青', '琼', '陕', '苏', '皖', '湘', '新',
'渝', '豫', '粤', '云', '浙'
])
for i in recall_rate.keys():
recall_rate[i] = 1
class_count = recall_rate.copy()
# y_train = []
# x_train = []
y_test = []
x_test = []
for char_class in char_classes:
# train_data = tfrecords2array.tfrecord2array(
# r"./data_tfrecords/" + char_class + "_tfrecords/train.tfrecords")
test_data = tfrecords2array.tfrecord2array(
r"./data_tfrecords/" + char_class + "_tfrecords/test.tfrecords")
# y_train.append(train_data[0])
# x_train.append(train_data[1])
y_test.append(test_data[0])
x_test.append(test_data[1])
for i in [y_test, x_test]: # y_train, x_train,
for j in i:
print(j.shape)
# y_train = np.vstack(y_train)
# x_train = np.vstack(x_train)
y_test = np.vstack(y_test)
x_test = np.vstack(x_test) class_num = y_test.shape[-1] # print("x_train.shape=" + str(x_train.shape))
print("x_test.shape=" + str(x_test.shape))
sess = tf.InteractiveSession() x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, class_num])
# 把x更改为4维张量,第1维代表样本数量,第2维和第3维代表图像长宽, 第4维代表图像通道数, 1表示黑白
x_image = tf.reshape(x, [-1, 28, 28, 1]) # 第一层:卷积层
conv1_weights = tf.get_variable(
"conv1_weights",
[5, 5, 1, 32],
initializer=tf.truncated_normal_initializer(stddev=0.1))
# 过滤器大小为5*5, 当前层深度为1, 过滤器的深度为32
conv1_biases = tf.get_variable("conv1_biases", [32],
initializer=tf.constant_initializer(0.0))
conv1 = tf.nn.conv2d(x_image, conv1_weights, strides=[1, 1, 1, 1],
padding='SAME') relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases)) # 激活函数Relu去线性化 pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
padding='SAME') conv2_weights = tf.get_variable(
"conv2_weights",
[5, 5, 32, 64],
initializer=tf.truncated_normal_initializer(stddev=0.1)) conv2_biases = tf.get_variable(
"conv2_biases", [64], initializer=tf.constant_initializer(0.0))
conv2 = tf.nn.conv2d(pool1, conv2_weights, strides=[1, 1, 1, 1],
padding='SAME') relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases)) pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
padding='SAME') fc1_weights = tf.get_variable("fc1_weights", [7 * 7 * 64, 1024],
initializer=tf.truncated_normal_initializer(
stddev=0.1)) fc1_biases = tf.get_variable(
"fc1_biases", [1024], initializer=tf.constant_initializer(0.1))
pool2_vector = tf.reshape(pool2, [-1, 7 * 7 * 64])
fc1 = tf.nn.relu(tf.matmul(pool2_vector, fc1_weights) + fc1_biases) # dropout
keep_prob = tf.placeholder(tf.float32)
fc1_dropout = tf.nn.dropout(fc1, keep_prob) fc2_weights = tf.get_variable("fc2_weights", [1024, class_num],
initializer=tf.truncated_normal_initializer(
stddev=0.1)) fc2_biases = tf.get_variable(
"fc2_biases", [class_num], initializer=tf.constant_initializer(0.1))
fc2 = tf.matmul(fc1_dropout, fc2_weights) + fc2_biases # softmax
y_conv = tf.nn.softmax(fc2)
pred_class_index = tf.argmax(y_conv, 1) # tf.argmax()返回的是某一维度上其数据最大所在的索引值,在这里即代表预测值和真实值
# 判断预测值y和真实值y_中最大数的索引是否一致,y的值为1-class_num概率
correct_prediction = tf.equal(pred_class_index, tf.argmax(y_, 1)) # 用平均值来统计测试准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 开始训练
saver = tf.train.Saver()
# sess.run(tf.global_variables_initializer())
saver.restore(sess, './my_model/model.ckpt')
# pred_value = sess.run([pred_class_index], feed_dict={
# x: x_test, y_: y_test, keep_prob: 1.0
# })
# print("pred_value=" + str(pred_value))
# acc_test = sess.run(accuracy, feed_dict={
# x: x_test, y_: y_test, keep_prob: 1.0
# })
#
batch_size_test = 64
epoch_test = y_test.shape[0] // batch_size_test + 1
acc_test = 0
class_sums = []
for i in range(epoch_test):
if (i*batch_size_test % x_test.shape[0]) > (((i+1)*batch_size_test) %
x_test.shape[0]):
x_data_test = np.vstack((
x_test[i*batch_size_test % x_test.shape[0]:],
x_test[:(i+1)*batch_size_test % x_test.shape[0]]))
y_data_test = np.vstack((
y_test[i*batch_size_test % y_test.shape[0]:],
y_test[:(i+1)*batch_size_test % y_test.shape[0]]))
else:
x_data_test = x_test[
i*batch_size_test % x_test.shape[0]:
(i+1)*batch_size_test % x_test.shape[0]]
y_data_test = y_test[
i*batch_size_test % y_test.shape[0]:
(i+1)*batch_size_test % y_test.shape[0]]
# plt.imshow(x_data_test[0].reshape(28, 28), cmap="gray")
# plt.show()
# Calculate batch loss and accuracy
pred_value = to_categorical(np.squeeze(
sess.run([pred_class_index], feed_dict={
x: x_data_test, y_: y_data_test, keep_prob: 1.0})), 68)
# print("{}-th pred_value={}".format(i, pred_value))
# print("{}-th y_data_test={}".format(i, y_data_test))
# print("\nCover:")
# print("pred_value:", pred_value)
# print("y_data_test:", y_data_test)
# input()
recall_sum = np.sum(cv2.bitwise_and(pred_value, y_data_test), axis=0)
class_sum = np.sum(y_data_test, axis=0)
class_sums.append(class_sum)
# print(recall_sum)
# input()
for idx in range(len(recall_sum)):
recall_rate[str(list(recall_rate.keys())[idx])] += recall_sum[idx]
class_count[str(list(class_count.keys())[idx])] += class_sum[idx]
# print(recall_rate)
c = accuracy.eval(feed_dict={
x: x_data_test, y_: y_data_test, keep_prob: 1.0})
acc_test += c / epoch_test
for i in list(recall_rate.keys()):
recall_rate[i] /= class_count[i]
print("recall_rate:\n", recall_rate)
print("class_count:\n", class_count)
print("class_sums:", np.sum(np.array(class_sums), axis=0))
print("Restored acc_test={}".format(acc_test))
return recall_rate def main():
# integers: 4679
# alphabets: 9796
# Chinese_letters: 3974
# training_set : testing_set == 4 : 1
train_lst = ['alphabets', 'integers']
recall_rate = lenet(train_lst)
recall_rate_values = recall_rate.values()
_, ax = plt.subplots(1, 1, figsize=(12, 6))
ax.plot(list(recall_rate_values), list(range(len(recall_rate_values))),
'^')
ax.hlines(list(range(len(recall_rate_values))), [0], recall_rate_values,
lw=2)
ax.set_xlabel('Recall rate')
ax.set_ylabel('Idx of elem')
ax.set_title('Statistics on Recall Rates')
plt.show() if __name__ == '__main__':
main()

TensorFlow+restore读取模型的更多相关文章

  1. TensorFlow学习笔记:保存和读取模型

    TensorFlow 更新频率实在太快,从 1.0 版本正式发布后,很多 API 接口就发生了改变.今天用 TF 训练了一个 CNN 模型,结果在保存模型的时候居然遇到各种问题.Google 搜出来的 ...

  2. tensorflow笔记:模型的保存与训练过程可视化

    tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 ...

  3. tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测

    由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...

  4. TensorFlow数据读取方式:Dataset API

    英文详细版参考:https://www.cnblogs.com/jins-note/p/10243716.html Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服 ...

  5. TensorFlow 训练好模型参数的保存和恢复代码

    TensorFlow 训练好模型参数的保存和恢复代码,之前就在想模型不应该每次要个结果都要重新训练一遍吧,应该训练一次就可以一直使用吧. TensorFlow 提供了 Saver 类,可以进行保存和恢 ...

  6. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  7. 《Entity Framework 6 Recipes》中文翻译系列 (38) ------ 第七章 使用对象服务之动态创建连接字符串和从数据库读取模型

    翻译的初衷以及为什么选择<Entity Framework 6 Recipes>来学习,请看本系列开篇 第七章 使用对象服务 本章篇幅适中,对真实应用中的常见问题提供了切实可行的解决方案. ...

  8. FaceRank-人脸打分基于 TensorFlow 的 CNN 模型

    FaceRank-人脸打分基于 TensorFlow 的 CNN 模型 隐私 因为隐私问题,训练图片集并不提供,稍微可能会放一些卡通图片. 数据集 130张 128*128 张网络图片,图片名: 1- ...

  9. Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析

    觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...

随机推荐

  1. Java中的深浅拷贝问题,你清楚吗?

    一.前言 拷贝这个词想必大家都很熟悉,在工作中经常需要拷贝一份文件作为副本.拷贝的好处也很明显,相较于新建来说,可以节省很大的工作量.在Java中,同样存在拷贝这个概念,拷贝的意义也是可以节省创建对象 ...

  2. 微信小程序腾讯地图SDK使用方法

    一.本篇文章主要知识点有以下几种: 1.授权当前位置 2.map组件的使用 3.腾讯地图逆地址解析 4.坐标系的转化 二.效果如下: 三.WXML代码 <map id="map&quo ...

  3. 开心!再也不用担心 IntelliJ IDEA 试用过期了

    背景 前段时间 Review 团队小伙伴代码,发现当他把鼠标挪到一个方法上时,就自动显示了该方法的所有注释信息,像下图这样,他和我用的 IDE 都是 IntelliJ IDEA. 而我还按古老的方式, ...

  4. jQuery 页面滚动 吸顶 和 吸底

    <!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...

  5. java中List元素移除元素的那些坑

    https://blog.csdn.net/javageektech/article/details/96668890  List  的迭代器类 采用倒序移除 jdk1.8的写法 public sta ...

  6. based on Greenlets (via Eventlet and Gevent) fork 孙子worker 比较 gevent不是异步 协程原理 占位符 placeholder (Future, Promise, Deferred) 循环引擎 greenlet 没有显式调度的微线程,换言之 协程

    gevent GitHub - gevent/gevent: Coroutine-based concurrency library for Python https://github.com/gev ...

  7. mysql中int型的数字怎么转换成字符串

    字段:number  是integer类型    在表test中 select cast(number as char) as number from test; 或者convert()方法.因为转换 ...

  8. HarmonyOS三方件开发指南(7)——compress组件

    目录:1. 组件compress功能介绍2. 组件compress使用方法3. 组件compress开发实现 1. 组件compress功能介绍1.1.  组件介绍:        compress是 ...

  9. cocos2d-x 调试问题

    1.昨天一个新功能,在xcode模拟器上测试没问题.后来打包安卓后,一直有问题 就又添加日志功能 #   define CCLOGFUNC(s)                             ...

  10. fedora 20安装vim Transaction check error

    Transaction check error安装时 yum remove vim-minimal 再安装vim ok