14.4. カーネルの実装(bitonic split)

バイトニックソートの実装については、大きく分けて3つの点を考慮する必要があります。

  1. データサイズの2分の1の要素を処理(各カーネルで2つの要素のスワップを行うため)
  2. スワップ距離によって処理する点が異なる。
  3. 上昇配列と下降配列の2つの種類の配列が存在する。

例えばデータの要素数が32の場合にスワップ距離(distance)が8の場合を考えてみましょう。この場合、上昇配列のスワップ・ペアは

となり、下降配列のスワップペアは

となります。この場合にカーネルが処理する点は、

の16の添字(インデックス)となります。つまり32のデータ数の半分を処理することになります。9-15、24-31のデータ点はスワップする点のため、該当するカーネルで処理されるため、OpenCLカーネルが処理するインデックス空間から除外します。

9と24については、上昇・下降配列内を分ける「分岐点」とします。

注目すべき点としては、スワップ距離(値は8)の2倍となる16を境に分けられることです。それが上昇と下降の方向性を切り替える「中間点」であることです。

次に同じ32の要素を持つデータでスワップ距離が2の場合を考えてみましょう。上昇配列のスワップ・ペアは

となります。下降配列のスワップペアは

となります。処理範囲は「0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30」となり、奇数の添字はカーネルがスワップするため、処理点から除外します。

「1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31」については、処理点から除外されますが、上昇・下降配列内の「分岐点」とします。

距離が8の場合に比べると分かりやすいようにも見えますが、上昇配列が8個、下降配列が8個となり、上昇・下降配列の中間点の数も8つとなります。

中間点(2、6、10、14、18、22、26、30)については、スワップ距離の2倍の数(4)を足していった数となります。

バイトニックソートの実装のためには、これら全ての条件を満たすロジックを考案する必要があります。これらの要件を満たすのは複雑に見えるかもしれませんが、難易度は高くありません。あまり小綺麗なコードにしようとすると難しくなるので、見栄えが悪くてもコードが正確に動くことを心がければ、実装は可能かと思います。

実装についての注意点としては、スパゲティ方式で記述すると分岐が増えてパフォーマンスが落ちるので、分岐をできるだけ減らすようにすると良いかと思います。例えば、select関数、比較関数を使えばプロセッサが負担する分岐を減少または緩和することができるので、実装にこれらの関数を用いるのもオプションの一つです。

実装例ではOpenCLのインデックス空間を活用して読者に慣れていただきたいので、ストレートな実装に若干に変更を加えました。

まず処理範囲、処理点の算出については、グローバルワークサイズをスワップ距離(distance)の2倍(つまり上昇・下降配列内の分岐点)として、グローバルIDが上昇配列、下降配列のいずれかの範囲と同一とします。

つまりグローバルIDのインデックス空間をdistanceの2倍の要素の範囲で分割して、それがdistanceより小さい場合に処理を行います。これは以下のようにカーネル内で実装することができます。

int in_range = isless(gid % (distance << 1),distance);
if(in_range) {

}

gid % (distance << 1)がdistanceより小さい場合に処理をするということになります。

それではカーネルのコードをすべて見てみましょう。

uint gid = get_global_id(0);
uint lid = get_local_id(0);

int in_range = isless(gid % (distance << 1),distance); //(1)

if(in_range) { //(2)

    uint middle = distance << 1; //(3)
    int dir_mask = isgreaterequal(gid%(middle*2),middle); //(4)

    uint left = data[gid]; //(5)
    uint right = data[gid+distance]; //(6)
    uint cmp_mask;

    if(dir_mask) { //(7)
        cmp_mask = left < right ? 1 : 0; //(8)
    } else {
        cmp_mask = left > right ? 1 : 0; //(9)
    }
    data[gid] = select(left,right,cmp_mask); //(10)
    data[gid+distance] = select(right,left,cmp_mask); //(11)

}

(1)

グローバルIDのインデックス空間をdistanceの2倍の要素の範囲で分割して、それがdistanceより小さい場合に真を返します。処理すべき添字は、distanceより小さい値を持つ、グローバルIDの場合とします。

(2)

カーネルは、①でdistanceの2倍で除したグローバルIDの値がdistanceより小さい場合は、処理点から除外。if分岐のカッコ内の処理は効率的でないため、前のステップで計算したフラグをif分岐のカッコに配置します。

(3)

スワップ距離の2倍を中間点の倍数として求めます。

(4)

グローバルIDを中間点の2倍(つまり上昇配列と下降配列を足したデータ要素数)で除した剰余が中間点よりも大きいか同じ場合に、「下降」方向フラグ(dir_mask)をオンにします。

(5)

処理点の値をleft変数に保存。

(6)

処理点+distanceの値をright変数に保存

(7)

下降フラグがオンか判定します。

(8)

下降フラグがオンの場合で、leftがrightより小さい場合はスワップフラグ(cmp_mask)をオンにします。

(9)

下降フラグがオンの場合で、leftがrightより大きい場合はスワップフラグ(cmp_mask)をオンにします。

(10)

スワップフラグ(cmp_mask)がオンの場合にleftとrightをスワップ(入れ替え)します。

(11)

スワップフラグ(cmp_mask)がオンの場合にrightとleftをスワップ(入れ替え)します。

上記のコードをデータサイズが16とした場合は、以下のような処理が行われます。dがスワップ距離、gdivが「gid % (distance << 1)」、pairがスワップペア、dataが処理点とスワップ点の値、gmodはグローバルIDを中間点の2倍で除した剰余です。

d = 1, gdiv = 0, pair: 0-1, data: 10, 16, gmod = 0
d = 1, gdiv = 0, pair: 4-5, data: 10, 12, gmod = 0
d = 1, gdiv = 0, pair: 2-3, data: 14, 10, gmod = 2
d = 1, gdiv = 0, pair: 12-13, data: 4, 10, gmod = 0
d = 1, gdiv = 0, pair: 6-7, data: 10, 10, gmod = 2
d = 1, gdiv = 0, pair: 8-9, data: 8, 10, gmod = 0
d = 1, gdiv = 0, pair: 14-15, data: 10, 2, gmod = 2
d = 1, gdiv = 0, pair: 10-11, data: 10, 6, gmod = 2
d = 2, gdiv = 0, pair: 8-10, data: 8, 10, gmod = 0
d = 2, gdiv = 1, pair: 9-11, data: 6, 10, gmod = 1
d = 2, gdiv = 0, pair: 0-2, data: 10, 14, gmod = 0
d = 2, gdiv = 1, pair: 1-3, data: 10, 16, gmod = 1
d = 2, gdiv = 0, pair: 4-6, data: 10, 10, gmod = 4
d = 2, gdiv = 1, pair: 5-7, data: 12, 10, gmod = 5
d = 2, gdiv = 0, pair: 12-14, data: 10, 4, gmod = 4
d = 2, gdiv = 1, pair: 13-15, data: 10, 2, gmod = 5
d = 1, gdiv = 0, pair: 0-1, data: 10, 10, gmod = 0
d = 1, gdiv = 0, pair: 4-5, data: 12, 10, gmod = 4
d = 1, gdiv = 0, pair: 12-13, data: 10, 10, gmod = 4
d = 1, gdiv = 0, pair: 2-3, data: 14, 16, gmod = 2
d = 1, gdiv = 0, pair: 8-9, data: 6, 8, gmod = 0
d = 1, gdiv = 0, pair: 6-7, data: 10, 10, gmod = 6
d = 1, gdiv = 0, pair: 14-15, data: 4, 2, gmod = 6
d = 1, gdiv = 0, pair: 10-11, data: 10, 10, gmod = 2
d = 4, gdiv = 0, pair: 0-4, data: 10, 12, gmod = 0
d = 4, gdiv = 1, pair: 1-5, data: 10, 10, gmod = 1
d = 4, gdiv = 2, pair: 2-6, data: 10, 14, gmod = 2
d = 4, gdiv = 3, pair: 3-7, data: 10, 16, gmod = 3
d = 4, gdiv = 0, pair: 8-12, data: 10, 6, gmod = 8
d = 4, gdiv = 1, pair: 9-13, data: 10, 8, gmod = 9
d = 4, gdiv = 2, pair: 10-14, data: 10, 4, gmod = 10
d = 4, gdiv = 3, pair: 11-15, data: 10, 2, gmod = 11
d = 2, gdiv = 0, pair: 8-10, data: 10, 10, gmod = 8
d = 2, gdiv = 1, pair: 9-11, data: 10, 10, gmod = 9
d = 2, gdiv = 0, pair: 0-2, data: 10, 10, gmod = 0
d = 2, gdiv = 1, pair: 1-3, data: 10, 10, gmod = 1
d = 2, gdiv = 0, pair: 4-6, data: 12, 14, gmod = 4
d = 2, gdiv = 1, pair: 5-7, data: 10, 16, gmod = 5
d = 2, gdiv = 0, pair: 12-14, data: 6, 4, gmod = 12
d = 2, gdiv = 1, pair: 13-15, data: 8, 2, gmod = 13
d = 1, gdiv = 0, pair: 0-1, data: 10, 10, gmod = 0
d = 1, gdiv = 0, pair: 4-5, data: 10, 12, gmod = 4
d = 1, gdiv = 0, pair: 12-13, data: 8, 6, gmod = 12
d = 1, gdiv = 0, pair: 8-9, data: 10, 10, gmod = 8
d = 1, gdiv = 0, pair: 2-3, data: 10, 10, gmod = 2
d = 1, gdiv = 0, pair: 6-7, data: 14, 16, gmod = 6
d = 1, gdiv = 0, pair: 14-15, data: 4, 2, gmod = 14
d = 1, gdiv = 0, pair: 10-11, data: 10, 10, gmod = 10

dが1(distance)の場合に、gdivが常に0(gid % (distance << 1))、dが2の場合では、gdivが0と1、dが4の場合ではgdivが0,1,2,3のいずれかとなっています。

Copyright 2018-2019, by Masaki Komatsu