数据集为玻森命名实体数据。

目前代码流程跑通了,后续再进行优化。

项目地址:https://github.com/cyandn/practice/tree/master/NER

步骤:

数据预处理:

def data_process():
zh_punctuation = [',', '。', '?', ';', '!', '……'] with open('data/BosonNLP_NER_6C_process.txt', 'w', encoding='utf8') as fw:
with open('data/BosonNLP_NER_6C.txt', encoding='utf8') as fr:
for line in fr.readlines():
line = ''.join(line.split()).replace('\\n', '') # 去除文本中的空字符 i = 0
while i < len(line):
word = line[i] if word in zh_punctuation:
fw.write(word + '/O')
fw.write('\n')
i += 1
continue if word == '{':
i += 2
temp = ''
while line[i] != '}':
temp += line[i]
i += 1
i += 2 type_ne = temp.split(':')
etype = type_ne[0]
entity = type_ne[1]
fw.write(entity[0] + '/B_' + etype + ' ')
for item in entity[1:]:
fw.write(item + '/I_' + etype + ' ')
else:
fw.write(word + '/O ')
i += 1

加载数据:

def load_data(self):
maxlen = 0 with open('data/BosonNLP_NER_6C_process.txt', encoding='utf8') as f:
for line in f.readlines():
word_list = line.strip().split()
one_sample, one_label = zip(
*[word.rsplit('/', 1) for word in word_list])
one_sample_len = len(one_sample)
if one_sample_len > maxlen:
maxlen = one_sample_len
one_sample = ' '.join(one_sample)
one_label = [config.classes[label] for label in one_label]
self.total_sample.append(one_sample)
self.total_label.append(one_label) tok = Tokenizer()
tok.fit_on_texts(self.total_sample)
self.vocabulary = len(tok.word_index) + 1
self.total_sample = tok.texts_to_sequences(self.total_sample) self.total_sample = np.array(pad_sequences(
self.total_sample, maxlen=maxlen, padding='post', truncating='post'))
self.total_label = np.array(pad_sequences(
self.total_label, maxlen=maxlen, padding='post', truncating='post'))[:, :, None] print('total_sample shape:', self.total_sample.shape)
print('total_label shape:', self.total_label.shape) X_train, self.X_test, y_train, self.y_test = train_test_split(
self.total_sample, self.total_label, test_size=config.proportion['test'], random_state=666)
self.X_train, self.X_val, self.y_train, self.y_val = train_test_split(
X_train, y_train, test_size=config.proportion['val'], random_state=666) print('X_train shape:', self.X_train.shape)
print('y_train shape:', self.y_train.shape)
print('X_val shape:', self.X_val.shape)
print('y_val shape:', self.y_val.shape)
print('X_test shape:', self.X_test.shape)
print('y_test shape:', self.y_test.shape) del self.total_sample
del self.total_label

构建模型:

def build_model(self):
model = Sequential() model.add(Embedding(self.vocabulary, 100, mask_zero=True))
model.add(Bidirectional(LSTM(64, return_sequences=True)))
model.add(CRF(len(config.classes), sparse_target=True))
model.summary() opt = Adam(lr=config.hyperparameter['learning_rate'])
model.compile(opt, loss=crf_loss, metrics=[crf_viterbi_accuracy]) self.model = model

训练:

def train(self):
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = '{epoch:03d}_{val_crf_viterbi_accuracy:.4f}.h5'
if not os.path.isdir(save_dir):
os.makedirs(save_dir) tensorboard = TensorBoard()
checkpoint = ModelCheckpoint(os.path.join(save_dir, model_name),
monitor='val_crf_viterbi_accuracy',
save_best_only=True)
lr_reduce = ReduceLROnPlateau(
monitor='val_crf_viterbi_accuracy', factor=0.2, patience=10) self.model.fit(self.X_train, self.y_train,
batch_size=config.hyperparameter['batch_size'],
epochs=config.hyperparameter['epochs'],
callbacks=[tensorboard, checkpoint, lr_reduce],
validation_data=[self.X_val, self.y_val])

预测:

def evaluate(self):
best_model_name = sorted(os.listdir('saved_models')).pop()
self.best_model = load_model(os.path.join('saved_models', best_model_name),
custom_objects={'CRF': CRF,
'crf_loss': crf_loss,
'crf_viterbi_accuracy': crf_viterbi_accuracy})
scores = self.best_model.evaluate(self.X_test, self.y_test)
print('test loss:', scores[0])
print('test accuracy:', scores[1])

参考:

https://zhuanlan.zhihu.com/p/44042528

https://blog.csdn.net/buppt/article/details/81180361

https://github.com/stephen-v/zh-NER-keras

http://www.voidcn.com/article/p-pykfinyn-bro.html

NER(BiLSTM+CRF,Keras)的更多相关文章

  1. 百度坐标(BD09)、国测局坐标(火星坐标,GCJ02)、和WGS84坐标系之间的转换(JS版代码)

    /** * Created by Wandergis on 2015/7/8. * 提供了百度坐标(BD09).国测局坐标(火星坐标,GCJ02).和WGS84坐标系之间的转换 */ //定义一些常量 ...

  2. Slider插件(滑动条,拉链)

    Slider插件(滑动条,拉链) 下载地址:http://files.cnblogs.com/elves/Slider.rar 提示:微软AJAX插件中也带此效果!

  3. NGUI系列教程四(自定义Atlas,Font)

    今天我们来看一下怎么自定义NGUIAtlas,制作属于自己风格的UI.第一部分:自定义 Atlas1 . 首先我们要准备一些图标素材,也就是我们的UI素材,将其导入到unity工程中.2. 全选我们需 ...

  4. Java基础知识强化之集合框架笔记60:Map集合之TreeMap(TreeMap<Student,String>)的案例

    1. TreeMap(TreeMap<Student,String>)的案例 2. 案例代码: (1)Student.java: package cn.itcast_04; public ...

  5. Java基础知识强化之集合框架笔记57:Map集合之HashMap集合(HashMap<Student,String>)的案例

    1. HashMap集合(HashMap<Student,String>)的案例 HashMap<Student,String>键:Student      要求:如果两个对象 ...

  6. Java基础知识强化之集合框架笔记56:Map集合之HashMap集合(HashMap<String,Student>)的案例

    1. HashMap集合(HashMap<String,Student>)的案例 HashMap是最常用的Map集合,它的键值对在存储时要根据键的哈希码来确定值放在哪里. HashMap的 ...

  7. Java基础知识强化之集合框架笔记54:Map集合之HashMap集合(HashMap<String,String>)的案例

    1. HashMap集合 HashMap集合(HashMap<String,String>)的案例 2. 代码示例: package cn.itcast_02; import java.u ...

  8. pearl(二分查找,stl)

    最近大概把有关二分的题目都看了一遍... 嗯..这题是二分查找...二分查找的代码都类似,所以打起来会水很多 但是刚开始打二分还是很容易写挂..所以依旧需要注意 题2 天堂的珍珠 [题目描述] 我有很 ...

  9. KMP算法(研究总结,字符串)

    KMP算法(研究总结,字符串) 前段时间学习KMP算法,感觉有些复杂,不过好歹是弄懂啦,简单地记录一下,方便以后自己回忆. 引入 首先我们来看一个例子,现在有两个字符串A和B,问你在A中是否有B,有几 ...

随机推荐

  1. 关于spring中请求返回值的json序列化/反序列化问题

    https://cloud.tencent.com/developer/article/1381083 https://www.jianshu.com/p/db07543ffe0a 先留个坑

  2. 不要在 MySQL 中使用“utf8”,请使用“utf8mb4”

    不要在 MySQL 中使用“utf8”,请使用“utf8mb4” 最近我遇到了一个bug,我试着通过Rails在以“utf8”编码的MariaDB中保存一个UTF-8字符串,然后出现了一个离奇的错误: ...

  3. Cron Expressions——Cron 表达式(QuartZ调度时间配置)

    如果你需要像日历那样按日程来触发任务,而不是像SimpleTrigger 那样每隔特定的间隔时间触发,CronTriggers通常比SimpleTrigger更有用. 使用CronTrigger,你可 ...

  4. h5表单亲测

    Document 下载进度: 标签. 牛奶 面包 男 女 one two three 按钮 搜索 请输入搜索内容 加密强度 用户名 Email 密码 年龄 身高 生日 这一系列是很酷的一个类型,完全解 ...

  5. MySQL使用alter修改表的结构

    SQL语句     DLL        数据定义语言         create,drop     DML     数据操纵语言         insert,delete,select,upda ...

  6. Linux服务管理之ntp

    NTP是网络时间协议(Network Time Protocol),它是用来同步网络中各个计算机的时间的协议. 在计算机的世界里,时间非常地重要,例如对于火箭发射这种科研活动,对时间的统一性和准确性要 ...

  7. js的函数三角恋

    原创,转载请标明来源https://www.cnblogs.com/sogeisetsu/ js的函数三角恋 1.什么是构造函数 是专门用于创建对象的 对象就是object **** 1.什么是函数? ...

  8. django cookie,session,auth

    一.最完美的auth auth_user 是用来存储的用户注册的username,password auth 首先需要引入模块 from django.contrib import auth 用户认证 ...

  9. Mybatis-plus中如何排除非表字段的三种方式

    1.transient关键字 2.使用静态变量(static) 3.TableField(exit=false) 这三种方式可以在使用的过程中,是这个对象中的属性不被序列化.(直接被忽略)

  10. AcWing 38. 二叉树的镜像

    习题地址 https://www.acwing.com/solution/acwing/content/2922/ 题目描述输入一个二叉树,将它变换为它的镜像. 样例 输入树: / \ / \ / \ ...