12.6. 実装例(GPU)

GPGPU実装は、CPU実装と基本的には同じものです。異なる点は、データ並列化が可能なコードがカーネル関数内に移されることです。

ここで使うカーネル関数は、FFT本体を担うfft_initとfft、および補助的なbit_reversalとfft_inverseの計4つです。まず一つ目が、fft_init関数です。

__kernel void fft_init(
                __global float2* data,
                __global float2* F,
                int N)

この関数は、Radix-2 FFTの導入部分となり、ストライドが1の時に使います。つまりCPU実装でいう以下の行に該当します。

// Butterfly without a twiddle factor: the first two outputs are the sums and
// the last two are the differences of the entries N2 apart.
// NOTE(review): indices offset / offset+1 presumably hold the real and
// imaginary parts of one interleaved complex value - confirm against the CPU listing.
F[offset] = data[offset] + data[offset+N2];
F[offset+1] = data[offset+1] + data[offset+N2+1];
F[offset+N2] = data[offset] - data[offset+N2];
F[offset+N2+1] = data[offset+1] - data[offset+N2+1];

この計算では三角関数が不要となります。またdata変数は生データですが、以後の計算では、F変数を用います。

Radix-2 FFTのGPU実装で、これ以降の段に該当するfftカーネル関数の本体は以下のようになります。CPU実装例と比べると、ほとんど同じコードであることがわかると思います。

// The two butterfly inputs, `stride` elements apart in F.
float2 in0, in1;

in0 = F[index];
in1 = F[index+stride];

// Twiddle-factor angle computed from the element index and the stage size N.
float angle = -2*M_PI_F*(index)/N;

float c,s;
float2 v;
float2 tmp0;

// sign scales the sine term; the host passes 1 for the forward pass and -1
// for the inverse pass.
c = native_cos(angle);
s = sign*native_sin(angle);

// Complex product (c + i*s) * in1, with x as the real and y as the imaginary part.
v.x = c * (in1.x) - s * in1.y;
v.y = c * (in1.y) + s * in1.x;

// Keep the original in0 so both outputs are computed from the same value.
tmp0 = in0;

in0 = tmp0 + v;
in1 = tmp0 - v;

// Write the butterfly result back in place.
F[index] = in0;
F[index + stride] = in1;

Javaの標準パッケージではベクトル型を使用できませんが、OpenCLではfloat2型を使い、xに実数部、yに虚数部を格納することで、コード行数を半分程度に抑えられます。

FFTのメインのアルゴリズムは以下のfftカーネル関数を使います。

__kernel void fft(
                __global float2* F,
                int N,
                int sign)

この関数はCooley-Tukeyアルゴリズムを実装しますが、該当するCPU実装は以下のようになります。

// Recurse into the two halves (at offset and offset+N2) before combining them.
forward(N2,offset,data,F,sign,step);
forward(N2,offset+N2,data,F,sign,step);

// i advances by 2 and the body touches F[i] and F[i+1] together.
// NOTE(review): this looks like interleaved real/imaginary storage - confirm
// against the full CPU listing.
for(int i = 0; i < N2; i+=2) {
    // Cosine and sine of the angle i*_PI/N2; sign scales the sine term.
    c = Math.cos(i*_PI/N2);
    s = sign*Math.sin(i*_PI/N2);

    // Upper-half element rotated by (c, s).
    real = F[i+N2+offset]*c + F[i+N2+1+offset]*s;
    imaginary = F[i+N2+1+offset]*c - F[i+N2+offset]*s;

    // Butterfly: the upper half receives the difference, the lower half the sum.
    // The upper half is written first because it reads the not-yet-updated lower half.
    F[i+N2+offset] = F[i+offset] - real;
    F[i+N2+1+offset] = F[i+1+offset] - imaginary;
    F[i+offset] += real;
    F[i+1+offset] += imaginary;

}

ここでは、forward関数は、fft関数に該当し、最初の2行で再帰処理をしています。

GPU実装については、カーネル関数内で再帰処理ができないため、再帰部分はOpenCLホストAPIを使ったホスト側のループに置き換え、残りの記述をカーネル関数に移します。構成としては以下のようになります。

// fftSize doubles on every pass: 2, 4, ..., N. One pass is one kernel launch.
int fftSize = 1;
// Number of passes, log2(N).
int ns = log2(N);
// Declared here but not used within this excerpt.
int stages = 0;
// One-element array so the current size can be handed to the kernel as an argument.
int[] fftSizePtr = new int[1];
for(int i = 0; i < ns; i++) {
    fftSize <<= 1;
    fftSizePtr[0] = fftSize;

    if(fftSize !=2) {
        // launch the fft kernel (every pass after the first)
    } else {
        // launch the fft_init kernel (first pass only, fftSize == 2)
    }
}

fft_initは、forループ内の反復の一番初めだけ実行され、残りはfftカーネル関数が代わりに実行されます。

FFTGPU1D.java. 

package com.book.jocl.fft;

import static org.jocl.CL.*;

import java.io.File;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Paths;
import java.util.Scanner;

import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.Sizeof;
import org.jocl.cl_command_queue;
import org.jocl.cl_context;
import org.jocl.cl_context_properties;
import org.jocl.cl_device_id;
import org.jocl.cl_kernel;
import org.jocl.cl_mem;
import org.jocl.cl_platform_id;
import org.jocl.cl_program;

/**
 * Host program for a 1-D radix-2 FFT on the GPU using JOCL.
 *
 * <p>The program fills {@code DATA_SIZE} complex samples, runs a forward FFT
 * followed by an inverse FFT on the device, and prints the reconstructed
 * samples (real and imaginary parts on alternating lines).
 *
 * <p>Each FFT is a bit-reversal pass followed by {@code log2(DATA_SIZE)}
 * butterfly passes. Every pass is a separate kernel launch on one in-order
 * command queue, so the passes execute in the order they are enqueued.
 */
public class FFTGPU1D {

    private static final String KERNEL_PATH = "fft1d.cl";
    private static final String KERNEL_INIT = "fft_init";
    private static final String KERNEL_BIT_REVERSAL = "bit_reversal";
    private static final String KERNEL_FFT = "fft";
    private static final String KERNEL_FFT_INVERSE = "fft_inverse";

    private static cl_context context;
    private static cl_command_queue queue;
    private static cl_program program;
    private static cl_kernel kernel_init;
    private static cl_kernel kernel_bit_reversal;
    private static cl_kernel kernel_fft;
    private static cl_kernel kernel_fft_inverse;

    /** Number of complex samples; must be a power of two and at least 2. */
    private static final int DATA_SIZE = 256;
    /** Input samples, interleaved as (real, imaginary) pairs. */
    private static final float[] data = new float[DATA_SIZE << 1];
    /** Host copy of the device result, interleaved as (real, imaginary) pairs. */
    private static final float[] processed_data = new float[DATA_SIZE << 1];

    /** One work-item per butterfly (half the number of samples). */
    private static long[] global_work_size = new long[]{DATA_SIZE >>> 1,1,1};
    private static long[] local_work_size = new long[]{1,1,1};
    /** One work-item per sample (bit reversal and final scaling). */
    private static long[] global_work_size_full = new long[]{DATA_SIZE,1,1};

    /**
     * Integer base-2 logarithm, rounded down.
     *
     * @param b a positive integer
     * @return the position of the highest set bit of {@code b}
     */
    private static int log2(int b) {
        int result = 0;
        if((b & 0xffff0000) != 0) {
            b >>>= 16;
            result = 16;
        }
        if(b >= 256) {
            b >>>= 8;
            result += 8;
        }
        if(b >= 16) {
            b >>>= 4;
            result += 4;
        }
        if(b >= 4) {
            b >>>= 2;
            result += 2;
        }
        return result + (b >>> 1);
    }

    /**
     * Runs the forward and inverse FFT on the first GPU of the first platform
     * and prints the reconstructed samples.
     *
     * @param args unused
     * @throws Exception if the kernel source cannot be read or an OpenCL call fails
     */
    public static void main(String[] args) throws Exception {

        CL.setExceptionsEnabled(true);

        final int N = DATA_SIZE;
        // The stage loop and the bit-reversal kernel only work for power-of-two sizes.
        if (N < 2 || Integer.bitCount(N) != 1) {
            throw new IllegalStateException("DATA_SIZE must be a power of two >= 2: " + N);
        }
        final int ns = log2(N);

        cl_platform_id[] platform = new cl_platform_id[1];
        cl_device_id[] device = new cl_device_id[1];
        int[] num_devices = new int[1];

        clGetPlatformIDs(1, platform, null);
        clGetDeviceIDs(platform[0], CL_DEVICE_TYPE_GPU, 1, device, num_devices);

        cl_mem data_mem = null;
        cl_mem processed_mem = null;
        try {
            cl_context_properties props = new cl_context_properties();
            props.addProperty(CL_CONTEXT_PLATFORM, platform[0]);
            context = clCreateContext(props, 1, device, null, null, null);

            queue = clCreateCommandQueue(context, device[0], 0, null);

            program = clCreateProgramWithSource(context, 1,
                    new String[] {loadKernelSource()}, null, null);
            String option = "-Werror";
            clBuildProgram(program, 0, null, option, null, null);

            // Fill the host array BEFORE the buffer is created, because the
            // buffer takes a snapshot of it at creation time.
            generateSample();

            // READ_WRITE: the bit_reversal kernel swaps elements of this buffer in place.
            // COPY_HOST_PTR: the device works on its own copy, so nothing depends on
            // the Java array staying at a fixed address after this call.
            data_mem = clCreateBuffer(context, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,
                    Sizeof.cl_float2 * DATA_SIZE, Pointer.to(data), null);

            // Device-side working buffer; fft_init overwrites every element
            // before anything reads it, so no initial contents are needed.
            processed_mem = clCreateBuffer(context, CL_MEM_READ_WRITE,
                    Sizeof.cl_float2 * DATA_SIZE, null, null);

            kernel_init = clCreateKernel(program, KERNEL_INIT, null);
            kernel_bit_reversal = clCreateKernel(program, KERNEL_BIT_REVERSAL, null);
            kernel_fft = clCreateKernel(program, KERNEL_FFT, null);
            kernel_fft_inverse = clCreateKernel(program, KERNEL_FFT_INVERSE, null);

            int[] Ni = new int[] {N};

            // Forward transform: data_mem -> processed_mem.
            enqueueBitReversal(data_mem, Ni);
            enqueueButterflyStages(data_mem, processed_mem, ns, 1, true);

            // Inverse transform, in place on processed_mem.
            enqueueBitReversal(processed_mem, Ni);
            enqueueButterflyStages(data_mem, processed_mem, ns, -1, false);

            // The inverse transform still has to be scaled by 1/N.
            clSetKernelArg(kernel_fft_inverse, 0, Sizeof.cl_int, Pointer.to(Ni));
            clSetKernelArg(kernel_fft_inverse, 1, Sizeof.cl_mem, Pointer.to(processed_mem));
            enqueue(kernel_fft_inverse, global_work_size_full);

            // Blocking map for reading: returns once all previously enqueued
            // kernels on this in-order queue have finished.
            ByteBuffer output = clEnqueueMapBuffer(queue,
                    processed_mem,
                    CL_TRUE,
                    CL_MAP_READ,
                    0,
                    Sizeof.cl_float2*DATA_SIZE,
                    0,
                    null,
                    null,
                    null);

            // A ByteBuffer is big-endian by default; interpret the floats as little-endian.
            output.order(ByteOrder.LITTLE_ENDIAN);
            // Copy the result out while the mapping is still valid; the mapped
            // region must not be touched after it has been unmapped.
            output.asFloatBuffer().get(processed_data);

            clEnqueueUnmapMemObject(queue, processed_mem, output, 0, null, null);
            clFinish(queue);

            for (float value : processed_data) {
                System.out.println(value);
            }
        } finally {
            // Release every OpenCL object, dependants before their owners.
            if (data_mem != null) clReleaseMemObject(data_mem);
            if (processed_mem != null) clReleaseMemObject(processed_mem);
            if (kernel_fft_inverse != null) clReleaseKernel(kernel_fft_inverse);
            if (kernel_fft != null) clReleaseKernel(kernel_fft);
            if (kernel_bit_reversal != null) clReleaseKernel(kernel_bit_reversal);
            if (kernel_init != null) clReleaseKernel(kernel_init);
            if (program != null) clReleaseProgram(program);
            if (queue != null) clReleaseCommandQueue(queue);
            if (context != null) clReleaseContext(context);
            clReleaseDevice(device[0]);
        }
    }

    /**
     * Reads the OpenCL source file located next to this class.
     *
     * @return the complete kernel source, each line terminated by a newline
     * @throws Exception if the resource is missing or cannot be read
     */
    private static String loadKernelSource() throws Exception {
        URL resource = FFTGPU1D.class.getResource(KERNEL_PATH);
        if (resource == null) {
            throw new java.io.FileNotFoundException(
                    "Kernel source not found on the class path: " + KERNEL_PATH);
        }
        File file = Paths.get(resource.toURI()).toFile();

        StringBuilder source = new StringBuilder();
        try (Scanner sc = new Scanner(file, "UTF-8")) {
            while (sc.hasNextLine()) {
                source.append(sc.nextLine()).append('\n');
            }
        }
        return source.toString();
    }

    /**
     * Enqueues the bit-reversal permutation on the given buffer.
     *
     * @param buffer the complex buffer to permute in place
     * @param Ni     one-element array holding the number of complex samples
     */
    private static void enqueueBitReversal(cl_mem buffer, int[] Ni) {
        clSetKernelArg(kernel_bit_reversal, 0, Sizeof.cl_mem, Pointer.to(buffer));
        clSetKernelArg(kernel_bit_reversal, 1, Sizeof.cl_uint, Pointer.to(Ni));
        enqueue(kernel_bit_reversal, global_work_size_full);
    }

    /**
     * Enqueues the {@code ns} butterfly passes with sizes 2, 4, ..., 2^ns.
     *
     * @param data_mem        input buffer, read only by the fft_init pass
     * @param processed_mem   buffer the passes operate on
     * @param ns              number of passes, log2 of the sample count
     * @param sign            1 for the forward, -1 for the inverse transform
     * @param initFromInput   if true, the size-2 pass uses fft_init to read
     *                        {@code data_mem} and write {@code processed_mem};
     *                        otherwise every pass uses fft in place
     */
    private static void enqueueButterflyStages(cl_mem data_mem, cl_mem processed_mem,
            int ns, int sign, boolean initFromInput) {
        int[] fftSizePtr = new int[1];
        int[] signPtr = new int[] {sign};

        int fftSize = 1;
        for (int i = 0; i < ns; i++) {
            fftSize <<= 1;
            fftSizePtr[0] = fftSize;

            if (initFromInput && fftSize == 2) {
                clSetKernelArg(kernel_init, 0, Sizeof.cl_mem, Pointer.to(data_mem));
                clSetKernelArg(kernel_init, 1, Sizeof.cl_mem, Pointer.to(processed_mem));
                clSetKernelArg(kernel_init, 2, Sizeof.cl_int, Pointer.to(fftSizePtr));
                enqueue(kernel_init, global_work_size);
            } else {
                clSetKernelArg(kernel_fft, 0, Sizeof.cl_mem, Pointer.to(processed_mem));
                clSetKernelArg(kernel_fft, 1, Sizeof.cl_int, Pointer.to(fftSizePtr));
                clSetKernelArg(kernel_fft, 2, Sizeof.cl_int, Pointer.to(signPtr));
                enqueue(kernel_fft, global_work_size);
            }
        }
    }

    /** Enqueues a 1-D launch of {@code kernel} with the given global size. */
    private static void enqueue(cl_kernel kernel, long[] globalSize) {
        clEnqueueNDRangeKernel(queue,
                kernel, 1, null,
                globalSize,
                local_work_size,
                0, null, null);
    }

    /** Fills {@code data} with the ramp 0, 1, 2, ... as real parts and zero imaginary parts. */
    private static void generateSample() {
        for(int i = 0; i < DATA_SIZE*2; i+=2) {
            data[i] = i/2;
            data[i+1] = 0.0f;
        }
    }
}

fft1d.cl. 

/*
 * Reverses the low-order bits of x.
 *
 * The number of bits reversed equals the number of right shifts needed to
 * bring `stage` down to zero (e.g. stage = 7 reverses 3 bits).
 * Returns the reversed value; bits of x above that width are ignored.
 */
inline int reverseBit(int x, int stage) {
        int reversed = 0;
        for (int remaining = stage; remaining != 0; remaining >>= 1) {
                /* Shift the result left and append the current lowest bit of x. */
                reversed = (reversed << 1) | (x & 1);
                x >>= 1;
        }
        return reversed;
}

/*
 * Bit-reversal permutation, in place.
 *
 * One work-item per element: work-item gid swaps data[gid] with the element
 * at the bit-reversed index. Only the work-item with the smaller index of a
 * pair performs the swap, so each pair is exchanged exactly once.
 *
 * data: complex buffer to permute
 * N:    number of elements; N-1 determines how many bits are reversed
 */
__kernel void bit_reversal(__global float2* data, uint N) {
        size_t gid = get_global_id(0);
        uint rev = reverseBit(gid, N-1);
        if(gid < rev) {
                float2 in1 = data[gid];
                float2 in2 = data[rev];
                /* %d expects an int: cast the size_t/uint values so the
                   arguments match the format string. */
                printf("pair: %d - %d, N = %d\n", (int)gid, (int)rev, (int)N);
                data[rev] = in1;
                data[gid] = in2;
        }
}

/*
 * First butterfly pass: no twiddle factor is applied.
 *
 * One work-item per butterfly. Reads a pair of elements `stride` apart from
 * `data` and writes their sum and difference to the same positions in `F`.
 *
 * data: input samples
 * F:    output buffer used by the following passes
 * N:    butterfly size of this pass (the host passes 2)
 */
__kernel void fft_init(
                __global float2* data,
                __global float2* F,
                int N)
{
        int gid = get_global_id(0);
        int stride = N/2;
        /* Lower element of this work-item's pair. Pure integer arithmetic:
           going through float would lose exactness for large indices. */
        int index = (gid / stride) * stride + gid;

        float2 in0 = data[index];
        float2 in1 = data[index + stride];

        F[index] = in0 + in1;
        F[index + stride] = in0 - in1;

        printf("gid=%d, pair: %d - %d, N = %d, s = %d, in0:in1 = %f:%f\n", gid, index, index+stride, N, stride, F[index].x, F[index+stride].x);

}


/*
 * One radix-2 butterfly pass, in place on F.
 *
 * One work-item per butterfly. The pair (index, index + stride) is combined
 * with the twiddle factor (c + i*s).
 *
 * F:    complex buffer, x = real part, y = imaginary part
 * N:    butterfly size of this pass (2, 4, ..., total length)
 * sign: scales the sine term; the host passes 1 for the forward and -1 for
 *       the inverse transform
 */
__kernel void fft(
                __global float2* F,
                int N,
                int sign)
{
        int gid = get_global_id(0);
        int stride = N/2;
        /* Lower element of this work-item's pair. Pure integer arithmetic:
           going through float would lose exactness for large indices. */
        int index = (gid / stride) * stride + gid;
        /* Position of the butterfly inside its size-N group (== index % N).
           Using it instead of index gives the same twiddle factor, because
           the two differ by a multiple of N, but keeps the angle within one
           period, where the native_* functions are most likely accurate. */
        int k = gid % stride;

        float2 in0, in1;

        in0 = F[index];
        in1 = F[index+stride];

        float angle = -2*M_PI_F*k/N;

        float c,s;
        float2 v;
        float2 tmp0;

        c = native_cos(angle);
        s = sign*native_sin(angle);

        /* v = (c + i*s) * in1 */
        v.x = c * (in1.x) - s * in1.y;
        v.y = c * (in1.y) + s * in1.x;

        tmp0 = in0;

        in0 = tmp0 + v;
        in1 = tmp0 - v;

        F[index] = in0;
        F[index + stride] = in1;

        printf("gid=%d, pair: %d - %d, N = %d, s = %d, sign = %d c:s = %f:%f\n in0:in1 = %f:%f\n", gid, index, index+stride, N, stride, sign, c, s, F[index].x, F[index+stride].x);

}

/*
 * Divides every element of F by N.
 *
 * One work-item per element; the host runs this after the inverse butterfly
 * passes.
 *
 * N: divisor (the host passes the number of samples)
 * F: complex buffer, scaled in place
 */
__kernel void fft_inverse(
    int N,
    __global float2* F)
{
        const size_t i = get_global_id(0);
        const float2 value = F[i];
        F[i] = value / N;
}

下記はFFTカーネル関数が出力したFFTの処理情報となります(データ数を8点にして実行した場合の出力です)。Nは処理点の数、「s =」は処理点間の距離(ストライド)、pairは実行中の2つの処理点(in0、in1)のインデックス、gidはグローバルIDです。「c:s」は、それぞれcos関数とsin関数の値です。

gid=2, pair: 4 - 5, N = 2, s = 1, in0:in1 = 6.000000:-4.000000
gid=0, pair: 0 - 1, N = 2, s = 1, in0:in1 = 4.000000:-4.000000
gid=3, pair: 6 - 7, N = 2, s = 1, in0:in1 = 10.000000:-4.000000
gid=1, pair: 2 - 3, N = 2, s = 1, in0:in1 = 8.000000:-4.000000
gid=2, pair: 4 - 6, N = 4, s = 2, sign = 1 c:s = 1.000000:-0.000000
 in0:in1 = 16.000000:-4.000000
gid=0, pair: 0 - 2, N = 4, s = 2, sign = 1 c:s = 1.000000:-0.000000
 in0:in1 = 12.000000:-4.000000
gid=3, pair: 5 - 7, N = 4, s = 2, sign = 1 c:s = -0.000000:-1.000000
 in0:in1 = -3.999999:-4.000001
gid=1, pair: 1 - 3, N = 4, s = 2, sign = 1 c:s = -0.000000:-1.000000
 in0:in1 = -4.000000:-4.000000
gid=2, pair: 2 - 6, N = 8, s = 4, sign = 1 c:s = -0.000000:-1.000000
 in0:in1 = -3.999999:-4.000001
gid=0, pair: 0 - 4, N = 8, s = 4, sign = 1 c:s = 1.000000:-0.000000
 in0:in1 = 28.000000:-4.000000
gid=3, pair: 3 - 7, N = 8, s = 4, sign = 1 c:s = -0.707110:-0.707110
 in0:in1 = -3.999999:-4.000001
gid=1, pair: 1 - 5, N = 8, s = 4, sign = 1 c:s = 0.707110:-0.707110
 in0:in1 = -3.999999:-4.000001

下記は上に同じく処理情報ですが、今度は逆(inverse)FFTの情報を採集しています。sign変数が「-1」となっていることに注目ください。

gid=0, pair: 0 - 1, N = 2, s = 1, sign = -1 c:s = 1.000000:0.000000
 in0:in1 = 24.000000:32.000000
gid=2, pair: 4 - 5, N = 2, s = 1, sign = -1 c:s = 1.000000:0.000000
 in0:in1 = -8.000000:0.000001
gid=3, pair: 6 - 7, N = 2, s = 1, sign = -1 c:s = 1.000000:0.000000
 in0:in1 = -8.000000:0.000002
gid=1, pair: 2 - 3, N = 2, s = 1, sign = -1 c:s = 1.000000:0.000000
 in0:in1 = -8.000000:0.000002
gid=2, pair: 4 - 6, N = 4, s = 2, sign = -1 c:s = 1.000000:0.000000
 in0:in1 = -15.999999:-0.000001
gid=0, pair: 0 - 2, N = 4, s = 2, sign = -1 c:s = 1.000000:0.000000
 in0:in1 = 16.000000:32.000000
gid=3, pair: 5 - 7, N = 4, s = 2, sign = -1 c:s = -0.000000:1.000000
 in0:in1 = -11.313758:11.313760
gid=1, pair: 1 - 3, N = 4, s = 2, sign = -1 c:s = -0.000000:1.000000
 in0:in1 = 24.000000:40.000000
gid=2, pair: 2 - 6, N = 8, s = 4, sign = -1 c:s = -0.000000:1.000000
 in0:in1 = 16.000000:48.000000
gid=0, pair: 0 - 4, N = 8, s = 4, sign = -1 c:s = 1.000000:0.000000
 in0:in1 = 0.000001:32.000000
gid=3, pair: 3 - 7, N = 8, s = 4, sign = -1 c:s = -0.707110:0.707110
 in0:in1 = 23.999861:56.000137
gid=1, pair: 1 - 5, N = 8, s = 4, sign = -1 c:s = 0.707110:0.707110
 in0:in1 = 7.999863:40.000137

プログラムが処理を終えた結果は以下のようになります。

1.1920929E-7
-3.5762787E-7
0.99998283
-2.526322E-7
2.0
-5.9604645E-8
2.9999826
-4.7709625E-7
4.0
2.3841858E-7
5.000017
1.9301845E-7
6.0
1.7881393E-7
7.000017
-6.554011E-7

元のデータがほとんど完全な形で復元に成功しています。

Copyright 2018-2019, by Masaki Komatsu