廿TT

譬如水怙牛過窓櫺 頭角四蹄都過了 因甚麼尾巴過不得

崩壊型ギブスサンプリングによるトピックモデル(Dirichlet-Multinomial)のパラメータ推定

[math/0604410] Discrete Component Analysis を元に実装しています。

論文では

  • Dirichlet-Multinomial Model と呼ばれていますがやってることはLDAと変わりません。
  • Rao-Balckwellised Gibbs Sampling と呼ばれていますがやってることは崩壊型ギブスサンプリングと変わりません。

崩壊型ギブスサンプリングのほうが、変分推論より性能がいいというのが定説らしいのですが、ちょっと試した感じでは変分推論(変分ベイズによるトピックモデル(Dirichlet-multinomial Model)のパラメータ推定の高速化 - 廿TT)とあんまり変わらなかったです。

C++のコード:

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
using namespace Rcpp;

arma::irowvec rmult(int N, arma::vec probs) {
  int k = probs.n_rows;
  arma::irowvec ans(k);
  R::rmultinom(N, probs.begin(), k, ans.begin());
  return ans;
}

// [[Rcpp::export]]
arma::field<arma::icube> DirMultGibbs(arma::imat Y, int L, int iter, double alpha=1.0, double beta=1.0){
  int N = Y.n_rows;
  int K = Y.n_cols;
  arma::field<arma::icube> S(iter);
  arma::icube tmpS(N,L,K);
  arma::vec p(L);
  p.fill(1.0/L);
  for(int n=0; n<N; n++){
    for(int k=0;k<K;k++){
      tmpS.slice(k).row(n) = rmult(Y(n,k),p);
    }
  }
  S(0) = tmpS;
  arma::imat Shi =sum(tmpS,0);
  arma::mat Sh = arma::conv_to<arma::mat>::from(Shi) + alpha;
  arma::imat Swi = sum(tmpS,2);
  arma::mat Sw = arma::conv_to<arma::mat>::from(Swi) + beta;
  arma::irowvec Sli = sum(sum(tmpS,2),0);
  arma::rowvec Sl = arma::conv_to<arma::rowvec>::from(Sli) + K*beta;
  for(int i=1; i<iter; i++){
    for(int n=0; n<N; n++){
      for(int k=0;k<K;k++){
        arma::irowvec tmpi = S(i-1).slice(k).row(n);
        arma::rowvec tmp = arma::conv_to<arma::rowvec>::from(tmpi);
        Sw.row(n) -= tmp;
        Sh.col(k) -= tmp.t();
        Sl -= tmp;
        arma::vec prop =  (Sh.col(k) % Sw.row(n).t())/Sl.t();
        tmpi = rmult(Y(n,k),prop/sum(prop));
        tmp = arma::conv_to<arma::rowvec>::from(tmpi);
        Sw.row(n) += tmp;
        Sh.col(k) += tmp.t();
        Sl += tmp;
        tmpS.slice(k).row(n) = tmpi;
      }
    }
    S(i) = tmpS;
  }
  return S;
};

Rのコード:

library(gtools)
library(Rcpp)
sourceCpp("~/Documents/DirMultGibbs.cpp")

W <- rdirichlet(100,rep(1,3))
W <- W[,order(W[1,])]
H <- rdirichlet(3,rep(1,50))
Y <- t(apply(W%*%H,1,function(p){rmultinom(1,10000,p)}))

out <- DirMultGibbs(Y,3,500,1,1)
out <- simplify2array(out)
plot(out[1,1,1,],type="l")
Smean <-apply(out[,,,250:500], 1:3,mean)
Sw <- apply(Smean,1:2,sum)
Sh <- apply(Smean,2:3,sum)
What <-(Sw+1)/rowSums(Sw+1)
Hhat <-(Sh+1)/rowSums(Sh+1)
plot(What%*%Hhat,Y)

plot(What[,order(What[1,])],W)
abline(0,1)
plot(Hhat[order(What[1,]),],H)
abline(0,1)