(第三章)TF框架之实现验证码识别
这里实现一个用神经网络(卷积神经网络也可以)实现验证码识别的小案例,主要记录本人做这个案例的流程,不会像之前那么详细,主要用作个人记录用。。。
- 这里是验证码的四个字母,被one-hot编码后形成的四个一维数组,[1, 26] * 4 ----> 可以转变成[4, 26] ----> [1, 104]
第一个位置:[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0]
第二个位置:[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1]
第三个位置:[0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0]
第四个位置:[0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0]
字母验证码识别设计:
这两个(真实值和预测值)104的一阶张量进行交叉熵损失计算,得出损失大小。会提高四个位置的概率,使得4组中每组26个目标值中为1的位置对应的预测概率值越来越大,在预测的四组当中概率值最大。这样得出预测中每组的字母位置。所有104个概率相加为1
流程设计
1、把图片的特征值和目标值,-----> 转换成tfrecords格式,方便数据特征值、目标值统一读取
[b'NZPP' b'WKHK' b'WPSJ' ..., b'FVQJ' b'BQYA' b'BCHR'] -----> [[13, 25, 15, 15], [22, 10, 7, 10], [22, 15, 18, 9], [16, 6, 13, 10]]
"ABCD……Z" —>"0, 1, …, 25"
2、训练验证码、准确率的计算
将原来的图片数据(特征)和csv数据(标签)------> 转变为tfrecords格式的数据,注意example协议(序列化后)
代码如下:
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("tfrecords_dir", "./tfrecords/captcha.tfrecords", "验证码tfrecords文件")
tf.app.flags.DEFINE_string("captcha_dir", "../data/Genpics/", "验证码图片路径")
tf.app.flags.DEFINE_string("letter", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "验证码字符的种类") def dealwithlabel(label_str): # 构建字符索引 {0:'A', 1:'B'......}
num_letter = dict(enumerate(list(FLAGS.letter))) # 键值对反转 {'A':0, 'B':1......}
letter_num = dict(zip(num_letter.values(), num_letter.keys())) print(letter_num) # 构建标签的列表
array = [] # 给标签数据进行处理[[b"NZPP"], ......]
for string in label_str: letter_list = []# [1,2,3,4] # 修改编码,b'FVQJ'到字符串,并且循环找到每张验证码的字符对应的数字标记
for letter in string.decode('utf-8'):
letter_list.append(letter_num[letter]) array.append(letter_list) # [[13, 25, 15, 15], [22, 10, 7, 10], [22, 15, 18, 9], [16, 6, 13, 10], [1, 0, 8, 17], [0, 9, 24, 14].....]
print(array) # 将array转换成tensor类型
label = tf.constant(array) return label def get_captcha_image():
"""
获取验证码图片数据
:param file_list: 路径+文件名列表
:return: image
"""
# 构造文件名
filename = [] for i in range(6000):
string = str(i) + ".jpg"
filename.append(string) # 构造路径+文件
file_list = [os.path.join(FLAGS.captcha_dir, file) for file in filename] # 构造文件队列
file_queue = tf.train.string_input_producer(file_list, shuffle=False) # 构造阅读器
reader = tf.WholeFileReader() # 读取图片数据内容
key, value = reader.read(file_queue) # 解码图片数据
image = tf.image.decode_jpeg(value) image.set_shape([20, 80, 3]) # 批处理数据 [6000, 20, 80, 3]
image_batch = tf.train.batch([image], batch_size=6000, num_threads=1, capacity=6000) return image_batch def get_captcha_label():
"""
读取验证码图片标签数据
:return: label
"""
file_queue = tf.train.string_input_producer(["../data/Genpics/labels.csv"], shuffle=False) reader = tf.TextLineReader() key, value = reader.read(file_queue) records = [[1], ["None"]] number, label = tf.decode_csv(value, record_defaults=records) # [["NZPP"], ["WKHK"], ["ASDY"]]
label_batch = tf.train.batch([label], batch_size=6000, num_threads=1, capacity=6000) return label_batch def write_to_tfrecords(image_batch, label_batch):
"""
将图片内容和标签写入到tfrecords文件当中
:param image_batch: 特征值
:param label_batch: 标签值
:return: None
"""
# 转换类型
label_batch = tf.cast(label_batch, tf.uint8) print(label_batch) # 建立TFRecords 存储器
writer = tf.python_io.TFRecordWriter(FLAGS.tfrecords_dir) # 循环将每一个图片上的数据构造example协议块,序列化后写入
for i in range(6000):
# 取出第i个图片数据,转换相应类型,图片的特征值要转换成字符串形式
image_string = image_batch[i].eval().tostring() # 标签值,转换成整型
label_string = label_batch[i].eval().tostring() # 构造协议块
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_string])),
"label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_string]))
})) writer.write(example.SerializeToString()) # 关闭文件
writer.close() return None if __name__ == "__main__": # 获取验证码文件当中的图片
image_batch = get_captcha_image() # 获取验证码文件当中的标签数据
label = get_captcha_label() print(image_batch, label) with tf.Session() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) # [b'NZPP' b'WKHK' b'WPSJ' ..., b'FVQJ' b'BQYA' b'BCHR']
label_str = sess.run(label) print(label_str) # 处理字符串标签到数字张量
label_batch = dealwithlabel(label_str) print(label_batch) # 将图片数据和内容写入到tfrecords文件当中
write_to_tfrecords(image_batch, label_batch) coord.request_stop() coord.join(threads)
- 训练验证码,得到准确率的代码
import tensorflow as tf class CaptchaIdentification(object):
"""
验证码的读取数据、网络训练
"""
def __init__(self): # 验证码图片的属性
self.height = 20
self.width = 80
self.channel = 3
# 每个验证码的目标值个数(4个字符)
self.label_num = 4
self.feature_num = 26 # 每批次训练样本个数
self.train_batch = 100 @staticmethod
def weight_variables(shape):
w = tf.Variable(tf.random_normal(shape=shape, mean=0.0, stddev=0.1))
return w @staticmethod
def bias_variables(shape):
b = tf.Variable(tf.random_normal(shape=shape, mean=0.0, stddev=0.1))
return b def read_captcha_tfrecords(self):
"""
读取验证码特征值和目标值数据
:return:
"""
# 1、构造文件的队列
file_queue = tf.train.string_input_producer(["./tfrecords/captcha.tfrecords"]) # 2、tf.TFRecordReader 读取TFRecords数据
reader = tf.TFRecordReader() # 单个样本数据
key, value = reader.read(file_queue) # 3、解析example协议
feature = tf.parse_single_example(value, features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.string)
}) # 4、解码操作、数据类型、形状
image = tf.decode_raw(feature["image"], tf.uint8)
label = tf.decode_raw(feature["label"], tf.uint8) # 确定类型和形状
# 图片形状 [20, 80, 3]
# 目标值 [4]
image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
label_reshape = tf.reshape(label, [self.label_num]) # 类型
image_type = tf.cast(image_reshape, tf.float32)
label_type = tf.cast(label_reshape, tf.int32) # 5、 批处理
# print(image_type, label_type)
# 提供每批次多少样本去进行训练
image_batch, label_batch = tf.train.batch([image_type, label_type],
batch_size=self.train_batch,
num_threads=1,
capacity=self.train_batch)
print(image_batch, label_batch)
return image_batch, label_batch def captcha_model(self, image_batch):
"""
建立全连接层网络
:param image_batch: 验证码图片特征值
:return: 预测结果
"""
# 全连接层
# [100, 20, 80, 3] --->[100, 20 * 80 * 3]
# [100, 20 * 80 * 3] * [20 * 80 * 3, 104] + [104] = [None, 104] 104 = 4*26
with tf.variable_scope("captcha_fc_model"):
# 初始化权重和偏置参数
self.weight = self.weight_variables([20 * 80 * 3, 104]) self.bias = self.bias_variables([104]) # 4维---->2维做矩阵运算
x_reshape = tf.reshape(image_batch, [self.train_batch, 20 * 80 * 3]) # [self.train_batch, 104]
y_predict = tf.matmul(x_reshape, self.weight) + self.bias return y_predict def loss(self, y_true, y_predict):
"""
建立验证码4个目标值的损失
:param y_true: 真实值
:param y_predict: 预测值
:return: loss
"""
with tf.variable_scope("loss"):
# 先进行网络输出的值的概率计算softmax,在进行交叉熵损失计算
# y_true:[100, 4, 26]------>[None, 104]
# y_predict:[100, 104]
y_reshape = tf.reshape(y_true,
[self.train_batch, self.label_num * self.feature_num]) all_loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_reshape,
logits=y_predict,
name="compute_loss")
# 求出平均损失
loss = tf.reduce_mean(all_loss) return loss def turn_to_onehot(self, label_batch):
"""
目标值转换成one_hot编码
:param label_batch: 目标值 [None, 4]
:return:
"""
with tf.variable_scope("one_hot"): # [None, 4]--->[None, 4, 26]
y_true = tf.one_hot(label_batch,
depth=self.feature_num,
on_value=1.0)
return y_true def sgd(self, loss):
"""
梯度下降优化损失
:param loss:
:return: train_op
"""
with tf.variable_scope("sgd"): train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss) return train_op def acc(self, y_true, y_predict):
"""
计算准确率
:param y_true: 真实值
:param y_predict: 预测值
:return: accuracy
"""
with tf.variable_scope("acc"): # y_true:[None, 4, 26]
# y_predict:[None, 104]
y_predict_reshape = tf.reshape(y_predict, [self.train_batch, self.label_num, self.feature_num]) # 先对最大值的位置去求解 这里的2指的是维度
euqal_list = tf.equal(tf.argmax(y_true, 2), tf.argmax(y_predict_reshape, 2)) # 需要对每个样本进行判断 这里的1指的是维度
# x = tf.constant([[True, True], [False, False]])
# tf.reduce_all(x, 1) # [True, False]
accuracy = tf.reduce_mean(tf.cast(tf.reduce_all(euqal_list, 1), tf.float32)) return accuracy def train(self):
"""
模型训练逻辑
:return:
"""
# 1、通过接口获取特征值和目标值
# image_batch:[100, 20, 80, 3]
# label_batch: [100, 4]
# [[13, 25, 15, 15], [22, 10, 7, 10]]
image_batch, label_batch = self.read_captcha_tfrecords() # 2、建立验证码识别的模型
# 全连接层神经网络
# y_predict [100, 104]
y_predict = self.captcha_model(image_batch) # 转换label_batch 到one_hot编码
# y_true:[None, 4, 26]
y_true = self.turn_to_onehot(label_batch) # 3、利用真实值和目标值建立损失
loss = self.loss(y_true, y_predict) # 4、对损失进行梯度下降优化
train_op = self.sgd(loss) # 5、计算准确率
accuracy = self.acc(y_true, y_predict) # 会话训练
with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 生成线程的管理
coord = tf.train.Coordinator() # 指定开启子线程去读取数据
threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 循环训练打印结果
for i in range(1000): _, acc_run = sess.run([train_op, accuracy]) print("第 %d 次训练的准确率为:%f " % (i, acc_run)) # 回收线程
coord.request_stop() coord.join(threads) return None if __name__ == '__main__':
ci = CaptchaIdentification()
ci.train()
(第三章)TF框架之实现验证码识别的更多相关文章
- jQuery系列 第三章 jQuery框架操作CSS
第三章 jQuery框架操作CSS 3.1 jQuery框架的CSS方法 jQuery框架提供了css方法,我们通过调用该方法传递对应的参数,可以方便的来批量设置标签的CSS样式. 使用JavaScr ...
- NancyFX 第三章 Web框架
如果使用Nancy作为一个WEB框架而言,会有什么不同?实际上很多. 在使用Nancy框架为网页添加Rest节点和路由和之前的Rest框架中是相同的,这方面没有什么需要学习的了.Nancy采用一贯的处 ...
- 第三章、drf框架 - 序列化组件 | Serializer
目录 第三章.drf框架 - 序列化组件 | Serializer 序列化组件 知识点:Serializer(偏底层).ModelSerializer(重点).ListModelSerializer( ...
- ASP.NET Core 中文文档 第三章 原理(17)为你的服务器选择合适版本的.NET框架
原文:Choosing the Right .NET For You on the Server 作者:Daniel Roth 翻译:王健 校对:谢炀(Kiler).何镇汐.许登洋(Seay).孟帅洋 ...
- 《Django By Example》第三章 中文 翻译 (个人学习,渣翻)
书籍出处:https://www.packtpub.com/web-development/django-example 原作者:Antonio Melé (译者注:第三章滚烫出炉,大家请不要吐槽文中 ...
- 《Entity Framework 6 Recipes》中文翻译系列 (11) -----第三章 查询之异步查询
翻译的初衷以及为什么选择<Entity Framework 6 Recipes>来学习,请看本系列开篇 第三章 查询 前一章,我们展示了常见数据库场景的建模方式,本章将向你展示如何查询实体 ...
- 《Entity Framework 6 Recipes》中文翻译系列 (19) -----第三章 查询之使用位操作和多属性连接(join)
翻译的初衷以及为什么选择<Entity Framework 6 Recipes>来学习,请看本系列开篇 3-16 过滤中使用位操作 问题 你想在查询的过滤条件中使用位操作. 解决方案 假 ...
- 第19章 集合框架(3)-Map接口
第19章 集合框架(3)-Map接口 1.Map接口概述 Map是一种映射关系,那么什么是映射关系呢? 映射的数学解释 设A,B是两个非空集合,如果存在一个法则,使得对A中的每一个元素a,按法则f,在 ...
- 第18章 集合框架(2)-Set接口
第18章 集合框架(2)-Set接口 Set是Collection子接口,模拟了数学上的集的概念 Set集合存储特点 1.不允许元素重复 2.不会记录元素的先后添加顺序 Set只包含从Collecti ...
随机推荐
- NextCloud + python API
NextCloud库地址:https://github.com/matejak/nextcloud-API 安装库依赖: 安装库: 建议在虚拟环境下使用 使用示例: # -*- coding: utf ...
- React Transition css动画案例解析
实现React Transition Css动画效果 首先在项目工程中引入react-transition-group: npm install react-transition-group --sa ...
- js修改css
转载请注明来源:https://www.cnblogs.com/hookjc/ <style type="text/css"> .style{font-size:9pt ...
- 生成静态库.a文件和动态库.so文件
转载来源:https://www.cnblogs.com/hookjc/ 静态库 在linux环境中, 使用ar命令创建静态库文件.如下是命令的选项: d -----从指定的静态库文件中删除文件 m ...
- webpack热更新 同时导出文件到本地
webpack 配置热更新后,文件配置导出到本地 安装 npm i webpack-dev-server-output --save-dev 引入 const WebpackDevServerOutp ...
- zookeeper集群+kafka集群 部署
zookeeper集群 +kafka 集群部署 1.Zookeeper 概述: Zookeeper 定义 zookeeper是一个开源的分布式的,为分布式框架提供协调服务的Apache项目 Zooke ...
- 让我一时不知所措 Linux 常用命令 爱情三部曲 下部
Linux目录与文件管理 我试着把你忘记,可总在夜里想你~ 1.linux目录结构 2.查看及检索文件 3.压缩及解压缩文件 4.vi文本编辑器 1.Linux目录结构:树形目录结构根目录:所有分区, ...
- Content-Type: multipart/form-data;文件上传利用
当我们找到一个文件上传接口时,发现他的MIME类型检测为Content-Type: multipart/form-data;时,我们就可以尝试下面几种方法来绕过限制. ---------------- ...
- Elementui【tooltip】 在弹框关闭之后再次‘出现’的问题
如图,第一次弹窗进来的时候,符合条件之后,这个提示文字的位置是对的,而且正常显示: 现在点击取消按钮,第二次进入弹窗,如图,提示文字就跑到了左上角,而且输入符合条件的数值之后,会另外显示一个toolt ...
- 聊一聊DTM子事务屏障功能之SQL Server版
背景 前面写了两篇如何用 C# 基于 DTM 轻松实现 SAGA 和 TCC 的分布式事务,其中有一个子事务屏障的功能,很好的处理了空补偿.悬挂.重复请求等异常问题. https://dtm.pub/ ...