廿TT

譬如水怙牛過窓櫺 頭角四蹄都過了 因甚麼尾巴過不得

Dirichlet Process Mixtures の変分推論(混合二項分布)

今日の川柳

Blei & Jordan (2004) http://www.cs.columbia.edu/~blei/papers/BleiJordan2004.pdf をもとに実装してます。

Rです。

あってるか自信ない。

混合数3で乱数を作って、二項分布のパラメータを推定してます。

最大で20のクラスタがあるように設定していますが、全部のデータが3つのクラスタのいずれかに割り当てられて、残りの17個は使われませんでした。

> table(s,shat)
   shat
s     1   2   3
  1   2 172   0
  2 183   1   0
  3   0   0 142

f:id:abrahamcow:20190607184645p:plain

17個のほうのパラメータの推定値は事前分布の平均がそのまま入ってます。

青い点線が正解のパラメータです。

この例だとなんとなくいいような気がするけど乱数によってはうまくいかないこともあります。

VBbinomMix <- function(x,m,Kmax,a=1,b=1,alpha=1,maxit=200,seed=1){
  set.seed(seed)
  p_ini <- runif(Kmax)
  N <- length(x)
  m_x <- m-x
  s <- tmp <- matrix(NA,N,Kmax)
  logp <- log(p_ini)
  log1_p <- log(1-p_ini)
  V <- rgamma(Kmax,1,1)
  V <- V/sum(V)
  logV <- log(V)
  logVm1 <- log(1-V)
  tmp[,1] <- exp(x*logp[1]+m_x*log1_p[1]+logV[1])
  for(k in 2:Kmax){
    tmp[,k] <- exp(x*logp[k]+m_x*log1_p[k]+sum(logVm1[1:(k-1)]))
  }
  den <-apply(tmp,1,sum)
  s<-sweep(tmp,1,den,"/")
  for(i in 1:maxit){
    ahat <- apply(s*x,2,sum) + a
    bhat <- apply(s*m_x,2,sum) + b
    alphahat <- apply(s,2,sum) + alpha
    logp <- digamma(ahat)-digamma(ahat+bhat)
    log1_p <- digamma(bhat)-digamma(ahat+bhat)
    gamma1 <- 1+colSums(s)
    gamma2 <- c(alpha+sapply(2:Kmax,function(k)sum(s[,k:Kmax])),alpha)
    logV <- digamma(gamma1)-digamma(gamma1+gamma2)
    logVm1 <- digamma(gamma2)-digamma(gamma1+gamma2)
    tmp[,1] <- exp(x*logp[1]+m_x*log1_p[1]+logV[1])
    for(k in 2:Kmax){
      tmp[,k] <- exp(x*logp[k]+m_x*log1_p[k]+sum(logVm1[1:(k-1)]))
    }
    den <-apply(tmp,1,sum)
    s<-sweep(tmp,1,den,"/")
  }
  phat <- ahat/(ahat+bhat)
  ratio=colMeans(s)
  ord <- order(ratio,decreasing = TRUE)
  list(phat=phat[ord],ratio=ratio[ord],ahat=ahat[ord],bhat=bhat[ord],s=s[,ord],iter=i)
}
set.seed(1234)
m <- rpois(500,100)+1
p <- c(0.1,0.3,0.7)
s <- sample.int(3,500,replace = TRUE)
x <- rbinom(500,m,p[s])
hist(x/m)
out <- VBbinomMix(x,m,Kmax =20)
shat <- apply(out$s,1,which.max)
table(s,shat)
barplot(out$phat,ylim = c(0,1))
abline(h=p,lty=2,col="royalblue",lwd=2)

機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)

機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)