Alink漫谈(十七) :Word2Vec源码分析 之 迭代训练

0x00 摘要

Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文和上文将带领大家来分析Alink中 Word2Vec 的实现。

因为Alink的公开资料太少,所以以下均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。

0x01 前文回顾

从前文 Alink漫谈(十六) :Word2Vec之建立霍夫曼树 我们了解了Word2Vec的概念、在Alink中的整体架构以及完成对输入的处理,以及词典、二叉树的建立。

此时我们已经有了一个已经构造好的Huffman树,以及初始化完毕的各个向量,可以开始输入文本来进行训练了。

1.1 上文总体流程图

先给出一个上文总体流程图:

1.2 回顾霍夫曼树

1.2.1 变量定义

现在定义变量如下:

  • n : 一个词的上下文包含的词数,与n-gram中n的含义相同
  • m : 词向量的长度,通常在10~100
  • h : 隐藏层的规模,一般在100量级
  • N :词典的规模,通常在1W~10W
  • T : 训练文本中单词个数

1.2.2 为何要引入霍夫曼树

word2vec也使用了CBOW与Skip-Gram来训练模型与得到词向量,但是并没有使用传统的DNN模型。最先优化使用的数据结构是用霍夫曼树来代替隐藏层和输出层的神经元,霍夫曼树的叶子节点起到输出层神经元的作用,叶子节点的个数即为词汇表的小大。 而内部节点则起到隐藏层神经元的作用

以CBOW为例,输入层为n-1个单词的词向量,长度为m(n-1),隐藏层的规模为h,输出层的规模为N。那么前向的时间复杂度就是o(m(n-1)h+hN) = o(hN) 这还是处理一个词所需要的复杂度。如果要处理所有文本,则需要o(hNT)的时间复杂度。这个是不可接受的。

同时我们也注意到,o(hNT)之中,h和T的值相对固定,想要对其进行优化,主要还是应该从N入手。而输出层的规模之所以为N,是因为这个神经网络要完成的是N选1的任务。那么可不可以减小N的值呢?答案是可以的。解决的思路就是将一次分类分解为多次分类,这也是Hierarchical Softmax的核心思想

举个栗子,有[1,2,3,4,5,6,7,8]这8个分类,想要判断词A属于哪个分类,我们可以一步步来,首先判断A是属于[1,2,3,4]还是属于[5,6,7,8]。如果判断出属于[1,2,3,4],那么就进一步分析是属于[1,2]还是[3,4],以此类推。这样一来,就把单个词的时间复杂度从o(hN)降为o(hlogN),更重要的减少了内存的开销。

从输入到输出,中间是一个树形结构,其中的每一个节点都完成一个二分类(logistic分类)问题。那么就存在一个如何构建树的问题。这里采用huffman树,因为这样构建的话,出现频率越高的词所经过的路径越短,从而使得所有单词的平均路径长度达到最短。

假设 足球 的路径 1001,那么当要输出 足球 这个词的时候,这个模型其实并不是直接输出 "1001" 这条路径,而是在每一个节点都进行一次二分类。这样相当于将最后输出的二叉树变成多个二分类的任务。而路径中的每个根节点都是一个待求的向量。 也就是说这个模型不仅需要求每个输入参数的变量,还需要求这棵二叉树中每个非叶子节点的向量,当然这些向量都只是临时用的向量。

0x02 训练

2.1 训练流程

Alink 实现的是 基于Hierarchical Softmax的Skip-Gram模型

现在我们先看看基于Skip-Gram模型时, Hierarchical Softmax如何使用。此时输入的只有一个词w,输出的为2c个词向量context(w)。

训练的过程主要有输入层(input),映射层(projection)和输出层(output)三个阶段。

  • 我们对于训练样本中的每一个词,该词本身作为样本的输入,其前面的c个词和后面的c个词作为了Skip-Gram模型的输出,期望这些词的softmax概率比其他的词大。
  • 我们需要先将词汇表建立成一颗霍夫曼树(此步骤在上文已经完成)。
  • 对于从输入层到隐藏层(映射层),这一步比CBOW简单,由于只有一个词,所以,即x_w就是词w对应的词向量。
  • 通过梯度上升法来更新我们的θwj−1和x_w,注意这里的x_w周围有2c个词向量,我们期望P(xi|xw),i=1,2...2c最大。此时我们注意到由于上下文是相互的,在期望P(xi|xw),i=1,2...2c最大化的同时,反过来我们也期望P(xw|xi),i=1,2...2c最大。那么是使用P(xi|xw)好还是P(xw|xi)好呢,word2vec使用了后者,这样做的好处就是在一个迭代窗口内,我们不是只更新xw一个词,而是xi,i=1,2...2c共2c个词。这样整体的迭代会更加的均衡。因为这个原因,Skip-Gram模型并没有和CBOW模型一样对输入进行迭代更新,而是对2c个输出进行迭代更新。
    • 从根节点开始,映射层的值需要沿着Huffman树不断的进行logistic分类,并且不断的修正各中间向量和词向量。
    • 假设映射层输入为 pro(t),单词为“足球”,即w(t)=“足球”,假设其Huffman码可知为d(t)=”1001,那么根据Huffman码可知,从根节点到叶节点的路径为“左右右左”,即从根节点开始,先往左拐,再往右拐2次,最后再左拐。
    • 既然知道了路径,那么就按照路径从上往下依次修正路径上各节点的中间向量。在第一个节点,根据节点的中间向量Θ(t,1)和pro(t)进行Logistic分类。如果分类结果显示为0,则表示分类错误 (应该向左拐,即分类到1),则要对Θ(t,1)进行修正,并记录误差量。
    • 接下来,处理完第一个节点之后,开始处理第二个节点。方法类似,修正Θ(t,2),并累加误差量。接下来的节点都以此类推。
    • 在处理完所有节点,达到叶节点之后,根据之前累计的误差来修正词向量v(w(t))。这里引入学习率概念,η 表示学习率。学习率越大,则判断错误的惩罚也越大,对中间向量的修正跨度也越大。

这样,一个词w(t)的处理流程就结束了。如果一个文本中有N个词,则需要将上述过程在重复N遍,从w(0)~w(N-1)。

这里总结下基于Hierarchical Softmax的Skip-Gram模型算法流程,梯度迭代使用了随机梯度上升法:

  • 输入:基于Skip-Gram的语料训练样本,词向量的维度大小M,Skip-Gram的上下文大小2c,步长η
  • 输出:霍夫曼树的内部节点模型参数θ,所有的词向量w

2.2 生成训练模型

Huffman树中非叶节点存储的中间向量的初始化值是零向量,而叶节点对应的单词的词向量是随机初始化的。我们可以看到,对于 input 和 output,是会进行 AllReduce 的,就是聚合每个task的计算结果

  1. DataSet <Row> model = new IterativeComQueue()
  2. .initWithPartitionedData("trainData", trainData)
  3. .initWithBroadcastData("vocSize", vocSize)
  4. .initWithBroadcastData("initialModel", initialModel)
  5. .initWithBroadcastData("vocabWithoutWordStr", vocabWithoutWordStr)
  6. .initWithBroadcastData("syncNum", syncNum)
  7. .add(new InitialVocabAndBuffer(getParams()))
  8. .add(new UpdateModel(getParams()))
  9. .add(new AllReduce("input"))
  10. .add(new AllReduce("output"))
  11. .add(new AvgInputOutput())
  12. .setCompareCriterionOfNode0(new Criterion(getParams()))
  13. .closeWith(new SerializeModel(getParams()))
  14. .exec();

2.3 初始化词典&缓冲

InitialVocabAndBuffer类完成此功能,主要是初始化参数,把词典加载到模型内存中。这里只有迭代第一次才会运行

  • input 数组存着Vocab的全部词向量,就是Huffman树所有叶子节点的词向量。大小|V|∗|M|,初始化范围[−0.5M,0.5M],经验规则。
  • output 数组存着Hierarchical Softmax的参数,就是Huffman树所有 非叶子 节点的参数向量(映射层到输出层之间的权重)。大小|V|∗|M|,初始化全为0,经验规则。实际使用|V−1|组。
  1. private static class InitialVocabAndBuffer extends ComputeFunction {
  2. Params params;
  3. public InitialVocabAndBuffer(Params params) {
  4. this.params = params;
  5. }
  6. @Override
  7. public void calc(ComContext context) {
  8. if (context.getStepNo() == 1) { // 只有迭代第一次才会运行
  9. int vectorSize = params.get(Word2VecTrainParams.VECTOR_SIZE);
  10. List <Long> vocSizeList = context.getObj("vocSize");
  11. List <Tuple2 <Integer, double[]>> initialModel = context.getObj("initialModel");
  12. List <Tuple2 <Integer, Word>> vocabWithoutWordStr = context.getObj("vocabWithoutWordStr");
  13. int vocSize = vocSizeList.get(0).intValue();
  14. // 生成一个 100 x 12 的input,这个迭代之后就是最终的词向量
  15. double[] input = new double[vectorSize * vocSize];
  16. Word[] vocab = new Word[vocSize];
  17. for (int i = 0; i < vocSize; ++i) {
  18. Tuple2 <Integer, double[]> item = initialModel.get(i);
  19. System.arraycopy(item.f1, 0, input,
  20. item.f0 * vectorSize, vectorSize); //初始化词向量
  21. Tuple2 <Integer, Word> vocabItem = vocabWithoutWordStr.get(i);
  22. vocab[vocabItem.f0] = vocabItem.f1;
  23. }
  24. context.putObj("input", input); // 把词向量放入系统上下文
  25. // 生成一个 100 x 11 的output,就是Hierarchical Softmax的参数
  26. context.putObj("output", new double[vectorSize * (vocSize - 1)]);
  27. context.putObj("vocab", vocab);
  28. context.removeObj("initialModel");
  29. context.removeObj("vocabWithoutWordStr");
  30. }
  31. }
  32. }

2.4 更新模型UpdateModel

这里进行“分布式计算”的分配。其中,如何计算给哪个task发送多少/发送起始位置,是在DefaultDistributedInfo完成的。这里需要结合 pieces 函数进行分析。具体在 [ Alink漫谈之三] AllReduce通信模型 有详细介绍。具体计算则是在 CalcModel.update 中完成。

  1. private static class UpdateModel extends ComputeFunction {
  2. @Override
  3. public void calc(ComContext context) {
  4. List <int[]> trainData = context.getObj("trainData");
  5. int syncNum = ((List <Integer>) context.getObj("syncNum")).get(0);
  6. DistributedInfo distributedInfo = new DefaultDistributedInfo();
  7. long startPos = distributedInfo.startPos(
  8. (context.getStepNo() - 1) % syncNum,
  9. syncNum,
  10. trainData.size()
  11. );
  12. long localRowCnt = distributedInfo.localRowCnt( //计算本分区信息
  13. (context.getStepNo() - 1) % syncNum,
  14. syncNum,
  15. trainData.size()
  16. );
  17. new CalcModel( //更新模型
  18. params.get(Word2VecTrainParams.VECTOR_SIZE),
  19. System.currentTimeMillis(),
  20. Boolean.parseBoolean(params.get(Word2VecTrainParams.RANDOM_WINDOW)),
  21. params.get(Word2VecTrainParams.WINDOW),
  22. params.get(Word2VecTrainParams.ALPHA),
  23. context.getTaskId(),
  24. context.getObj("vocab"),
  25. context.getObj("input"),
  26. context.getObj("output")
  27. ).update(trainData.subList((int) startPos, (int) (startPos + localRowCnt)));
  28. }
  29. }

2.5 计算更新

CalcModel.update 中完成计算更新。

2.5.1 sigmoid函数值近似计算

在利用神经网络模型对样本进行预測的过程中。须要对其进行预測,此时,须要使用到sigmoid函数,sigmoid函数的具体形式为:σ(x)=1 / (1+e^x)

σ(x) 在 x = 0 附近变化剧烈,往两边逐渐趋于平缓,当 x > 6 或者 x < -6 时候,函数值就基本不变了,前者趋近于0,后者趋近于1.

假设每一次都请求计算sigmoid值,对性能将会有一定的影响,当sigmoid的值对精度的要求并非非常严格时。能够採用近似计算。在word2vec中。将区间[−6,6](设置的參数MAX_EXP为6)等距离划分成EXP_TABLE_SIZE等份,并将每个区间中的sigmoid值计算好存入到数组expTable中。须要使用时,直接从数组中查找。

Alink中实现如下:

  1. public class ExpTableArray {
  2. public final static float[] sigmoidTable = {
  3. 0.002473f, 0.002502f, 0.002532f, 0.002562f, 0.002593f, 0.002624f, 0.002655f, 0.002687f, 0.002719f, 0.002751f ......
  4. }
  5. }

2.5.2 窗口及上下文

Context(w) 就是在词 w 的前后各取 C 个词,Alink是事先设置一个窗口预置参数window(默认为5),每次构造Context(w)时候,首先生成一个[1, window] 上的一个随机整数 C~ ,于是 w 前后各取 C~ 个词就构成了 Context(w)。

  1. if (randomWindow) {
  2. b = random.nextInt(window);
  3. } else {
  4. b = 0;
  5. }
  6. int bound = window * 2 + 1 - b;
  7. for (int a = b; a < bound; ++a) {
  8. .....
  9. }

2.5.3 训练

2.5.3.1 数据结构

c语言代码 中:

  • syn0数组存着Vocab的全部词向量,就是Huffman树所有叶子节点的词向量,即input -> hidden 的 weights 。大小|V|∗|M|,初始化范围[−0.5M,0.5M],经验规则。在code中是一个1维数组,但是应该按照二维数组来理解。访问时实际上可以看成 syn0[i, j] i为第i个单词,j为第j个隐含单元。
  • syn1数组存着Hierarchical Softmax的参数,就是Huffman树所有 非叶子 节点的参数向量,即 hidden----> output 的 weights。大小|V|∗|M|,初始化全为0,经验规则。实际使用|V−1|组。

原本的Softmax问题,被近似退化成了近似log(K)个Logistic回归组合成决策树。

Softmax的K组θ,现在变成了K-1组,代表着二叉树的K-1个非叶结点。在Word2Vec中,由syn1数组存放,。

Alink代码 中:

  • input 就对应了 syn0,就是上图的 v。
  • output 就对应了 syn1,就是上图的 θ。
2.5.3.2 具体代码

具体代码如下(我们使用最大似然法来寻找所有节点的词向量和所有内部节点θ):

  1. private static class CalcModel {
  2. public void update(List <int[]> values) {
  3. double[] neu1e = new double[vectorSize];
  4. double f, g;
  5. int b, c, lastWord, l1, l2;
  6. for (int[] val : values) {
  7. for (int i = 0; i < val.length; ++i) {
  8. if (randomWindow) {
  9. b = random.nextInt(window);
  10. } else {
  11. b = 0;
  12. }
  13. // 在Skip-gram模型中。须要使用当前词分别预測窗体中的词,因此。这是一个循环的过程
  14. // 因为需要预测Context(w)中的每个词,因此需要循环2window - 2b + 1次遍历整个窗口
  15. int bound = window * 2 + 1 - b;
  16. for (int a = b; a < bound; ++a) {
  17. if (a != window) { //遍历时跳过中心单词
  18. c = i - window + a;
  19. if (c < 0 || c >= val.length) {
  20. continue;
  21. }
  22. lastWord = val[c]; //last_word为当前待预测的上下文单词
  23. l1 = lastWord * vectorSize; //l1为当前单词的词向量在syn0中的起始位置
  24. Arrays.fill(neu1e, 0.f); //初始化累计误差
  25. Word w = vocab[val[i]];
  26. int codeLen = w.code.length;
  27. //根据Haffman树上从根节点到当前词的叶节点的路径,遍历所有经过的中间节点
  28. for (int d = 0; d < codeLen; ++d) {
  29. f = 0.f;
  30. //l2为当前遍历到的中间节点的向量在syn1中的起始位置
  31. l2 = w.point[d] * vectorSize;
  32. // 正向传播,得到该编码单元对应的output 值f
  33. //注意!这里用到了模型对称:p(u|w) = p(w|u),其中w为中心词,u为context(w)中每个词, 也就是skip-gram虽然是给中心词预测上下文,真正训练的时候还是用上下文预测中心词, 与CBOW不同的是这里的u是单个词的词向量,而不是窗口向量之和
  34. // 将路径上所有Node连锁起来,累积得到 输入向量与中间结点向量的内积
  35. // f=σ(W.θi)
  36. for (int t = 0; t < vectorSize; ++t) {
  37. // 这里就是 X * Y
  38. // 映射层即为输入层
  39. f += input[l1 + t] * output[l2 + t];
  40. }
  41. if (f > -6.0f && f < 6.0f) {
  42. // 从 ExpTableArray 中查询到相应的值。
  43. f = ExpTableArray.sigmoidTable[(int) ((f + 6.0) * 84.0)];
  44. //@brief此处最核心,loss是交叉熵 Loss=xlogp(x)+(1-x)*log(1-p(x))
  45. //其中p(x)=exp(neu1[c] * syn1[c + l2])/(1+exp(neu1[c] * syn1[c + l2]))
  46. //x=1-code#作者才此处定义label为1-code,实际上也可以是code
  47. //log(L) = (1-x) * neu1[c] * syn1[c + l2] -x*log(1 + exp(neu1[c] * syn1[c + l2]))
  48. //对log(L)中的syn1进行偏导,g=(1 -code - p(x))*syn1
  49. //因此会有
  50. //g = (1 - vocab[word].code[d] - f) * alpha;alpha学习速率
  51. // 'g' is the gradient multiplied by the learning rate
  52. // g是梯度和学习率的乘积
  53. //注意!word2vec中将Haffman编码为1的节点定义为负类,而将编码为0的节点定义为正类,即一个节点的label = 1 - d
  54. g = (1.f - w.code[d] - f) * alpha;
  55. // Propagate errors output -> hidden
  56. // 根据计算得到的修正量g和中间节点的向量更新累计误差
  57. for (int t = 0; t < vectorSize; ++t) {
  58. neu1e[t] += g * output[l2 + t]; // 修改映射后的结果
  59. }
  60. // Learn weights hidden -> output
  61. for (int t = 0; t < vectorSize; ++t) {
  62. output[l2 + t] += g * input[l1 + t]; // 改动映射层到输出层之间的权重
  63. }
  64. }
  65. }
  66. for (int t = 0; t < vectorSize; ++t) {
  67. input[l1 + t] += neu1e[t]; // 返回改动每个词向量
  68. }
  69. }
  70. }
  71. }
  72. }
  73. }
  74. }

2.6 平均化

AvgInputOutput 类会对Input,output做平均化。

  1. .add(new AllReduce("input"))
  2. .add(new AllReduce("output"))
  3. .add(new AvgInputOutput())

原因在于做AllReduce时候,会简单的累积,如果有 context.getNumTask() 个task在同时进行,就容易简单粗暴的相加,这样数值就会扩大 context.getNumTask() 倍。

  1. private static class AvgInputOutput extends ComputeFunction {
  2. @Override
  3. public void calc(ComContext context) {
  4. double[] input = context.getObj("input");
  5. for (int i = 0; i < input.length; ++i) {
  6. input[i] /= context.getNumTask(); //平均化
  7. }
  8. double[] output = context.getObj("output");
  9. for (int i = 0; i < output.length; ++i) {
  10. output[i] /= context.getNumTask(); //平均化
  11. }
  12. }
  13. }

2.7 判断收敛

这里能够看到,收敛就是判断是否达到迭代次数

  1. private static class Criterion extends CompareCriterionFunction {
  2. @Override
  3. public boolean calc(ComContext context) {
  4. return (context.getStepNo() - 1)
  5. == ((List <Integer>) context.getObj("syncNum")).get(0)
  6. * params.get(Word2VecTrainParams.NUM_ITER);
  7. }
  8. }

2.8 序列化模型

这是在 context.getTaskId() 为 0 的task中完成序列化操作,其他task直接返回。这里收集了所有task的计算结果。

  1. private static class SerializeModel extends CompleteResultFunction {
  2. @Override
  3. public List <Row> calc(ComContext context) {
  4. // 在 context.getTaskId() 为 0 的task中完成序列化操作,其他task直接返回
  5. if (context.getTaskId() != 0) {
  6. return null; //其他task直接返回
  7. }
  8. int vocSize = ((List <Long>) context.getObj("vocSize")).get(0).intValue();
  9. int vectorSize = params.get(Word2VecTrainParams.VECTOR_SIZE);
  10. List <Row> ret = new ArrayList <>(vocSize);
  11. double[] input = context.getObj("input");
  12. for (int i = 0; i < vocSize; ++i) {
  13. // 完成序列化操作
  14. DenseVector dv = new DenseVector(vectorSize);
  15. System.arraycopy(input, i * vectorSize, dv.getData(), 0, vectorSize);
  16. ret.add(Row.of(i, dv));
  17. }
  18. return ret;
  19. }
  20. }

0x03 输出模型

输出模型的代码如下,功能分别是:

  • 把词典和计算出来的向量联系起来
  • 按分区分割模型成row
  • 发送模型
  1. model = model
  2. .map(new MapFunction <Row, Tuple2 <Integer, DenseVector>>() {
  3. @Override
  4. public Tuple2 <Integer, DenseVector> map(Row value) throws Exception {
  5. return Tuple2.of((Integer) value.getField(0), (DenseVector) value.getField(1));
  6. }
  7. })
  8. .join(vocab)
  9. .where(0)
  10. .equalTo(0) //把词典和计算出来的向量联系起来
  11. .with(new JoinFunction <Tuple2 <Integer, DenseVector>, Tuple3 <Integer, String, Word>, Row>() {
  12. @Override
  13. public Row join(Tuple2 <Integer, DenseVector> first, Tuple3 <Integer, String, Word> second)
  14. throws Exception {
  15. return Row.of(second.f1, first.f1);
  16. }
  17. })
  18. .mapPartition(new MapPartitionFunction <Row, Row>() {
  19. @Override
  20. public void mapPartition(Iterable <Row> values, Collector <Row> out) throws Exception {
  21. Word2VecModelDataConverter model = new Word2VecModelDataConverter();
  22. model.modelRows = StreamSupport
  23. .stream(values.spliterator(), false)
  24. .collect(Collectors.toList());
  25. model.save(model, out);
  26. }
  27. });
  28. setOutput(model, new Word2VecModelDataConverter().getModelSchema());

3.1 联系词典和向量

.join(vocab).where(0).equalTo(0)就是把词典和计算出来的向量联系起来。两个Join来源分别如下:

  1. // 来源1,计算出来的向量
  2. first = {Tuple2@11501}
  3. f0 = {Integer@11509} 9
  4. f1 = {DenseVector@11502} "0.9371751984171548 0.33341686580829943 0.6472255126130384 0.36692156358000316 0.1187895685629788 0.9223451469664975 0.763874142430857 0.1330720374498615 0.9631811135902764 0.9283700030050634......"
  5. // 来源2,词典
  6. second = {Tuple3@11499} "(9,我们,com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp$Word@1ffa469)"
  7. f0 = {Integer@11509} 9
  8. f1 = "我们"
  9. f2 = {Word2VecTrainBatchOp$Word@11510}

3.2 按分区分割模型成row

首先按照分区计算,分割模型成row。这里用到了java 8的新特性 StreamSupport,spliterator。

但是这里只是使用到了Stream的方式,没有使用其并行功能(可能过后会有文章进行研究)。

  1. model.modelRows = StreamSupport
  2. .stream(values.spliterator(), false)
  3. .collect(Collectors.toList());

比如某个分区得到:

  1. model = {Word2VecModelDataConverter@11561}
  2. modelRows = {ArrayList@11562} size = 3
  3. 0 = {Row@11567} "胖,0.4345151137723066 0.4923534386513069 0.49497589358976174 0.10917632806760409 0.7007392318076214 0.6468149904858065 0.3804865818632239 0.4348997489483902 0.03362685646645655 0.29769437681180916 0.04287936035337748..."
  4. 1 = {Row@11568} "的,0.4347763498886036 0.6852891840621573 0.9862851622413142 0.7061202166493431 0.9896492612656784 0.46525497532250026 0.03379287230189395 0.809333161215095 0.9230387687661015 0.5100444513892355 0.02436724648194081..."
  5. 2 = {Row@11569} "老王,0.4337285110643647 0.7605192699353084 0.6638406386520266 0.909594031681524 0.26995654043189604 0.3732722125930673 0.16171135697228312 0.9759668223869069 0.40331291071231623 0.22651841541002585 0.7150087001048662...."
  6. ......

3.3 发送数据

然后发送数据

  1. public class Word2VecModelDataConverter implements ModelDataConverter<Word2VecModelDataConverter, Word2VecModelDataConverter> {
  2. public List <Row> modelRows;
  3. @Override
  4. public void save(Word2VecModelDataConverter modelData, Collector<Row> collector) {
  5. modelData.modelRows.forEach(collector::collect); //发送数据
  6. }
  7. @Override
  8. public TableSchema getModelSchema() {
  9. return new TableSchema( //返回schema
  10. new String[] {"word", "vec"},
  11. new TypeInformation[] {Types.STRING, VectorTypes.VECTOR}
  12. );
  13. }
  14. }

0x04 问题答案

我们上文提到了一些问题,现在逐一回答:

  • 哪些模块用到了Alink的分布式处理能力?答案是:

    • 分割单词,计数(为了剔除低频词,排序);
    • 单词排序;
    • 训练;
  • Alink实现了Word2vec的哪个模型?是CBOW模型还是skip-gram模型?答案是:
    • skip-gram模型
  • Alink用到了哪个优化方法?是Hierarchical Softmax?还是Negative Sampling?答案是:
    • Hierarchical Softmax
  • 是否在本算法内去除停词?所谓停用词,就是出现频率太高的词,如逗号,句号等等,以至于没有区分度。答案是:
    • 本实现中没有去处停词
  • 是否使用了自适应学习率?答案是:
    • 没有

0xFF 参考

word2vec原理推导与代码分析

文本深度表示模型Word2Vec

word2vec原理(二) 基于Hierarchical Softmax的模型

word2vec原理(一) CBOW与Skip-Gram模型基础

word2vec原理(三) 基于Negative Sampling的模型

word2vec概述

对Word2Vec的理解

自己动手写word2vec (一):主要概念和流程

自己动手写word2vec (二):统计词频

自己动手写word2vec (三):构建Huffman树

自己动手写word2vec (四):CBOW和skip-gram模型

word2vec 中的数学原理详解(一)目录和前言

基于 Hierarchical Softmax 的模型

基于 Negative Sampling 的模型

机器学习算法实现解析——word2vec源代码解析

Word2Vec源码解析

word2vec源码思路和关键变量

Word2Vec源码最详细解析(下)

word2vec源码思路和关键变量

Alink漫谈(十七) :Word2Vec源码分析 之 迭代训练的更多相关文章

  1. Alink漫谈(十六) :Word2Vec源码分析 之 建立霍夫曼树

    Alink漫谈(十六) :Word2Vec源码分析 之 建立霍夫曼树 目录 Alink漫谈(十六) :Word2Vec源码分析 之 建立霍夫曼树 0x00 摘要 0x01 背景概念 1.1 词向量基础 ...

  2. Alink漫谈(二) : 从源码看机器学习平台Alink设计和架构

    Alink漫谈(二) : 从源码看机器学习平台Alink设计和架构 目录 Alink漫谈(二) : 从源码看机器学习平台Alink设计和架构 0x00 摘要 0x01 Alink设计原则 0x02 A ...

  3. [源码分析] Facebook如何训练超大模型---(1)

    [源码分析] Facebook如何训练超大模型---(1) 目录 [源码分析] Facebook如何训练超大模型---(1) 0x00 摘要 0x01 简介 1.1 FAIR & FSDP 1 ...

  4. [源码分析] Facebook如何训练超大模型 --- (2)

    [源码分析] Facebook如何训练超大模型 --- (2) 目录 [源码分析] Facebook如何训练超大模型 --- (2) 0x00 摘要 0x01 回顾 1.1 ZeRO 1.1.1 Ze ...

  5. [源码分析] Facebook如何训练超大模型 --- (3)

    [源码分析] Facebook如何训练超大模型 --- (3) 目录 [源码分析] Facebook如何训练超大模型 --- (3) 0x00 摘要 0x01 ZeRO-Offload 1.1 设计原 ...

  6. Alink漫谈(二十二) :源码分析之聚类评估

    Alink漫谈(二十二) :源码分析之聚类评估 目录 Alink漫谈(二十二) :源码分析之聚类评估 0x00 摘要 0x01 背景概念 1.1 什么是聚类 1.2 聚类分析的方法 1.3 聚类评估 ...

  7. 手机自动化测试:appium源码分析之bootstrap十七

    手机自动化测试:appium源码分析之bootstrap十七   poptest是国内唯一一家培养测试开发工程师的培训机构,以学员能胜任自动化测试,性能测试,测试工具开发等工作为目标.如果对课程感兴趣 ...

  8. ABP源码分析十七:DTO 自动校验的实现

    对传给Application service对象中的方法的DTO参数,ABP都会在方法真正执行前自动完成validation(根据标注到DTO对象中的validate规则). ABP是如何做到的? 思 ...

  9. ABP源码分析二十七:ABP.Entity Framework

    IRepository:接口定义了Repository常见的方法 AbpRepositoryBase:实现了IRepository接口的常见方法 EfRepositoryBase:实现了AbpRepo ...

随机推荐

  1. CRM【第二篇】: stark组件

    介绍: stark组件,是一个帮助开发者快速实现数据库表的增删改查+的组件.目标: 10s 中完成一张表的增删改查. 前戏: django项目启动时,自定义执行某个py文件. django启动时,且在 ...

  2. Django框架10 /sweetalert插件、django事务和锁、中间件、django请求生命周期

    Django框架10 /sweetalert插件.django事务和锁.中间件.django请求生命周期 目录 Django框架10 /sweetalert插件.django事务和锁.中间件.djan ...

  3. Docker搭建部署Node项目

    前段时间做了个node全栈项目,服务端技术栈是 nginx + koa + postgresql.其中在centos上搭建环境和部署都挺费周折,部署测试服务器,接着上线的时候又部署生产环境服务器.这中 ...

  4. JavaScript图形实例:阿基米德螺线

    1.阿基米德螺线 阿基米德螺线亦称“等速螺线”.当一点P沿动射线OP以等速率运动的同时,该射线又以等角速度绕点O旋转,点P的轨迹称为“阿基米德螺线”. 阿基米德螺线的笛卡尔坐标方程式为: r=10*( ...

  5. 【翻译】Scriban README 文本模板语言和.NET引擎

    scriban Scriban是一种快速.强大.安全和轻量级的文本模板语言和.NET引擎,具有解析liquid模板的兼容模式 Github https://github.com/lunet-io/sc ...

  6. C++语法小记---函数对象

    函数对象 用于替代函数指针 优势:函数对象内部可以保存状态,而不必使用全局变量或静态局部变量 关键:重载"()"操作符 #include<iostream> #incl ...

  7. [spring] -- MVC篇

    流程: 客户端(浏览器)发送请求,直接请求到 DispatcherServlet. DispatcherServlet 根据请求信息调用 HandlerMapping,解析请求对应的 Handler. ...

  8. 云原生时代高性能Java框架—Quarkus(二)

    --- *构建Quarkus本地镜像.容器化部署Quarkus项目* Quarkus系列博文 Quarkus&GraalVM介绍.创建并启动第一个项目 构建Quarkus本地镜像.容器化部署Q ...

  9. C#结合SMTP实现邮件报警通知

    写在前面 C#是微软推出的一门面向对象的通用型编程语言,它除了可以开发PC软件.网站(借助 http://ASP.NET)和APP(基于 Windows Phone),还能作为游戏脚本,编写游戏逻辑. ...

  10. PPT如何转换为Word文档?

    首先,打开你要转换的PPT,按F12键,此时会跳出另存为窗口,如图: 然后点击保存类型,选择RTF文件,保存到指定路径即可. 找到保存好的RTF文件,用word打开即可.