廿TT

譬如水怙牛過窓櫺 頭角四蹄都過了 因甚麼尾巴過不得

多項式回帰モデルの予測分布(『ベイズ推論による機械学習入門』をRで)

今日の川柳

ベイズ推論による機械学習入門』p.109 の図3.8をRで再現してみました。

機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)

機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)

まずサインカーブ正規分布でノイズを加えたデータを生成します。

f:id:abrahamcow:20180329011549p:plain

それをもとに多項式回帰で学習させた予測分布の平均と1標準偏差の範囲をプロットしたのが以下の図です。

f:id:abrahamcow:20180329011533p:plain

パラメータ等の設定はテキストと同じになっているはずです。

多項式の字数 M=1 や M=2 のときはうまくデータの構造をつかめていないようです。

M=4くらいがデータの上がり下がりをうまくつかめていそうです。

M=6 のときは点線の範囲が広めになっています。これはモデルが予測結果に「自信をもっていない」ことを示しています。

以下にRのコードを貼ります。

library(tidyverse)
library(cowplot)
set.seed(4321)
x <- runif(10,0,6)
y <- sin(x)+rnorm(10,0,sqrt(1/10))
dat <- data_frame(y,x)

p_true <- ggplot(dat,aes(x=x,y=y))+
  geom_point()+
  stat_function(fun = sin,colour="red")

ggsave("p_true.png",p_true)

lmBayes <- function(formula,data,lambda,Lambda=NULL,m=NULL){
  mf <- model.frame(formula,data)
  X <- model.matrix(formula,data)
  y <- model.response(mf)
  p <- ncol(X)
  if(is.null(Lambda)){
    Lambda <- diag(1,p)
  }
  if(is.null(m)){
    m <- numeric(p)
  }
  Lambdahat <- lambda * t(X) %*% X + Lambda
  mhat <- drop(solve(Lambdahat) %*% (lambda * colSums(y*X) + Lambda %*% m))
  return(list(formula=formula,lambda=lambda,mhat=mhat,Lambdahat=Lambdahat))
}

predlmBayes <- function(posterior, newdata){
  X <- model.matrix(posterior$formula[-2],newdata)
  n <- nrow(X)
  mu_ast <- drop(X %*% posterior$mhat)
  inv_lambda_ast <- unname(1/posterior$lambda +diag(X %*% solve(posterior$Lambdahat)  %*% t(X)))
  return(list(mu_ast=mu_ast,inv_lambda_ast=inv_lambda_ast))
}

newdata <- data_frame(x=seq(0,6,by=0.1))
post0 <-lmBayes(y~1,data = dat,lambda = 10)
pred0 <-predlmBayes(post0,newdata)

p0 <-ggplot(dat)+
  geom_point(aes(x=x,y=y))+
  geom_line(data=newdata,aes(x=x,y=pred0$mu_ast),colour="royalblue")+
  geom_line(data=newdata,aes(x=x,y=pred0$mu_ast-sqrt(pred0$inv_lambda_ast)),linetype=2,colour="royalblue")+
  geom_line(data=newdata,aes(x=x,y=pred0$mu_ast+sqrt(pred0$inv_lambda_ast)),linetype=2,colour="royalblue")+
  ggtitle("M=1")

post1 <-lmBayes(y~x,data = dat,lambda = 10)
pred1 <-predlmBayes(post1,newdata)

p1 <-ggplot(dat)+
  geom_point(aes(x=x,y=y))+
  geom_line(data=newdata,aes(x=x,y=pred1$mu_ast),colour="royalblue")+
  geom_line(data=newdata,aes(x=x,y=pred1$mu_ast-sqrt(pred1$inv_lambda_ast)),linetype=2,colour="royalblue")+
  geom_line(data=newdata,aes(x=x,y=pred1$mu_ast+sqrt(pred1$inv_lambda_ast)),linetype=2,colour="royalblue")+
  ggtitle("M=2")

post2 <-lmBayes(y~x+I(x^2),data = dat,lambda = 10)
pred2 <-predlmBayes(post2,newdata)

p2 <-ggplot(dat)+
  geom_point(aes(x=x,y=y))+
  geom_line(data=newdata,aes(x=x,y=pred2$mu_ast),colour="royalblue")+
  geom_line(data=newdata,aes(x=x,y=pred2$mu_ast-sqrt(pred2$inv_lambda_ast)),linetype=2,colour="royalblue")+
  geom_line(data=newdata,aes(x=x,y=pred2$mu_ast+sqrt(pred2$inv_lambda_ast)),linetype=2,colour="royalblue")+
  ggtitle("M=3")

post3 <-lmBayes(y~x+I(x^2)+I(x^3),data = dat,lambda = 10)
pred3 <-predlmBayes(post3,newdata)

p3 <-ggplot(dat)+
  geom_point(aes(x=x,y=y))+
  geom_line(data=newdata,aes(x=x,y=pred3$mu_ast),colour="royalblue")+
  geom_line(data=newdata,aes(x=x,y=pred3$mu_ast-sqrt(pred3$inv_lambda_ast)),linetype=2,colour="royalblue")+
  geom_line(data=newdata,aes(x=x,y=pred3$mu_ast+sqrt(pred3$inv_lambda_ast)),linetype=2,colour="royalblue")+
  ggtitle("M=4")

post4 <-lmBayes(y~x+I(x^2)+I(x^3)+I(x^4),data = dat,lambda = 10)
pred4 <-predlmBayes(post4,newdata)

p4 <-ggplot(dat)+
  geom_point(aes(x=x,y=y))+
  geom_line(data=newdata,aes(x=x,y=pred4$mu_ast),colour="royalblue")+
  geom_line(data=newdata,aes(x=x,y=pred4$mu_ast-sqrt(pred4$inv_lambda_ast)),linetype=2,colour="royalblue")+
  geom_line(data=newdata,aes(x=x,y=pred4$mu_ast+sqrt(pred4$inv_lambda_ast)),linetype=2,colour="royalblue")+
  ggtitle("M=5")

post5 <-lmBayes(y~x+I(x^2)+I(x^3)+I(x^4)+I(x^5),data = dat,lambda = 10)
pred5 <-predlmBayes(post5,newdata)

p5 <-ggplot(dat)+
  geom_point(aes(x=x,y=y))+
  geom_line(data=newdata,aes(x=x,y=pred5$mu_ast),colour="royalblue")+
  geom_line(data=newdata,aes(x=x,y=pred5$mu_ast-sqrt(pred5$inv_lambda_ast)),linetype=2,colour="royalblue")+
  geom_line(data=newdata,aes(x=x,y=pred5$mu_ast+sqrt(pred5$inv_lambda_ast)),linetype=2,colour="royalblue")+
  ggtitle("M=6")

p_pred <-plot_grid(p0,p1,p2,p3,p4,p5)
ggsave("pred.png",p_pred)

おなじようなことを黒木玄氏がより詳細に(Juliaで)やっています:Polynomial regression · GitHub