読者です 読者をやめる 読者になる 読者になる

廿TT

譬如水怙牛過窓櫺 頭角四蹄都過了 因甚麼尾巴過不得

混合正規分布のパラメータ推定(あるいは EM アルゴリズム練習問題)

前置き

R では mixtools パッケージなどを使えば混合正規分布のパラメータを推定できる.
ExploringDataBlog: Fitting mixture distributions with the R package mixtools などを参照して欲しい.

ここではパッケージはさておき, 愚直に計算して更新式を導出する.

EM アルゴリズムの勉強には役立つかもしれない.

計算

一般の混合分布

g 個の成分からなる混合分布
\displaystyle f_\boldsymbol{\theta}(y) = \sum_{j=1}^{g} \xi _j f_j(y;\phi_j)
を考える.
ここで ξ は混合比率
 \displaystyle\sum _{j=1}^{g} \xi _j = 1.

混合分布の対数尤度 l (\boldsymbol{\theta,y})
\displaystyle l (\boldsymbol{\theta ,y}) =\sum _{i=1}^{n} \log f_\boldsymbol{\theta}(y_i) =\sum _{i=1}^{n}\log\left\{ \sum_{j=1}^{g} \xi _j f_j(y_i;\phi_j) \right\}.

したがって, 尤度方程式は
\displaystyle \frac{\partial}{\partial \xi _j} l(\boldsymbol{\theta ,y}) = \sum _{i=1}^{n} \left\{ \frac{f_j(y_i|\phi _j)}{f(y_i|\theta)} - \frac{f_g(y_i|\phi _j)}{f(y_i|\theta)} \right\} =0 \quad j=1,\ldots , g-1,

\displaystyle \frac{\partial}{\partial \phi _j} l(\boldsymbol{\theta ,y}) = \xi _j \sum _{i=1}^{n} \frac{\frac{\partial}{\partial \phi _j}f _j (y_i|\phi _j)}{f(y_i|\boldsymbol{\theta})} = \boldsymbol{0} \quad j=1,\ldots , g
となる.

ある観測値 y がどの部分母集団から得られたかを示す観測 \boldsymbol{z}=(z_1,\ldots,z_g)^T
 \displaystyle z_j =\left\{ \begin{array}{ll}1 & \text{観測が $\Omega _j$ から得られた}\\0& \text{それ以外} \end{array}\right.
があるとする.

単一観測 \boldsymbol{X}^{\ast T}=(y, \boldsymbol{Z}^T)^T に関する同時分布は, 条件付き密度を用いて,
\displaystyle f(\boldsymbol{x}^{\ast}|\boldsymbol{\theta})=\prod_{j=1}^{g}f^{z_j}_{j}(y|\phi _j) \prod_{j=1}^{g} \xi ^{z_j}_{j}
と書ける.

完全観測 \boldsymbol{X}=(\boldsymbol{Y}^T,\boldsymbol{Z}^T)^T の密度関数は,
\displaystyle f(\boldsymbol{x}|\boldsymbol{\theta})=\prod^{n}_{i=1}\left(\prod^{g}_{j=1} f^{z_{ij}}_{j}(y_i|\phi _j) \prod^{g}_{j=1}\xi ^{z_ij}_{j}\right)
となる.

ここから完全観測に基づく対数尤度 l^C (\boldsymbol{\theta ,x})
\displaystyle l^C (\boldsymbol{\theta ,x}) =\sum^{n}_{i=1} \log f(\boldsymbol {x}_i|\boldsymbol{\theta})= \sum^{n}_{i=1}\sum^{g}_{j=1} z_{ij} \log f_j (y_i|\phi _j) + \sum^{n}_{i=1}\sum^{g}_{j=1} z_{ij} \log \xi _j
となる.

EM アルゴリズムの E ステップでは, 現時点でのパラメータ値 \boldsymbol{\theta}^{(k)} が得られ, 観測 \boldsymbol{Y=y} が与えられた条件のもとでの l^C (\boldsymbol{\theta ,x}) に関する条件付き期待値 Q (\boldsymbol{\theta ,\theta}^{(k)}) を計算する.

\displaystyle Q (\boldsymbol{\theta ,\theta}^{(k)}) 
= E _{ \boldsymbol{\theta}^{(k)} } \left\{ l^C (\boldsymbol{\theta ,x}) |\boldsymbol{Y=y}  \right\} \\
=  \sum^{n}_{i=1}\sum^{g}_{j=1} z^{(k)}_{ij} \log f_j (y_i|\phi _j) + \sum^{n}_{i=1}\sum^{g}_{j=1} z^{(k)}_{ij} \log \xi _j

M ステップでは, E ステップで得られた Q を \boldsymbol{\theta} に関して最大化する.

混合比率 ξ に関しては, \xi _g = 1- \xi _1 - \cdots -\xi _{g-1} であることから,
\displaystyle \frac{\partial}{\partial \xi _j} Q (\boldsymbol{\theta ,\theta}^{(k)}) = \sum^{n}_{i=1} \left( \frac{z^{(k)}_{ij}}{\xi _j} -\frac{z^{(k)}_{ig}}{\xi _g}\right) = 0 \quad j=1,\ldots, g-1,
なる方程式を解けば良い.

ここで,
\displaystyle \sum^{g}_{j=1} z^{(k)}_{ij}=1, \quad \sum^{n}_{i=1}\sum^{g}_{j=1} z^{(k)}_{ij}=n
に着目すると, \xi _j に関するパラメータ更新式
\displaystyle \xi^{(k+1)}_{j} = \frac{1}{n} \sum^{n}_{i=1}z^{(k)}_{ij} \quad j=1,\ldots, g-1 \tag{1}
を得る.

さらに, 各部分母集団でのパラメータ  \phi _l については, Q の第二項目が  \phi _l に依存しないことから,
\displaystyle \frac{\partial}{\partial \phi _l} Q (\boldsymbol{\theta ,\theta}^{(k)}) = \frac{\partial}{\partial \phi _l}\sum^{n}_{i=1}\sum^{g}_{j=1} z^{(k)}_{ij} \log f_j (y_i|\phi _j) \quad i=1,\ldots, g,
を解けば良い.

混合正規分布

これより各部分母集団が正規分布の場合を考える.

第 j 母集団の観測の密度関数は,
\displaystyle f _j (y| \phi _j) = f(y|(\mu _j , \sigma ^2) ) = \frac{1}{\sqrt{s \pi \sigma ^2}}\exp \left\{-\frac{1}{\sigma ^2}(y-\mu _j)^2 \right\}
であるから,
\displaystyle \log f_j(y_i |\phi _j ) = - \frac{1}{2} \log 2\pi \sigma ^2 - \frac{1}{2 \sigma ^2}(y_i -\mu _j) ^2.

よって, 尤度方程式は,
\displaystyle 0 = \frac{\partial}{\partial \mu _ll} \sum^{n}_{i=1} \sum^{g}_{j=1} z^{(k)}_{ij} \log f_j(y_i |\phi _j )
= \sum^{n}_{i=1}z^{(k)}_{il} \left\{ \sigma^2 (y_i - \mu _l)\right\} ,
\displaystyle 0 = \frac{\partial}{\partial  \sigma^2} \sum^{n}_{i=1} \sum^{g}_{j=1} z^{(k)}_{ij} \log f_j(y_i |\phi _j )
= \sum^{n}_{i=1} \sum^{g}_{j=1} z^{(k)}_{ij} \left\{ -\frac{1}{\sigma^2} + \frac{1}{2 (\sigma^2)^2} (y_i - \mu _l)\right\}
となる.

これを解いて(\sum^{n}_{i=1}\sum^{g}_{j=1} z^{(k)}_{ij}=n に留意)パラメータ更新式,
 \displaystyle \mu^{k+1}_l =\frac{\sum^{n}_{i=1} z^{(k)}_{il} y_i}{\sum^{n}_{i=1} z^{(k)}_{il}}\tag{2}
 \displaystyle (\sigma^2)^{(k+1)} = \frac{1}{n} \sum^{n}_{i=1}\sum^{g}_{j=1} z^{(k)}_{ij}(y_i - \mu^{(k+1)}_j)^2 \tag{3}
を得る.

R による計算例

パラメータが増えると収束しなくなる. 最尤法の限界なのか, ぼくのコードのまずさなのかはわからない.

estmixnorm <- function(x,mu,Sigma,pi,maxit=1000){
  messe <-  "収束しませんでした"
  N <- length(x)
  LL_1 <- numeric(0)
  n_k <- numeric(0)
  K <- length(mu)
  f_k <- function(i){pi[i]*dnorm(x,mu[i],sqrt(Sigma[i]))}

  LL <- function(x, mu, sigma, pi){
    pL <- sapply(1:K,f_k)
    sum(log(apply(pL,1,sum)))   
  }

  LL_1[1] <- LL(x,mu,sigma,pi)
  for(i in 1:maxit){
    tmp <- sapply(1:K,f_k)
    den <- apply(tmp,1,sum)
    w_k <- matrix(,nrow=K,ncol=N)  
    for(j in 1:K){
      w_k[j,] <- pi[j]*dnorm(x,mu[j],sqrt(Sigma[j]))/den  
    }
    n_k <- apply(w_k,1,sum)
    if(any(n_k==0)){messe = "重み係数が0になりました"; break }
    for(j in 1:K){
      mu[j] <- sum((w_k[j,] * x))/n_k[j]
    #  gamma2 * (x-mu[2])^2
      Sigma[j] <- sum(w_k[j,]* (x-mu[j])^2)/n_k[j]
    }
    pi <- n_k/N
    LL_1[i+1] <-LL(x,mu,sigma,pi)
    if(abs(LL_1[i+1]-LL_1[i]) < 1e-6){messe = "収束しました"; break }
  }
  list(mu=mu,Sigma= Sigma, pi= pi, it =i, LL= LL_1, convergence =messe)
}

混合正規乱数を発生させて, 上記の関数でパラメータを推定してみよう.

rmixnorm <- function(n,mu, sigma,weight){
  stopifnot( length(mu) == length(sigma) & length(sigma)==length(weight))
  Y <-sample(length(weight), n, replace=TRUE, prob=weight)
  X <-rnorm(n, mu[Y], sigma[Y])
}

mu <- c(1, 10)
sigma <- sqrt(c(1, 9))
weight <- c(0.4, 0.6)

X <- rmixnorm(1000,mu,sigma,weight)


res <-estmixnorm(X,c(1,3),sqrt(c(1, 9)),c(0.5,0.5))



dmixnorm <- function(x,mu,sigma,weight){
  weight[1]*dnorm(x,mu[1],sigma[1]) +
    weight[2]*dnorm(x,mu[2],sigma[2])
}
hist(X,breaks="FD",freq=FALSE,ylim=c(0,0.17))
curve(dmixnorm(x,mu,sigma,weight),add=TRUE,col="royalblue",lwd=2)
curve(dmixnorm(x,res$mu,sqrt(res$Sigma),res$pi),add=TRUE,col="orange2",lwd=2)
legend("topright",legend=c("true","estimate"),col=c("royalblue","orange2"),lwd=2)

ヒストグラムを見た感じでは, うまく推定できたようだ.

f:id:abrahamcow:20150801044858p:plain

> res$mu
[1] 0.9873231 9.9251457
> res$Sigma
[1] 1.071319 9.368002
> res$pi
[1] 0.4045819 0.5954181

続いて対数尤度の変化を見て収束の様子を調べる.

plot(res$LL,type="b")

f:id:abrahamcow:20150801045019p:plain

一回目の更新で急激に大きくなっている. 収束は早いようだ.