読者です 読者をやめる 読者になる 読者になる

廿TT

譬如水怙牛過窓櫺 頭角四蹄都過了 因甚麼尾巴過不得

R によるすごくかんたんなパーティクルフィルタの実装例

R 時系列

これであってるのかあんまり自信ない。

主に http://daweb.ism.ac.jp/koza/koza2008/PF_Nakano20081030.pdf を参考にしました。

パーティクルフィルタのアルゴリズムは、

  1. 一期先の予測を乱数でばらまく
  2. 得られた乱数に尤度の重みをつけてまき直す

だと理解しました。

そこですごくかんたんなモデルで試してみることにしました。

モデル:「未観測の状態変数は標準偏差 1 の正規分布したがって推移する。得られる観測は状態変数に標準偏差 3 の正規分布に従うノイズが乗ったものである。」

システムモデル: x_t \sim N(x_{t-1},1)
観測モデル: y_t \sim N(x_{t},3)

 y_t が観測値です。ただし x_0=0 としました。

particle_filter <- function(x0,y,N) {
  tmax <- length(y)
  xx <- matrix(,N,tmax+1)
  xx[,1] <-rep(x0,N)
  for(t in 1:tmax){
    x <- xx[,t] + rnorm(N,0,1) #一期先予測
    w <- dnorm(rep(y[t],N),x,3) #尤度
    xx[,t+1] <- sample(x,N,replace=TRUE,prob=w/sum(w)) #リサンプリング
  }
  return(xx)
}
set.seed(1)
tmax <- 50
x0 <- 0
x <- cumsum(rnorm(tmax))
y <- rnorm(tmax,x,3)

xx <- particle_filter(x0,y,N=5000)
plot(y,type="b")
lines(x,col="red",lwd=2)
lines(apply(xx[,-1],2,mean),col="blue",lwd=2)

f:id:abrahamcow:20170329200749p:plain

赤い線が真の状態、青い線がパーティクルフィルタによる状態の推定値です。

あってる、のか……?

自信がないので dlm パッケージ(岩波データサイエンス Vol.1 の年輪の例題を dlm でやる - 廿TT)をつかってカルマンフィルタでも状態の推定をやってみました。

パーティクルフィルタによる状態の推定値とカルマンフィルタによる状態の推定値は一致するはずです。

library(dlm)
Filt1 <- dlmFilter(y,dlmModPoly(order=1,dV=3^2,C0=0))
lines(dropFirst(Filt1$m),col="orange",lwd=2)

f:id:abrahamcow:20170329200758p:plain

オレンジの線がカルマンフィルタによる状態の推定値です。

青い線とほぼ一致しています。

なんか、あっていそうな気がする。

今後はパーティクルフィルタで状態を推定しつつ MCMC などで未知パラメータもあわせて推定したいです。

いい資料があったら教えて下さい。

以下はアマゾンアフィリエイトのコーナーです。

カルマンフィルタの基礎

カルマンフィルタの基礎

Rcpp で独立メトロポリスヘイスティングス

R

独立メトロポリス・ヘイスティングス法を用いたベイズ推測の簡単な例題 - 廿TT でやったのと同じことを Rcpp で書いてみた。
ハローワールド。

C++ のコードはこう。

#include <Rcpp.h>
using namespace Rcpp;
double lik(double lambda, NumericVector dat) {
  double sumdat;
  sumdat = sum(dat);
  return(dat.length()*log(lambda)-sumdat*lambda);
}
// [[Rcpp::export]]
NumericVector MHexp(int N, NumericVector dat) {
  NumericVector chain(N);
  double a;
  double r;
  RNGScope scope;
  chain[0] =R::runif(0, 10);
  for(int i = 0; i < N; i++) {
    a = R::runif(0, 10);
    r = lik(a,dat)-lik(chain[i-1],dat);
    if(R::runif(0,1) < exp(r)){
      chain[i] = a;
    }else{
      chain[i] = chain[i-1];
    }
  }
  return(chain);
}

R からはこうやって使う。

library(Rcpp)
sourceCpp('MHexp.cpp')
dat <- rexp(100,2)
system.time({out <-MHexp(100000,dat)})
#   ユーザ   システム       経過  
#      0.03       0.00       0.03 
plot(out,type="l")

f:id:abrahamcow:20170329065503p:plain