廿TT

譬如水怙牛過窓櫺 頭角四蹄都過了 因甚麼尾巴過不得

ラプラス近似による非共役モデルの変分推論

今日の川柳

[1209.4360] Variational Inference in Nonconjugate Models に出ている例題をやります。

変分推論は便利ですが、モデルが複雑になってくると近似事後分布が解析的に求まらない場合が多いです。

上記の論文、Wang & Blei (2013) では解析的に求まる部分は解析的に求め、求まらないところはラプラス近似で推定する方法を提案しています。

大域的なパラメータと個体差を同時に推定したいときなどに便利だと思います。

アルゴリズム

例として、以下のモデルを考えます。

 \theta \sim \mathrm{Norma}(0,I)
 z_n \sim \mathrm{Dirichlet}(\exp(\theta))
 y_n \sim \mathrm{Multinomial}(M_n,z_n)

ここで I単位行列です。

イメージとしては y_n が文書 n のワードカウントで  z_n が文章ごとの単語の出現頻度を決めるパラメータだと思ってください。

\theta が与えられたとき、z_n の事後分布はパラメータ y_n+\exp(\theta) のディリクレ分布です。

一方で z_n が与えられたとき、\theta の事後分布は解析的には求まりません。

そこでラプラス近似です。

\theta の事後分布は次の式に比例します。

 p(z_n|\theta)p(\theta)

この式の対数をとって、期待値を計算すると、ディリクレ分布の密度関数と正規分布の密度関数より、

f(\theta)= \sum_n\sum_i \{ (\exp(\theta_i)-1)E[\log z_{ni}]-\log C_D\} - \sum_i\theta_i/2+\mathrm{Const.}

ここで C_D はディリクレ分布の正規化定数です。

\theta は関数 f(\theta) を最大化する \hat \theta を平均とする正規分布で近似できます。

こうして、\theta に関する近似事後分布と z_n に関する近似事後分布を繰り返し更新していくことでパラメータを推定します。

R による実装

f(\theta) の最大化は optim に丸投げできます。

本当は E[\exp(\theta)] \neq \exp(E[\theta]) に注意する必要があります。

しかしまあ、サンプルサイズが十分に大きければ E[\exp(\theta)] \approx \exp(E[\theta]) だと思って計算してもたぶん大過ないでしょう。

それがいやなら optim にはヘッセ行列を返す機能もあるので、より正確に E[\exp(\theta)] を求めることもできます。

ただしヘッセ行列の逆行列が常に求まる保証はありません。

library(gtools)
set.seed(1)
theta <- rnorm(10,0,1)
z <- rdirichlet(100,exp(theta))
y <- t(apply(z,1,function(z)rmultinom(1,1000,z)))

VIunigram <- function(y,theta=numeric(ncol(y))){
  f <- function(theta,logphi){
    N <- nrow(logphi)
    pd <- 0
    logD <- sum(lgamma(exp(theta)))-lgamma(sum(exp(theta)))
    for(n in 1:N){
      pd <- pd + sum((exp(theta)-1)*logphi[n,])-logD
    }
    pd <- pd-sum(theta^2)/2
    return(pd)
  }
  for(i in 1:100){
    num <- (y+exp(theta))
    den <- rowSums(y+exp(theta))
    Elphi <- digamma(num)-digamma(den)
    opt <- optim(theta,f,logphi=Elphi,control = list(fnscale=-1,maxit=1000),method = "BFGS") 
    if(all(abs(theta-opt$par)<1e-8)){
      break
    }
    theta <- opt$par
  }
  return(list(z=num/den,theta=opt$par,iter=i,opt=opt))
}
out <- VIunigram(y)
print(out$iter)
plot(out$theta,theta)
abline(0,1,lty=2)
plot(out$z,z)
abline(0,1,lty=2)

36回の繰り返しで収束し、推定されたパラメータと真のパラメータを比較するとこんな感じです。

f:id:abrahamcow:20190103023944p:plain

f:id:abrahamcow:20190103023924p:plain

けっこううまくいってるのではないでしょうか。

応用例としては次のようなものがあります:
ガンマ・ポアソン分布回帰による margarine データの分析(特に根拠のない推定法) - 廿TT

変分推論自体の説明としては次のようなものがあります:
変分法を使わずに変分ベイズの導出をする - 廿TT

機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)

機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)