テンプレートで書くか継承で書くか

数値計算C++っぽく書く場合、速度的な事を考えるといちいち仮想関数テーブルを漁りにいかないのでテンプレートの方が速いんだろうな」と思っていたが、例のごとく億劫がってやってなかったので、やってみましたよというお話。

結論としてはテンプレートを使った方が2〜3倍程度速かったので積極的にテンプレートを使う方向でやっていきたい。

比較対象は

  • x^3の積分を中点公式で積分する操作を500,000回繰り返す

のに要した実行時間。

結果(/O2最適化)は

継承版(ミリ秒):24879 結果:2.0287e+009
テンプレート版(ミリ秒):9700 結果:2.0287e+009

となり、テンプレート版の方が2〜3倍程度速くなる。

以下、使用したコード。

#include <iostream>
#include <iomanip>
#include <vector>
#include <numeric>
#include <time.h>
//積分区間のグリッドサイズ
static const int GRID = 10000;
//継承使って書いた版
class MidpointRule1
{
public :
    virtual double operator()(double x) = 0;
    double Integate(double x_lower, double x_upper)
    {
        double dx = (x_upper - x_lower)/GRID;
        double result = 0.0;
        for(int i = 0; i < GRID; ++i){
            result += this->operator ()(x_lower + (i+0.5)*dx);
        }
        return (result * dx);
    }
};
class F1 : public MidpointRule1
{
    double operator()(double x){return x*x*x;}
};

//テンプレートを使って作った版
template<class Integration>
class F2
{
public:
    F2():integration_(Integration()){}
    double operator()(double x){return x*x*x;}
    double Integate(double x_lower, double x_upper)
    {
        return integration_.Integrate(x_lower, x_upper, *this);
    }
private: 
    Integration integration_;
};
class MidpointRule2
{
public:
    template<class Integrand>
    double Integrate(double x_lower, double x_upper, Integrand & integrand)
    {
        double dx = (x_upper - x_lower)/GRID;
        double result = 0.0;
        for(int i = 0; i < GRID; ++i){
            result += integrand(x_lower + (i+0.5)*dx);
        }
        return (result * dx);
    }
};

int main()
{
    //時間計測用変数
    clock_t start, finish;
    //積分の繰り返し回数
    const int N = 500000;
    //結果格納用ベクトル
    std::vector<double> result(N);

    //継承版
    F1 f1;
    start = clock();
    for(int i = 0; i < N; ++i){result[i] = f1.Integate(-i, i);}
    finish = clock();
    std::cout << "継承版(ミリ秒):" << (finish - start) << " 結果:" << std::accumulate(result.begin(), result.end(), 0.0) << std::endl;

    //template版
    F2<MidpointRule2> f2;
    start = clock();
    for(int i = 0; i < N; ++i){result[i] = f2.Integate(-i, i);}
    finish = clock();
    std::cout << "テンプレート版(ミリ秒):" << (finish - start) << " 結果:" << std::accumulate(result.begin(), result.end(), 0.0) << std::endl;

    return 0;
}