廿TT

譬如水牯牛過窓櫺 頭角四蹄都過了 因甚麼尾巴過不得

多項ロジスティック回帰(ソフトマックス回帰)の変分推論をRで(Adagrad)

今日の川柳

abrahamcow.hatenablog.com

のおまけです。

「AdaGradのすすめ - Qiita」を参考に、AdaGradという最適化手法を使ってみたら、若干収束が速くなりました。

# Row-wise softmax of a numeric matrix.
#
# Each row of `x` is treated as a vector of logits and mapped to
# probabilities that sum to 1. The row maximum is subtracted before
# exponentiating; this leaves the result mathematically unchanged but keeps
# exp() from overflowing to Inf (and the ratio from becoming NaN) when the
# logits are large.
#
# Args:
#   x: numeric matrix of logits, one observation per row.
# Returns:
#   A matrix with the same dimensions as `x` whose rows sum to 1.
row_softmax <- function(x) {
  # The length-nrow(x) vector is recycled down the columns, so element
  # [i, j] is shifted by the maximum of row i.
  shifted <- x - apply(x, 1, max)
  ex <- exp(shifted)
  ex / rowSums(ex)
}

# Variational inference for multinomial (softmax) logistic regression.
#
# The coefficient matrix W (ncol(X) x ncol(Y)) gets a factorised Gaussian
# approximation with mean `mu` and standard deviation log1p(exp(rho)), and an
# independent N(0, 1/lambda) prior on every coefficient. Each iteration draws
# one reparameterised sample W = mu + sigma * ep and takes an AdaGrad step on
#   log q(W) - log p(Y | W) - log p(W).
#
# Args:
#   Y:        response matrix with one row per observation and one column
#             per class (the example usage passes 0/1 indicator columns).
#   X:        design matrix with the same number of rows as Y.
#   lambda:   prior precision of each coefficient.
#   lr:       AdaGrad base learning rate.
#   max_iter: number of stochastic gradient iterations (>= 1).
#   eps:      small constant added to the AdaGrad denominator so that a
#             zero accumulated gradient cannot cause a division by zero.
# Returns:
#   list(mu, rho, KL): the fitted variational parameters and the
#   single-sample objective estimate recorded at every iteration.
#
# Relies on row_softmax() being defined in the calling environment.
VIlogistic <- function(Y, X, lambda = 1, lr = 1e-5, max_iter = 50000,
                       eps = 1e-8) {
  stopifnot(nrow(Y) == nrow(X), max_iter >= 1)
  D <- ncol(Y)
  M <- ncol(X)

  # Maps the unconstrained rho to a positive standard deviation.
  softplus <- function(r) log1p(exp(r))

  # Row-wise log-softmax computed directly in log space, so a probability
  # that underflows to 0 cannot turn Y * log(p) into 0 * -Inf = NaN.
  row_log_softmax <- function(z) {
    z <- z - apply(z, 1, max)
    z - log(rowSums(exp(z)))
  }

  # Gradient of the negative log-likelihood with respect to W.
  dw <- function(Y, X, W) {
    -t(X) %*% (Y - row_softmax(X %*% W))
  }

  # Single-sample estimate of the objective being minimised.
  compute_obj <- function(Y, X, W, mu, rho, lambda) {
    term1 <- sum(dnorm(W, mu, softplus(rho), log = TRUE))
    term2 <- sum(Y * row_log_softmax(X %*% W))
    term3 <- sum(dnorm(W, 0, 1 / sqrt(lambda), log = TRUE))
    term1 - term2 - term3
  }

  mu <- matrix(rnorm(M * D), M, D)
  rho <- matrix(rnorm(M * D), M, D)

  # AdaGrad accumulators: running sums of squared gradients.
  d2_mu <- matrix(0, M, D)
  d2_rho <- matrix(0, M, D)

  KL <- numeric(max_iter)
  # min = 0 so that max_iter = 1 is a valid range; the bar is closed on exit
  # even if an iteration fails.
  pb <- txtProgressBar(min = 0, max = max_iter, style = 3)
  on.exit(close(pb), add = TRUE)

  for (i in seq_len(max_iter)) {
    # sample epsilon
    ep <- matrix(rnorm(M * D), M, D)
    sigma <- softplus(rho)
    W_tmp <- mu + sigma * ep

    # Record the objective with the parameters W_tmp was actually drawn
    # from, i.e. before they are updated below.
    KL[i] <- compute_obj(Y, X, W_tmp, mu, rho, lambda)

    # Gradient of -log p(Y | W) - log p(W) with respect to W. The prior
    # contributes +lambda * W; the same term feeds both mu and rho.
    d_w <- dw(Y, X, W_tmp) + lambda * W_tmp
    d_mu <- d_w
    # Chain rule through sigma = softplus(rho): d sigma / d rho is the
    # logistic function of rho; -1 / sigma comes from the log q(W) term.
    d_rho <- (d_w * ep - 1 / sigma) * (1 / (1 + exp(-rho)))

    d2_mu <- d2_mu + d_mu^2
    d2_rho <- d2_rho + d_rho^2

    # update variational parameters
    mu <- mu - lr * d_mu / (sqrt(d2_mu) + eps)
    rho <- rho - lr * d_rho / (sqrt(d2_rho) + eps)

    setTxtProgressBar(pb, i)
  }
  list(mu = mu, rho = rho, KL = KL)
}

# Fix the RNG seed and hold out 10 randomly chosen iris rows for testing.
set.seed(1)
test <- sample.int(n = nrow(iris), size = 10)
train_rows <- iris[-test, ]

# Y: one indicator column per species (no intercept).
# X: intercept plus every non-Species column.
Y <- model.matrix(~ Species - 1, data = train_rows)
X <- model.matrix(Species ~ ., data = train_rows)

# Fit on the training rows and plot the per-iteration objective trace.
outVI <- VIlogistic(Y, X, lr = 0.1, max_iter = 5000)
plot(outVI$KL, type = "l")

# Draw one posterior-predictive class assignment for every row of predX.
#
# The RNG is seeded with `i`, a coefficient matrix is sampled as
# mu + log1p(exp(rho)) * noise, and one multinomial draw of size 1 is taken
# per row from the resulting softmax probabilities.
#
# Args:
#   i:     seed passed to set.seed() (this resets the global RNG state).
#   mu:    matrix of variational means.
#   rho:   matrix of the same shape as mu; log1p(exp(rho)) is the std. dev.
#   predX: design matrix with ncol(predX) == nrow(mu).
# Returns:
#   An nrow(predX) x ncol(mu) matrix whose rows each contain a single 1.
#
# Relies on row_softmax() being defined in the calling environment.
pred_VI_logis <- function(i, mu, rho, predX) {
  set.seed(i)
  n_coef <- nrow(mu)
  n_class <- ncol(mu)
  noise <- matrix(rnorm(n_coef * n_class), nrow = n_coef, ncol = n_class)
  sigma <- log1p(exp(rho))
  weights <- mu + sigma * noise
  class_prob <- row_softmax(predX %*% weights)
  # apply() stacks the per-row draws as columns, hence the transpose.
  draws <- apply(class_prob, 1, function(p) rmultinom(1, 1, p))
  t(draws)
}

# Design and indicator matrices for the held-out rows.
holdout_rows <- iris[test, ]
testX <- model.matrix(Species ~ ., data = holdout_rows)
testY <- model.matrix(~ Species - 1, data = holdout_rows)

# 100 posterior-predictive draws (seeds 1..100), then the per-row, per-class
# average of the sampled indicators.
outpred <- lapply(
  1:100,
  function(seed) {
    pred_VI_logis(seed, mu = outVI$mu, rho = outVI$rho, predX = testX)
  }
)
predmean <- apply(simplify2array(outpred), c(1, 2), mean)

# Most frequent predicted class versus the true class, as column indices.
predint <- apply(predmean, 1, which.max)
testYint <- apply(testY, 1, which.max)
table(predint, testYint)

# TRUE where the prediction matches the truth.
tf <- predint == testYint

library(tidyverse)
# Reshape the averaged predictive probabilities to long format: one row per
# (held-out observation, species) pair, flagged by whether the observation
# was classified correctly. pivot_longer() replaces the superseded gather().
dfpred <- as.data.frame(predmean) %>%
  setNames(colnames(Y)) %>%
  mutate(answer = tf, rowname = row.names(testY)) %>%
  pivot_longer(
    cols = -c(answer, rowname),
    names_to = "Species",
    values_to = "prob"
  )

# One facet per held-out observation; horizontal bars show the predictive
# probability of each species, coloured by correctness.
ggplot(dfpred, aes(x = Species, y = prob, fill = answer)) +
  geom_col() +
  facet_wrap(~rowname) +
  coord_flip() +
  theme_bw(12)