R言語でスライスサンプリング(Slice Sampling)を実装してみた
スライスサンプリング(Slice Sampling)というサンプリング手法
Slice sampling, Radford M. Neal, Source: Ann. Statist. Volume 31, Number 3 (2003), 705-767.
[physics/0009028] Slice Sampling
についてお勉強していたのでまとめる。以下の文章で「原論文」としてこの論文を参照する。
ちなみに1P程度ではあるものの皆大好きPRML・下の11章にも記載がある。
概要&アルゴリズム
手法としてはMCMCの一種と考えられるものであるので、まずサンプリング対象として欲しい確率分布(の規格化定数を除いた部分)を設定する。
この時、に従うサンプル系列を得るために、スライスサンプリングでは以下のようなアルゴリズムを考える。
- 初期時点として、その時の初期値を適当に設定
- となるような一様分布からをサンプリングし、"スライス"を決定
- 区間をなんらかの方法で決める
- を領域 から一様にサンプリングして持ってくる
というものである。もう少し細かい実装のお話は次の節に書く。
ざっくりでいうと
[tex: p(x,y) = \frac{1_{\{0
R言語での実装
基本的に原論文、特にFigure3〜5あたりに合わせて実装してある。
注意すべき点としては
- アンダーフローしないようにのlogを取って計算するための方法で実装(原論文P8半ば参照)
- 区間の選び方として、"doubling"手順ではなく"stepping out"手順を使用している
- 区間を無限大まで伸ばして良いと仮定したので、Figure3のJとKに関する処理は無視(原論文P10半ば参照)
- 区間を決める際の幅は、過去の点間の距離の平均値で決定(原論文P16最終段落参照)*1
である。
これを踏まえた上で以下のようにスライスサンプリングを以下のように実装した。
#The "stepping out" procedure for finding an interval around x stepping.out <- function(x, w, is.in.S) { #initial range L <- x - w * runif(1) R <- L + w #find inerval around x while(is.in.S(L)){ L <- L - w } while(is.in.S(R)){ R <- R + w } list(L=L, R=R) } #The "shrinkage" procedure for sampling from the interval shrinkage <- function(x0, I, is.in.S) { L.bar <- I$L R.bar <- I$R repeat { #select new point from the interval between L.bar and R.bar x1 <- L.bar + runif(1) * (R.bar - L.bar) if(is.in.S(x1)){break} #shrinkage the interval if(x1 < x0){ L.bar <- x1 } else{ R.bar <- x1 } } x1 } #Slice sampling function slice.sample <- function(n, x0, f, w=1.0) { g <- function(x){log(f(x))} make.is.in.S <- function(z, g){function(x){z < g(x)}} sum.dist <- 0 result <- rep(x0,n) for(i in 2:n) { z <- g(x0) - rexp(1) is.in.S <- make.is.in.S(z,g) #calc interval I <- stepping.out(x0, w, is.in.S) #sample next point x1 <- shrinkage(x0, I, is.in.S) #update results sum.dist <- sum.dist + abs(x1 - x0) w <- sum.dist / (i - 1) result[i] <- x1 x0 <- x1 } result }
実際に動かしてみる
まずは平均0・標準偏差1の正規分布を生成してみる。
MCMC同様(というかこの手法自体がその一部なので)、確率密度関数の定数項は無視して良い。
SIZE <- 10^4 points <- slice.sample(SIZE, 0, function(x)exp(-0.5*x^2), w=1)
作成したサンプリングポイント(点列)の平均と標準偏差は
> mean(points) [1] 0.009999361 > sd(points) [1] 1.018237
のようにそれぞれ0と1に近い値を取っている。更にサンプリングした点列のヒストグラムと密度関数を重ねてPLOTすると
hist(points, SIZE^0.5, freq=FALSE) x<-seq(-3 ,3 ,0.01) lines(x, dnorm(x), col=2, lwd=3)
次に混合正規分布を作成してみる。
まずは平均が−3と3の位置にある5:5の混合正規分布を作成し、ヒストグラムと密度関数を重ねてPLOTしてみる。
mu <- 3 SIZE <- 10^4 points <- slice.sample(SIZE, 0, function(x)exp(-0.5*(x+mu)^2)*0.8 + exp(-0.5*(x-mu)^2)*0.2, w=1) hist(points, SIZE^0.5, freq=FALSE,xlim=c(-(mu+3), mu+3)) x<-seq(-(mu+3), mu+3 ,0.01) lines(x, 0.2*dnorm(x-mu)+0.8*dnorm(x+mu), col=2, lwd=3)
確かに混合正規分布もうまくできてそうだ。
ここからもう少し分布間の幅を広げ、平均を−6と6へ変更してみると・・・
mu <- 6 SIZE <- 10^4 points <- slice.sample(SIZE, 0, function(x)exp(-0.5*(x+mu)^2)*0.8 + exp(-0.5*(x-mu)^2)*0.2, w=1) hist(points, SIZE^0.5, freq=FALSE,xlim=c(-(mu+3), mu+3)) x<-seq(-(mu+3), mu+3 ,0.01) lines(x, 0.2*dnorm(x-mu)+0.8*dnorm(x+mu), col=2, lwd=3)
やはり幅が広い(≒多峰性の系)だとこの手法でもうまくいかないことがわかる。
汎用的なサンプリング手法への道は険しい。。。
参考
*1:これは単峰系のみという注釈が原論文にあったので、やらない方がいいかも