Dirichlet Multinomial Mixture Model做短文本聚类(包括代码)

标签:#DPMM# 时间:2018/03/07 20:22:06 作者:十七岁的雨季

原文地址

http://blog.csdn.net/qy20115549/article/details/79429127

论文来源

Yin J, Wang J. A dirichlet multinomial mixture model-based approach for short text clustering[C]//Proceedings of the 20th ACM SIGKDD international conference on Knowledge discovery and data mining. ACM, 2014: 233-242.

论文理解及公式推理


这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述

核心源码

`//模型初始化
public void intialize(DocumentSet documentSet)
{

    D = documentSet.D;    //获取文档总数目
    z = new int[D]; //文档对应的主题数目
    for(int d = 0; d < D; d++){
        //获取每一篇文档的内容
        Document document = documentSet.documents.get(d);
        //针对每一篇文档随机初始化一个主题
        int cluster = (int) (K * Math.random());
        z[d] = cluster;
        //每个主题对应的文档数目统计
        m_z[cluster]++;
        //对文档的每个单词进行循环
        for(int w = 0; w < document.wordNum; w++){
            //获取文档每个单词的编号
            int wordNo = document.wordIdArray[w];
            //获取文档每个单词出现的数目
            int wordFre = document.wordFreArray[w];
            //统计每个主题下,每个单词出现的数目
            n_zv[cluster][wordNo] += wordFre;
            //统计每个主题下所有单词的数目
            n_z[cluster] += wordFre; 
        }
    }
}
//gibbs采样
public void gibbsSampling(DocumentSet documentSet)
{
    for(int i = 0; i < iterNum; i++){
        //每篇文档循环
        for(int d = 0; d < D; d++){
            Document document = documentSet.documents.get(d);
            //获取文档对应的主题
            int cluster = z[d];
            //移除该文档,该主题对应的文档数目减去1
            m_z[cluster]--;
            for(int w = 0; w < document.wordNum; w++){
                int wordNo = document.wordIdArray[w];
                int wordFre = document.wordFreArray[w];
                //该主题对应的文档中的单词的数目,减少文档该单词出现的数目
                n_zv[cluster][wordNo] -= wordFre;
                //该主题对应的单词总数减去了该文档对应单词的总数
                n_z[cluster] -= wordFre;
            }
            //抽取该文档所属的新主题
            cluster = sampleCluster(d, document);
            //分配该文档对应的新主题后,重新统计相关词频
            z[d] = cluster;
            m_z[cluster]++;
            for(int w = 0; w < document.wordNum; w++){
                int wordNo = document.wordIdArray[w];
                int wordFre = document.wordFreArray[w];
                n_zv[cluster][wordNo] += wordFre; 
                n_z[cluster] += wordFre; 
            }
        }
    }
}

private int sampleCluster(int d, Document document)
{ 
    double[] prob = new double[K];
    //统计是哪个主题
    int[] overflowCount = new int[K];
    //对所有主题进行循环,计算该文档属于每个主题的概率
    for(int k = 0; k < K; k++){
        //依照计算公式计算文档d属于每个单词的概率
        prob[k] = (m_z[k] + alpha) / (D - 1 + alpha0);
        double valueOfRule2 = 1.0;
        int i = 0;
        for(int w=0; w < document.wordNum; w++){
            //获取该文档中某一单词的编号及出现的频率
            int wordNo = document.wordIdArray[w];
            int wordFre = document.wordFreArray[w];
            //文档的每个单词进行计算,这里有防止连乘积概率过小的判断及处理
            for(int j = 0; j < wordFre; j++){
                if(valueOfRule2 < smallDouble){
                    overflowCount[k]--;
                    valueOfRule2 *= largeDouble;
                }

                valueOfRule2 *= (n_zv[k][wordNo] + beta + j) 
                         / (n_z[k] + beta0 + i);
                i++;
            }
        }
        prob[k] *= valueOfRule2;            
    }
    //重新计算概率
    reComputeProbs(prob, overflowCount, K);
    //使用轮盘赌,分配新的主题
    for(int k = 1; k < K; k++){
        prob[k] += prob[k - 1];
    }
    double thred = Math.random() * prob[K - 1];
    int kChoosed;
    for(kChoosed = 0; kChoosed < K; kChoosed++){
        if(thred < prob[kChoosed]){
            break;
        }
    }

    return kChoosed;
}
//重新计算概率值
private void reComputeProbs(double[] prob, int[] overflowCount, int K)
{
    int max = Integer.MIN_VALUE;
    for(int k = 0; k < K; k++){
        if(overflowCount[k] > max && prob[k] > 0){
            max = overflowCount[k];
            System.out.println("max:"+max);
        }
    }
    //概率值统一扩展
    for(int k = 0; k < K; k++){            
        if(prob[k] > 0){
            prob[k] = prob[k] * Math.pow(largeDouble, overflowCount[k] - max);
        }
    }        
}

`

欢迎大家关注DataLearner官方微信,接受最新的AI技术推送