廿TT

譬如水怙牛過窓櫺 頭角四蹄都過了 因甚麼尾巴過不得

0-1データのNMF(非負値行列因子分解)を使った項目反応理論っぽい分析

今日の川柳

分析の対象

「たのしいベイズモデリング」第4章のデータです。

北大路書房:『たのしいベイズモデリング-事例で拓く研究のフロンティア-』ダウンロード から入手できます。

no sex age p1 p2 p3 p4 p5 p6 p7
1 2 37 0 0 1 0 0 1 1
2 2 32 1 1 1 1 0 1 1
3 1 59 1 0 1 0 1 1 1
4 2 33 0 1 0 0 1 1 1
5 2 55 1 1 1 1 1 1 1
6 2 33 0 0 1 0 0 1 0

no が被験者、p1からp7が質問を表しています。

質問の内容は省略します。本を買ってください。

yesと答えた場合が1で、noが0だと思ってください。

「たのしいベイズモデリング」の中では項目の特性や被験者の属性情報を加味した分析は行われておらず、今後の課題とされています。

ここではそれに 0-1データのNMF(非負値行列因子分解) - 廿TT を使って挑戦します。

モデル

Y=(Y_{n,k}) を被験者 n の項目 k に対する回答(0 または 1)とします。
X_n を被験者 n の属性を表す説明変数(年齢と性別)とします。

トピック数 L はユーザーが定めます。

A_{l,k} \sim \mathrm{Beta}(a,b)
\alpha_{n,1:L} = \exp(X_n V)
S_{n,1:L} \sim \mathrm{Dirichlet}(\alpha_{n, 1:L})
 z_{n,k} \sim \mathrm{Categorical}(S_{n,1:L})
 Y_{n,k} \sim \mathrm{Bernoulli}(A_{k,z_{n,k}})

パラメータ A_{l,k} はトピック数×項目数個の成功確率パラメータ、パラメータ S_{n,1:L} はサンプルごとのトピックの構成割合です。

潜在変数  z_{n,k} はトピックのインデックスを表すインジケータ変数です。

変分推論

R のコードを貼ります。

berNMF <- function(Y,X,L,a=1,b=1,maxit=5000,tol=1e-4,seed=1) {
  set.seed(seed)
  lldir <- function(par){
    V <-matrix(par,D,L)
    beta <- exp(X%*%V)
    logD <- sum(rowSums(lgamma(beta)) - lgamma(rowSums(beta)))
    -(sum(beta*logS)-logD)
  }
  lldir_grad <- function(par){
    V <-matrix(par,D,L)
    beta <- exp(X%*%V)
    lastterm <- V
    for(l in 1:L){
      lastterm[,l] <- t(X)%*%(beta[,l]*digamma(rowSums(beta)))
    }
    -c(t(X)%*%(beta*logS)-t(X)%*%(beta*digamma(beta))+lastterm)
  }
  N <- nrow(Y)
  K <- ncol(Y)
  D <- ncol(X)
  S <- gtools::rdirichlet(N,rep(1,L))
  A <- matrix(rbeta(L*K,1,1),L,K)
  Abar <- 1-A
  Ybar <- 1-Y
  M <- is.na(Y)
  Y[M] <- 0
  Ybar[M] <- 0
  V <- matrix(0,D,L)
  beta <- exp(X%*%V)
  for (i in 1:maxit) {
    Z_s = S*t((A)%*%t(Y/(S%*%A))+(Abar)%*%t(Ybar/(S%*%Abar))) + beta
    logS = digamma(Z_s)-digamma(rowSums(Z_s))
    opt <- optim(V, lldir, lldir_grad, method = "BFGS")
    if(all(abs(V-opt$par)<tol)){
      break
    }
    V <- opt$par
    beta <- exp(X%*%V)
    S <- exp(logS)
    Z_a = A*(t(S)%*%(Y/(S%*%A))) + a
    Z_b = Abar*(t(S)%*%(Ybar/(S%*%Abar))) + b
    A = exp(digamma(Z_a)-digamma(Z_a+Z_b))
    Abar = exp(digamma(Z_b)-digamma(Z_a+Z_b))
  }
  hess <- optimHess(V, lldir, lldir_grad)
  S <- Z_s/rowSums(Z_s)
  A <- Z_a/(Z_a+Z_b)
  ll <- sum((!M)*dbinom(Y,1,(S%*%A),log = TRUE))
  list(S=S, A=A, V=V, hess=hess, ll=ll, iter=i)
}

結果

トピック数を2から10まで変えて試したところ対数尤度はどんどん下がっていく感じになっちゃったので、トピック数は2にしたいと思います。

f:id:abrahamcow:20190806074256p:plain

ためしにトピック数2で行列Yの10%をランダムに欠損させてフィッティングをみたところ、AUCは0.77くらいでした。

まあまあ妥当な予測が行えていると判断します。

f:id:abrahamcow:20190806074628p:plain

ここからはパラメータの解釈です。

下図は推定された成功確率パラメータAの棒グラフです。デンドログラムは普通のユークリッド距離を使ってウォード法で書きました。

f:id:abrahamcow:20190806074729p:plain

p3、p6は常にyesが選ばれやすい質問のようです。

p1、p4は比較的yesが選ばれにくい質問のようです。

p1、p4で比較的yesが選ばれやすくなるトピック2では、対照的にp2、p7がnoになりやすくなるという傾向が見て取れます。

これはp1、p4とp2、p7の特性の違いを反映していると考えれます。

次に回答者の属性の影響をみていきます。

下図は黒丸が係数V、エラーバーが95%信用区間です(ラプラス近似による荒い信用区間です。実際にどの程度の被覆確率になっているかはシミュレーションしてみないとわかりませんので、95という数字は目安です。)

f:id:abrahamcow:20190806075449p:plain

性別によるちがいはほとんどないようです。

年齢(100で割ってスケーリングしています)は高くなるほどトピック1の割合が増えるようです。

他に、推定されたSを使って回答者をクラスタリングしたりすることもできます。

R のコード

library(tidyverse)
library(ROCR)
library(ggdendro)
library(patchwork)
library(parallel)
dat <- read_csv("~/Downloads/chapter01_19/chapter04平川/data_hirakawa.csv")

age <- dat$age/100
sex <- dat$sex-1
X <- model.matrix(~sex+age)
Y <- as.matrix(select(dat,starts_with("p")))
out <- mclapply(2:10,function(L)berNMF(Y,X,L=L),mc.cores = detectCores())
plot(2:10,sapply(out, function(x)x$ll),type="b", xlab="number of topics", ylab="log-likelihood")

Y2 <- Y
set.seed(1)
ind <- sample.int(length(Y),280)
Y2[ind] <- NA
out_test <- berNMF(Y2,X,2)
pred <- prediction(as.vector((out_test$S%*%out_test$A)[ind]),as.vector(Y[ind]))
perf <- performance(pred, "tpr", "fpr")
plot(perf)
perf <- performance(pred, "auc")
perf@y.values


hc <- hclust(dist(t(out[[1]]$A)),method = "ward.D")
hcdata <- dendro_data(hc)

dfA <- as.data.frame(out[[1]]$A) %>% 
  mutate(topic=row_number()) %>% 
  gather(Q,value,-topic) %>% 
  mutate(Q=factor(Q,levels = levels(hcdata$labels$label)))
p2 <- ggplot(dfA,aes(x=Q,y=value))+
  geom_col()+
  geom_hline(yintercept = 0.5,linetype=2)+
  facet_grid(topic~.)+
  theme_bw()

p1 <- ggdendrogram(hcdata)
p1 + p2 + 
  plot_layout(ncol = 1, heights = c(1, 3))

se <- matrix(sqrt(diag(solve(out[[1]]$hess))),3,2)
dfV <- data.frame(out[[1]]$V) %>% 
  set_names(1:2) %>% 
  mutate(name=colnames(X)) %>% 
  gather(topic,coef,-name)
dfse <- data.frame(se) %>% 
  set_names(1:2) %>% 
  mutate(name=colnames(X)) %>% 
  gather(topic,se,-name)
dfVse <- left_join(dfV,dfse)
ggplot(dfVse,aes(x=name,y=coef,
                 ymin=coef+qnorm(0.25)*se,
                 ymax=coef+qnorm(0.975)*se))+
  geom_pointrange()+
  geom_hline(yintercept = 0, linetype=2)+
  facet_wrap(~topic)+
  theme_bw(20)