条件随机场之CRF++源码详解-特征
我在学习条件随机场的时候经常有这样的疑问,crf预测当前节点label如何利用其他节点的信息、crf的训练样本与其他的分类器有什么不同、crf的公式中特征函数是什么以及这些特征函数是如何表示的。在这一章中,我将在CRF++源码中寻找答案。
输入过程
CRF++训练的入口在crf_learn.cpp文件的main函数中,在该函数中调用了encoder.cpp的crfpp_learn(int argc, char **argv)函数。在CRF++中,训练被称为encoder,显然预测就称为decoder。crfpp_learn的源码如下:
int crfpp_learn(int argc, char **argv) {
CRFPP::Param param; //存放输入的参数
param.open(argc, argv, CRFPP::long_options); //处理命令行输入的参数,存在param对象中
return CRFPP::crfpp_learn(param);
}
Param对象主要存放输入的参数,调用open方法处理命令行输入的参数并存储。最后调用crfpp_learn(const Param ¶m)函数,在该函数中将初始化Encoder对象encoder,并调用encoder的learn方法。
样本的处理以及特征的构造
本章的重点便是这个learn方法,该方法主要是根据输入的样本和特征模板构造特征。阅读该函数源码之前可以去CRF++官网了解一下CRF++输入的参数,以及模板文件和训练文件的格式。
bool Encoder::learn(const char *templfile, //模板文件
const char *trainfile, //训练样本
const char *modelfile, //模型输出文件
bool textmodelfile,
size_t maxitr,
size_t freq,
double eta,
double C,
unsigned short thread_num,
unsigned short shrinking_size,
int algorithm) {
std::cout << COPYRIGHT << std::endl; CHECK_FALSE(eta > 0.0) << "eta must be > 0.0"; //CHECK_FALSE是宏定义,如果传入的条件是false,则输出异常信息
CHECK_FALSE(C >= 0.0) << "C must be >= 0.0";
CHECK_FALSE(shrinking_size >= ) << "shrinking-size must be >= 1";
CHECK_FALSE(thread_num > ) << "thread must be > 0"; #ifndef CRFPP_USE_THREAD
CHECK_FALSE(thread_num == )
<< "This architecture doesn't support multi-thrading";
#endif if (algorithm == MIRA && thread_num > ) {//MIRAS算法无法启用多线程
std::cerr << "MIRA doesn't support multi-thrading. use thread_num=1"
<< std::endl;
} EncoderFeatureIndex feature_index; //所有的特征将存储在feature_index中
Allocator allocator(thread_num); //allocator对象主要用来做资源分配以及回收
std::vector<TaggerImpl* > x; //x存放输入的样本,例如:如果做词性标注的话,TaggerTmpl对象存放的是每句话,而x是所有句子 std::cout.setf(std::ios::fixed, std::ios::floatfield);
std::cout.precision(); #define WHAT_ERROR(msg) do { \
for (std::vector<TaggerImpl *>::iterator it = x.begin(); \
it != x.end(); ++it) \
delete *it; \
std::cerr << msg << std::endl; \
return false; } while () CHECK_FALSE(feature_index.open(templfile, trainfile)) //打开“模板文件”和“训练文件”
<< feature_index.what(); {
progress_timer pg; std::ifstream ifs(WPATH(trainfile));
CHECK_FALSE(ifs) << "cannot open: " << trainfile; std::cout << "reading training data: " << std::flush;
size_t line = ;
while (ifs) { //开始读取训练样本
TaggerImpl *_x = new TaggerImpl(); //_x存放的是一句话的内容,CRF++官网中提到,用一个空白行将每个sentence隔开
_x->open(&feature_index, &allocator); //做一些属性赋值,所有的句子都对应相同的feature_index和allocator对象
if (!_x->read(&ifs) || !_x->shrink()) {
WHAT_ERROR(_x->what());
} if (!_x->empty()) {
x.push_back(_x);
} else {
delete _x;
continue;
} _x->set_thread_id(line % thread_num); //每个句子都会分配一个线程id,可以多线程并发处理不同的句子 if (++line % == ) {
std::cout << line << ".. " << std::flush;
}
} ifs.close();
std::cout << "\nDone!";
} feature_index.shrink(freq, &allocator); // 根据训练是指定的-f参数,将特征出现的频率小于freq的过滤掉 std::vector <double> alpha(feature_index.size()); // feature_index.size()返回的是maxid_,即:特征函数的个数,alpha是每个特征函数的权重,便是CRF中要学习的参数
std::fill(alpha.begin(), alpha.end(), 0.0);
feature_index.set_alpha(&alpha[]); std::cout << "Number of sentences: " << x.size() << std::endl;
std::cout << "Number of features: " << feature_index.size() << std::endl;
std::cout << "Number of thread(s): " << thread_num << std::endl;
std::cout << "Freq: " << freq << std::endl;
std::cout << "eta: " << eta << std::endl;
std::cout << "C: " << C << std::endl;
std::cout << "shrinking size: " << shrinking_size
<< std::endl; ... //省略后续代码
95 }
我阅读源码是按照深度优先遍历的方式,遇到一个函数会不断地深入进去,直到理解了该函数的功能再返回。上述源码需要重点介绍的部分,我也按照深度优先的方式记录。对于比较容易理解的部分则直接在源码中添加注释。首先看下源码第43行feature_index.open(templfile, trainfile),表面是理解是打开模板文件和训练集文件,但具体做了什么事儿呢,进入这个函数发现分别调用了两个函数。一个是EncoderFeatureIndex::openTemplate(const char *filename),这个函数主要是读取模板文件中的unigram特征和bigram特征分别存储,从官网文章中也可以知道,crf的特征分为两种特征,unigram对应的是状态特征,bigram对应的是转移特征。另一个函数是EncoderFeatureIndex::openTagSet(const char *filename),该函数读取训练集文件,获得训练集特征的数量(feature_index.xsize_属性)以及训练集中label的集合(feature_index.y_属性),以后可以用集合中label值得的索引代替label。
learn函数的第57行,有两个函数调用。一个是_x->read(&ifs),这个函数是对输入的样本做处理。解释该函数之前,我先做一个约定,以词性标注为例。我们输入的训练样本每一行代表一个词,每一列代表词的特征,多个词(多行)代表一个句子,句子与句子之间用空白行分隔。这个规则从CRF++文档中也能看出,我们就统一用句子和词表示,方便表达。那么,该函数会读取一个句子。经过层层调用,会对_x对象中几个重要的数据结构进行赋值,由于这个函数的处理逻辑不复杂,因此我直接给出最终赋值的结果。如下:
class TaggerImpl : public Tagger {
FeatureIndex *feature_index_;
Allocator *allocator_;
std::vector<std::vector <const char *> > x_; //代表一个句子,外部vector代表多行(多个词),内部vector代表每行的多列,具体的列用char*表示
std::vector<std::vector <Node *> > node_; //相当于二位数组,node_[i][j]表示一个节点,即:第i个词是第j个label的点。如:“我”这个词是“代词”
std::vector<unsigned short int> answer_; //每个词对应的label
std::vector<unsigned short int> result_;
};
另一个调用是_x->shrink(),该函数的主要功能就是构造特征,具体来说是调用了feature_index的FeatureIndex::buildFeatures(TaggerImpl *tagger)方法,源码如下:
#define ADD { const int id = getID(os.c_str()); \
if (id != -1) feature.push_back(id); } while (0)
bool FeatureIndex::buildFeatures(TaggerImpl *tagger) const {
string_buffer os;
std::vector<int> feature; FeatureCache *feature_cache = tagger->allocator()->feature_cache(); //存放是每个节点或者边对应的特征向量,节点便是node[i][j],边的概念后续会接触,暂时可以忽略
tagger->set_feature_id(feature_cache->size()); //做个标记,以后要取该句子的特征,可以从该id的位置取 for (size_t cur = ; cur < tagger->size(); ++cur) {//遍历每个词,计算每个词的特征
for (std::vector<std::string>::const_iterator it
= unigram_templs_.begin();
it != unigram_templs_.end(); ++it) { //遍历每个unigram特征
if (!applyRule(&os, it->c_str(), cur, *tagger)) {applyRule函数根据当前词(cur)以及当前的特征(如: %x[-2,0]),生成一个特征,存放在os中
return false;
}
ADD; //将根据特征os,获取该特征的id,如果不存在该特征,生成新的id,将该id添加到feature变量中
}
feature_cache->add(feature); //将该词的特征添加到feature_cache中,add方法会将feature拷贝一份并将最后添加-1,方便后续读取
feature.clear();
} for (size_t cur = ; cur < tagger->size(); ++cur) {//遍历每条边,计算每条边的特征
for (std::vector<std::string>::const_iterator
it = bigram_templs_.begin();
it != bigram_templs_.end(); ++it) {//遍历每个bigram特征
if (!applyRule(&os, it->c_str(), cur, *tagger)) {//处理同上
return false;
}
ADD;
}
feature_cache->add(feature);
feature.clear();
} return true;
}
经过上面处理,最终会存储节点(单词)和边(相邻词连接)的特征列表(函数中feature变量),并存储在feature_cache中,由于在该函数中调用了set_feature_id方法,因此很容易拿到每个句子对应的特征列表。这里需要关注一下applyRule函数和ADD宏定义中的getID函数。下面我将举个例子,来直观感受下这两个函数的功能。
tempfile:
# Unigram
U00:%x[-1,0]
U01:%x[0,0]
trainfile:
0 - -1 -1 -1 -1 O
0 submit 7 0 0 0 B
1 submit 3 4 0 0 E
先看下CRF++中的特征模板,模板文件比较简单,只有unigram特征,特征的表示形如 U00:%x[a,b],开头的'U'代表unigram特征还是bigram特征,b代表的是哪列特征,a代表的是当前词的行偏移量。样本集文件更简单,只有一个句子,该句子有3个单词,每个单词有6个特征。
1) 当cur=0,遍历第一个unigram特征U00:%x[-1,0], 0代表第0个特征(第0列),-1代表前一个词的第0个特征。由于第一个词没有前一个词,所以CRF++中用_B-1代替,这部分可在源码中找到。调用applyRule将会生成"U00:_B-1"特征,调用getID函数返回的maxid_并存储在feature_index的dic_属性中,maxid_初始值为0,如果当前特征是新的则返回maxid_并更新maxid_为新值,maxid更新代码为maxid_ += (key[0] == 'U' ? y_.size() : y_.size() * y_.size()); 由于unigram是状态特征label与当前节点有关,所以加y_.size()表示y_.size()个特征函数,而bigram表示转移特征(边),与当前状态和前一个状态有关,有y_.size() * y_.size()种情况,因此加上y_.size() * y_.size(),代表y_.size()*y_.size()个特征函数。以上述例子unigram来说,对于某个词的特征,该词的label可能有y_.size()种情况,最终生成的特征函数是 f(特征='U00:_B-1', y='O')=1,f(特征='U00:_B-1', y='B')=1,f(特征='U00:_B-1', y='E')=1。总结一下,对于这个例子来说,一个unigram特征对应3状态特征函数,一个bigram特征对应9个转移特征函数。
2) 当cur=0,遍历第二个unigram特征U01:%x[0,0],调用applyRule生成特征"U01:0",调用getID函数,返回特征id为3,feature变量为[0,3]
3) 当cur=1,遍历第一个unigram特征U00:%x[-1,0],调用applyRule生成特征"U00:0",调用getID函数,返回特征id为6
4) 当cur=1,遍历第二个unigram特征U01:%x[0,0],调用applyRule生成特征"U01:0",调用getID函数,返回特征id为3, feature变量为[6,3]
5) 当cur=2,遍历第一个unigram特征U00:%x[-1,0],调用applyRule生成特征"U00:0",调用getID函数,返回特征id为6
6) 当cur=2,遍历第二个unigram特征U01:%x[0,0],调用applyRule生成特征"U01:1",调用getID函数,返回特征id为9, 此时maxid_更新为12,feature变量为[6,9]
因此,特征一共有4个,状态特征有12个,转移特征为0个,因此feature_index的maxid_为12,feature_cache的大小为5(3个节点+2条边)。本例子中只有1句话并且只有一个特征的unigram特征函数,对于多句话和多个特征函数,计算逻辑是一样的,并且都会更新到公共的变量feature_index中。
至此,就_x->shrink()的核心逻辑便梳理完毕, 同时也是整个learn函数的核心逻辑,回到learn函数的源码继续往下看,while循环会对每个句子重复进行上述操作,并将表示句子的变量x_存储到变量x中,代表整个训练集。还有需要注意的是我们平时一般用w表示待学习的参数,但在CRF++中使用变量alpha表示w。
总结
本章主要结合源码和实际的例子,了解了CRF++如何处理输入的样本,如何生成特征以及特征函数。首先,通过本章可以清晰的找到开头提到的几个问题。其次,可以学习CRF++如何定义数据结构表示条件随机场各个元素及其之间的关系,如果再仔细体会一下,就能发现CRF++里设计的数据结构和代码实现还是非常巧妙的,值得学习。如对本章内容有疑问的欢迎在留言区交流,我会及时回复,同时如有表述不对的地方,烦请指正。
条件随机场之CRF++源码详解-特征的更多相关文章
- 条件随机场之CRF++源码详解-预测
这篇文章主要讲解CRF++实现预测的过程,预测的算法以及代码实现相对来说比较简单,所以这篇文章理解起来也会比上一篇条件随机场训练的内容要容易. 预测 上一篇条件随机场训练的源码详解中,有一个地方并没有 ...
- 条件随机场之CRF++源码详解-训练
上篇的CRF++源码阅读中, 我们看到CRF++如何处理样本以及如何构造特征.本篇文章将继续探讨CRF++的源码,并且本篇文章将是整个系列的重点,会介绍条件随机场中如何构造无向图.前向后向算法.如何计 ...
- 条件随机场之CRF++源码详解-开篇
介绍 最近在用条件随机场做切分标注相关的工作,系统学习了下条件随机场模型.能够理解推导过程,但还是比较抽象.因此想研究下模型实现的具体过程,比如:1) 状态特征和转移特征具体是什么以及如何构造 2)前 ...
- Spark Streaming揭秘 Day25 StreamingContext和JobScheduler启动源码详解
Spark Streaming揭秘 Day25 StreamingContext和JobScheduler启动源码详解 今天主要理一下StreamingContext的启动过程,其中最为重要的就是Jo ...
- [转]Linux内核源码详解--iostat
Linux内核源码详解——命令篇之iostat 转自:http://www.cnblogs.com/york-hust/p/4846497.html 本文主要分析了Linux的iostat命令的源码, ...
- saltstack源码详解一
目录 初识源码流程 入口 1.grains.items 2.pillar.items 2/3: 是否可以用python脚本实现 总结pillar源码分析: @(python之路)[saltstack源 ...
- Activiti架构分析及源码详解
目录 Activiti架构分析及源码详解 引言 一.Activiti设计解析-架构&领域模型 1.1 架构 1.2 领域模型 二.Activiti设计解析-PVM执行树 2.1 核心理念 2. ...
- 源码详解系列(六) ------ 全面讲解druid的使用和源码
简介 druid是用于创建和管理连接,利用"池"的方式复用连接减少资源开销,和其他数据源一样,也具有连接数控制.连接可靠性测试.连接泄露控制.缓存语句等功能,另外,druid还扩展 ...
- Mybatis源码详解系列(四)--你不知道的Mybatis用法和细节
简介 这是 Mybatis 系列博客的第四篇,我本来打算详细讲解 mybatis 的配置.映射器.动态 sql 等,但Mybatis官方中文文档对这部分内容的介绍已经足够详细了,有需要的可以直接参考. ...
随机推荐
- C++学习3--编程基础(vector、string、三种传参)
知识点学习 Vector容器 vector是C++标准程序库中的一个类,其定义于头文件中,与其他STL组件一样,ventor属于STD名称空间: ventor是C++标准程序库里最基本的容器,设计之初 ...
- 分布式系列 - dubbo服务telnet命令【转】
dubbo服务发布之后,我们可以利用telnet命令进行调试.管理.Dubbo2.0.5以上版本服务提供端口支持telnet命令,下面我以通过实例抛砖引玉一下: 1.连接服务 测试对应IP和端口下的d ...
- TCP/IP指纹鉴别 fingerprint
http://www.freebuf.com/articles/system/30037.html使用TCP/IP协议栈指纹进行远程操作系统辨识 Fyodor <fyodor@insecure. ...
- nodejs前端开发环境安装
1. nodejs安装 要求:node版本6.2.0及以上,npm版本3.8.9及以上 Nodejs安装包地址: 2. 在rTools上下载并安装git 3. 在rTools上 ...
- android项目结构
- 输入一个数,求1到他 的和(for循环)
- Object Detection
这篇博客对目标检测做了总结:https://handong1587.github.io/deep_learning/2015/10/09/object-detection.html
- 使用@font-family时各浏览器对字体格式(format)的支持情况
说到浏览器对@font-face的兼容问题,这里涉及到一个字体format的问题,因为不同的浏览器对字体格式支持是不一致的,这样大家有必要了解一下,各种版本的浏览器支持什么样的字体,前面也简单带到了有 ...
- SeaJS入门教程系列之SeaJS介绍(一)
前言SeaJS是一个遵循CommonJS规范的JavaScript模块加载框架,可以实现JavaScript的模块化开发及加载机制.与jQuery等JavaScript框架不同,SeaJS不会扩展封装 ...
- tensorflow-安装
1.pip安装(最好在虚拟环境中安装) →更新pip:pip install --upgrade pip →安装最新版tensorflow(GPU):pip install tensorflow-gp ...