### train_model.py ###

#!/usr/bin/env python
# coding=utf-8
import codecs
import simplejson as json
import numpy as np
import pandas as pd
from keras.models import Sequential, load_model
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.preprocessing import sequence
from keras.utils import to_categorical
from keras.layers import *
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.externals import joblib
import logging
import re
import pickle as pkl

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(filename)s: %(message)s',
                    datefmt='%Y-%m-%d %H:%M', filename='log/train_model.log', filemode='a+')

ngram_range = 1
max_features = 6500
maxlen = 120

fw = open('error_line_test.txt', 'wb')  # note: opened here but never written to below
DIRTY_LABEL = re.compile(r'\W+')
# example stop words: set([u'业务', u'代销', u'施工', u'策划', u'设计', u'销售', u'除外', u'零售', u'食品'])
STOP_WORDS = pkl.load(open('./data/stopwords.pkl'))  # applied at inference time in api_tgind.py


def load_data(fname='data/12315_industry_business_train.csv', nrows=None):
"""
载入训练数据
"""
    data, labels = [], []
    char2idx = json.load(open('data/char2idx.json'))
    used_keys = set(['name', 'business'])
    df = pd.read_csv(fname, encoding='utf-8', nrows=nrows)
    for idx, item in df.iterrows():
        item = item.to_dict()
        line = ''
        for key, value in item.iteritems():
            if key in used_keys:
                # NB: the field name is concatenated too; api_tgind.predict uses only the value
                line += key + value
        data.append([char2idx[char] for char in line if char in char2idx])
        labels.append(item['label'])
    le = LabelEncoder()
    logging.info('%d nb_class: %s' % (len(np.unique(labels)), str(np.unique(labels))))
    onehot_label = to_categorical(le.fit_transform(labels))
    joblib.dump(le, 'model/tgind_labelencoder.h5')
    x_train, x_test, y_train, y_test = train_test_split(data, onehot_label, test_size=0.1)
    return (x_train, y_train), (x_test, y_test)
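

# For reference, load_data expects a CSV with at least these columns; the sample
# row below is made up purely for illustration:
#   name,business,label
#   某某食品有限公司,预包装食品批发兼零售,52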


def create_ngram_set(input_list, ngram_value=2):
    return set(zip(*[input_list[i:] for i in range(ngram_value)]))


def add_ngram(sequences, token_indice, ngram_range=2):
"""
Augment the input list of sequences by appending n-grams values """
    new_sequences = []
    for input_list in sequences:
        new_list = input_list[:]
        for i in range(len(new_list) - ngram_range + 1):
            for ngram_value in range(2, ngram_range + 1):
                ngram = tuple(new_list[i:i + ngram_value])
                if ngram in token_indice:
                    new_list.append(token_indice[ngram])
        new_sequences.append(new_list)
    return new_sequences
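

# Quick sanity check of the two helpers above (the token ids and the
# token_indice mapping are made up for illustration):
#   create_ngram_set([1, 4, 9, 10], ngram_value=2) -> {(1, 4), (4, 9), (9, 10)}
#   add_ngram([[1, 4, 9, 10]], {(1, 4): 6501, (9, 10): 6502}, ngram_range=2)
#   -> [[1, 4, 9, 10, 6501, 6502]]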

(x_train, y_train), (x_test, y_test) = load_data()
nb_class = y_train.shape[1]
logging.info('x_train size: %d' % (len(x_train)))
logging.info('x_test size: %d' % (len(x_test)))
logging.info('x_train sent average len: %.2f' % (np.mean(list(map(len, x_train)))))
print 'x_train sent avg length: %.2f' % (np.mean(list(map(len, x_train))))

if ngram_range > 1:
    print 'add {}-gram features'.format(ngram_range)
    ngram_set = set()
    for input_list in x_train:
        for i in range(2, ngram_range + 1):
            set_of_ngram = create_ngram_set(input_list, ngram_value=i)
            ngram_set.update(set_of_ngram)
    start_index = max_features + 1
    token_indice = {v: k + start_index for k, v in enumerate(ngram_set)}
    indice_token = {token_indice[k]: k for k in token_indice}
    max_features = np.max(list(indice_token.keys())) + 1
    x_train = add_ngram(x_train, token_indice, ngram_range)
    x_test = add_ngram(x_test, token_indice, ngram_range)

print 'pad sequences (samples x time)'
x_train = sequence.pad_sequences(x_train, maxlen=maxlen, padding='post', truncating='post')
x_test = sequence.pad_sequences(x_test, maxlen=maxlen, padding='post', truncating='post')
logging.info('x_train.shape: %s' % (str(x_train.shape)))
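
# Note: 'post' padding/truncating keeps the start of each text, e.g.
# sequence.pad_sequences([[3, 1, 2]], maxlen=5, padding='post') -> array([[3, 1, 2, 0, 0]])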
"""
准确率统计
"""
    y_test = np.argmax(y_test, axis=1)
    y_pred = model.predict_classes(x_test)
    correct_cnt = np.sum(y_pred == y_test)
    return float(correct_cnt) / len(y_test)


DEBUG = False
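
# Shape sketch of the DEBUG-branch model below (batch dimension omitted, sizes
# taken from the constants above):
#   ids (maxlen=120) -> Embedding(max_features, 200) -> (120, 200)
#   -> GlobalAveragePooling1D -> (200,) -> Dense(nb_class, softmax) -> (nb_class,)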
if DEBUG:  # build the fastText-style model from scratch
    model = Sequential()
    model.add(Embedding(max_features, 200, input_length=maxlen))
    model.add(GlobalAveragePooling1D())
    model.add(Dropout(0.3))
    model.add(Dense(nb_class, activation='softmax'))
else:  # continue training the previously saved model
    model = load_model('./model/tgind_dalei.h5')

# model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

earlystop = EarlyStopping(monitor='val_loss', patience=8)
checkpoint = ModelCheckpoint(filepath='./model/tgind_dalei.h5', monitor='val_loss',
                             save_best_only=True, save_weights_only=False)

model.fit(x_train, y_train, shuffle=True, batch_size=64, epochs=80,
          validation_split=0.1, callbacks=[checkpoint, earlystop])

loss, acc = model.evaluate(x_test, y_test)
print '\n\nlast model: loss', loss
print 'acc', acc

# reload the checkpointed best model and report its test metrics
model = load_model('model/tgind_dalei.h5')
loss, acc = model.evaluate(x_test, y_test)
print '\n\n cur best model: loss', loss
print 'accuracy', acc
logging.info('loss: %.4f; accuracy: %.4f' % (loss, acc))
logging.info('\nmodel acc: %.4f' % acc)
logging.info('\nmodel config:\n %s' % model.get_config())

### test_model.py ###

#!/usr/bin/env python
# coding=utf-8
import matplotlib.pyplot as plt
from api_tgind import TgIndustry
import pandas as pd
import codecs
import json
from collections import OrderedDict

# ########## accuracy under a probability threshold ##########


def cal_model_acc(model, fname='./data/industry_dalei_test_sample2k.txt', nrows=None):
"""
载入数据, 并计算前5的准确率
"""
    res = {}
    res['y_pred'] = []
    res['y_true'] = []
    with codecs.open(fname, encoding='utf-8') as fr:
        for idx, line in enumerate(fr):
            tokens = line.strip().split()
            if len(tokens) > 3:
                tokens, label = tokens[:-1], tokens[-1].replace('__label__', '')
                tmp = {}
                tmp['business'] = ''.join(tokens)
                res['y_pred'].append(model.predict(tmp))
                res['y_true'].append(label)
            if nrows and idx > nrows:
                break
    json.dump(res, codecs.open('log/total_acc_output.json', 'wb', encoding='utf-8'))
    return res
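

# Expected input line for cal_model_acc above: whitespace-separated tokens with a
# fastText-style label appended; this sample line is made up for illustration:
#   预包装 食品 批发 零售 __label__52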
"""
直接根据csv预测结果
"""
    res = {}
    res['y_pred'] = []
    res['y_true'] = []
    df = pd.read_csv(fname, encoding='utf-8')
    for idx, item in df.iterrows():
        try:
            res['y_pred'].append(model.predict(item.to_dict()))
        except Exception as e:
            print e
            print idx
            print item['name']
            continue
        res['y_true'].append(item['label'])
        if nrows and idx > nrows:
            break
    json.dump(res, codecs.open('log/total_acc_output.json', 'wb', encoding='utf-8'))
    return res


def get_model_acc_menlei(res, topk=5, threshold=0.8):
"""
根据阈值计算模型准确率
"""
    correct_cnt, total_cnt = 0, 0
    for idx, y_pred in enumerate(res['y_pred']):
        y_pred_tuple = sorted(y_pred.iteritems(), key=lambda x: float(x[1]), reverse=True)  # sort by probability
        y_pred = OrderedDict()
        for c, s in y_pred_tuple:
            y_pred[c] = float(s)
        if y_pred.values()[0] > threshold:  # top class probability must exceed the threshold
            # the first character of a second-level label is its first-level label
            if res['y_true'][idx][0] in map(lambda x: x[0], y_pred.keys()[:topk]):
                correct_cnt += 1
            total_cnt += 1
    acc = float(correct_cnt) / total_cnt
    recall = float(total_cnt) / len(res['y_true'])
    return acc, recall


def get_model_acc_dalei(res, topk=5, threshold=0.8):
"""
根据阈值计算模型准确率
"""
    correct_cnt, total_cnt = 0, 0
    for idx, y_pred in enumerate(res['y_pred']):
        y_pred_tuple = sorted(y_pred.iteritems(), key=lambda x: float(x[1]), reverse=True)  # sort by probability
        y_pred = OrderedDict()
        for c, s in y_pred_tuple:
            y_pred[c] = float(s)
        if y_pred.values()[0] >= threshold:  # top class probability must reach the threshold
            if res['y_true'][idx] in y_pred.keys()[:topk]:
                correct_cnt += 1
            total_cnt += 1
    acc = float(correct_cnt) / total_cnt
    recall = float(total_cnt) / len(res['y_true'])
    return acc, recall
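

# A quick hand-check of the metric above (all numbers are made up):
#   res_demo = {'y_pred': [{'13': 0.9, '52': 0.05}, {'26': 0.4, '13': 0.3}],
#               'y_true': ['13', '26']}
#   get_model_acc_dalei(res_demo, topk=1, threshold=0.8) -> (1.0, 0.5)
# Only the first sample clears the 0.8 threshold and its top-1 label matches, so
# accuracy is 1/1 while coverage ('recall' here) is 1/2.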
"""
准确率绘图
"""
    for topk in range(1, 5):
        tmpdf = df[df.topk == topk]
        fig = plt.figure()
        ax1 = fig.add_subplot(111)
        plt.subplots_adjust(top=0.85)
        ax1.plot(tmpdf['threshold'], tmpdf['accuracy'], 'ro-', label='accuracy')
        # ax2 = ax1.twinx()
        ax1.plot(tmpdf['threshold'], tmpdf['recall'], 'g^-', label='recall')
        ax1.set_ylim(0.3, 1.0)
        ax1.legend(loc=3)
        ax1.set_xlabel('threshold')
        plt.grid(True)
        plt.title('%s Industry Classify Result\n topk=%d, number=%d\n' % (title, topk, number))
        plt.savefig('log/test_%s_acc_topk%d.png' % (title, topk))
        print topk, 'done!'


def gen_plot_data(model_acc, ctype='2nd'):
"""
生成图数据
"""
    res = {}
    res['accuracy'] = []
    res['threshold'] = []
    res['topk'] = []
    res['recall'] = []
    for topk in range(1, 5):
        for threshold in range(0, 10):
            threshold = 0.1 * threshold
            if ctype == '1st':
                acc, recall = get_model_acc_menlei(model_acc, topk, threshold)
            else:
                acc, recall = get_model_acc_dalei(model_acc, topk, threshold)
            res['accuracy'].append(acc)
            res['recall'].append(recall)
            res['threshold'].append(threshold)
            res['topk'].append(topk)
            print ctype, topk, acc
    json.dump(res, open('log/test_model_threshold_%s.log' % ctype, 'wb'))
    df = pd.DataFrame(res)
    df.to_csv('log/test_model_result_%s.csv' % ctype, index=False)
    plot_accuracy(ctype, df, len(model_acc['y_true']))
    return df


if __name__ == '__main__':
    model = TgIndustry()
    # model_acc = cal_model_acc2(model, fname='data/test_12315_industry_business_sample100.csv')
    model_acc = json.load(codecs.open('log/total_acc_output_12315.json', encoding='utf-8'))
    gen_plot_data(model_acc, '1st')
    gen_plot_data(model_acc, '2nd')

### api_tgind.py ###

#!/usr/bin/env python
# coding=utf-8
import numpy as np
import codecs
import simplejson as json
from keras.models import load_model
from keras.preprocessing import sequence
from sklearn.externals import joblib
from collections import OrderedDict
import pickle as pkl
import re, os
import jieba
import time

"""
Industry classification API.

__author__: jkmiao
__date__: 2017-07-05
"""


class TgIndustry(object):

    def __init__(self, model_path='model/tgind_dalei_acc76.h5'):
        base_path = os.path.dirname(__file__)
        model_path = os.path.join(base_path, model_path)
        # load the pre-trained model
        self.model = load_model(model_path)
        # load the LabelEncoder
        self.le = joblib.load(os.path.join(base_path, './model/tgind_labelencoder.h5'))
        # load the char-to-index mapping
        self.char2idx = json.load(open(os.path.join(base_path, 'data/char2idx.json')))
        # load the stop-word list
        # self.stop_words = set([line.strip() for line in codecs.open('./data/stopwords.txt', encoding='utf-8')])
        self.stop_words = pkl.load(open(os.path.join(base_path, './data/stopwords.pkl')))
        # load the mappings from label ids to category names
        self.menlei_label2name = json.load(open(os.path.join(base_path, 'data/menlei_label2name.json')))  # first-level categories
        self.dalei_label2name = json.load(open(os.path.join(base_path, 'data/dalei_label2name.json')))  # second-level categories

    def predict(self, company_info, topk=2, firstIndustry=False, final_name=False):
"""
:type company_info: 公司相关信息
:rtype business: str: 对应 label
"""
        line = ''
        for key, value in company_info.iteritems():
            if key in ['name', 'business']:  # use the company name and business scope
                line += company_info[key]
        if not isinstance(line, unicode):
            line = line.decode('utf-8')
        # rebuild the text with stop words removed
        line = ''.join([token for token in jieba.cut(line) if token not in self.stop_words])
        data = [self.char2idx[char] for char in line if char in self.char2idx]
        # NB: maxlen must match the input length the loaded model was trained with
        data = sequence.pad_sequences([data], maxlen=100, padding='post', truncating='post')
        y_pred_proba = self.model.predict(data, verbose=0)
        y_pred_idx_list = [c[-topk:][::-1] for c in np.argsort(y_pred_proba, axis=-1)][0]
        res = OrderedDict()
        for y_pred_idx in y_pred_idx_list:
            y_pred_label = self.le.inverse_transform(y_pred_idx)
            if final_name:
                y_pred_label = self.dalei_label2name[y_pred_label]
            if firstIndustry:  # report only the first-level label (the label's first character)
                res[y_pred_label[0]] = round(y_pred_proba[0, y_pred_idx], 3)  # keep 3 decimals
            else:
                res[y_pred_label] = round(y_pred_proba[0, y_pred_idx], 3)  # keep 3 decimals
        return res
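

# Minimal usage sketch (the input text and the returned labels/probabilities are
# illustrative only):
#   clf = TgIndustry()
#   clf.predict({'business': u'预包装食品批发兼零售'}, topk=2)
#   -> OrderedDict([(u'52', 0.83), (u'51', 0.12)])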


if __name__ == '__main__':
    DIRTY_LABEL = re.compile(r'\W+')
    test = TgIndustry()
    cnt, total_cnt = 0, 0
    start_time = time.time()
    fw2 = codecs.open('./output/industry_dalei_test_sample2k_error.txt', 'wb', encoding='utf-8')
    with codecs.open('./data/industry_dalei_test_sample2k.txt', encoding='utf-8') as fr:
        for idx, line in enumerate(fr):
            tokens = line.strip().split()
            if len(tokens) > 3:
                tokens, label = tokens[:-1], tokens[-1].replace('__label__', '')
                if len(label) not in [2, 3] or DIRTY_LABEL.search(label):
                    print 'error line:'
                    print idx, line, label
                    continue
                tmp = {}
                tmp['business'] = ''.join(tokens)
                y_pred = test.predict(tmp, topk=1)
                if label in y_pred:
                    cnt += 1
                elif y_pred.values()[0] < 0.3:  # log low-confidence misses for inspection
                    print 'error: ', ''.join(tokens), y_pred, 'y_true:', label
                    fw2.write(''.join(tokens) + '\n')
                total_cnt += 1
                print label
                print json.dumps(y_pred, ensure_ascii=False)
                print idx, '==' * 20, float(cnt) / total_cnt
            if idx > 200:
                break
    print 'avg cost time:', float(time.time() - start_time) / idx
