復元抽出のアルゴリズム
もくてき
粒子フィルタ(パーティクル・フィルタ)を実行する際には、粒子のウェイト(weight)に比例する確率でリサンプリングを実行する必要がある。そのためのアルゴリズムとコードを考えたい。ここでは手元にある各粒子のウェイトはK個の要素からなるベクトルだと仮定して、さらにそれを復元を許してN個リサンプリングするという状況で考える。
こんな面倒な状況を考えなくても、要するにこれは
- いろんな色の球が入ってる壺から、1個適当に球を取り出して、その色をメモって、球を戻す
を複数回繰り返すことと同じで、高校生で習う確率の範囲で理解できる計算なわけだ。
アルゴリズム1(逆変換法)
逆変換法のアイディアを使って以下のようにするのが素朴なアイディアでコードも短い(後述)。
1: weightを確率に直し、その累積確率を計算し、これを{Qk, k=1,2,...K}とする
2: [0, 1]の実数値乱数rを一個生成する
3: rを、k=1,2,..., Kまで順番に累積確率Qkと比較していき、r
アルゴリズム2(ソートを使う方法)
ちょいと頭を使うといかのようなアルゴリズムもできる。
1: weightを確率に直し、その累積確率を計算し、これを{Qk, k=1,2,...K}とする
2: [0, 1]の実数値乱数rをN個生成し、{rn, n=1,2,...N}とする
3: {rn}∪{Qk}として、これをソートし、配列unionとする
4: ソートした配列unionの間に挟まってる"乱数rn"の個数を数える
5: 4で数えた個数ぶんだけ、該当するweightの要素を返却する
速度比較
上述のアルゴリズムのオーダーは、
- アルゴリズム1 : (K-1)*N(乱数との比較がボトルネック。NとKの二重ループが走る。)
- アルゴリズム2 : (N+K)log(N+K)(ソートがボトルネック。ソートにはn*log(n)オーダーの計算アルゴリズムを使う)
なので、データ(K) or リサンプリング(N)数が多い場合には、アルゴリズム2の方が効率的であると考えられるので、それを確かめてみる。
Rで書くと以下のような感じか。乱数は一度に生成するなどの計算時間節約は入れてしまっている。
# アルゴリズム1 resampling1 <- function(weight, size){ prob_cum <- cumsum(weight)/sum(weight) index <- sapply(runif(size), function(x)which(x < prob_cum)[1]) weight[index] } # アルゴリズム2 resampling2 <- function(weight, size){ prob_cum <- cumsum(weight)/sum(weight) size_weight <- length(weight) union <- c(prob_cum, runif(size)) union_indexes <- order(union, decreasing=TRUE) index <- numeric(size) value <- union_indexes[1] counter <- 1 for(union_index in union_indexes[-1]){ if(union_index > size_weight){ index[counter] <- value counter <- counter + 1 }else{ value <- union_index } } weight[index] }
速度の比較結果は
> x <- sample(1:10^4) > system.time(resampling1(x, 10^5)) ユーザ システム 経過 9.35 0.00 9.36 > system.time(resampling2(x, 10^5)) ユーザ システム 経過 0.25 0.00 0.25
というわけで、コードの長さとは裏腹にアルゴリズム2の方が遥かに速い。
・・・で、ここまで書いてから、Rに組み込みのsample関数のprob引数を与えれば、同じ処理ができることに気がついた。なので、アルゴリズム2とsample関数の計算速度を比べてみる。
> system.time(resampling2(x, 10^5)) ユーザ システム 経過 0.25 0.00 0.24 > system.time(sample(x, 10^5, replace=TRUE, prob=1:10^4)) ユーザ システム 経過 0 0 0
・・・Rの実装の方が圧倒的に速い!?!?!?!?!?
Rcppで書きなおす
きぃぃい、悔しい!!!負けるのは悔しいのでRcppでアルゴリズム2を書きなおしてみた。RcppとRcppArmadilloパッケージは
install.packages("Rcpp") install.packages("RcppArmadillo")
として突っ込んでおいておく。
library(Rcpp) sourceCpp(code=' #include <RcppArmadillo.h> using namespace Rcpp; // [[Rcpp::depends(RcppArmadillo)]] // [[Rcpp::export]] NumericVector resamplingCpp(arma::vec weight, int size){ RNGScope scope; const unsigned int size_weight = weight.size(); const arma::vec prob_cum = cumsum(weight)/sum(weight); arma::vec unionset = arma::join_cols(prob_cum, as<arma::vec>(runif(size))); arma::uvec union_indexes = arma::sort_index(unionset, "descend"); arma::uvec index = arma::uvec(size); double value = union_indexes[0]; int counter = 0; const arma::uvec::iterator union_indexes_begin = union_indexes.begin()+1; const arma::uvec::iterator union_indexes_end = union_indexes.end(); for(arma::uvec::iterator union_iterator=union_indexes_begin; union_iterator!=union_indexes_end; ++union_iterator){ if(*union_iterator > (size_weight-1)){ index[counter] = value; counter++; }else{ value = *union_iterator; } } arma::vec result = weight.elem(index); return NumericVector(result.begin(), result.end()); }')
そして、このresamplingCpp関数と、アルゴリズム2(R実装)・sample関数の速度比較してみる。
> system.time(resampling2(x, 10^6)) ユーザ システム 経過 2.90 0.00 2.93 > system.time(sample(x, 10^6, replace=TRUE, prob=1:10^4)) ユーザ システム 経過 0.03 0.02 0.04 > system.time(resamplingCpp(x, 10^6)) ユーザ システム 経過 0.22 0.00 0.22
うおお、なんてこった。元のコード(アルゴリズム2)よりも計算速度が10倍以上速くなったRcppのコードでも負けた…R恐るべし。俺のC++実装の問題な気もするのが、生のCっぽく書かくのはしんどいので、ここでおしまい。sample関数使えばいいや、もう。
答えの確認
ここで俺俺実装した結果が、ちゃんと同じ結果を返すことを確認しておく。
> weight <- sample(1:3) > set.seed(100) > x1 <- resampling1(weight, 10^2) > set.seed(100) > x2 <- resampling2(weight, 10^2) > set.seed(100) > x3 <- resamplingCpp(weight, 10^2) > all(x1[order(x1)]==x2[order(x2)]) [1] TRUE > all(x1[order(x1)]==x3[order(x3)]) [1] TRUE
上のように、乱数のシードを揃えれば確かに全ての結果が一致する。
sample関数については、実際に復元抽出した結果から確認しておく。
> val1 <- table(sample(1:5, 10^5, replace=TRUE, prob=1:5)) > val1/min(val1) 1 2 3 4 5 1.000000 2.026021 3.031165 4.059304 5.012103
確かに、モンテカルロ法の誤差のレベルで、指定したウェイト(1,2,3,4,5)に比例した答えになる。