廿TT

譬如水怙牛過窓櫺 頭角四蹄都過了 因甚麼尾巴過不得

R と Stan で Factorization Machines

今日の川柳

Factorization Machines の解説はこの記事がわかりやすかった:

一歩Matrix Factorization、二歩Factorization Machines、三歩Field-aware Factorization Machines…『分解、三段突き!!』 - F@N Ad-Tech Blog

Factorization Machines は傾向線に以下の式を仮定した回帰型のモデルである。

 \hat y(x) = w_0 + \sum_{i=1}^nw_ix_i + \sum_{i=1}\sum_{j=i+1} \langle\mathbf{v}_i ,\mathbf{v}_j \rangle x_ix_j

ただし \langle\mathbf{v}_i ,\mathbf{v}_j \rangle はドット積、

 \langle\mathbf{v}_i ,\mathbf{v}_j \rangle = \sum_{k=1}^{K} v_{ik} v_{jk}

を表す。

例えばこんな感じの表が与えられたとき、

deviceCategory userGender userAgeBracket userType landingPagePath sessions
desktop female 18-24 New Visitor /entry/2013/03/22/004017 11
desktop female 18-24 New Visitor /entry/2014/08/13/031519 35
desktop female 18-24 New Visitor /entry/2014/11/06/060634 11
desktop female 18-24 New Visitor /entry/2014/12/19/221230 50
desktop female 18-24 New Visitor /entry/2015/01/17/064522 15
desktop female 18-24 New Visitor /entry/2018/02/21/230602 23

deviceCategory から landingPagePath の列をダミー変数に変えたものを説明変数 x、sessions の列を目的変数 y にできる。

Stan でやる場合には観測モデルを正規分布に限定する必要はないので、今回はポアソン分布を仮定してみる。

y\sim \mathrm{Poisson}(\exp( \hat y(x)))

です。

ところで、

 \sum_{i=1}\sum_{j=i+1} \langle\mathbf{v}_i ,\mathbf{v}_j \rangle x_ix_j

の項はちょっと変形してやると、

 \frac{1}{2} \sum_k \left( (\sum_i v_{ik}x_i)^2 - \sum_i v_{jk}^2x^2 \right)

と簡単に書ける。

よって、Stan のコードはこうなりました。

data{
  int N;
  int D;
  int K;
  int y[N];
  matrix[N,D] X;
}
parameters{
  vector[D] W;
  matrix[D,K] V;
  real<lower=0> sigma;
}
transformed parameters{
  matrix[N,K] tmp;
  tmp = ( (X*V) .* (X*V) - (X .* X) * (V .* V) )/2;
}
model{
  for(n in 1:N){
    y[n] ~ poisson_log(X[n,]*W+sum(tmp[n,]));
  }
  W ~ normal(0,sigma);
  to_vector(V) ~ normal(0,sigma);
}
generated quantities{
  real pred[N];
  for(n in 1:N){
    pred[n] = poisson_log_rng(X[n,]*W+sum(tmp[n,]));
  }
}

切片項 w_0 を入れるの忘れちゃった。まあいいや。

K はとりあえず2でやってみたけど当てはまりはいい感じ。

f:id:abrahamcow:20190220223117p:plain

推定された v をプロットしたのが以下の図です。

f:id:abrahamcow:20190220223247p:plain

パラメータの解釈もある程度できる。

f:id:abrahamcow:20190220223315p:plain

真ん中の方で目立っているlandingPagePath/entry/2014/12/19/221230はv1が負の値で小さく、v2が正の値で大きいので、たとえば符号が逆になってるdeviceCategorytablet とは相性が悪いことがわかる。

今回 x は186行61列の行列でしたが、4000回回して646.587 secondsで終わりました。

スモールデータならレコメンドとかに使えるかもしれません。

R のコードを貼ります。

library(rstan)
library(googleAnalyticsR)
library(tidyverse)
library(caret)
rstan_options(auto_write = TRUE)
FM <- stan_model("~/Documents/FM.stan")
ga_auth()
account_list <- ga_account_list()
ga_id <- account_list$viewId[3]

gadata <-google_analytics(ga_id,
                   date_range = c("2019-01-01","2019-01-31"),
                   metrics = c("sessions"),
                   dimensions = c("deviceCategory","userGender","userAgeBracket","userType","landingPagePath"))

dmy <- dummyVars(~deviceCategory+userGender+userAgeBracket+userType+landingPagePath,data=gadata)
X <- predict(dmy,newdata = gadata)
dim(X)
#[1] 186  61
dat4stan <- list(X=X,y=gadata$sessions,K=2,N=nrow(gadata),D=ncol(X))
fit <- sampling(FM,dat4stan,chain=1,iter=4000,seed=893)
#646.587 seconds (Total)
all(summary(fit)$summary[,"Rhat"]<1.1)
traceplot(fit,pars="V[1,1]")
plot(gadata$sessions,get_posterior_mean(fit,par="pred"),
     xlab="observed",ylab="fitted")
abline(0,1,lty=2)
Vdf <- matrix(get_posterior_mean(fit,pars="V"),dat4stan$D,dat4stan$K,byrow = TRUE) %>% 
  as.data.frame() %>% 
  mutate(variables=colnames(X)) %>% 
  gather(key,value,-variables)

ggplot(Vdf_res,aes(x=variables,y=value,fill=key))+
  geom_col(position = "dodge")+
  geom_hline(yintercept = 0, linetype=2)+
  theme_bw()+
  theme(axis.text.x = element_text(angle = 90, hjust = 1, colour="black",size = 12))

https://www.ismll.uni-hildesheim.de/pub/pdfs/FreudenthalerRendle_BayesianFactorizationMachines.pdf

https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf