13.5. 2次元カーネルの実装

2次元FFTの実装は、各行・各列のスライス(配列)を独立して処理することになります。

もちろん前の項目で紹介した1次元FFTカーネルを各行、各列にそのまま適用することも可能です。ただこの項目では、OpenCLのカーネルのインデックス空間を2次元に拡大して、1つのカーネルで全行、全列を処理する方法を紹介させていただきます。

まずワークサイズについては以下のように2次元として設定します。

// 2-D index space for the butterfly kernels: width/2 work-items along x, height along y.
private static long[] global_work_size = new long[]{width/2,height,1};
// Work-group size of 1 in every dimension.
private static long[] local_work_size = new long[]{1,1,1};

1つ目の次元のワークサイズをwidth/2に設定していますが、これはFFTのバタフライ演算では1つのワークアイテムが2つの要素を処理し、必要な処理点が各行の要素数の半分となるためです。そのため、ワークサイズ(インデックス空間)についても半分としています。

/*
 * One radix-2 butterfly stage of an in-place FFT, applied to every row of F.
 *
 * Launched over a 2-D index space: dimension 0 selects one butterfly inside a
 * row (the row length is taken as twice the dimension-0 global size) and
 * dimension 1 selects the row.
 *
 *   F    - row-major complex data (x = real, y = imaginary), updated in place
 *   N    - butterfly span of this stage (2, 4, 8, ...); the stride is N/2
 *   sign - multiplies the sine term of the twiddle factor; the host passes 1
 *          for the forward transform and -1 for the inverse transform
 */
__kernel void fft2d(
                __global float2* F,
                int N,
                int sign)
{
        int x = get_global_id(0);
        int row_size = get_global_size(0)*2;
        int y = get_global_id(1);
        int stride = N/2;

        /* Offset of this butterfly inside its group (0 .. stride-1). */
        int k = x%stride;
        /* First element of the pair, computed in integer arithmetic only
           (the former int -> float -> ceil -> int round trip was a lossy no-op). */
        int index = (x/stride)*stride + x;
        int base = y*row_size;

        float2 in0 = F[base + index];
        float2 in1 = F[base + index + stride];

        /* For even N, index and k differ by a multiple of N, so the twiddle is
           mathematically unchanged; the reduced argument keeps native_cos and
           native_sin near zero where their accuracy is best. */
        float angle = -2*M_PI_F*(k)/N;

        float c = native_cos(angle);
        float s = sign*native_sin(angle);

        /* v = twiddle * in1 (complex multiplication). */
        float2 v;
        v.x = c * (in1.x) - s * in1.y;
        v.y = c * (in1.y) + s * in1.x;

        F[base + index] = in0 + v;
        F[base + index + stride] = in0 - v;
}

1次元からの大きな変更点としては各行を一つ一つ処理するために、yとrow_sizeの2つの変数が追加されたことです。

// y*row_size selects row y; index is the element position inside that row.
in0 = F[y*row_size + index];
in1 = F[y*row_size + index+stride];

この記述では、Fは2次元行列でindexは行の中の要素に対応します。またy*row_sizeは、y行を指定します。

13.5.1. 全体の流れ

擬似フロー(図13.2「図:2次元FFTの手順」)にある通り、2D-FFTの手順では、4回のビット反転と2回の行列転置を行う必要があります。

実装例では以下のような流れで処理をしています。

bit_reverse(data_mem); //(1) bit-reverse each row of the source data
fft_2d(width, data_mem,processed_mem,1,true); //(2) forward FFT, data_mem -> processed_mem

transpose(processed_mem, transpose_mem); //(3) transpose into transpose_mem

bit_reverse(transpose_mem); //(4) bit-reverse the transposed data
fft_2d(height,null,transpose_mem,1,false); //(5) forward FFT in place

bit_reverse(transpose_mem); //(6) bit-reverse again before the inverse pass
fft_2d(height,null,transpose_mem,-1,false); //(7) inverse FFT in place (sign = -1)
fft_inverse2d(height,transpose_mem); //(8) scale by 1/height

transpose(transpose_mem, processed_mem); //(9) transpose back into processed_mem

bit_reverse(processed_mem); //(10) bit-reverse the processed data
fft_2d(width,null,processed_mem,-1,false); //(11) inverse FFT in place (sign = -1)
fft_inverse2d(width,processed_mem);     //(12) scale by 1/width

(1)

元データ(data_mem)のビット反転。

(2)

元データを変換して、処理済みデータ(processed_mem)に書き込み。

(3)

処理済みデータ(processed_mem)を転置して、転置データ(transpose_mem)に書き込み。

(4)

転置データ(transpose_mem)のビット反転。

(5)

転置データ(transpose_mem)のFFTをして転置データ(transpose_mem)を更新。

(6)

転置データ(transpose_mem)のビット反転。

(7)

転置データ(transpose_mem)の逆FFTをして転置データ(transpose_mem)を更新。

(8)

転置データ(transpose_mem)にスケールファクターをかけて値を調整します。

(9)

転置データ(transpose_mem)を転置して、処理済みデータ(processed_mem)に書き込み。

(10)

処理済みデータ(processed_mem)のビット反転。

(11)

処理済みデータ(processed_mem)の逆FFTをしてデータを更新。

(12)

処理済みデータ(processed_mem)にスケールファクターをかけて値を調整します。

bit_reverse関数がビット反転、transpose関数が転置、fft_2dが高速フーリエ変換(第4引数を-1にすると逆変換)、fft_inverse2dが逆フーリエ変換のスケーリングに対応します。

fft_inverse2dは以下のカーネル関数に対応します。

/*
 * Scaling step of the inverse transform: divides one complex element of F by N.
 * Each work-item handles the element at (global id 0, global id 1); the row
 * length is the dimension-0 global size.
 */
__kernel void fft_inverse2d(int N, __global float2* F) {
        const size_t col = get_global_id(0);
        const size_t cols = get_global_size(0);
        const size_t row = get_global_id(1);
        const size_t pos = row*cols + col;
        F[pos] = F[pos] / N;
}

この関数は単純に、全ての要素を処理対象のデータサイズ(N)で割っています。

FFTGPU2D.java. 

package com.book.jocl.fft;

import static org.jocl.CL.*;

import java.io.File;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Paths;
import java.util.Scanner;

import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.Sizeof;
import org.jocl.cl_command_queue;
import org.jocl.cl_context;
import org.jocl.cl_context_properties;
import org.jocl.cl_device_id;
import org.jocl.cl_kernel;
import org.jocl.cl_mem;
import org.jocl.cl_platform_id;
import org.jocl.cl_program;

public class FFTGPU2D {
        // Resource name of the OpenCL source file; resolved with Class.getResource in main.
        private static final String KERNEL_PATH = "fft2d.cl";
        // Names of the kernel functions looked up in the built program.
        private static final String KERNEL_INIT = "fft_init";
        private static final String KERNEL_BIT_REVERSAL = "bit_reversal";
        private static final String KERNEL_FFT = "fft2d";
        private static final String KERNEL_FFT_INVERSE = "fft_inverse2d";
        //private static final String KERNEL_FFT_TRANSPOSE = "async_transpose";
        private static final String KERNEL_TRANSPOSE = "transpose";

    // OpenCL handles; all of them are assigned in main.
    private static cl_context context;
    private static cl_command_queue queue;
    private static cl_program program;
    private static cl_kernel kernel_init;
    private static cl_kernel kernel_bit_reversal;
    private static cl_kernel kernel_fft;
    private static cl_kernel kernel_fft_inverse;
    //private static cl_kernel kernel_fft_transpose;
    private static cl_kernel kernel_transpose;

        // Dimensions of the 2-D data set, in complex elements.
        private static final int width = 16;
        private static final int height = 16;

        // Total number of complex elements.
        private static final int DATA_SIZE = width*height;
        // Work-group edge length used for the transpose launch; also handed to the
        // kernel build as the "size" macro.
        private static final int LOCAL_SIZE = 4;

        // Host arrays holding two floats per complex element (real part, then imaginary part).
        private static final float[] data = new float[DATA_SIZE << 1];
        private static final float[] processed_data = new float[DATA_SIZE << 1];

        // Index space of the butterfly kernels: width/2 work-items per row, one row per y.
        private static long[] global_work_size = new long[]{width/2,height,1};
        private static long[] local_work_size = new long[]{1,1,1};

        // Index space with one work-item per element (bit reversal, scaling, transpose).
        private static long[] global_work_size_rect = new long[]{width,height,1};
        private static long[] local_work_size_rect = new long[]{LOCAL_SIZE,LOCAL_SIZE,1};

        /**
         * Returns floor(log2(b)), treating b as an unsigned 32-bit value.
         *
         * @param b value whose highest set bit is wanted
         * @return index of the highest set bit of b, or 0 when b is 0
         */
        private static int log2(int b) {
                if (b == 0) {
                        return 0;
                }
                // The highest set bit sits at 31 minus the number of leading zero bits.
                return 31 - Integer.numberOfLeadingZeros(b);
        }
        /**
         * Runs a forward 2-D FFT followed by the inverse 2-D FFT on the first GPU of the
         * first OpenCL platform and prints the real part of every resulting element.
         *
         * Steps: build the kernels from KERNEL_PATH, upload the generated sample, run the
         * forward pass (row FFT, transpose, row FFT on the transposed data), run the inverse
         * pass (row FFT, scale, transpose, row FFT, scale), map the result and print it.
         *
         * @param args unused
         * @throws Exception if the kernel source cannot be located or read, or if any OpenCL
         *                   call fails (JOCL exceptions are enabled)
         */
        public static void main(String[] args) throws Exception {

                CL.setExceptionsEnabled(true);

                cl_platform_id[] platform = new cl_platform_id[1];
                cl_device_id[] device = new cl_device_id[1];
                int[] num_devices = new int[1];

                clGetPlatformIDs(1, platform, null);
                clGetDeviceIDs(platform[0], CL_DEVICE_TYPE_GPU, 1, device, num_devices);

                cl_context_properties props = new cl_context_properties();
                props.addProperty(CL_CONTEXT_PLATFORM, platform[0]);
                context = clCreateContext(props, 1, device, null, null, null);

                queue = clCreateCommandQueue(context, device[0], 0, null);

                // Resolve the kernel file relative to this class rather than a sibling class.
                URL resource = FFTGPU2D.class.getResource(KERNEL_PATH);
                if (resource == null) {
                        throw new IllegalStateException("Kernel source not found: " + KERNEL_PATH);
                }
                String path = Paths.get(resource.toURI()).toFile().getAbsolutePath();
                StringBuilder sb = new StringBuilder();
                // try-with-resources closes the scanner even when reading fails.
                try (Scanner sc = new Scanner(new File(path))) {
                        while (sc.hasNextLine()) {
                                sb.append(sc.nextLine()).append('\n');
                        }
                }
                program = clCreateProgramWithSource(context, 1, new String[] {sb.toString()}, null, null);
                // "size" is the tile edge used by the transpose kernel.
                String option = "-Werror -Dsize=" + LOCAL_SIZE;
                clBuildProgram(program, 0, null, option, null, null);

                // Fill the host array before it is copied into the device buffer.
                generateSample();

                // bit_reversal swaps elements of data_mem in place, so the buffer must be
                // writable by kernels. The sample is copied at creation time; a Java array
                // is presumably only pinned for the duration of the JNI call, so it is not
                // used as a persistent host pointer (NOTE(review): confirm against JOCL docs).
                cl_mem data_mem = clCreateBuffer(context, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,
                                Sizeof.cl_float2 * DATA_SIZE, Pointer.to(data), null);

                // Fully overwritten by fft_init before it is read; read back through a mapping.
                cl_mem processed_mem = clCreateBuffer(context, CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR,
                                Sizeof.cl_float2 * DATA_SIZE, null, null);

                cl_mem transpose_mem = clCreateBuffer(context, CL_MEM_ALLOC_HOST_PTR,
                                Sizeof.cl_float2 * DATA_SIZE, null, null);

                kernel_init = clCreateKernel(program, KERNEL_INIT, null);
                kernel_bit_reversal = clCreateKernel(program, KERNEL_BIT_REVERSAL, null);
                kernel_fft = clCreateKernel(program, KERNEL_FFT, null);
                kernel_fft_inverse = clCreateKernel(program, KERNEL_FFT_INVERSE, null);
                kernel_transpose = clCreateKernel(program, KERNEL_TRANSPOSE, null);

                // Forward transform along the rows.
                bit_reverse(data_mem);
                fft_2d(width, data_mem,processed_mem,1,true);

                transpose(processed_mem, transpose_mem);

                // Forward transform along the rows of the transposed data.
                bit_reverse(transpose_mem);
                fft_2d(height,null,transpose_mem,1,false);

                // Inverse transform, scaled by 1/height.
                bit_reverse(transpose_mem);
                fft_2d(height,null,transpose_mem,-1,false);
                fft_inverse2d(height,transpose_mem);

                transpose(transpose_mem, processed_mem);

                // Inverse transform, scaled by 1/width.
                bit_reverse(processed_mem);
                fft_2d(width,null,processed_mem,-1,false);
                fft_inverse2d(width,processed_mem);

                // Blocking map for reading; on this in-order queue it completes only after
                // the kernels enqueued above have finished.
                ByteBuffer outBuffer = clEnqueueMapBuffer(queue,
                                processed_mem,
                                CL_TRUE,
                                CL_MAP_READ,
                                0,
                                Sizeof.cl_float2*width*height,
                                0,
                                null,
                                null,
                                null);

                System.out.println("COMPLETE");

                outBuffer.order(ByteOrder.LITTLE_ENDIAN);

                // Print the real part of each element and skip the imaginary part.
                for(int i = 0; i < width*height*2; i+=2) {
                        System.out.println(outBuffer.getFloat());
                        outBuffer.getFloat();
                }

                // The mapped region is only valid until it is unmapped, so unmap after reading.
                clEnqueueUnmapMemObject(queue, processed_mem, outBuffer, 0, null, null);
                clFinish(queue);

                // Release everything created above; the device was obtained from
                // clGetDeviceIDs and is not retained here, so it is not released.
                clReleaseMemObject(data_mem);
                clReleaseMemObject(processed_mem);
                clReleaseMemObject(transpose_mem);
                clReleaseKernel(kernel_init);
                clReleaseKernel(kernel_bit_reversal);
                clReleaseKernel(kernel_fft);
                clReleaseKernel(kernel_fft_inverse);
                clReleaseKernel(kernel_transpose);
                clReleaseProgram(program);
                clReleaseCommandQueue(queue);
                clReleaseContext(context);
        }

        /**
         * Fills the host input array with the test signal: the real part of each complex
         * element equals its offset in the interleaved array (0, 2, 4, ...) and the
         * imaginary part is zero.
         */
        private static void generateSample() {
                final int elements = width*height;
                for (int n = 0; n < elements; n++) {
                        final int re = 2*n;
                        data[re] = re;
                        data[re + 1] = 0.0f;
                }
        }

        /**
         * Enqueues the transpose kernel over the full width x height index space with
         * LOCAL_SIZE x LOCAL_SIZE work-groups.
         *
         * @param input_mem  buffer to read from
         * @param output_mem buffer that receives the transposed data
         */
        private static void transpose(cl_mem input_mem, cl_mem output_mem) {
                final int[] widthArg = { width };
                final int[] heightArg = { height };

                // Kernel signature: (input, output, width, height).
                clSetKernelArg(kernel_transpose, 0, Sizeof.cl_mem, Pointer.to(input_mem));
                clSetKernelArg(kernel_transpose, 1, Sizeof.cl_mem, Pointer.to(output_mem));
                clSetKernelArg(kernel_transpose, 2, Sizeof.cl_int, Pointer.to(widthArg));
                clSetKernelArg(kernel_transpose, 3, Sizeof.cl_int, Pointer.to(heightArg));

                clEnqueueNDRangeKernel(queue, kernel_transpose, 2, null,
                                global_work_size_rect, local_work_size_rect,
                                0, null, null);
        }

        /**
         * Enqueues the scaling kernel of the inverse transform, one work-item per element.
         *
         * @param N           divisor handed to the kernel
         * @param inverse_mem buffer whose elements are divided in place
         */
        private static void fft_inverse2d(int N, cl_mem inverse_mem) {
                final int[] divisor = { N };

                clSetKernelArg(kernel_fft_inverse, 0, Sizeof.cl_int, Pointer.to(divisor));
                clSetKernelArg(kernel_fft_inverse, 1, Sizeof.cl_mem, Pointer.to(inverse_mem));

                clEnqueueNDRangeKernel(queue, kernel_fft_inverse, 2, null,
                                global_work_size_rect, local_work_size,
                                0, null, null);
        }

        /**
         * Enqueues the bit-reversal kernel, one work-item per element; the buffer is
         * permuted in place.
         *
         * @param input_mem buffer to permute
         */
        private static void bit_reverse(cl_mem input_mem) {
                // Second kernel argument: the total element count.
                final int[] count = { DATA_SIZE };

                clSetKernelArg(kernel_bit_reversal, 0, Sizeof.cl_mem, Pointer.to(input_mem));
                clSetKernelArg(kernel_bit_reversal, 1, Sizeof.cl_uint, Pointer.to(count));

                clEnqueueNDRangeKernel(queue, kernel_bit_reversal, 2, null,
                                global_work_size_rect, local_work_size,
                                0, null, null);
        }

        /**
         * Enqueues the log2(N) butterfly stages of a row-wise FFT, with spans 2, 4, ..., up
         * to 2^log2(N).
         *
         * @param N          transform length per row (log2 floors, so a power of two is
         *                   presumably expected)
         * @param input_mem  source buffer; only read by the first stage when is_init is true
         * @param output_mem buffer that holds the result; later stages update it in place
         * @param sign       passed to the fft2d kernel (1 forward, -1 inverse)
         * @param is_init    when true, the first stage (span 2) runs the fft_init kernel,
         *                   which reads input_mem and writes output_mem
         */
        private static void fft_2d(int N, cl_mem input_mem, cl_mem output_mem, int sign, boolean is_init){
                final int stages = log2(N);
                final int[] spanArg = new int[1];
                final int[] signArg = { sign };

                for (int stage = 1; stage <= stages; stage++) {
                        final int span = 1 << stage;
                        spanArg[0] = span;

                        if (is_init && span == 2) {
                                // First stage: butterflies from input_mem into output_mem.
                                clSetKernelArg(kernel_init, 0, Sizeof.cl_mem, Pointer.to(input_mem));
                                clSetKernelArg(kernel_init, 1, Sizeof.cl_mem, Pointer.to(output_mem));
                                clSetKernelArg(kernel_init, 2, Sizeof.cl_uint, Pointer.to(spanArg));

                                clEnqueueNDRangeKernel(queue, kernel_init, 2, null,
                                                global_work_size, local_work_size,
                                                0, null, null);
                        } else {
                                // Remaining stages work in place on output_mem.
                                clSetKernelArg(kernel_fft, 0, Sizeof.cl_mem, Pointer.to(output_mem));
                                clSetKernelArg(kernel_fft, 1, Sizeof.cl_uint, Pointer.to(spanArg));
                                clSetKernelArg(kernel_fft, 2, Sizeof.cl_int, Pointer.to(signArg));

                                clEnqueueNDRangeKernel(queue, kernel_fft, 2, null,
                                                global_work_size, local_work_size,
                                                0, null, null);
                        }
                }
        }

}

fft2d.cl. 

/*
 * Reverses the low-order bits of x.
 *
 *   x     - value whose bits are reversed
 *   stage - bit mask that fixes the width: one bit of x is consumed for every
 *           right shift needed to bring stage down to zero (e.g. 15 -> 4 bits)
 *
 * Returns the reversed value; 0 when stage is 0.
 *
 * Declared without `inline`: a bare `inline` provides only an inline
 * definition, which can leave calls unresolved on C99-conforming compilers.
 */
int reverseBit(int x, int stage) {
        int reversed = 0;
        for (int bits = stage; bits != 0; bits >>= 1) {
                /* Move the lowest bit of x onto the low end of the result. */
                reversed = (reversed << 1) | (x & 1);
                x >>= 1;
        }
        return reversed;
}

/*
 * In-place bit-reversal permutation of every row of `data`.
 *
 * One work-item per element: dimension 0 is the position inside the row (the
 * row length is the dimension-0 global size), dimension 1 is the row.
 *
 *   data - row-major complex data, permuted in place
 *   N    - not used by the kernel; kept so the argument list set by the host
 *          stays valid
 *
 * The debug printf that used to run in every swapping work-item was removed:
 * it printed size_t values with %d, which may be diagnosed as a format
 * mismatch and then fail the build under -Werror.
 */
__kernel void bit_reversal(__global float2* data, uint N) {
        size_t x = get_global_id(0);
        size_t row_size = get_global_size(0);
        size_t y = get_global_id(1);
        uint rev = reverseBit((int)x, (int)(row_size-1));

        /* Only the lower index of a pair performs the swap, so each pair is
           exchanged exactly once and self-paired elements stay untouched. */
        if(x < rev) {
                size_t a = y*row_size + x;
                size_t b = y*row_size + rev;
                float2 tmp = data[a];
                data[a] = data[b];
                data[b] = tmp;
        }
}

/*
 * Butterfly stage without twiddle multiplication that reads from `data` and
 * writes to `F`. The host uses it only for the first stage (N == 2).
 *
 * Dimension 0 selects one butterfly inside a row (the row length is taken as
 * twice the dimension-0 global size); dimension 1 selects the row.
 *
 *   data - row-major complex input
 *   F    - row-major complex output
 *   N    - butterfly span of this stage; the stride is N/2
 */
__kernel void fft_init(
                __global float2* data,
                __global float2* F,
                int N)
{
        int x = get_global_id(0);
        int row_size = get_global_size(0)*2;
        int y = get_global_id(1);
        int stride = N/2;
        /* First element of the pair, computed in integer arithmetic only
           (the former int -> float -> ceil -> int round trip was a lossy no-op). */
        int index = (x/stride)*stride + x;
        int base = y*row_size;

        float2 in0 = data[base + index];
        float2 in1 = data[base + index + stride];

        F[base + index] = in0 + in1;
        F[base + index + stride] = in0 - in1;
}


/*
 * One radix-2 butterfly stage of an in-place FFT, applied to every row of F.
 *
 * Launched over a 2-D index space: dimension 0 selects one butterfly inside a
 * row (the row length is taken as twice the dimension-0 global size) and
 * dimension 1 selects the row.
 *
 *   F    - row-major complex data (x = real, y = imaginary), updated in place
 *   N    - butterfly span of this stage (2, 4, 8, ...); the stride is N/2
 *   sign - multiplies the sine term of the twiddle factor; the host passes 1
 *          for the forward transform and -1 for the inverse transform
 */
__kernel void fft2d(
                __global float2* F,
                int N,
                int sign)
{
        int x = get_global_id(0);
        int row_size = get_global_size(0)*2;
        int y = get_global_id(1);
        int stride = N/2;

        /* Offset of this butterfly inside its group (0 .. stride-1). */
        int k = x%stride;
        /* First element of the pair, computed in integer arithmetic only
           (the former int -> float -> ceil -> int round trip was a lossy no-op). */
        int index = (x/stride)*stride + x;
        int base = y*row_size;

        float2 in0 = F[base + index];
        float2 in1 = F[base + index + stride];

        /* For even N, index and k differ by a multiple of N, so the twiddle is
           mathematically unchanged; the reduced argument keeps native_cos and
           native_sin near zero where their accuracy is best. */
        float angle = -2*M_PI_F*(k)/N;

        float c = native_cos(angle);
        float s = sign*native_sin(angle);

        /* v = twiddle * in1 (complex multiplication). */
        float2 v;
        v.x = c * (in1.x) - s * in1.y;
        v.y = c * (in1.y) + s * in1.x;

        F[base + index] = in0 + v;
        F[base + index + stride] = in0 - v;
}

/*
 * Scaling step of the inverse transform: divides one complex element of F by N.
 * Each work-item handles the element at (global id 0, global id 1); the row
 * length is the dimension-0 global size.
 */
__kernel void fft_inverse2d(int N, __global float2* F) {
        const size_t col = get_global_id(0);
        const size_t cols = get_global_size(0);
        const size_t row = get_global_id(1);
        const size_t pos = row*cols + col;
        F[pos] = F[pos] / N;
}

/*
 * Tiled matrix transpose through local memory.
 *
 * Each work-group copies one size x size tile of `input` into a local tile and
 * writes it back transposed into `output`. `size` is a macro supplied through
 * the build options and must equal the work-group edge length; the tile rows
 * are padded to size+1 elements.
 *
 *   input  - row-major source matrix with `width` elements per row
 *   output - row-major destination matrix with `height` elements per row
 *   width  - row length of the input
 *   height - number of input rows
 *
 * width and height are declared int: OpenCL C does not allow size_t kernel
 * arguments, and the host sets both with sizeof(cl_int).
 *
 * No bounds checks are performed, so the global size is assumed to be a
 * multiple of `size` in both dimensions.
 */
__kernel void transpose(
                         __global float2* input,
                         __global float2* output,
                         int width,
                         int height)
{

        __local float2 tile[size * (size+1)];
        size_t x = get_global_id(0);
        size_t y = get_global_id(1);

        size_t lx = get_local_id(0);
        size_t ly = get_local_id(1);

        size_t gx = get_group_id(0);
        size_t gy = get_group_id(1);

        /* Stage the element at (x, y) in the local tile. */
        size_t index_input = y * (size_t)width + x;
        size_t index_tile = ly * (size+1) + lx;
        tile[index_tile] = input[index_input];

        /* Every element of the tile must be stored before any is read back. */
        barrier(CLK_LOCAL_MEM_FENCE);

        /* Swap the group coordinates and read the tile with swapped local ids. */
        size_t ox = gy * size + lx;
        size_t oy = gx * size + ly;

        size_t index_output = oy * (size_t)height + ox;
        index_tile = lx * (size+1) + ly;
        output[index_output] = tile[index_tile];
}

Copyright 2018-2019, by Masaki Komatsu