廿TT

譬如水怙牛過窓櫺 頭角四蹄都過了 因甚麼尾巴過不得

Rcpp を使った最適化ができるパッケージ roptim の簡単な例題

今日の川柳

roptim (GitHub - ypan1988/roptim: General Purpose Optimization in R using C++) の使い方をメモ。

インストールはCRANからいけます。

install.packages("roptim")

線形回帰をやります。

二乗誤差を最小化。

reg.cpp というファイルを作る。

// [[Rcpp::plugins(cpp11)]]
#include <cmath>  // std::pow

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

#include <roptim.h>
// [[Rcpp::depends(roptim)]]

using namespace roptim;

class LM : public Functor {
private:
  arma::vec Y_;
  arma::mat X_;
  
public:
  LM(const arma::mat &Y, const arma::mat &X) : Y_(Y), X_(X) {
  }
  
  double operator()(const arma::vec &beta) override {
    return sum(0.5*pow(Y_-X_*beta,2));
  }
  
  void Gradient(const arma::vec &beta, arma::vec &gr) override {
    gr = -X_.t()*(Y_-X_*beta);
  }
};


// [[Rcpp::export]]
Rcpp::List lm_bfgs(arma::vec Y, arma::mat X, arma::vec beta) {
  LM lm(Y,X);
  Roptim<LM> opt("BFGS");
  
  opt.minimize(lm, beta);
  
  return Rcpp::List::create(
    Rcpp::Named("coefficients") = opt.par(),
    Rcpp::Named("value") = opt.value(),
    Rcpp::Named("convergence") = opt.convergence()
  );
}

R からはこうやって使う。

library(Rcpp)

sourceCpp("~/Documents/reg.cpp")
ctl <- c(4.17,5.58,5.18,6.11,4.50,4.61,5.17,4.53,5.33,5.14)
trt <- c(4.81,4.17,4.41,3.59,5.87,3.83,6.03,4.89,4.32,4.69)
group <- gl(2, 10, 20, labels = c("Ctl","Trt"))
weight <- c(ctl, trt)
X <- model.matrix(~ group)

lm(weight ~ group)
lm_bfgs(weight,X,c(1,1))

lm と同じ結果が求まる。

> lm(weight ~ group)

Call:
lm(formula = weight ~ group)

Coefficients:
(Intercept)     groupTrt  
      5.032       -0.371  

> lm_bfgs(weight,X,c(1,1))
$coefficients
       [,1]
[1,]  5.032
[2,] -0.371

$value
[1] 4.364625

$convergence
[1] 0

これで最尤推定やMAP推定が好きなだけできそう。