数据集为玻森命名实体数据。

目前代码流程跑通了,后续再进行优化。

项目地址:https://github.com/cyandn/practice/tree/master/NER

步骤:

数据预处理:

def data_process():
zh_punctuation = [',', '。', '?', ';', '!', '……'] with open('data/BosonNLP_NER_6C_process.txt', 'w', encoding='utf8') as fw:
with open('data/BosonNLP_NER_6C.txt', encoding='utf8') as fr:
for line in fr.readlines():
line = ''.join(line.split()).replace('\\n', '') # 去除文本中的空字符 i = 0
while i < len(line):
word = line[i] if word in zh_punctuation:
fw.write(word + '/O')
fw.write('\n')
i += 1
continue if word == '{':
i += 2
temp = ''
while line[i] != '}':
temp += line[i]
i += 1
i += 2 type_ne = temp.split(':')
etype = type_ne[0]
entity = type_ne[1]
fw.write(entity[0] + '/B_' + etype + ' ')
for item in entity[1:]:
fw.write(item + '/I_' + etype + ' ')
else:
fw.write(word + '/O ')
i += 1

加载数据:

def load_data(self):
maxlen = 0 with open('data/BosonNLP_NER_6C_process.txt', encoding='utf8') as f:
for line in f.readlines():
word_list = line.strip().split()
one_sample, one_label = zip(
*[word.rsplit('/', 1) for word in word_list])
one_sample_len = len(one_sample)
if one_sample_len > maxlen:
maxlen = one_sample_len
one_sample = ' '.join(one_sample)
one_label = [config.classes[label] for label in one_label]
self.total_sample.append(one_sample)
self.total_label.append(one_label) tok = Tokenizer()
tok.fit_on_texts(self.total_sample)
self.vocabulary = len(tok.word_index) + 1
self.total_sample = tok.texts_to_sequences(self.total_sample) self.total_sample = np.array(pad_sequences(
self.total_sample, maxlen=maxlen, padding='post', truncating='post'))
self.total_label = np.array(pad_sequences(
self.total_label, maxlen=maxlen, padding='post', truncating='post'))[:, :, None] print('total_sample shape:', self.total_sample.shape)
print('total_label shape:', self.total_label.shape) X_train, self.X_test, y_train, self.y_test = train_test_split(
self.total_sample, self.total_label, test_size=config.proportion['test'], random_state=666)
self.X_train, self.X_val, self.y_train, self.y_val = train_test_split(
X_train, y_train, test_size=config.proportion['val'], random_state=666) print('X_train shape:', self.X_train.shape)
print('y_train shape:', self.y_train.shape)
print('X_val shape:', self.X_val.shape)
print('y_val shape:', self.y_val.shape)
print('X_test shape:', self.X_test.shape)
print('y_test shape:', self.y_test.shape) del self.total_sample
del self.total_label

构建模型:

def build_model(self):
model = Sequential() model.add(Embedding(self.vocabulary, 100, mask_zero=True))
model.add(Bidirectional(LSTM(64, return_sequences=True)))
model.add(CRF(len(config.classes), sparse_target=True))
model.summary() opt = Adam(lr=config.hyperparameter['learning_rate'])
model.compile(opt, loss=crf_loss, metrics=[crf_viterbi_accuracy]) self.model = model

训练:

def train(self):
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = '{epoch:03d}_{val_crf_viterbi_accuracy:.4f}.h5'
if not os.path.isdir(save_dir):
os.makedirs(save_dir) tensorboard = TensorBoard()
checkpoint = ModelCheckpoint(os.path.join(save_dir, model_name),
monitor='val_crf_viterbi_accuracy',
save_best_only=True)
lr_reduce = ReduceLROnPlateau(
monitor='val_crf_viterbi_accuracy', factor=0.2, patience=10) self.model.fit(self.X_train, self.y_train,
batch_size=config.hyperparameter['batch_size'],
epochs=config.hyperparameter['epochs'],
callbacks=[tensorboard, checkpoint, lr_reduce],
validation_data=[self.X_val, self.y_val])

预测:

def evaluate(self):
best_model_name = sorted(os.listdir('saved_models')).pop()
self.best_model = load_model(os.path.join('saved_models', best_model_name),
custom_objects={'CRF': CRF,
'crf_loss': crf_loss,
'crf_viterbi_accuracy': crf_viterbi_accuracy})
scores = self.best_model.evaluate(self.X_test, self.y_test)
print('test loss:', scores[0])
print('test accuracy:', scores[1])

参考:

https://zhuanlan.zhihu.com/p/44042528

https://blog.csdn.net/buppt/article/details/81180361

https://github.com/stephen-v/zh-NER-keras

http://www.voidcn.com/article/p-pykfinyn-bro.html

NER(BiLSTM+CRF,Keras)的更多相关文章

  1. 百度坐标(BD09)、国测局坐标(火星坐标,GCJ02)、和WGS84坐标系之间的转换(JS版代码)

    /** * Created by Wandergis on 2015/7/8. * 提供了百度坐标(BD09).国测局坐标(火星坐标,GCJ02).和WGS84坐标系之间的转换 */ //定义一些常量 ...

  2. Slider插件(滑动条,拉链)

    Slider插件(滑动条,拉链) 下载地址:http://files.cnblogs.com/elves/Slider.rar 提示:微软AJAX插件中也带此效果!

  3. NGUI系列教程四(自定义Atlas,Font)

    今天我们来看一下怎么自定义NGUIAtlas,制作属于自己风格的UI.第一部分:自定义 Atlas1 . 首先我们要准备一些图标素材,也就是我们的UI素材,将其导入到unity工程中.2. 全选我们需 ...

  4. Java基础知识强化之集合框架笔记60:Map集合之TreeMap(TreeMap<Student,String>)的案例

    1. TreeMap(TreeMap<Student,String>)的案例 2. 案例代码: (1)Student.java: package cn.itcast_04; public ...

  5. Java基础知识强化之集合框架笔记57:Map集合之HashMap集合(HashMap<Student,String>)的案例

    1. HashMap集合(HashMap<Student,String>)的案例 HashMap<Student,String>键:Student      要求:如果两个对象 ...

  6. Java基础知识强化之集合框架笔记56:Map集合之HashMap集合(HashMap<String,Student>)的案例

    1. HashMap集合(HashMap<String,Student>)的案例 HashMap是最常用的Map集合,它的键值对在存储时要根据键的哈希码来确定值放在哪里. HashMap的 ...

  7. Java基础知识强化之集合框架笔记54:Map集合之HashMap集合(HashMap<String,String>)的案例

    1. HashMap集合 HashMap集合(HashMap<String,String>)的案例 2. 代码示例: package cn.itcast_02; import java.u ...

  8. pearl(二分查找,stl)

    最近大概把有关二分的题目都看了一遍... 嗯..这题是二分查找...二分查找的代码都类似,所以打起来会水很多 但是刚开始打二分还是很容易写挂..所以依旧需要注意 题2 天堂的珍珠 [题目描述] 我有很 ...

  9. KMP算法(研究总结,字符串)

    KMP算法(研究总结,字符串) 前段时间学习KMP算法,感觉有些复杂,不过好歹是弄懂啦,简单地记录一下,方便以后自己回忆. 引入 首先我们来看一个例子,现在有两个字符串A和B,问你在A中是否有B,有几 ...

随机推荐

  1. Bloom’S Taxonomy

    引用:https://www.learning-theories.com/blooms-taxonomy-bloom.html Bloom's Taxonomy is a model that is ...

  2. javascript中常用函数

    1.js 获取文件后缀名 <script type="text/javascript"> var filename="www/data/index.php&q ...

  3. CRM产品主数据在行业解决方案industry solution中的应用

    AG3, choose this role: Create a new Acquisition Contracts: Here our product advances search will be ...

  4. Spring面试专题之aop

    1.背景 aop是编程中非常非常重要的一种思想,在spring项目中用的场景也非常广 2.面试问题 2.1.简单的面试问题 1.什么是aop,aop的作用是什么? 面向切面编程(AOP)提供另外一种角 ...

  5. 简明conda使用指南

    目录 区分conda, anaconda, miniconda conda版本 虚拟环境 分享环境 查看某个环境的位置 列出软件包 安装软件包 删除软件包 查找软件包 conda配置 conda实践: ...

  6. Codeforces G. The Brand New Function(枚举)

    题目描述: The Brand New Function time limit per test 2 seconds memory limit per test 256 megabytes input ...

  7. Python3 if 变量variable SQL where 语句拼接

    最近在写python3的项目,在实际中运用到了根据 if 判断变量variable ,然后去拼接where子句.但是在百度.BING搜索中未找到合适的答案,这是自己想出来的典型php写法,这里做一下记 ...

  8. 富文本编辑器 KindEditor 的基本使用 文件上传 图片上传

    富文本编辑器 KindEditor 富文本编辑器,Rich Text Editor , 简称 RTE , 它提供类似于 Microsoft Word 的编辑功能. 常用的富文本编辑器: KindEdi ...

  9. 开源许可证GPL、BSD、MIT、Mozilla、Apache和LGPL 介绍

    原文地址

  10. USACO Poker Hands

    洛谷 P3078 [USACO13MAR]扑克牌型Poker Hands 题目传送门 JDOJ 2359: USACO 2013 Mar Silver 1.Poker Hands JDOJ传送门 题目 ...