Java source code for word2vec [repost]
1. Core code: Word2VEC.java
package com.ansj.vec;

import java.io.*;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeSet;
import com.ansj.vec.domain.WordEntry;
import com.ansj.vec.util.WordKmeans;
import com.ansj.vec.util.WordKmeans.Classes;

public class Word2VEC {

public static void main(String[] args) throws IOException {
//Learn learn = new Learn();
//learn.learnFile(new File("C:\\Users\\le\\Desktop\\0328-事件相关法律的算法进展\\Result_Country.txt"));
//learn.saveModel(new File("C:\\Users\\le\\Desktop\\0328-事件相关法律的算法进展\\javaSkip1"));
Word2VEC vec = new Word2VEC();
vec.loadJavaModel("C:\\Users\\le\\Desktop\\0328-事件相关法律的算法进展\\javaSkip1");
System.out.println("中国" + "\t" +Arrays.toString(vec.getWordVector("中国")));
System.out.println("何润东" + "\t" +Arrays.toString(vec.getWordVector("何润东")));
System.out.println("足球" + "\t" + Arrays.toString(vec.getWordVector("足球"))); String str = "中国";
System.out.println(vec.distance(str));
WordKmeans wordKmeans = new WordKmeans(vec.getWordMap(), 50, 10);
Classes[] explain = wordKmeans.explain();
for (int i = 0; i < explain.length; i++) {
System.out.println("--------" + i + "---------");
System.out.println(explain[i].getTop(10));
}
}

private HashMap<String, float[]> wordMap = new HashMap<String, float[]>();

private int words;
private int size;
private int topNSize = 40;

/**
* Load a model in the original C word2vec (Google) binary format.
*
* @param path
* path of the model file
* @throws IOException
*/
public void loadGoogleModel(String path) throws IOException {
DataInputStream dis = null;
BufferedInputStream bis = null;
double len = 0;
float vector = 0;
try {
bis = new BufferedInputStream(new FileInputStream(path));
dis = new DataInputStream(bis);
// read the vocabulary size
words = Integer.parseInt(readString(dis));
// read the vector dimension
size = Integer.parseInt(readString(dis));
String word;
float[] vectors = null;
for (int i = 0; i < words; i++) {
word = readString(dis);
vectors = new float[size];
len = 0;
for (int j = 0; j < size; j++) {
vector = readFloat(dis);
len += vector * vector;
vectors[j] = (float) vector;
}
len = Math.sqrt(len);
for (int j = 0; j < size; j++) {
vectors[j] /= len;
}
wordMap.put(word, vectors);
dis.read();
}
} finally {
if (dis != null) {
dis.close();
}
if (bis != null) {
bis.close();
}
}
}

/**
* Load a model previously saved by Learn.saveModel (Java binary format).
*
* @param path
* path of the model file
* @throws IOException
*/
public void loadJavaModel(String path) throws IOException {
try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(path)))) {
words = dis.readInt();
size = dis.readInt();
float vector = 0;
String key = null;
float[] value = null;
for (int i = 0; i < words; i++) {
double len = 0;
key = dis.readUTF();
value = new float[size];
for (int j = 0; j < size; j++) {
vector = dis.readFloat();
len += vector * vector;
value[j] = vector;
}
len = Math.sqrt(len);
for (int j = 0; j < size; j++) {
value[j] /= len;
}
wordMap.put(key, value);
}
}
}

private static final int MAX_SIZE = 50;

/**
* Analogy query: words closest to word1 - word0 + word2.
*
* @return the nearest words, or null if any of the three words is unknown
*/
public TreeSet<WordEntry> analogy(String word0, String word1, String word2) {
float[] wv0 = getWordVector(word0);
float[] wv1 = getWordVector(word1);
float[] wv2 = getWordVector(word2);
if (wv1 == null || wv2 == null || wv0 == null) {
return null;
}
float[] wordVector = new float[size];
for (int i = 0; i < size; i++) {
wordVector[i] = wv1[i] - wv0[i] + wv2[i];
}
float[] tempVector;
String name;
List<WordEntry> wordEntrys = new ArrayList<WordEntry>(topNSize);
for (Entry<String, float[]> entry : wordMap.entrySet()) {
name = entry.getKey();
if (name.equals(word0) || name.equals(word1) || name.equals(word2)) {
continue;
}
float dist = 0;
tempVector = entry.getValue();
for (int i = 0; i < wordVector.length; i++) {
dist += wordVector[i] * tempVector[i];
}
insertTopN(name, dist, wordEntrys);
}
return new TreeSet<WordEntry>(wordEntrys);
}

private void insertTopN(String name, float score, List<WordEntry> wordsEntrys) {
if (wordsEntrys.size() < topNSize) {
wordsEntrys.add(new WordEntry(name, score));
return;
}
float min = Float.MAX_VALUE;
int minOffe = 0;
for (int i = 0; i < topNSize; i++) {
WordEntry wordEntry = wordsEntrys.get(i);
if (min > wordEntry.score) {
min = wordEntry.score;
minOffe = i;
}
}
if (score > min) {
wordsEntrys.set(minOffe, new WordEntry(name, score));
}
}

public Set<WordEntry> distance(String queryWord) {
float[] center = wordMap.get(queryWord);
if (center == null) {
return Collections.emptySet();
}
int resultSize = wordMap.size() < topNSize ? wordMap.size() : topNSize;
TreeSet<WordEntry> result = new TreeSet<WordEntry>();
double min = Float.MIN_VALUE;
for (Map.Entry<String, float[]> entry : wordMap.entrySet()) {
float[] vector = entry.getValue();
float dist = 0;
for (int i = 0; i < vector.length; i++) {
dist += center[i] * vector[i];
}
if (dist > min) {
result.add(new WordEntry(entry.getKey(), dist));
if (resultSize < result.size()) {
result.pollLast();
}
min = result.last().score;
}
}
result.pollFirst();
return result;
}

public Set<WordEntry> distance(List<String> words) {
float[] center = null;
for (String word : words) {
center = sum(center, wordMap.get(word));
}
if (center == null) {
return Collections.emptySet();
}
int resultSize = wordMap.size() < topNSize ? wordMap.size() : topNSize;
TreeSet<WordEntry> result = new TreeSet<WordEntry>();
double min = Float.MIN_VALUE;
for (Map.Entry<String, float[]> entry : wordMap.entrySet()) {
float[] vector = entry.getValue();
float dist = 0;
for (int i = 0; i < vector.length; i++) {
dist += center[i] * vector[i];
}
if (dist > min) {
result.add(new WordEntry(entry.getKey(), dist));
if (resultSize < result.size()) {
result.pollLast();
}
min = result.last().score;
}
}
result.pollFirst();
return result;
}

private float[] sum(float[] center, float[] fs) {
if (center == null && fs == null) {
return null;
}
if (fs == null) {
return center;
}
if (center == null) {
return fs;
}
for (int i = 0; i < fs.length; i++) {
center[i] += fs[i];
}
return center;
}

/**
* Get the vector for a word.
*
* @param word
* @return the word vector, or null if the word is not in the vocabulary
*/
public float[] getWordVector(String word) {
return wordMap.get(word);
}

public static float readFloat(InputStream is) throws IOException {
byte[] bytes = new byte[4];
is.read(bytes);
return getFloat(bytes);
}

/**
* Decode a float from 4 little-endian bytes.
*
* @param b
* @return
*/
public static float getFloat(byte[] b) {
int accum = 0;
accum = accum | (b[0] & 0xff) << 0;
accum = accum | (b[1] & 0xff) << 8;
accum = accum | (b[2] & 0xff) << 16;
accum = accum | (b[3] & 0xff) << 24;
return Float.intBitsToFloat(accum);
}

/**
* Read a whitespace-terminated string from the stream.
*
* @param dis
* @return
* @throws IOException
*/
private static String readString(DataInputStream dis) throws IOException {
// TODO Auto-generated method stub
byte[] bytes = new byte[MAX_SIZE];
byte b = dis.readByte();
int i = -1;
StringBuilder sb = new StringBuilder();
while (b != 32 && b != 10) {
i++;
bytes[i] = b;
b = dis.readByte();
if (i == 49) {
sb.append(new String(bytes));
i = -1;
bytes = new byte[MAX_SIZE];
}
}
sb.append(new String(bytes, 0, i + 1));
return sb.toString();
}

public int getTopNSize() {
return topNSize;
}

public void setTopNSize(int topNSize) {
this.topNSize = topNSize;
}

public HashMap<String, float[]> getWordMap() {
return wordMap;
}

public int getWords() {
return words;
}

public int getSize() {
return size;
}
}
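For reference, a minimal usage sketch of the query API above; the model path, top-N size, and example words are placeholders, and any model written by Learn.saveModel in section 2 will work:

import java.io.IOException;
import com.ansj.vec.Word2VEC;
import com.ansj.vec.domain.WordEntry;

public class QueryDemo {
    public static void main(String[] args) throws IOException {
        Word2VEC vec = new Word2VEC();
        vec.loadJavaModel("library/javaVector"); // placeholder path, see Learn.main in section 2
        vec.setTopNSize(10); // keep only the 10 nearest words
        // nearest neighbours of a single word (dot product of unit vectors = cosine similarity)
        for (WordEntry entry : vec.distance("中国")) {
            System.out.println(entry); // prints "word<TAB>score"
        }
        // analogy: vector("国王") - vector("男人") + vector("女人"); returns null if any word is unknown
        System.out.println(vec.analogy("男人", "国王", "女人"));
    }
}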
2. Word-vector model training code: Learn.java
package com.ansj.vec;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import com.ansj.vec.util.MapCount;
import com.ansj.vec.domain.HiddenNeuron;
import com.ansj.vec.domain.Neuron;
import com.ansj.vec.domain.WordNeuron;
import com.ansj.vec.util.Haffman;

public class Learn {

private Map<String, Neuron> wordMap = new HashMap<>();
/**
* Dimensionality of the word vectors (number of features).
*/
private int layerSize = 200;

/**
* Context window size.
*/
private int window = 5;

private double sample = 1e-3;
private double alpha = 0.025;
private double startingAlpha = alpha;

public int EXP_TABLE_SIZE = 1000;

private Boolean isCbow = false;

private double[] expTable = new double[EXP_TABLE_SIZE];

private int trainWordsCount = 0;

private int MAX_EXP = 6;

public Learn(Boolean isCbow, Integer layerSize, Integer window, Double alpha,
Double sample) {
createExpTable();
if (isCbow != null) {
this.isCbow = isCbow;
}
if (layerSize != null)
this.layerSize = layerSize;
if (window != null)
this.window = window;
if (alpha != null)
this.alpha = alpha;
if (sample != null)
this.sample = sample;
}

public Learn() {
createExpTable();
}

/**
* trainModel
*
* @throws IOException
*/
private void trainModel(File file) throws IOException {
try (BufferedReader br = new BufferedReader(new InputStreamReader(
new FileInputStream(file)))) {
String temp = null;
long nextRandom = 5;
int wordCount = 0;
int lastWordCount = 0;
int wordCountActual = 0;
while ((temp = br.readLine()) != null) {
if (wordCount - lastWordCount > 10000) {
System.out.println("alpha:" + alpha + "\tProgress: "
+ (int) (wordCountActual / (double) (trainWordsCount + 1) * 100)
+ "%");
wordCountActual += wordCount - lastWordCount;
lastWordCount = wordCount;
alpha = startingAlpha
* (1 - wordCountActual / (double) (trainWordsCount + 1));
if (alpha < startingAlpha * 0.0001) {
alpha = startingAlpha * 0.0001;
}
}
String[] strs = temp.split(" ");
wordCount += strs.length;
List<WordNeuron> sentence = new ArrayList<WordNeuron>();
for (int i = 0; i < strs.length; i++) {
Neuron entry = wordMap.get(strs[i]);
if (entry == null) {
continue;
}
// The subsampling randomly discards frequent words while keeping the
// ranking same
if (sample > 0) {
double ran = (Math.sqrt(entry.freq / (sample * trainWordsCount)) + 1)
* (sample * trainWordsCount) / entry.freq;
nextRandom = nextRandom * 25214903917L + 11;
if (ran < (nextRandom & 0xFFFF) / (double) 65536) {
continue;
}
}
sentence.add((WordNeuron) entry);
}

for (int index = 0; index < sentence.size(); index++) {
nextRandom = nextRandom * 25214903917L + 11;
if (isCbow) {
cbowGram(index, sentence, (int) nextRandom % window);
} else {
skipGram(index, sentence, (int) nextRandom % window);
}
} }
System.out.println("Vocab size: " + wordMap.size());
System.out.println("Words in train file: " + trainWordsCount);
System.out.println("sucess train over!");
}
}

/**
* Skip-gram model training (hierarchical softmax).
*
* @param index
* @param sentence
* @param b
*/
private void skipGram(int index, List<WordNeuron> sentence, int b) {
// TODO Auto-generated method stub
WordNeuron word = sentence.get(index);
int a, c = 0;
for (a = b; a < window * 2 + 1 - b; a++) {
if (a == window) {
continue;
}
c = index - window + a;
if (c < 0 || c >= sentence.size()) {
continue;
}
double[] neu1e = new double[layerSize]; // error term
// HIERARCHICAL SOFTMAX
List<Neuron> neurons = word.neurons;
WordNeuron we = sentence.get(c);
for (int i = 0; i < neurons.size(); i++) {
HiddenNeuron out = (HiddenNeuron) neurons.get(i);
double f = 0;
// Propagate hidden -> output
for (int j = 0; j < layerSize; j++) {
f += we.syn0[j] * out.syn1[j];
}
if (f <= -MAX_EXP || f >= MAX_EXP) {
continue;
} else {
f = (f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2);
f = expTable[(int) f];
}
// 'g' is the gradient multiplied by the learning rate
double g = (1 - word.codeArr[i] - f) * alpha;
// Propagate errors output -> hidden
for (c = 0; c < layerSize; c++) {
neu1e[c] += g * out.syn1[c];
}
// Learn weights hidden -> output
for (c = 0; c < layerSize; c++) {
out.syn1[c] += g * we.syn0[c];
}
}
// Learn weights input -> hidden
for (int j = 0; j < layerSize; j++) {
we.syn0[j] += neu1e[j];
}
}
}

/**
* CBOW (continuous bag-of-words) model training.
*
* @param index
* @param sentence
* @param b
*/
private void cbowGram(int index, List<WordNeuron> sentence, int b) {
WordNeuron word = sentence.get(index);
int a, c = 0;
List<Neuron> neurons = word.neurons;
double[] neu1e = new double[layerSize]; // error term
double[] neu1 = new double[layerSize]; // sum of the context word vectors (hidden layer)
WordNeuron last_word;
for (a = b; a < window * 2 + 1 - b; a++)
if (a != window) {
c = index - window + a;
if (c < 0)
continue;
if (c >= sentence.size())
continue;
last_word = sentence.get(c);
if (last_word == null)
continue;
for (c = 0; c < layerSize; c++)
neu1[c] += last_word.syn0[c];
}
// HIERARCHICAL SOFTMAX
for (int d = 0; d < neurons.size(); d++) {
HiddenNeuron out = (HiddenNeuron) neurons.get(d);
double f = 0;
// Propagate hidden -> output
for (c = 0; c < layerSize; c++)
f += neu1[c] * out.syn1[c];
if (f <= -MAX_EXP)
continue;
else if (f >= MAX_EXP)
continue;
else
f = expTable[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
// 'g' is the gradient multiplied by the learning rate
// double g = (1 - word.codeArr[d] - f) * alpha;
// double g = f*(1-f)*( word.codeArr[i] - f) * alpha;
double g = f * (1 - f) * (word.codeArr[d] - f) * alpha;
//
for (c = 0; c < layerSize; c++) {
neu1e[c] += g * out.syn1[c];
}
// Learn weights hidden -> output
for (c = 0; c < layerSize; c++) {
out.syn1[c] += g * neu1[c];
}
}
for (a = b; a < window * 2 + 1 - b; a++) {
if (a != window) {
c = index - window + a;
if (c < 0)
continue;
if (c >= sentence.size())
continue;
last_word = sentence.get(c);
if (last_word == null)
continue;
for (c = 0; c < layerSize; c++)
last_word.syn0[c] += neu1e[c];
}
}
}

/**
* Count word frequencies (build the vocabulary).
*
* @param file
* @throws IOException
*/
private void readVocab(File file) throws IOException {
MapCount<String> mc = new MapCount<>();
try (BufferedReader br = new BufferedReader(new InputStreamReader(
new FileInputStream(file)))) {
String temp = null;
while ((temp = br.readLine()) != null) {
String[] split = temp.split(" ");
trainWordsCount += split.length;
for (String string : split) {
mc.add(string);
}
}
}
for (Entry<String, Integer> element : mc.get().entrySet()) {
wordMap.put(element.getKey(), new WordNeuron(element.getKey(),
(double) element.getValue() / mc.size(), layerSize));
}
}

/**
* Build the vocabulary from pre-classified corpora (supervised categories).
*
* @param files
* @throws IOException
* @throws FileNotFoundException
*/
private void readVocabWithSupervised(File[] files) throws IOException {
for (int category = 0; category < files.length; category++) {
// learn from each category file
MapCount<String> mc = new MapCount<>();
try (BufferedReader br = new BufferedReader(new InputStreamReader(
new FileInputStream(files[category])))) {
String temp = null;
while ((temp = br.readLine()) != null) {
String[] split = temp.split(" ");
trainWordsCount += split.length;
for (String string : split) {
mc.add(string);
}
}
}
for (Entry<String, Integer> element : mc.get().entrySet()) {
double tarFreq = (double) element.getValue() / mc.size();
if (wordMap.get(element.getKey()) != null) {
double srcFreq = wordMap.get(element.getKey()).freq;
if (srcFreq >= tarFreq) {
continue;
} else {
Neuron wordNeuron = wordMap.get(element.getKey());
wordNeuron.category = category;
wordNeuron.freq = tarFreq;
}
} else {
wordMap.put(element.getKey(), new WordNeuron(element.getKey(),
tarFreq, category, layerSize));
}
}
}
}

/**
* Precompute the sigmoid table: f(x) = e^x / (e^x + 1) for x in [-MAX_EXP, MAX_EXP].
*/
private void createExpTable() {
for (int i = 0; i < EXP_TABLE_SIZE; i++) {
expTable[i] = Math.exp(((i / (double) EXP_TABLE_SIZE * 2 - 1) * MAX_EXP));
expTable[i] = expTable[i] / (expTable[i] + 1);
}
}

/**
* Train from a text file.
*
* @param file
* @throws IOException
*/
public void learnFile(File file) throws IOException {
readVocab(file);
new Haffman(layerSize).make(wordMap.values());
// build the Huffman path for every word
for (Neuron neuron : wordMap.values()) {
((WordNeuron) neuron).makeNeurons();
}
trainModel(file);
}

/**
* Train from pre-classified files.
*
* @param summaryFile
* the merged corpus file used for training
* @param classifiedFiles
* the per-category corpus files
* @throws IOException
*/
public void learnFile(File summaryFile, File[] classifiedFiles)
throws IOException {
readVocabWithSupervised(classifiedFiles);
new Haffman(layerSize).make(wordMap.values());
// build the Huffman path for every word
for (Neuron neuron : wordMap.values()) {
((WordNeuron) neuron).makeNeurons();
}
trainModel(summaryFile);
}

/**
* Save the model in the Java binary format read by Word2VEC.loadJavaModel.
*/
public void saveModel(File file) {
try (DataOutputStream dataOutputStream = new DataOutputStream(
new BufferedOutputStream(new FileOutputStream(file)))) {
dataOutputStream.writeInt(wordMap.size());
dataOutputStream.writeInt(layerSize);
double[] syn0 = null;
for (Entry<String, Neuron> element : wordMap.entrySet()) {
dataOutputStream.writeUTF(element.getKey());
syn0 = ((WordNeuron) element.getValue()).syn0;
for (double d : syn0) {
dataOutputStream.writeFloat(((Double) d).floatValue());
}
}
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}

public int getLayerSize() {
return layerSize;
}

public void setLayerSize(int layerSize) {
this.layerSize = layerSize;
}

public int getWindow() {
return window;
}

public void setWindow(int window) {
this.window = window;
}

public double getSample() {
return sample;
}

public void setSample(double sample) {
this.sample = sample;
}

public double getAlpha() {
return alpha;
}

public void setAlpha(double alpha) {
this.alpha = alpha;
this.startingAlpha = alpha;
}

public Boolean getIsCbow() {
return isCbow;
}

public void setIsCbow(Boolean isCbow) {
this.isCbow = isCbow;
}

public static void main(String[] args) throws IOException {
Learn learn = new Learn();
long start = System.currentTimeMillis();
learn.learnFile(new File("library/xh.txt"));
System.out.println("use time " + (System.currentTimeMillis() - start));
learn.saveModel(new File("library/javaVector")); }
}
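A small sketch of training with explicit hyperparameters rather than the defaults; the corpus path is a placeholder, and the file is assumed to be pre-segmented text with words separated by single spaces, which is what readVocab and trainModel expect:

import java.io.File;
import java.io.IOException;
import com.ansj.vec.Learn;

public class TrainDemo {
    public static void main(String[] args) throws IOException {
        // isCbow = true trains with cbowGram, false with skipGram;
        // layerSize = 200 dimensions, window = 5, alpha = 0.025, sample = 1e-3
        Learn learn = new Learn(Boolean.TRUE, 200, 5, 0.025, 1e-3);
        learn.learnFile(new File("library/corpus_segmented.txt")); // placeholder corpus path
        learn.saveModel(new File("library/javaVector"));
    }
}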
3. K-means clustering of word vectors: util/WordKmeans.java
package com.ansj.vec.util;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import com.ansj.vec.Word2VEC;
/*import com.ansj.vec.domain.WordEntry;
import com.ansj.vec.util.WordKmeans.Classes;*/
/**
* K-means clustering of word vectors.
*
* @author ansj
*
*/
public class WordKmeans {

public static void main(String[] args) {
Word2VEC vec = new Word2VEC();
try {
vec.loadJavaModel("C:\\Users\\le\\Desktop\\0328-事件相关法律的算法进展\\javaSkip1");
System.out.println("中国" + "\t" +Arrays.toString(vec.getWordVector("中国")));
System.out.println("何润东" + "\t" +Arrays.toString(vec.getWordVector("何润东")));
System.out.println("足球" + "\t" + Arrays.toString(vec.getWordVector("足球")));
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
System.out.println("load model ok!");
WordKmeans wordKmeans = new WordKmeans(vec.getWordMap(), 50, 50);
Classes[] explain = wordKmeans.explain();
for (int i = 0; i < explain.length; i++) {
System.out.println("--------" + i + "---------");
System.out.println(explain[i].getTop(10));
}
}

private HashMap<String, float[]> wordMap = null;

private int iter;

private Classes[] cArray = null;

public WordKmeans(HashMap<String, float[]> wordMap, int clcn, int iter) {
this.wordMap = wordMap;
this.iter = iter;
cArray = new Classes[clcn];
}

public Classes[] explain() {
// initialization: use the first clcn word vectors as cluster centers
Iterator<Entry<String, float[]>> iterator = wordMap.entrySet().iterator();
for (int i = 0; i < cArray.length; i++) {
Entry<String, float[]> next = iterator.next();
cArray[i] = new Classes(i, next.getValue());
}

for (int i = 0; i < iter; i++) {
for (Classes classes : cArray) {
classes.clean();
}
iterator = wordMap.entrySet().iterator();
while (iterator.hasNext()) {
Entry<String, float[]> next = iterator.next();
double miniScore = Double.MAX_VALUE;
double tempScore;
int classesId = 0;
for (Classes classes : cArray) {
tempScore = classes.distance(next.getValue());
if (miniScore > tempScore) {
miniScore = tempScore;
classesId = classes.id;
}
}
cArray[classesId].putValue(next.getKey(), miniScore);
}
for (Classes classes : cArray) {
classes.updateCenter(wordMap);
}
System.out.println("iter " + i + " ok!");
}
return cArray;
}

public static class Classes {

private int id;

private float[] center;

public Classes(int id, float[] center) {
this.id = id;
this.center = center.clone();
}

Map<String, Double> values = new HashMap<>();

public double distance(float[] value) {
double sum = 0;
for (int i = 0; i < value.length; i++) {
sum += (center[i] - value[i]) * (center[i] - value[i]);
}
return sum;
}

public void putValue(String word, double score) {
values.put(word, score);
}

/**
* Recompute the cluster center as the mean of its members.
* @param wordMap
*/
public void updateCenter(HashMap<String, float[]> wordMap) {
for (int i = 0; i < center.length; i++) {
center[i] = 0;
}
float[] value = null;
for (String keyWord : values.keySet()) {
value = wordMap.get(keyWord);
for (int i = 0; i < value.length; i++) {
center[i] += value[i];
}
}
for (int i = 0; i < center.length; i++) {
center[i] = center[i] / values.size();
}
}

/**
* Clear the results of the previous iteration.
*/
public void clean() {
// TODO Auto-generated method stub
values.clear();
}

/**
* Get the top n words of this cluster (closest to the center).
* @param n
* @return
*/
public List<Entry<String, Double>> getTop(int n) {
List<Map.Entry<String, Double>> arrayList = new ArrayList<Map.Entry<String, Double>>(
values.entrySet());
Collections.sort(arrayList, new Comparator<Map.Entry<String, Double>>() {
@Override
public int compare(Entry<String, Double> o1, Entry<String, Double> o2) {
// TODO Auto-generated method stub
return o1.getValue() > o2.getValue() ? 1 : -1;
}
});
int min = Math.min(n, arrayList.size() - 1);
if (min <= 1) return Collections.emptyList();
return arrayList.subList(0, min);
}
}
}
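WordKmeans only needs a HashMap<String, float[]>, so it can be exercised without a trained model. Below is a toy sketch with made-up two-dimensional vectors; note that getTop returns an empty list for clusters with fewer than three members because of the min <= 1 check above:

import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;
import com.ansj.vec.util.WordKmeans;
import com.ansj.vec.util.WordKmeans.Classes;

public class KmeansDemo {
    public static void main(String[] args) {
        HashMap<String, float[]> map = new HashMap<>();
        map.put("a", new float[] { 1.0f, 0.1f }); // made-up vectors, two rough groups
        map.put("b", new float[] { 0.9f, 0.0f });
        map.put("c", new float[] { 1.0f, 0.0f });
        map.put("d", new float[] { 0.0f, 1.0f });
        map.put("e", new float[] { 0.1f, 0.9f });
        map.put("f", new float[] { 0.0f, 1.1f });
        WordKmeans kmeans = new WordKmeans(map, 2, 10); // 2 clusters, 10 iterations
        Classes[] classes = kmeans.explain();
        for (int i = 0; i < classes.length; i++) {
            List<Entry<String, Double>> top = classes[i].getTop(2); // squared distance to the center
            System.out.println("cluster " + i + ": " + top);
        }
    }
}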
4. Utility classes: util/Haffman.java and util/MapCount.java
package com.ansj.vec.util;

import java.util.Collection;
import java.util.List;
import java.util.TreeSet;

import com.ansj.vec.domain.HiddenNeuron;
import com.ansj.vec.domain.Neuron;

/**
* Build the Huffman coding tree used for hierarchical softmax.
*
* @author ansj
*
*/
public class Haffman {
private int layerSize;

public Haffman(int layerSize) {
this.layerSize = layerSize;
}

private TreeSet<Neuron> set = new TreeSet<>();

public void make(Collection<Neuron> neurons) {
set.addAll(neurons);
while (set.size() > 1) {
merger();
}
}

private void merger() {
HiddenNeuron hn = new HiddenNeuron(layerSize);
Neuron min1 = set.pollFirst();
Neuron min2 = set.pollFirst();
hn.category = min2.category;
hn.freq = min1.freq + min2.freq;
min1.parent = hn;
min2.parent = hn;
min1.code = 0;
min2.code = 1;
set.add(hn);
}
}
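A standalone sketch, with made-up word frequencies, of how the tree built by Haffman.make turns into the per-word path (neurons) and binary code (codeArr) that skipGram and cbowGram consume:

import java.util.Arrays;
import java.util.List;
import com.ansj.vec.domain.Neuron;
import com.ansj.vec.domain.WordNeuron;
import com.ansj.vec.util.Haffman;

public class HaffmanDemo {
    public static void main(String[] args) {
        int layerSize = 10;
        WordNeuron a = new WordNeuron("a", 0.5, layerSize); // made-up relative frequencies
        WordNeuron b = new WordNeuron("b", 0.3, layerSize);
        WordNeuron c = new WordNeuron("c", 0.2, layerSize);
        new Haffman(layerSize).make(Arrays.<Neuron>asList(a, b, c));
        for (WordNeuron w : new WordNeuron[] { a, b, c }) {
            List<Neuron> path = w.makeNeurons(); // inner nodes from the root down to the word
            System.out.println(w.name + "  path length = " + path.size()
                + "  code = " + Arrays.toString(w.codeArr));
        }
    }
}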
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by Fernflower decompiler)
//
package com.ansj.vec.util;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;

public class MapCount<T> {

private HashMap<T, Integer> hm = null;

public MapCount() {
this.hm = new HashMap<>();
}

public MapCount(int initialCapacity) {
this.hm = new HashMap<>(initialCapacity);
}

public void add(T t, int n) {
Integer integer = null;
if ((integer = (Integer) this.hm.get(t)) != null) {
this.hm.put(t, Integer.valueOf(integer.intValue() + n));
} else {
this.hm.put(t, Integer.valueOf(n));
}
}

public void add(T t) {
this.add(t, 1);
}

public int size() {
return this.hm.size();
}

public void remove(T t) {
this.hm.remove(t);
}

public HashMap<T, Integer> get() {
return this.hm;
}

public String getDic() {
Iterator iterator = this.hm.entrySet().iterator();
StringBuilder sb = new StringBuilder();
Entry next = null;
while (iterator.hasNext()) {
next = (Entry) iterator.next();
sb.append(next.getKey());
sb.append("\t");
sb.append(next.getValue());
sb.append("\n");
}
return sb.toString();
}

public static void main(String[] args) {
System.out.println(9223372036854775807L);
}
}
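A tiny sketch of MapCount as the word-frequency counter used by readVocab; the sample words are made up:

import com.ansj.vec.util.MapCount;

public class MapCountDemo {
    public static void main(String[] args) {
        MapCount<String> mc = new MapCount<>();
        for (String w : "the quick fox the fox".split(" ")) {
            mc.add(w); // add(t) increments the count of t by 1
        }
        System.out.println(mc.get().get("fox")); // 2
        System.out.print(mc.getDic()); // one "word<TAB>count" line per entry
    }
}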
5. The domain package
package com.ansj.vec.domain;

public class HiddenNeuron extends Neuron {

public double[] syn1; // hidden -> output

public HiddenNeuron(int layerSize) {
syn1 = new double[layerSize];
}
}
package com.ansj.vec.domain;

public abstract class Neuron implements Comparable<Neuron> {
public double freq;
public Neuron parent;
public int code;
// corpus category used for supervised pre-classification
public int category = -1;

@Override
public int compareTo(Neuron neuron) {
if (this.category == neuron.category) {
if (this.freq > neuron.freq) {
return 1;
} else {
return -1;
}
} else if (this.category > neuron.category) {
return 1;
} else {
return -1;
}
}
}
package com.ansj.vec.domain;

public class WordEntry implements Comparable<WordEntry> {

public String name;
public float score;

public WordEntry(String name, float score) {
this.name = name;
this.score = score;
}

@Override
public String toString() {
return this.name + "\t" + score;
}

@Override
public int compareTo(WordEntry o) {
if (this.score < o.score) {
return 1;
} else {
return -1;
}
}
}

package com.ansj.vec.domain;

import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;

public class WordNeuron extends Neuron {
public String name;
public double[] syn0 = null; // input->hidden
public List<Neuron> neurons = null; // neurons on the Huffman path

public int[] codeArr = null;

public List<Neuron> makeNeurons() {
if (neurons != null) {
return neurons;
}
Neuron neuron = this;
neurons = new LinkedList<>();
while ((neuron = neuron.parent) != null) {
neurons.add(neuron);
}
Collections.reverse(neurons);
codeArr = new int[neurons.size()];
for (int i = 1; i < neurons.size(); i++) {
codeArr[i - 1] = neurons.get(i).code;
}
codeArr[codeArr.length - 1] = this.code;
return neurons;
}

public WordNeuron(String name, double freq, int layerSize) {
this.name = name;
this.freq = freq;
this.syn0 = new double[layerSize];
Random random = new Random();
for (int i = 0; i < syn0.length; i++) {
syn0[i] = (random.nextDouble() - 0.5) / layerSize;
}
}

/**
* Constructor used when building the Huffman tree with supervised categories.
*
* @param name
* @param freq
* @param layerSize
*/
public WordNeuron(String name, double freq, int category, int layerSize) {
this.name = name;
this.freq = freq;
this.syn0 = new double[layerSize];
this.category = category;
Random random = new Random();
for (int i = 0; i < syn0.length; i++) {
syn0[i] = (random.nextDouble() - 0.5) / layerSize;
}
}
}