廿TT

譬如水怙牛過窓櫺 頭角四蹄都過了 因甚麼尾巴過不得

正規分布で事後分布を近似する変分推論のアルゴリズム

今日の川柳

機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)

機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)

ベイズ推論による機械学習入門』ではロジスティック回帰とニューラルネットのところで近似事後分布として1次元の正規分布を仮定して、変分パラメータを勾配法で推定するアルゴリズムが出てくる。

勾配を評価するときには再パラメータ化トリック(re-parameterization trick)というアイデアを使う。

これはニューラルネットワークに限らずどんな分布でも使えるアルゴリズムで、共役事前分布を使ったモデリングができないときに便利だと思ったので、アルゴリズムをまとめ直してみる。

まずすべてのパラメータに対して、平均0、分散 \lambda^{-1} の正規事前分布を仮定する。

すべてのパラメータの近似事後分布として、平均\mu、分散 \sigma^{2}正規分布を仮定する。ここでは添字は省略した。

標準偏差は正の数なので \sigma = \log(1+\exp(\rho)) とおいて、\rho を最適化する。

真の事後分布とのカルバック・ライブラ距離が近い正規分布を求めるのがアルゴリズムの目的。

最適化の目的関数を g(\eta) と置く。

モデルの対数尤度関数を \log p(w) とすると

  1. 学習率 \alpha と 事前分布の精度パラメータ  \lambda を設定。
  2. \mu\rho を適当に初期化し以下を繰り返す。
  3. すべてのパラメータに対して、標準正規乱数を使ってサンプル  \tilde{w} = \mu + \sigma \tilde{\varepsilon} を得る。\tilde{\varepsilon} は標準正規乱数
  4.  \frac{d}{dw} \log p(w) を計算。
  5.  \frac{d}{d\mu}g(\tilde{w}) =  -\frac{d}{d\tilde{w}} \log p(\tilde{w})+ \lambda\tilde{w} を計算。
  6.  \frac{d}{d\rho}g(\tilde{w}) = (\{\frac{d}{dw} \log p(w)\}\tilde{\varepsilon} - 1/\log(\exp(\rho)) + \lambda \tilde{W} \tilde{\varepsilon})\frac{1}{\log(1+\exp(-\rho))} を計算。
  7.  \mu \leftarrow \mu + \alpha \frac{d}{d\mu}g(\tilde{w}) で更新
  8.  \rho \leftarrow \rho + \alpha \frac{d}{d\rho}g(\tilde{w}) で更新

例題:ポアソン回帰

 y\sim \mathrm{Poisson}(X\beta)

R です。

x <- rnorm(100)
X <- cbind(1,x)
beta <- c(2,-1)
set.seed(1)
y <- rpois(100,exp(X%*%beta))

VIpoisreg <- function(y,X,lambda,alpha,iter){
  dlogll <- function(beta){
    lambda <- exp(X%*%beta)
    drop(t(y-lambda)%*%X)
  }
compute_obj <- function(Y,X,W,mu,rho,lambda){
  M <- length(W)
  term1 <- sum(dnorm(W,mu,log1p(exp(rho)),log = TRUE))
  term2 <- sum(dpois(y,exp(X%*%W),log = TRUE))
  term3 <- sum(dnorm(W,numeric(M),1/sqrt(lambda),log=TRUE))
  return(term1-term2-term3)
}

N <- length(y)
M <- ncol(X)
mu = rnorm(M)
rho = rnorm(M)
KL <- numeric(iter)
pb <- txtProgressBar(min = 1, max = iter, style = 3)
for(i in 1:iter){
  ep = rnorm(M)
  W_tmp = mu + log1p(exp(rho)) * ep
  
  # calculate gradient
  d_tmp <- -dlogll(W_tmp)
  d_mu = d_tmp + lambda*W_tmp
  d_rho = (d_tmp*ep - 1/log1p(exp(rho)) + lambda*W_tmp*ep)*plogis(rho)
  
  # update variational parameters
  mu = mu - alpha * d_mu / N
  rho = rho - alpha * d_rho / N
  KL[i] <- compute_obj(y,X,W_tmp,mu,rho,lambda)
  setTxtProgressBar(pb, i)
}
return(list(mu=mu,sigma=log1p(exp(rho)),KL=KL))
}

fit <- VIpoisreg(y,X,lambda = 0.01,alpha = 0.001,iter=5000)
plot(fit$KL,type="l")

fit$mu
fit$sigma

summary(glm(y~x,family = "poisson"))

目的関数が小さくなっていく様子です。

f:id:abrahamcow:20190303024931p:plain

事後分布の平均はけっこう正確です。

事後分布の標準偏差に関してはちょっと広めに求まりました。

> fit$mu
                    x 
 1.9752896 -0.9736318 
> fit$sigma
                    x 
0.09651999 0.08934027 
> summary(glm(y~x,family = "poisson"))

Call:
glm(formula = y ~ x, family = "poisson")

Deviance Residuals: 
     Min        1Q    Median        3Q       Max  
-2.46625  -0.41209  -0.03585   0.44706   2.41542  

Coefficients:
            Estimate Std. Error z value Pr(>|z|)    
(Intercept)  1.99564    0.04068   49.05   <2e-16 ***
x           -0.98920    0.03469  -28.51   <2e-16 ***
---
Signif. codes:  0***0.001**0.01*0.05 ‘.’ 0.1 ‘ ’ 1

(Dispersion parameter for poisson family taken to be 1)

    Null deviance: 919.679  on 99  degrees of freedom
Residual deviance:  80.006  on 98  degrees of freedom
AIC: 455.38

Number of Fisher Scoring iterations: 4