廿TT

譬如水怙牛過窓櫺 頭角四蹄都過了 因甚麼尾巴過不得

『ガウス過程と機械学習』図3.20 (b) を R で再現する (p.91)

今日の川柳

ガウスカーネルのハイパーパラメータを最適化して2標準偏差の予測区間をプロットするところ。

f:id:abrahamcow:20190617232918p:plain

コードだけ貼ります。解説はしませんので、アマゾンアフィリエイト経由で本を買ってください。

ガウス過程と機械学習 (機械学習プロフェッショナルシリーズ)

ガウス過程と機械学習 (機械学習プロフェッショナルシリーズ)

library(ggplot2)
dat <- read.csv("http://chasen.org/~daiti-m/gpbook/data/gpr.dat",sep = "\t",header = FALSE)
colnames(dat) <- c("x","y")
RBF <- function(x1,x2,eta){
  theta <- exp(eta) 
  theta[1]*exp(-((x1-x2)^2)/theta[2]) + theta[3]*(x1==x2)
}
ll <- function(eta,y,X){
  cov <- outer(X,X,RBF,eta=eta)
  detK <- determinant(cov)
  -detK$sign*detK$modulus-t(y)%*%solve(cov)%*%y
}
dll <- function(eta,y,X){
  dRBF1 <- function(x1,x2,eta){
    theta <- exp(eta) 
    theta[1]*exp(-((x1-x2)^2)/theta[2])
  }
  dRBF2 <- function(x1,x2,eta){
    theta <- exp(eta)
    theta[1]*exp(-((x1-x2)^2)/theta[2])*((x1-x2)^2)/theta[2]
  }
  dRBF3 <- function(x1,x2,eta){
    theta <- exp(eta)
    theta[3]*(x1==x2)
  }
  cov <- outer(X,X,RBF,eta=eta)
  d1 <- outer(X,X,dRBF1,eta=eta)
  d2 <- outer(X,X,dRBF2,eta=eta)
  d3 <- outer(X,X,dRBF3,eta=eta)
  Kinv <- solve(cov)
  Kinvy <- Kinv%*%y
  c(-sum(diag(Kinv%*%d1))+t(Kinvy)%*%d1%*%Kinvy,
    -sum(diag(Kinv%*%d2))+t(Kinvy)%*%d2%*%Kinvy,
    -sum(diag(Kinv%*%d3))+t(Kinvy)%*%d3%*%Kinvy)
}
opt <- optim(c(0,0,0),ll,dll,y=dat$y,X=dat$x,control = list(fnscale=-1),method = "BFGS")
print(exp(opt$par))
#[1] 1.52450203 0.68928552 0.06701316
cov <- outer(dat$x,dat$x,RBF,eta=opt$par)
xv <- seq(-1,4,length.out = 100)
cov2 <- outer(dat$x,xv,RBF,eta=opt$par)
ybar <- as.vector(t(cov2)%*%solve(cov)%*%dat$y)
sd <- sqrt(diag(RBF(xv,xv,opt$par)-t(cov2)%*%solve(cov)%*%cov2))
pred <- data.frame(x=xv,y=ybar)
ggplot(dat,aes(x=x,y=y))+
  geom_line(data=pred,aes(x=xv,y=ybar),colour="royalblue")+
  geom_ribbon(data=pred,aes(x=xv,ymin=ybar-2*sd,ymax=ybar+2*sd),alpha=0.2,fill="royalblue")+
  geom_point(pch=4,size=5)+
  theme_bw()