数据集下载地址:http://www.nlpr.ia.ac.cn/databases/handwriting/download.html

chinese_write_detection.py

# -*- coding: utf-8 -*-
import tensorflow as tf
import os
import random
import tensorflow.contrib.slim as slim
import time
import numpy as np
import pickle
from PIL import Image
from log_utils import get_logger

# Module-wide logger shared by the training / validation / inference routines.
logger = get_logger("HandWritten Practice")

root_path = 'D:/eclipse-workspace/sxzsb'

# --- Data augmentation switches -------------------------------------------
tf.app.flags.DEFINE_boolean('random_flip_up_down', False, "Whether to random flip up down")
tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to random constrast")

# --- Data / model shape ----------------------------------------------------
tf.app.flags.DEFINE_integer('charset_size', 3755, "Choose the first `charset_size` character to conduct our experiment.")
tf.app.flags.DEFINE_integer('image_size', 64, "Needs to provide same value as in training.")
tf.app.flags.DEFINE_boolean('gray', True, "whether to change the rbg to gray")

# --- Training schedule -----------------------------------------------------
tf.app.flags.DEFINE_integer('max_steps', 12002, 'the max training steps ')
tf.app.flags.DEFINE_integer('eval_steps', 50, "the step num to eval")
tf.app.flags.DEFINE_integer('save_steps', 2000, "the steps to save")

# --- Filesystem locations --------------------------------------------------
tf.app.flags.DEFINE_string('checkpoint_dir', 'D:/eclipse-workspace/sxzsb/checkpoint', 'the checkpoint dir')
tf.app.flags.DEFINE_string('train_data_dir', 'D:/eclipse-workspace/sxzsb/data/train', 'the train dataset dir(containing png files)')
tf.app.flags.DEFINE_string('test_data_dir', 'D:/eclipse-workspace/sxzsb/data/test', 'the test dataset dir(containing png files)')
tf.app.flags.DEFINE_string('log_dir', 'D:/eclipse-workspace/sxzsb/log', 'the logging path)')

# --- Run control -----------------------------------------------------------
tf.app.flags.DEFINE_boolean('restore', False, 'whether to restore from checkpoint')
tf.app.flags.DEFINE_integer('epoch', 1, 'Number of epoches')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Validation batch size')
# The help text lists the mode names that main() actually dispatches on.
tf.app.flags.DEFINE_string('mode', 'train', 'Running mode. One of {"train", "validation", "inference"}')
FLAGS = tf.app.flags.FLAGS


class DataIterator:
    """Index a directory of per-class image folders and feed it through TF queues.

    The expected layout is ``<data_dir>/<numeric class id>/<image files>``,
    e.g. ``../data/train/00020/123.png``; the folder name is the integer label.
    """

    def __init__(self, data_dir):
        """Collect and shuffle the image paths and labels found under *data_dir*.

        Only class folders whose numeric id is below ``FLAGS.charset_size`` are
        used. Set ``FLAGS.charset_size`` to a small value if available
        computation power is limited.
        """
        data_root = os.path.normpath(data_dir)
        self.image_names = []
        for root, _, file_list in os.walk(data_dir):
            folder = os.path.normpath(root)
            class_name = os.path.basename(folder)
            # Select classes by their numeric id instead of comparing raw path
            # strings, which only worked when data_dir ended with a separator.
            if folder == data_root or not class_name.isdigit():
                continue
            if int(class_name) >= FLAGS.charset_size:
                continue
            self.image_names += [os.path.join(root, file_name) for file_name in file_list]
        random.shuffle(self.image_names)
        # The label is the name of the folder holding the image: int("00020") == 20.
        self.labels = [int(os.path.basename(os.path.dirname(name))) for name in self.image_names]

    @property
    def size(self):
        """Number of indexed images (read-only)."""
        return len(self.labels)

    @staticmethod
    def data_augmentation(images):
        """Apply the random augmentations enabled through FLAGS and return the result."""
        if FLAGS.random_flip_up_down:
            images = tf.image.random_flip_up_down(images)
        if FLAGS.random_brightness:
            images = tf.image.random_brightness(images, max_delta=0.3)
        if FLAGS.random_contrast:
            images = tf.image.random_contrast(images, 0.8, 1.2)
        return images

    def input_pipeline(self, batch_size, num_epochs=None, aug=False):
        """Build the queue-based input pipeline.

        Args:
            batch_size: number of examples per dequeued batch.
            num_epochs: passed to ``tf.train.slice_input_producer``.
            aug: when true, run ``data_augmentation`` on every decoded image.

        Returns:
            ``(image_batch, label_batch)`` tensors produced by
            ``tf.train.shuffle_batch``.
        """
        # 1. Turn the Python lists into tensors and slice them into a queue.
        images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
        labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
        input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)

        # 2. Read and decode one image per queue element.
        labels = input_queue[1]
        images_content = tf.read_file(input_queue[0])
        images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)
        if aug:
            images = self.data_augmentation(images)
        new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)
        images = tf.image.resize_images(images, new_size)

        # 3. Collect shuffled batches from the queue.
        image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,
                                                          min_after_dequeue=10000)
        return image_batch, label_batch


def build_graph(top_k):
    """Build the CNN classifier together with its loss, metrics and train op.

    Args:
        top_k: ``k`` used for the top-k prediction and top-k accuracy ops.

    Returns:
        A dict of placeholders and ops keyed by name: ``images``, ``labels``,
        ``keep_prob``, ``top_k``, ``global_step``, ``train_op``, ``loss``,
        ``accuracy``, ``accuracy_top_k``, ``merged_summary_op``,
        ``predicted_distribution``, ``predicted_index_top_k`` and
        ``predicted_val_top_k``.
    """
    keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
    # Use the configured image size so the placeholder always matches what the
    # input pipeline and inference() resize to (the flag defaults to 64).
    images = tf.placeholder(dtype=tf.float32, shape=[None, FLAGS.image_size, FLAGS.image_size, 1],
                            name='image_batch')
    labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')

    # Three conv(3x3) + max-pool(2x2, stride 2) stages: 64 -> 128 -> 256 filters.
    conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1')
    max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME')
    conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2')
    max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME')
    conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3')
    max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME')

    flatten = slim.flatten(max_pool_3)
    fc1 = slim.fully_connected(tf.nn.dropout(flatten, keep_prob), 1024, activation_fn=tf.nn.tanh, scope='fc1')
    logits = slim.fully_connected(tf.nn.dropout(fc1, keep_prob), FLAGS.charset_size, activation_fn=None, scope='fc2')

    # Softmax and cross-entropy are computed together from the raw logits.
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))

    global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
    rate = tf.train.exponential_decay(2e-4, global_step, decay_steps=2000, decay_rate=0.97, staircase=True)
    train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(loss, global_step=global_step)
    # Plain softmax output, used for predictions rather than for the loss.
    probabilities = tf.nn.softmax(logits)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    merged_summary_op = tf.summary.merge_all()
    predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
    accuracy_in_top_k = tf.reduce_mean(tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32))

    return {'images': images,
            'labels': labels,
            'keep_prob': keep_prob,
            'top_k': top_k,
            'global_step': global_step,
            'train_op': train_op,
            'loss': loss,
            'accuracy': accuracy,
            'accuracy_top_k': accuracy_in_top_k,
            'merged_summary_op': merged_summary_op,
            'predicted_distribution': probabilities,
            'predicted_index_top_k': predicted_index_top_k,
            'predicted_val_top_k': predicted_val_top_k}


def train():
    """Train the network, periodically evaluating one test batch and saving checkpoints.

    Runs until the global step exceeds ``FLAGS.max_steps`` or the input queue
    raises ``OutOfRangeError``.
    """
    print('Begin training')
    # NOTE(review): the dataset locations are hard-coded here; the
    # train_data_dir / test_data_dir flags are not consulted - confirm intent.
    train_feeder = DataIterator(data_dir='../data/train/')
    test_feeder = DataIterator(data_dir='../data/test/')
    with tf.Session() as sess:
        train_images, train_labels = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size)
        graph = build_graph(top_k=1)
        sess.run(tf.global_variables_initializer())
        # The queue runners must be started before any batch can be dequeued.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()

        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')
        if FLAGS.restore:
            # Resume from the latest checkpoint; the saved global step variable
            # is restored together with the weights.
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print("restore from the checkpoint {0}".format(ckpt))

        logger.info(':::Training Start:::')
        try:
            while not coord.should_stop():
                start_time = time.time()
                print(start_time)
                train_images_batch, train_labels_batch = sess.run([train_images, train_labels])
                print(len(train_images_batch))
                feed_dict = {graph['images']: train_images_batch,
                             graph['labels']: train_labels_batch,
                             graph['keep_prob']: 0.8}  # keep 80% of the connections
                _, loss_val, train_summary, step = sess.run(
                    [graph['train_op'], graph['loss'], graph['merged_summary_op'], graph['global_step']],
                    feed_dict=feed_dict)
                train_writer.add_summary(train_summary, step)
                end_time = time.time()
                logger.info("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))
                if step > FLAGS.max_steps:
                    break
                if step % FLAGS.eval_steps == 1:
                    test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                    feed_dict = {graph['images']: test_images_batch,
                                 graph['labels']: test_labels_batch,
                                 graph['keep_prob']: 1.0}
                    accuracy_test, test_summary = sess.run(
                        [graph['accuracy'], graph['merged_summary_op']],
                        feed_dict=feed_dict)
                    test_writer.add_summary(test_summary, step)
                    logger.info('===============Eval a batch=======================')
                    logger.info('the step {0} test accuracy: {1}'
                                .format(step, accuracy_test))
                    logger.info('===============Eval a batch=======================')
                if step % FLAGS.save_steps == 1:
                    logger.info('Save the ckpt of {0}'.format(step))
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'),
                               global_step=graph['global_step'])
        except tf.errors.OutOfRangeError:
            logger.info('==================Train Finished================')
            saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'), global_step=graph['global_step'])
        finally:
            # Once any thread requests a stop, coord.should_stop() turns True
            # and every queue thread winds down.
            coord.request_stop()
            # Flush pending summaries to disk and release the event files.
            train_writer.close()
            test_writer.close()
        coord.join(threads)


def validation():
    """Evaluate the latest checkpoint on one pass over the test set.

    Returns:
        A dict with the top-3 probabilities (``prob``), the top-3 class
        indices (``indices``) and the true labels (``groundtruth``) of every
        evaluated example.
    """
    print('validation')
    test_feeder = DataIterator(data_dir='../data/test/')

    final_predict_val = []
    final_predict_index = []
    groundtruth = []

    with tf.Session() as sess:
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=1)
        graph = build_graph(3)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())  # initialises the input pipeline's internal state

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt is not None:
            saver.restore(sess, ckpt)
            print("restore from the checkpoint {0}".format(ckpt))
        else:
            logger.warning('No checkpoint found in {0}; evaluating freshly initialised weights'
                           .format(FLAGS.checkpoint_dir))

        logger.info(':::Start validation:::')
        try:
            i = 0
            # Running sums of correctly classified examples (not of batch means).
            acc_top_1, acc_top_k = 0.0, 0.0
            while not coord.should_stop():
                i += 1
                start_time = time.time()
                test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                feed_dict = {graph['images']: test_images_batch,
                             graph['labels']: test_labels_batch,
                             graph['keep_prob']: 1.0}
                batch_labels, probs, indices, acc_1, acc_k = sess.run([graph['labels'],
                                                                       graph['predicted_val_top_k'],
                                                                       graph['predicted_index_top_k'],
                                                                       graph['accuracy'],
                                                                       graph['accuracy_top_k']], feed_dict=feed_dict)
                final_predict_val += probs.tolist()
                final_predict_index += indices.tolist()
                groundtruth += batch_labels.tolist()
                # Weight each batch accuracy by the number of examples it covers.
                acc_top_1 += acc_1 * len(batch_labels)
                acc_top_k += acc_k * len(batch_labels)
                end_time = time.time()
                logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1) {3}(top_k)"
                            .format(i, end_time - start_time, acc_1, acc_k))
        except tf.errors.OutOfRangeError:
            logger.info('==================Validation Finished================')
            # Average over the examples actually evaluated: shuffle_batch drops
            # a final partial batch, so dividing by test_feeder.size would
            # under-report the accuracy.
            evaluated = len(groundtruth)
            if evaluated:
                acc_top_1 = acc_top_1 / evaluated
                acc_top_k = acc_top_k / evaluated
            logger.info('top 1 accuracy {0} top k accuracy {1}'.format(acc_top_1, acc_top_k))
        finally:
            coord.request_stop()
        coord.join(threads)
    return {'prob': final_predict_val, 'indices': final_predict_index, 'groundtruth': groundtruth}


def inference(image):
    """Classify a single image file with the latest checkpoint.

    Args:
        image: path of the image to classify.

    Returns:
        ``(predict_val, predict_index)``: the top-3 probabilities and the
        matching class indices.

    Raises:
        IOError: if ``FLAGS.checkpoint_dir`` holds no checkpoint to restore.
    """
    print('inference')
    temp_image = Image.open(image).convert('L')
    # LANCZOS is the filter that the ANTIALIAS alias referred to; the alias is
    # no longer available in recent Pillow releases.
    temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size), Image.LANCZOS)
    temp_image = np.asarray(temp_image) / 255.0
    temp_image = temp_image.reshape([-1, FLAGS.image_size, FLAGS.image_size, 1])
    with tf.Session() as sess:
        logger.info('========start inference============')
        graph = build_graph(top_k=3)
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if not ckpt:
            # Without a restore the variables are never initialised, so fail
            # with a clear message instead of an error inside sess.run.
            raise IOError('No checkpoint found in {0}'.format(FLAGS.checkpoint_dir))
        saver.restore(sess, ckpt)
        predict_val, predict_index = sess.run([graph['predicted_val_top_k'], graph['predicted_index_top_k']],
                                              feed_dict={graph['images']: temp_image, graph['keep_prob']: 1.0})
    return predict_val, predict_index


def main(_):
    """Dispatch to train / validation / inference according to ``FLAGS.mode``."""
    print(FLAGS.mode)
    if FLAGS.mode == "train":
        train()
    elif FLAGS.mode == 'validation':
        dct = validation()  # predictions and ground truth of the whole run
        result_file = 'result.dict'
        logger.info('Write result into {0}'.format(result_file))
        with open(result_file, 'wb') as f:
            pickle.dump(dct, f)
        logger.info('Write file ends')
    elif FLAGS.mode == 'inference':
        image_path = '../data/test/00159/75700.png'
        # The true label is the name of the folder that holds the image.
        true_label = int(os.path.basename(os.path.dirname(image_path)))
        final_predict_val, final_predict_index = inference(image_path)
        logger.info('the result info label {0} predict index {1} predict_val {2}'.format(true_label,
                                                                                         final_predict_index,
                                                                                         final_predict_val))
    else:
        raise ValueError('Unknown mode {0!r}; expected "train", "validation" or "inference"'.format(FLAGS.mode))


if __name__ == "__main__":
    # tf.app.run() parses the flags and then dispatches to main().
    tf.app.run()

log_utils.py

# -*- coding:utf-8 -*-
import os, os.path as osp
import time


def strftime(t=None):
    """Format *t* (seconds since the epoch) as ``YYYYmmdd-HHMMSS`` local time.

    ``None`` means "now". ``0`` is honoured as the epoch itself instead of
    being treated as a missing value.
    """
    return time.strftime("%Y%m%d-%H%M%S", time.localtime(t))


#################
# Logging
#################
import logging
from logging.handlers import TimedRotatingFileHandler

logging.basicConfig(format="[ %(asctime)s][%(module)s.%(funcName)s] %(message)s")

DEFAULT_LEVEL = logging.INFO
DEFAULT_LOGGING_DIR = osp.join("logs", "gcforest")
# Shared file handler, created lazily by init_fh() and reused by every logger.
fh = None


def init_fh():
    """Create the shared log-file handler once.

    Does nothing when the handler already exists or when file logging is
    disabled (``DEFAULT_LOGGING_DIR is None``). The log file is named after
    the current timestamp and placed in ``DEFAULT_LOGGING_DIR``, which is
    created if missing.
    """
    global fh
    if fh is not None:
        return
    if DEFAULT_LOGGING_DIR is None:
        return
    if not osp.exists(DEFAULT_LOGGING_DIR):
        os.makedirs(DEFAULT_LOGGING_DIR)
    logging_path = osp.join(DEFAULT_LOGGING_DIR, strftime() + ".log")
    fh = logging.FileHandler(logging_path)
    fh.setFormatter(logging.Formatter("[ %(asctime)s][%(module)s.%(funcName)s] %(message)s"))


def update_default_level(defalut_level):
    """Set the level used by get_logger() when no explicit level is passed."""
    global DEFAULT_LEVEL
    DEFAULT_LEVEL = defalut_level


def update_default_logging_dir(default_logging_dir):
    """Set the directory for the log file; ``None`` disables file logging.

    Only effective before the shared file handler has been created.
    """
    global DEFAULT_LOGGING_DIR
    DEFAULT_LOGGING_DIR = default_logging_dir


def get_logger(name="HandWrittenPractice", level=None):
    """Return the logger called *name*, set to *level* and wired to the log file.

    A falsy *level* (``None`` or ``0``) selects ``DEFAULT_LEVEL``.
    """
    level = level or DEFAULT_LEVEL
    logger = logging.getLogger(name)
    logger.setLevel(level)
    init_fh()
    # logging.getLogger() hands back the same object for the same name, so
    # attach the shared handler only once; otherwise each repeated call would
    # duplicate every line written to the log file.
    if fh is not None and fh not in logger.handlers:
        logger.addHandler(fh)
    return logger

Train

python chinese_write_detection.py --mode=train --max_steps=200000 --eval_steps=1000 --save_steps=10000

Validation

python chinese_write_detection.py --mode=validation

Inference

python chinese_write_detection.py --mode=inference

tensorflow创建cnn网络进行中文手写文字识别的更多相关文章

  1. Atitit s2018.2 s2 doc list on home ntpc.docx  \Atiitt uke制度体系 法律 法规 规章 条例 国王诏书.docx \Atiitt 手写文字识别 讯飞科大 语音云.docx \Atitit 代码托管与虚拟主机.docx \Atitit 企业文化 每日心灵 鸡汤 值班 发布.docx \Atitit 几大研发体系对比 Stage-Gat

    Atitit s2018.2 s2 doc list on home ntpc.docx \Atiitt uke制度体系  法律 法规 规章 条例 国王诏书.docx \Atiitt 手写文字识别   ...

  2. 5 TensorFlow入门笔记之RNN实现手写数字识别

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  3. Tensorflow项目实战一:MNIST手写数字识别

    此模型中,输入是28*28*1的图片,经过两个卷积层(卷积+池化)层之后,尺寸变为7*7*64,将最后一个卷积层展成一个以为向量,然后接两个全连接层,第一个全连接层加一个dropout,最后一个全连接 ...

  4. CNN完成mnist数据集手写数字识别

    # coding: utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data d ...

  5. TensorFlow(十):卷积神经网络实现手写数字识别以及可视化

    上代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...

  6. Google机器学习笔记(七)TF.Learn 手写文字识别

    转载请注明作者:梦里风林 Google Machine Learning Recipes 7 官方中文博客 - 视频地址 Github工程地址 https://github.com/ahangchen ...

  7. TensorFlow卷积神经网络实现手写数字识别以及可视化

    边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...

  8. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  9. 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)

    上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...

随机推荐

  1. 安装phpredis扩展以及phpRedisAdmin工具

    先从phpredis的git拿到最新的源码包:wget https://github.com/nicolasff/phpredis/archive/master.tar.gz 然后解压到进入目录:ta ...

  2. RMQ Fanout

    原创转载请注明出处:https://www.cnblogs.com/agilestyle/p/11795256.html RMQ Fanout Project Directory Maven Depe ...

  3. LeetCode--055--跳跃游戏(java)

    给定一个非负整数数组,你最初位于数组的第一个位置. 数组中的每个元素代表你在该位置可以跳跃的最大长度. 判断你是否能够到达最后一个位置. 示例 1: 输入: [2,3,1,1,4] 输出: true ...

  4. 3 Base64编码主要应用在那些场合?

    ,电子邮件数据也好,经常要用到Base64编码,那么为什么要作一下这样的编码呢? 我们知道在计算机中任何数据都是按ascii码存储的,而ascii码的128-255之间的值是不可见字符.而在网络上交换 ...

  5. __getattr__属性查找

    from datetime import date """ __getattr__ : 在查找不到对象的属性时调用 __getattribute__ : 在查找属性之前调 ...

  6. Python3解leetcode Min Cost Climbing Stairs

    问题描述: On a staircase, the i-th step has some non-negative cost cost[i]assigned (0 indexed). Once you ...

  7. 怎么测试php代码

    没有任何一名程序员可以一气呵成.完美无缺的在不用调试的情况下完成一个功能或模块.调试实际分很多种情况. 暴力调试 这种方式简单粗暴,一般PHP程序员都会用,那就是浏览器调试,在编辑器内写完代码后随后打 ...

  8. 基于mpvue搭建小程序项目框架

    简介: mpvue框架对于从没有接触过小程序又要尝试小程序开发的人员来说,无疑是目前最好的选择.mpvue从底层支持 Vue.js 语法和构建工具体系,同时再结合相关UI组件库,便可以高效的实现小程序 ...

  9. 实现word在线预览 有php的写法 也有插件似

    <?php //header("Content-type:text/html;charset=utf-8"); //word转html 展示 $lj=$_GET['file' ...

  10. Gym 100917F Find the Length

    题目链接:http://codeforces.com/gym/100917/problem/F ---------------------------------------------------- ...