LFDMM Source Code Analysis (A Probabilistic Graphical Model Incorporating Word Embeddings)

Tags: #source-code#  Time: 2018/03/27 14:48:37  Author: 十七岁的雨季

Author: Qian Yang, School of Management, Hefei University of Technology. Email: 1563178220@qq.com. There may be mistakes or omissions; discussion and corrections are welcome.
Reposting without the author's permission is prohibited.

Original post: https://blog.csdn.net/qy20115549/article/details/79675572

Paper Source

Nguyen D Q, Billingsley R, Du L, et al. Improving topic models with latent feature word representations[J]. Transactions of the Association for Computational Linguistics, 2015, 3: 299-313.

The paper appeared in 2015 in TACL (Transactions of the Association for Computational Linguistics), the ACL's journal; it is solid work, and a number of later papers have built on it. Below is a walkthrough of the source code released by the authors.
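
Before going through the code, it helps to keep the model's core equation in mind. In LF-DMM, each short document $d$ is assigned a single topic $z_d$, and every word $w$ in the document is generated from a two-component mixture of a latent feature (word embedding) component and a Dirichlet multinomial component, weighted by $\lambda$:

$$P(w \mid z_d = t) = \lambda\,\mathrm{CatE}(w \mid \boldsymbol{\tau}_t) + (1-\lambda)\,\mathrm{Mult}(w \mid \boldsymbol{\theta}_t), \qquad \mathrm{CatE}(w \mid \boldsymbol{\tau}_t) = \frac{\exp(\boldsymbol{\tau}_t \cdot \boldsymbol{\omega}_w)}{\sum_{w' \in V} \exp(\boldsymbol{\tau}_t \cdot \boldsymbol{\omega}_{w'})},$$

where $\boldsymbol{\tau}_t$ is the vector of topic $t$ and $\boldsymbol{\omega}_w$ is the fixed, pre-trained vector of word $w$. The fields and methods below all serve these two components.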

Source Code Walkthrough

package models;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import utility.FuncUtils;
import utility.LBFGS;
import utility.Parallel;

import cc.mallet.optimize.InvalidOptimizableException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.Randoms;

/**
 * Implementation of the LF-DMM latent feature topic model, using collapsed Gibbs sampling, as
 * described in:
 *
 * Dat Quoc Nguyen, Richard Billingsley, Lan Du and Mark Johnson. 2015. Improving Topic Models with
 * Latent Feature Word Representations. Transactions of the Association for Computational
 * Linguistics, vol. 3, pp. 299-313.
 *
 * @author Dat Quoc Nguyen
 */
public class LFDMM
{
    public double alpha; // Hyper-parameter alpha
    public double beta; // Hyper-parameter beta
    // public double alphaSum; // alpha * numTopics
    public double betaSum; // beta * vocabularySize

    public int numTopics; // Number of topics
    public int topWords; // Number of most probable words printed for each topic
    public double lambda; // Mixture weight value
    public int numInitIterations; // Number of initial (count-based) sampling iterations
    public int numIterations; // Number of EM-style sampling iterations

    public List<List<Integer>> corpus; // Word-ID-based corpus
    public List<List<Integer>> topicAssignments; // Topic assignments for words in the corpus
    public int numDocuments; // Number of documents in the corpus
    public int numWordsInCorpus; // Total number of word tokens in the corpus

    public HashMap<String, Integer> word2IdVocabulary; // Vocabulary: word -> ID
    public HashMap<Integer, String> id2WordVocabulary; // Vocabulary: ID -> word (used for output)
    public int vocabularySize; // Number of word types in the corpus

    // Number of documents assigned to each topic
    public int[] docTopicCount;
    // numTopics * vocabularySize matrix
    // Given a topic: number of times a word type was generated from the topic by
    // the Dirichlet multinomial component
    public int[][] topicWordCountDMM;
    // Total number of words generated from each topic by the Dirichlet
    // multinomial component
    public int[] sumTopicWordCountDMM;
    // numTopics * vocabularySize matrix
    // Given a topic: number of times a word type was generated from the topic by
    // the latent feature component
    public int[][] topicWordCountLF;
    // Total number of words generated from each topic by the latent feature component
    public int[] sumTopicWordCountLF;

    // Double array of (unnormalized) probabilities used to sample a topic
    public double[] multiPros;

    // Path to the directory containing the corpus
    public String folderPath;
    // Path to the topic-modeling corpus
    public String corpusPath;
    public String vectorFilePath;

    public double[][] wordVectors; // Vector representations of words
    public double[][] topicVectors; // Vector representations of topics
    public int vectorSize; // Number of vector dimensions
    public double[][] dotProductValues; // Dot products between topic and word vectors
    public double[][] expDotProductValues; // exp() of those dot products
    public double[] sumExpValues; // Partition function values (normalizers)

    public final double l2Regularizer = 0.01; // L2 regularizer value for learning topic vectors
    public final double tolerance = 0.05; // Tolerance value for LBFGS convergence

    public String expName = "LFDMM";
    public String orgExpName = "LFDMM";
    public String tAssignsFilePath = "";
    public int savestep = 0;
    public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
        double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
        int inNumIterations, int inTopWords)
        throws Exception
    {
        this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
            inNumInitIterations, inNumIterations, inTopWords, "LFDMM");
    }

    public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
        double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
        int inNumIterations, int inTopWords, String inExpName)
        throws Exception
    {
        this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
            inNumInitIterations, inNumIterations, inTopWords, inExpName, "", 0);
    }

    public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
        double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
        int inNumIterations, int inTopWords, String inExpName, String pathToTAfile)
        throws Exception
    {
        this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
            inNumInitIterations, inNumIterations, inTopWords, inExpName, pathToTAfile, 0);
    }

    public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
        double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
        int inNumIterations, int inTopWords, String inExpName, int inSaveStep)
        throws Exception
    {
        this(pathToCorpus, pathToWordVectorsFile, inNumTopics, inAlpha, inBeta, inLambda,
            inNumInitIterations, inNumIterations, inTopWords, inExpName, "", inSaveStep);
    }
    public LFDMM(String pathToCorpus, String pathToWordVectorsFile, int inNumTopics,
        double inAlpha, double inBeta, double inLambda, int inNumInitIterations,
        int inNumIterations, int inTopWords, String inExpName, String pathToTAfile,
        int inSaveStep)
        throws Exception
    {
        alpha = inAlpha;
        beta = inBeta;
        lambda = inLambda;
        numTopics = inNumTopics;
        numIterations = inNumIterations;
        numInitIterations = inNumInitIterations;
        topWords = inTopWords;
        savestep = inSaveStep;
        expName = inExpName;
        orgExpName = expName;
        // Path to the word2vec (word-vector) file
        vectorFilePath = pathToWordVectorsFile;
        // Path to the corpus
        corpusPath = pathToCorpus;
        folderPath = pathToCorpus.substring(0,
            Math.max(pathToCorpus.lastIndexOf("/"), pathToCorpus.lastIndexOf("\\")) + 1);

        System.out.println("Reading topic modeling corpus: " + pathToCorpus);

        // Word -> ID mapping
        word2IdVocabulary = new HashMap<String, Integer>();
        // ID -> word mapping
        id2WordVocabulary = new HashMap<Integer, String>();
        corpus = new ArrayList<List<Integer>>();
        numDocuments = 0;
        numWordsInCorpus = 0;

        // Read the corpus
        BufferedReader br = null;
        try {
            int indexWord = -1;
            br = new BufferedReader(new FileReader(pathToCorpus));
            // Each line is one document
            for (String doc; (doc = br.readLine()) != null;) {
                if (doc.trim().length() == 0)
                    continue;
                // Split the document into words
                String[] words = doc.trim().split("\\s+");
                // Represent the document as a list of word IDs
                List<Integer> document = new ArrayList<Integer>();
                for (String word : words) {
                    // Word IDs are global across the whole corpus
                    if (word2IdVocabulary.containsKey(word)) {
                        // Known word: add its ID directly
                        document.add(word2IdVocabulary.get(word));
                    }
                    else {
                        // New word: IDs start from 0; record the reverse mapping
                        // in id2WordVocabulary as well
                        indexWord += 1;
                        word2IdVocabulary.put(word, indexWord);
                        id2WordVocabulary.put(indexWord, word);
                        document.add(indexWord);
                    }
                }
                numDocuments++;
                // Total number of word tokens in the corpus
                numWordsInCorpus += document.size();
                corpus.add(document);
            }
        }
        catch (Exception e) {
            e.printStackTrace();
        }

        // Number of distinct word types in the corpus
        vocabularySize = word2IdVocabulary.size();
        // Per-topic document counts
        docTopicCount = new int[numTopics];
        // Topic-word counts for the Dirichlet multinomial component
        topicWordCountDMM = new int[numTopics][vocabularySize];
        // Per-topic word totals for the Dirichlet multinomial component
        sumTopicWordCountDMM = new int[numTopics];
        // Topic-word counts for the latent feature component
        topicWordCountLF = new int[numTopics][vocabularySize];
        // Per-topic word totals for the latent feature component
        sumTopicWordCountLF = new int[numTopics];

        // Sampling probabilities, initialized uniformly to 1/K; the roulette-wheel
        // draw during initialization requires these values to be set
        multiPros = new double[numTopics];
        for (int i = 0; i < numTopics; i++) {
            multiPros[i] = 1.0 / numTopics;
        }

        // alphaSum = numTopics * alpha;
        betaSum = vocabularySize * beta;

        // Read the word2vec word-vector file
        readWordVectorsFile(vectorFilePath);
        topicVectors = new double[numTopics][vectorSize];
        dotProductValues = new double[numTopics][vocabularySize];
        expDotProductValues = new double[numTopics][vocabularySize];
        sumExpValues = new double[numTopics];

        System.out
            .println("Corpus size: " + numDocuments + " docs, " + numWordsInCorpus + " words");
        System.out.println("Vocabulary size: " + vocabularySize);
        System.out.println("Number of topics: " + numTopics);
        System.out.println("alpha: " + alpha);
        System.out.println("beta: " + beta);
        System.out.println("lambda: " + lambda);
        System.out.println("Number of initial sampling iterations: " + numInitIterations);
        System.out.println("Number of EM-style sampling iterations for the LF-DMM model: "
            + numIterations);
        System.out.println("Number of top topical words: " + topWords);

        tAssignsFilePath = pathToTAfile;
        if (tAssignsFilePath.length() > 0)
            initialize(tAssignsFilePath);
        else
            initialize();
    }
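
    // A note on the input formats, inferred from the reading code above: the
    // topic-modeling corpus is a plain-text file with one document per line and
    // tokens separated by whitespace; the word-vector file is a plain-text
    // word2vec-style file in which each line is "word v_1 v_2 ... v_n". Every
    // word type occurring in the corpus must have a vector, otherwise
    // readWordVectorsFile() below throws an exception.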
    // Read the word-vector file
    public void readWordVectorsFile(String pathToWordVectorsFile)
        throws Exception
    {
        System.out.println("Reading word vectors from word-vectors file " + pathToWordVectorsFile
            + "...");
        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(pathToWordVectorsFile));
            // Tokens are separated by whitespace
            String[] elements = br.readLine().trim().split("\\s+");
            // Vector dimensionality: subtract 1 because the first token is the word itself
            vectorSize = elements.length - 1;
            // Only keep vectors for the vocabularySize words that occur in the corpus
            wordVectors = new double[vocabularySize][vectorSize];
            // The first token is the word
            String word = elements[0];
            // If this word occurs in the corpus, store its vector in wordVectors
            if (word2IdVocabulary.containsKey(word)) {
                for (int j = 0; j < vectorSize; j++) {
                    wordVectors[word2IdVocabulary.get(word)][j] = new Double(elements[j + 1]);
                }
            }
            // Read the remaining lines; the first line was read separately above only
            // to determine the vector dimensionality
            for (String line; (line = br.readLine()) != null;) {
                elements = line.trim().split("\\s+");
                word = elements[0];
                // Store the vector of every word that occurs in the corpus
                if (word2IdVocabulary.containsKey(word)) {
                    for (int j = 0; j < vectorSize; j++) {
                        wordVectors[word2IdVocabulary.get(word)][j] = new Double(elements[j + 1]);
                    }
                }
            }
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        // Guard against corpus words that are missing from the word2vec file
        for (int i = 0; i < vocabularySize; i++) {
            if (MatrixOps.absNorm(wordVectors[i]) == 0.0) {
                System.out.println("The word \"" + id2WordVocabulary.get(i)
                    + "\" doesn't have a corresponding vector!!!");
                throw new Exception();
            }
        }
    }
    // Random initialization
    public void initialize()
        throws IOException
    {
        // Randomly assign a topic to each document
        System.out.println("Randomly initializing topic assignments ...");
        topicAssignments = new ArrayList<List<Integer>>();
        // Loop over documents
        for (int docId = 0; docId < numDocuments; docId++) {
            List<Integer> topics = new ArrayList<Integer>();
            // Draw a topic by roulette-wheel selection; multiPros was initialized
            // uniformly in the constructor, so this is a uniform draw
            int topic = FuncUtils.nextDiscrete(multiPros);
            // One more document assigned to this topic
            docTopicCount[topic] += 1;
            // Number of words in the document
            int docSize = corpus.get(docId).size();
            // Loop over the words
            for (int j = 0; j < docSize; j++) {
                // The word's ID
                int wordId = corpus.get(docId).get(j);
                // Flip a fair coin to decide whether this word is generated by the
                // latent feature component or the Dirichlet multinomial component
                boolean component = new Randoms().nextBoolean();
                // subtopic encodes both the topic and the generating component:
                // values in [0, numTopics) mean the latent feature component,
                // values in [numTopics, 2 * numTopics) mean the DMM component
                int subtopic = topic;
                if (!component) { // Generated from the latent feature component
                    topicWordCountLF[topic][wordId] += 1;
                    sumTopicWordCountLF[topic] += 1;
                }
                else { // Generated from the Dirichlet multinomial component
                    topicWordCountDMM[topic][wordId] += 1;
                    sumTopicWordCountDMM[topic] += 1;
                    subtopic = subtopic + numTopics;
                }
                topics.add(subtopic);
            }
            topicAssignments.add(topics);
        }
    }
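
    // For reference: FuncUtils.nextDiscrete() above performs a roulette-wheel draw
    // over (possibly unnormalized) weights. The helper below is a minimal sketch of
    // that idea, added for illustration only; the actual utility.FuncUtils
    // implementation may differ in details.
    private static int rouletteWheelSketch(double[] weights)
    {
        double sum = 0.0;
        for (double w : weights)
            sum += w; // total mass of the wheel
        double r = Math.random() * sum; // a uniform point on the wheel
        for (int i = 0; i < weights.length; i++) {
            r -= weights[i]; // walk through the slots
            if (r <= 0.0)
                return i;
        }
        return weights.length - 1; // guard against floating-point rounding
    }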
    // Initialization from a given topic-assignment file
    public void initialize(String pathToTopicAssignmentFile)
        throws Exception
    {
        System.out.println("Reading topic-assignment file: " + pathToTopicAssignmentFile);
        topicAssignments = new ArrayList<List<Integer>>();
        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(pathToTopicAssignmentFile));
            int docId = 0;
            int numWords = 0;
            for (String line; (line = br.readLine()) != null;) {
                String[] strTopics = line.trim().split("\\s+");
                List<Integer> topics = new ArrayList<Integer>();
                // Recover the document's topic from the first subtopic value
                int topic = new Integer(strTopics[0]) % numTopics;
                docTopicCount[topic] += 1;
                for (int j = 0; j < strTopics.length; j++) {
                    int wordId = corpus.get(docId).get(j);
                    int subtopic = new Integer(strTopics[j]);
                    // subtopic == topic means the latent feature component;
                    // otherwise (subtopic == topic + numTopics) the DMM component
                    if (subtopic == topic) {
                        topicWordCountLF[topic][wordId] += 1;
                        sumTopicWordCountLF[topic] += 1;
                    }
                    else {
                        topicWordCountDMM[topic][wordId] += 1;
                        sumTopicWordCountDMM[topic] += 1;
                    }
                    topics.add(subtopic);
                    numWords++;
                }
                topicAssignments.add(topics);
                docId++;
            }
            if ((docId != numDocuments) || (numWords != numWordsInCorpus)) {
                System.out
                    .println("The topic modeling corpus and topic assignment file are not consistent!!!");
                throw new Exception();
            }
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }
    // Model inference
    public void inference()
        throws IOException
    {
        System.out.println("Running Gibbs sampling inference: ");
        // Initial iterations: count-based sampling only, since the topic vectors
        // have not been trained yet
        for (int iter = 1; iter <= numInitIterations; iter++) {
            System.out.println("\tInitial sampling iteration: " + (iter));
            sampleSingleInitialIteration();
        }
        // Main EM-style iterations: alternate between optimizing the topic vectors
        // and one Gibbs sampling pass
        for (int iter = 1; iter <= numIterations; iter++) {
            System.out.println("\tLFDMM sampling iteration: " + (iter));
            // Optimize the topic vectors
            optimizeTopicVectors();
            // One sampling pass
            sampleSingleIteration();
            if ((savestep > 0) && (iter % savestep == 0) && (iter < numIterations)) {
                System.out.println("\t\tSaving the output from the " + iter + "^{th} sample");
                expName = orgExpName + "-" + iter;
                write();
            }
        }
        expName = orgExpName;
        // Save the model parameters
        writeParameters();
        System.out.println("Writing output from the last sample ...");
        // Save the model outputs
        write();
        System.out.println("Sampling completed!");
    }
    // Optimize the topic vectors
    public void optimizeTopicVectors()
    {
        System.out.println("\t\tEstimating topic vectors ...");
        sumExpValues = new double[numTopics];
        dotProductValues = new double[numTopics][vocabularySize];
        expDotProductValues = new double[numTopics][vocabularySize];
        // Topic vectors are independent given the counts, so optimize them in parallel
        Parallel.loop(numTopics, new Parallel.LoopInt()
        {
            @Override
            public void compute(int topic)
            {
                int rate = 1;
                boolean check = true;
                while (check) {
                    double l2Value = l2Regularizer * rate;
                    try {
                        // Arguments: the topic vector, the latent-feature word counts of
                        // this topic, the word vectors and the regularizer value (these
                        // are what TopicVectorOptimizer works on)
                        TopicVectorOptimizer optimizer = new TopicVectorOptimizer(
                            topicVectors[topic], topicWordCountLF[topic], wordVectors, l2Value);
                        // Optimize with L-BFGS
                        Optimizer gd = new LBFGS(optimizer, tolerance);
                        gd.optimize(600);
                        // Copy the optimized parameters back into the topic vector
                        optimizer.getParameters(topicVectors[topic]);
                        // Fill in the topic-word dot products and their exponentials, and
                        // return the partition function value (needed for updating the
                        // topic assignments)
                        sumExpValues[topic] = optimizer.computePartitionFunction(
                            dotProductValues[topic], expDotProductValues[topic]);
                        check = false;
                        if (sumExpValues[topic] == 0 || Double.isInfinite(sumExpValues[topic])) {
                            // Numerical underflow/overflow: recompute with the standard
                            // max-subtraction (log-sum-exp) trick
                            double max = -1000000000.0;
                            for (int index = 0; index < vocabularySize; index++) {
                                if (dotProductValues[topic][index] > max)
                                    max = dotProductValues[topic][index];
                            }
                            for (int index = 0; index < vocabularySize; index++) {
                                expDotProductValues[topic][index] = Math
                                    .exp(dotProductValues[topic][index] - max);
                                sumExpValues[topic] += expDotProductValues[topic][index];
                            }
                        }
                    }
                    catch (InvalidOptimizableException e) {
                        e.printStackTrace();
                        check = true;
                    }
                    // If optimization failed, retry with a 10x stronger regularizer
                    rate = rate * 10;
                }
            }
        });
    }
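
    // For reference, the regularized objective that TopicVectorOptimizer maximizes
    // for each topic vector tau_t (as described in the paper; omega_w is the vector
    // of word w, K_t^w the latent-feature count of word w under topic t, K_t their
    // sum, and mu the L2 regularizer):
    //
    //     L(tau_t) = sum_w K_t^w * (tau_t . omega_w)
    //                - K_t * log( sum_{w'} exp(tau_t . omega_{w'}) )
    //                - mu * ||tau_t||^2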
    // One full Gibbs sampling pass (main iterations)
    public void sampleSingleIteration()
    {
        // Loop over documents
        for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
            // The words of the current document
            List<Integer> document = corpus.get(dIndex);
            // Document length
            int docSize = document.size();
            // Current topic of the document: all words of a document share one topic,
            // so it can be recovered from the first word's subtopic. Next, remove the
            // document from all counts.
            int topic = topicAssignments.get(dIndex).get(0) % numTopics;
            // One document fewer assigned to this topic
            docTopicCount[topic] = docTopicCount[topic] - 1;
            // Loop over the words and remove them from the component counts
            for (int wIndex = 0; wIndex < docSize; wIndex++) {
                // The word's ID
                int word = document.get(wIndex);
                int subtopic = topicAssignments.get(dIndex).get(wIndex);
                if (subtopic == topic) {
                    topicWordCountLF[topic][word] -= 1;
                    sumTopicWordCountLF[topic] -= 1;
                }
                else {
                    topicWordCountDMM[topic][word] -= 1;
                    sumTopicWordCountDMM[topic] -= 1;
                }
            }
            // Compute the sampling distribution over topics for this document
            for (int tIndex = 0; tIndex < numTopics; tIndex++) {
                multiPros[tIndex] = (docTopicCount[tIndex] + alpha);
                for (int wIndex = 0; wIndex < docSize; wIndex++) {
                    int word = document.get(wIndex);
                    // Per-token product of the two-component mixture. (In this blog
                    // author's view, the exponents N_{d,w} + K_{d,w} in the paper's
                    // sampling formula are problematic; the correct derivation should
                    // give the form implemented below.)
                    multiPros[tIndex] *= (lambda * expDotProductValues[tIndex][word]
                        / sumExpValues[tIndex] + (1 - lambda)
                        * (topicWordCountDMM[tIndex][word] + beta)
                        / (sumTopicWordCountDMM[tIndex] + betaSum));
                }
            }
            // Roulette-wheel draw of the new topic
            topic = FuncUtils.nextDiscrete(multiPros);
            // Add the document back into the counts under the new topic
            docTopicCount[topic] += 1;
            for (int wIndex = 0; wIndex < docSize; wIndex++) {
                int word = document.get(wIndex);
                int subtopic = topic;
                // The component indicator s_{di} is set by directly comparing the two
                // components' probabilities, not by roulette-wheel sampling
                if (lambda * expDotProductValues[topic][word] / sumExpValues[topic] > (1 - lambda)
                    * (topicWordCountDMM[topic][word] + beta)
                    / (sumTopicWordCountDMM[topic] + betaSum)) {
                    // Counts for the latent feature component
                    topicWordCountLF[topic][word] += 1;
                    sumTopicWordCountLF[topic] += 1;
                }
                else {
                    // Counts for the Dirichlet multinomial component
                    topicWordCountDMM[topic][word] += 1;
                    sumTopicWordCountDMM[topic] += 1;
                    subtopic += numTopics;
                }
                // Update the topic assignment
                topicAssignments.get(dIndex).set(wIndex, subtopic);
            }
        }
    }
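
    // The per-document conditional implemented above, in formula form (M_t is the
    // number of other documents assigned to topic t, N_t^w the DMM count of word w
    // under topic t, N_t its sum over words, and V the vocabulary size):
    //
    //     P(z_d = t | .) ∝ (M_t + alpha)
    //         * prod_{w in d} [ lambda * exp(tau_t . omega_w) / sum_{w'} exp(tau_t . omega_{w'})
    //                           + (1 - lambda) * (N_t^w + beta) / (N_t + V * beta) ]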
    // One sampling pass for the initial iterations
    public void sampleSingleInitialIteration()
    {
        // Loop over documents
        for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
            // The current document
            List<Integer> document = corpus.get(dIndex);
            // Document length, i.e. its number of word tokens
            int docSize = document.size();
            // Current topic of the document, recovered from the first word's subtopic
            int topic = topicAssignments.get(dIndex).get(0) % numTopics;
            // Remove the document from the per-topic document counts
            docTopicCount[topic] = docTopicCount[topic] - 1;
            // Loop over the words and remove them from the component counts
            for (int wIndex = 0; wIndex < docSize; wIndex++) {
                // The word's ID
                int word = document.get(wIndex);
                // Its subtopic
                int subtopic = topicAssignments.get(dIndex).get(wIndex);
                // If subtopic equals topic the word came from the latent feature
                // component, otherwise from the Dirichlet multinomial component
                if (topic == subtopic) {
                    // Topic-word count minus 1
                    topicWordCountLF[topic][word] -= 1;
                    // Topic word total minus 1
                    sumTopicWordCountLF[topic] -= 1;
                }
                else {
                    // Topic-word count minus 1
                    topicWordCountDMM[topic][word] -= 1;
                    // Topic word total minus 1
                    sumTopicWordCountDMM[topic] -= 1;
                }
            }
            // Compute the probability of the document under each topic, then draw a
            // new topic by roulette-wheel selection. The topic vectors have not been
            // trained yet at this stage, so the latent feature probability is
            // approximated by a count-based estimate with the same Dirichlet
            // smoothing as the DMM component.
            for (int tIndex = 0; tIndex < numTopics; tIndex++) {
                multiPros[tIndex] = (docTopicCount[tIndex] + alpha);
                for (int wIndex = 0; wIndex < docSize; wIndex++) {
                    int word = document.get(wIndex);
                    multiPros[tIndex] *= (lambda * (topicWordCountLF[tIndex][word] + beta)
                        / (sumTopicWordCountLF[tIndex] + betaSum) + (1 - lambda)
                        * (topicWordCountDMM[tIndex][word] + beta)
                        / (sumTopicWordCountDMM[tIndex] + betaSum));
                }
            }
            // Roulette-wheel draw
            topic = FuncUtils.nextDiscrete(multiPros);
            // One more document assigned to the new topic
            docTopicCount[topic] += 1;
            // Decide for each word whether it comes from the latent feature component
            // or the Dirichlet multinomial component
            for (int wIndex = 0; wIndex < docSize; wIndex++) {
                int word = document.get(wIndex); // word ID
                int subtopic = topic;
                // As in sampleSingleIteration(), the component indicator s_{di} is set
                // by directly comparing the two probabilities, not by roulette wheel
                if (lambda * (topicWordCountLF[topic][word] + beta)
                    / (sumTopicWordCountLF[topic] + betaSum) > (1 - lambda)
                    * (topicWordCountDMM[topic][word] + beta)
                    / (sumTopicWordCountDMM[topic] + betaSum)) {
                    topicWordCountLF[topic][word] += 1;
                    sumTopicWordCountLF[topic] += 1;
                }
                else {
                    topicWordCountDMM[topic][word] += 1;
                    sumTopicWordCountDMM[topic] += 1;
                    subtopic += numTopics;
                }
                // Update topic assignments
                topicAssignments.get(dIndex).set(wIndex, subtopic);
            }
        }
    }
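
    // The initial-iteration conditional implemented above differs from the main one
    // only in the latent-feature term, which is count-based here (K_t^w and K_t are
    // the latent-feature counts):
    //
    //     P(z_d = t | .) ∝ (M_t + alpha)
    //         * prod_{w in d} [ lambda * (K_t^w + beta) / (K_t + V * beta)
    //                           + (1 - lambda) * (N_t^w + beta) / (N_t + V * beta) ]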
    // Save the model settings
    public void writeParameters()
        throws IOException
    {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".paras"));
        writer.write("-model" + "\t" + "LFDMM");
        writer.write("\n-corpus" + "\t" + corpusPath);
        writer.write("\n-vectors" + "\t" + vectorFilePath);
        writer.write("\n-ntopics" + "\t" + numTopics);
        writer.write("\n-alpha" + "\t" + alpha);
        writer.write("\n-beta" + "\t" + beta);
        writer.write("\n-lambda" + "\t" + lambda);
        writer.write("\n-initers" + "\t" + numInitIterations);
        writer.write("\n-niters" + "\t" + numIterations);
        writer.write("\n-twords" + "\t" + topWords);
        writer.write("\n-name" + "\t" + expName);
        if (tAssignsFilePath.length() > 0)
            writer.write("\n-initFile" + "\t" + tAssignsFilePath);
        if (savestep > 0)
            writer.write("\n-sstep" + "\t" + savestep);
        writer.close();
    }

    public void writeDictionary()
        throws IOException
    {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
            + ".vocabulary"));
        for (String word : word2IdVocabulary.keySet()) {
            writer.write(word + " " + word2IdVocabulary.get(word) + "\n");
        }
        writer.close();
    }

    public void writeIDbasedCorpus()
        throws IOException
    {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
            + ".IDcorpus"));
        for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
            int docSize = corpus.get(dIndex).size();
            for (int wIndex = 0; wIndex < docSize; wIndex++) {
                writer.write(corpus.get(dIndex).get(wIndex) + " ");
            }
            writer.write("\n");
        }
        writer.close();
    }

    public void writeTopicAssignments()
        throws IOException
    {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
            + ".topicAssignments"));
        for (int dIndex = 0; dIndex < numDocuments; dIndex++) {
            int docSize = corpus.get(dIndex).size();
            for (int wIndex = 0; wIndex < docSize; wIndex++) {
                writer.write(topicAssignments.get(dIndex).get(wIndex) + " ");
            }
            writer.write("\n");
        }
        writer.close();
    }

    public void writeTopicVectors()
        throws IOException
    {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
            + ".topicVectors"));
        for (int i = 0; i < numTopics; i++) {
            for (int j = 0; j < vectorSize; j++)
                writer.write(topicVectors[i][j] + " ");
            writer.write("\n");
        }
        writer.close();
    }

    public void writeTopTopicalWords()
        throws IOException
    {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName
            + ".topWords"));
        for (int tIndex = 0; tIndex < numTopics; tIndex++) {
            writer.write("Topic" + new Integer(tIndex) + ":");
            Map<Integer, Double> topicWordProbs = new TreeMap<Integer, Double>();
            for (int wIndex = 0; wIndex < vocabularySize; wIndex++) {
                // Word probability under the topic: note that it fuses both parts of
                // the model, mixing the latent feature component and the DMM component
                double pro = lambda * expDotProductValues[tIndex][wIndex] / sumExpValues[tIndex]
                    + (1 - lambda) * (topicWordCountDMM[tIndex][wIndex] + beta)
                    / (sumTopicWordCountDMM[tIndex] + betaSum);
                topicWordProbs.put(wIndex, pro);
            }
            // Sort the topic-word distribution in descending order
            topicWordProbs = FuncUtils.sortByValueDescending(topicWordProbs);
            Set<Integer> mostLikelyWords = topicWordProbs.keySet();
            int count = 0;
            for (Integer index : mostLikelyWords) {
                if (count < topWords) {
                    writer.write(" " + id2WordVocabulary.get(index));
                    count += 1;
                }
                else {
                    writer.write("\n\n");
                    break;
                }
            }
        }
        writer.close();
    }

    public void writeTopicWordPros()
        throws IOException
    {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".phi"));
        for (int t = 0; t < numTopics; t++) {
            for (int w = 0; w < vocabularySize; w++) {
                double pro = lambda * expDotProductValues[t][w] / sumExpValues[t] + (1 - lambda)
                    * (topicWordCountDMM[t][w] + beta) / (sumTopicWordCountDMM[t] + betaSum);
                writer.write(pro + " ");
            }
            writer.write("\n");
        }
        writer.close();
    }

    public void writeDocTopicPros()
        throws IOException
    {
        BufferedWriter writer = new BufferedWriter(new FileWriter(folderPath + expName + ".theta"));
        for (int i = 0; i < numDocuments; i++) {
            int docSize = corpus.get(i).size();
            double sum = 0.0;
            // Recompute the per-topic weights of the document and normalize them into
            // the document-topic distribution theta
            for (int tIndex = 0; tIndex < numTopics; tIndex++) {
                multiPros[tIndex] = (docTopicCount[tIndex] + alpha);
                for (int wIndex = 0; wIndex < docSize; wIndex++) {
                    int word = corpus.get(i).get(wIndex);
                    multiPros[tIndex] *= (lambda * expDotProductValues[tIndex][word]
                        / sumExpValues[tIndex] + (1 - lambda)
                        * (topicWordCountDMM[tIndex][word] + beta)
                        / (sumTopicWordCountDMM[tIndex] + betaSum));
                }
                sum += multiPros[tIndex];
            }
            for (int tIndex = 0; tIndex < numTopics; tIndex++) {
                writer.write((multiPros[tIndex] / sum) + " ");
            }
            writer.write("\n");
        }
        writer.close();
    }

    public void write()
        throws IOException
    {
        // Top words per topic
        writeTopTopicalWords();
        // Document-topic distributions
        writeDocTopicPros();
        // Subtopic assignments
        writeTopicAssignments();
        // Topic-word distributions
        writeTopicWordPros();
    }

    public static void main(String args[])
        throws Exception
    {
        // Arguments: corpus path, word-vector path, number of topics, alpha, beta,
        // lambda, initial iterations, main iterations, top words, experiment name.
        // NOTE: the two empty path strings are placeholders and must be replaced
        // with real file paths before running.
        LFDMM lfdmm = new LFDMM("", "", 40, 0.1, 0.01, 0.6, 20,
            20, 20, "LFDMM");
        lfdmm.writeParameters();
        lfdmm.inference();
    }
}
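
To actually run the model, replace the two empty path strings in main with real files. A minimal usage sketch (the paths data/corpus.txt and data/wordVectors.txt are hypothetical placeholders for your own preprocessed corpus and word2vec text file):

    // Hypothetical paths, for illustration only
    LFDMM lfdmm = new LFDMM("data/corpus.txt", "data/wordVectors.txt",
        40,   // number of topics K
        0.1,  // alpha
        0.01, // beta
        0.6,  // lambda: weight of the latent feature component
        20,   // initial count-based sampling iterations
        20,   // EM-style iterations (optimize topic vectors + one Gibbs pass each)
        20,   // number of top words written per topic
        "LFDMM");
    lfdmm.writeParameters();
    lfdmm.inference();

After inference finishes, the .topWords, .theta, .topicAssignments and .phi output files are written next to the corpus file.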