廿TT

譬如水怙牛過窓櫺 頭角四蹄都過了 因甚麼尾巴過不得

グループドデータの非負値行列因子分解

今日の川柳

モチベーション

たとえばこういう表がある。

gist.github.com

表の左のほうにユーザー層の情報、右の方にユーザー層ごとのブログへのアクセス経路が書かれている。

どのユーザー層がどの経路を好むか知りたいとする。

そこでトピックモデルとしてポアソン分布を使った非負値行列因子分解を考える。
トピックモデルシリーズ 6 GaP (Gamma-Poisson Model) - StatModeling Memorandum などを参照。)

ユーザー層が文書、アクセス経路が単語に対応する。

ユーザー層の情報を捨てて、行列を分解してしまうのはおもしろくない。

ユーザー層の情報を説明変数として、ユーザー層ごとにトピックの構成が変わるようなモデルにしたい。

モデル

観測値を行列の積 XVH で近似することを目指します。

Y: 観測された分解したい行列(N行K列)
X: 観測された説明変数(N行J列)
V: パラメータ(J行L列)
H: パラメータ(L行K列)

X の中身は0または1の変数でユーザー層を表したものです。

また、W = XVH とします。

L はトピック数です。分析者が決めます。

XV がユーザーの持つトピック、H がトピックが与えられたときの単語の出現しやすさと解釈します。

行列W、Y、X、V、Hの要素は小文字のw、y、x、v、xに下付きの添字を付けて表します。

観測モデルは以下です。

 y_{n,k} \sim {\rm Poisson}(w_{n,k})

事前分布として以下を仮定します。
v_{n,l} \sim {\rm Gamma}(\alpha_v,\beta_v)
h_{l,j} \sim {\rm Gamma}(\alpha_h,\beta_h)

事前分布を入れないと0の多い行列は途中計算でNaNになっちゃうことが多いです。

以下の潜在変数を考えます。
 y_{n,k} = \sum_{l} \sum_{j} s_{n,j,l,k}
 s_{n,j,l,k} \sim {\rm Poisson}(x_{n,j}v_{j,l}h_{l,k})

ポアソン分布に従う確率変数の和はポアソン分布に従うという性質を使いました。

補足

上記の説明ではわかりにくいと思うので補足します。

たとえば男を (1,0) 女を (0,1) とコード化したものが X です。

f:id:abrahamcow:20180710022200p:plain

Vの要素がもし0か1かの値を取るとしたら、Xによってトピックが一つに決まります。

トピックごとの単語の共起する重みを表すのが H です。図の場合、単語が1対2対3で出現します。

でも入力によってトピックがたった一つに決まってしまうというのは不自由です。

一つの文書に複数のトピックが混ざっていると考えたい。

単語が1対2対3で出現するトピックと3対2対1で出現するトピックがあるとしてその二つのトピックが混ざって出てくることにします。

f:id:abrahamcow:20180710022313p:plain

XVでユーザー層の持つトピックが決まり、それに重みHをかけたものがアウトプットになります。

これが行列 W = XVH の意味です。

変分ベイズ

ベイズ推論による機械学習入門』に出てくる非負値行列因子分解とほぼ同じ計算なので詳しくはそちらを参照してください。

機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)

機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)

本を買うのがいやな人は論文をみてください。

Bayesian Inference for Nonnegative Matrix Factorisation Models

変分ベイズの更新式を導出するのはじめてなので、間違ってるかもしれません。

事後分布が以下のように分解できるとして近似します。

q(S,V,H)=q(S)q(V)q(H)

a の期待値を取る操作を E(a) で表します。

変分ベイズの更新式は以下のようになります。

 q(v_{j,l}) =  {\rm Gamma}(\hat \alpha_v^{(j,l)}, \hat \beta_v^{(j,l)})

\hat \alpha_v^{(j,l)} = \sum_n\sum_k E(s_{n,l,k}) + \alpha_v
\hat \beta_v^{(j,l)} = \sum_n\sum_k x_{n,j} E(h_{l,k}) + \beta_v

  q(h_{l,k}) = {\rm Gamma}(\hat \alpha_v^{(l,k)}, \hat \beta_v^{(l,k)})

\hat \alpha_h^{(l,k)} = \sum_n\sum_j E(s_{n,l,k}) + \alpha_h
\hat \beta_h^{(l,k)} = \sum_n\sum_j x_{n,j} E(v_{j,l}) + \beta_h

 q(s_{n,:,:,k}) = {\rm Multinomial}(y_{n,k},\hat p_{n,:,:,k})

\hat p_{n,j,l,k} \propto \exp E(\log(x_{n,j}v_{j,l}h_{l,k}))

確率密度関数も確率分布とおなじ記号 Gamma で書きましたが、文脈で区別できると思います。

近似分布の更新に必要なガンマ分布の対数の期待値はディガンマ関数 \psi を用いて以下のように表わせます。

 X \sim Gamma(a,b)
ならば
 E(\log X) = \psi (a) - \log b

R による実装

s を四次元配列として保存しないで、V や H の更新式に代入してしまうのがこつです。

NMFVB <-function(Y,X,L=2,alpha=1,beta=1,maxit=5000,seed=1234){
  set.seed(seed)
  N <- nrow(Y)
  K <- ncol(Y)
  J <- ncol(X)
  EV <- EelV <- matrix(rgamma(J*L,shape=alpha,rate=beta),J,L)
  EH <- EelH <- matrix(rgamma(L*K,shape=alpha,rate=beta),L,K)
  for(iter in 2:maxit){
    den <- ((X %*% EelV)) %*% EelH
    Sh <- EelH * (t((X %*% EelV)) %*% (Y/den))
    Sv <- EelV * t(X) %*% (((Y/den)%*%t(EelH)))
    beta_H <- colSums(X%*%EV) + beta
    alpha_H <- alpha + Sh
    EH <- alpha_H/beta_H
    beta_V <- outer(colSums(X),rowSums(EH)) + beta
    alpha_V <- alpha + Sv
    EV <- alpha_V/beta_V
    EelH <-exp(digamma(alpha_H))/beta_H
    EelV <-exp(digamma(alpha_V))/beta_V
  }
  return(list(V=EV,H=EH,
              alpha_V=alpha_V,beta_V=beta_V,
              alpha_H=alpha_H,beta_H=beta_H))
}

結果

勘によりトピック数は L = 3 としました。

勘できめるのがいやな人は論文をみてください。

Bayesian Inference for Nonnegative Matrix Factorisation Models

ELBOを比べることでモデルを選択する方法が記載されています。
(ぼくはELBOの計算自信ない。だれかできたら教えてください。)

もとめた W = XVH(上)ともとの行列 Y(下)を並べてみます。

f:id:abrahamcow:20180705225935p:plain

まあまあ雰囲気を再現できているんじゃないでしょうか。

行列 V の値をみてみます。V1、V2、V3はそれぞれ V の1列目、2列目、3列目で潜在的なトピックを表します。

f:id:abrahamcow:20180705021620p:plain

18-24歳はV1が支配的です。歳を取るにつれてV2が多くなってきます。また女性は男性よりV2が多く他が少ないです。
V3は18-24歳にはほとんどなく、25歳以上になると増えますが歳を取るにつれて減っていきます。

V1、V2、V3に解釈を与えるために H の中身を見ます。H はトピックごとの単語の出現しやすさです。

f:id:abrahamcow:20180705022544p:plain

一応ガンマ分布の95%区間をエラーバーで重ねていますが、幅が狭くてほとんど見えません。

V1はGoogle成分が多いですがBingも多めです。調べものをしていてこのブログにたどり着く層でしょうか。若年層は具体的に知りたいことがあってこのブログに来る人が多いようです。
V2はGoogleが少なくYahooが多いのが特徴です。歳を取るにつれてYahooユーザーが多くなるみたいです。また、女性のほうがYahooユーザーが多そうです。
V3のV1、V2との違いはt.co(ツイッター)成分がそこそこある点です。SNS経由で来る人は18-24歳にはほとんどなく、25歳以上になると増え歳を取るにつれて減っていくみたいです。

最後にRのコードをまとめて貼ります。

library(googleAnalyticsR)
library(tidyverse)

ga_auth()
account_list <- ga_account_list()
ga_id <- account_list$viewId[3]

gadata <-
  google_analytics(ga_id,
                   date_range = c("2018-01-01","2018-06-30"),
                   metrics = c("sessions"),
                   dimensions = c("source","userGender","userAgeBracket"))

gadata_w <-spread(gadata,source,sessions,fill=0)

gamat <-as.matrix(gadata_w[,-c(1:2)])

X1 <-model.matrix(~userGender-1,data=gadata_w)
X2 <-model.matrix(~userAgeBracket-1,data=gadata_w)
gaX <-cbind(X1,X2)

out <-NMFVB(Y=gamat,X=gaX,L=3)

obsdf <-as.data.frame(gamat) %>% 
  set_names(1:ncol(gamat)) %>% 
  mutate(row=row_number()) %>% 
  gather(col,sessions,-row) %>% 
  mutate(col=as.integer(col),type="obs")

fitdf <- as.data.frame(gaX %*% out$V %*% out$H) %>% 
  set_names(1:ncol(gamat)) %>% 
  mutate(row=row_number()) %>% 
  gather(col,sessions,-row) %>% 
  mutate(col=as.integer(col),type="fit")

outdf <-bind_rows(obsdf,fitdf)

ggplot(outdf,aes(x=col,y=row,fill=sessions))+
  geom_tile(colour="black")+
  scale_fill_continuous(low="white",high="cornflowerblue")+
  facet_wrap(~type,nrow=2)

dfV <-as.data.frame(out$V) %>%
  rownames_to_column() %>%
  gather(key,value,-rowname)

ggplot(dfV,aes(x=rowname,y=value,fill=key))+
  geom_col(colour="black",position = "fill")+
  coord_flip()


CIHlower <- as.data.frame(qgamma(0.025,shape=out$alpha_H,rate=out$beta_H)) %>% 
  mutate(l=row_number()) %>%
  gather(source,lower,-l)


CIHupper <- as.data.frame(qgamma(0.975,shape=out$alpha_H,rate=out$beta_H)) %>% 
  mutate(l=row_number()) %>%
  gather(source,upper,-l)

Hdf <- as.data.frame(out$H) %>% 
  mutate(l=row_number()) %>% 
  gather(source,value,-l) %>% 
  left_join(CIHlower) %>% 
  left_join(CIHupper)

ggplot(Hdf,aes(x=source,y=value,ymin=lower,ymax=upper))+
  geom_col(fill="white",colour="black")+
  geom_errorbar(width=0.5)+
  facet_wrap(~l,scales="free_x")+
  coord_flip()