GPGPU実装は、CPU実装と基本的には同じものです。異なる点は、データ並列化が可能なコードがカーネル関数内に移されることです。
ここで使うカーネル関数は3つあります。まず一つ目が、fft_init関数です。
__kernel void fft_init( __global float2* data, __global float2* F, int N)
この関数は、Radix-2 FFTの導入部分となり、ストライドが1の時に使います。つまりCPU実装でいう以下の行に該当します。
F[offset] = data[offset] + data[offset+N2]; F[offset+1] = data[offset+1] + data[offset+N2+1]; F[offset+N2] = data[offset] - data[offset+N2]; F[offset+N2+1] = data[offset+1] - data[offset+N2+1];
この計算では三角関数が不要となります。またdata変数は生データですが、以後の計算では、F変数を用います。
Radix-2 FFTのGPU実装の該当するFFTカーネル関数は以下のようになります。CPU実装例と比べると、ほとんど同じコードであることがわかると思います。
float2 in0, in1; in0 = F[index]; in1 = F[index+stride]; float angle = -2*M_PI_F*(index)/N; float c,s; float2 v; float2 tmp0; c = native_cos(angle); s = sign*native_sin(angle); v.x = c * (in1.x) - s * in1.y; v.y = c * (in1.y) + s * in1.x; tmp0 = in0; in0 = tmp0 + v; in1 = tmp0 - v; F[index] = in0; F[index + stride] = in1;
Javaの標準パッケージではベクトル型を使用できませんが、OpenCLではfloat2型を使い、xに実数部、yに虚数部を格納することで、コード行数を少なくとも半分程度に抑制できます。
FFTのメインのアルゴリズムは以下のfftカーネル関数を使います。
__kernel void fft( __global float2* F, int N, int sign)
この関数はCooley-Tukeyアルゴリズムを実装しますが、該当するCPU実装は以下のようになります。
forward(N2,offset,data,F,sign,step); forward(N2,offset+N2,data,F,sign,step); for(int i = 0; i < N2; i+=2) { c = Math.cos(i*_PI/N2); // (_PI * 2 * k) s = sign*Math.sin(i*_PI/N2); // (_PI * 2 * k) real = F[i+N2+offset]*c + F[i+N2+1+offset]*s; imaginary = F[i+N2+1+offset]*c - F[i+N2+offset]*s; F[i+N2+offset] = F[i+offset] - real; F[i+N2+1+offset] = F[i+1+offset] - imaginary; F[i+offset] += real; F[i+1+offset] += imaginary; }
ここでは、forward関数は、fft関数に該当し、最初の2行で再帰処理をしています。
GPU実装については、カーネル関数内で再帰処理ができないため、再帰部分はOpenCLホストAPIを使って反復処理として記述し、残りの記述はカーネル関数に移します。構成としては以下のようになります。
int fftSize = 1; int ns = log2(N); int stages = 0; int[] fftSizePtr = new int[1]; for(int i = 0; i < ns; i++) { fftSize <<= 1; fftSizePtr[0] = fftSize; if(fftSize !=2) { // fftカーネル関数 } else { // fft_initカーネル関数 } }
fft_initは、forループ内の反復の一番初めだけ実行され、残りはfftカーネル関数が代わりに実行されます。
FFTGPU1D.java.
package com.book.jocl.fft;

import static org.jocl.CL.*;

import java.io.File;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Paths;
import java.util.Scanner;

import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.Sizeof;
import org.jocl.cl_command_queue;
import org.jocl.cl_context;
import org.jocl.cl_context_properties;
import org.jocl.cl_device_id;
import org.jocl.cl_kernel;
import org.jocl.cl_mem;
import org.jocl.cl_platform_id;
import org.jocl.cl_program;

/**
 * Host program for a one-dimensional radix-2 FFT executed through JOCL.
 *
 * <p>The program fills {@code data} with a ramp (real part {@code 0, 1, 2, ...},
 * imaginary part {@code 0}), runs a forward transform followed by an inverse
 * transform using the kernels in {@code fft1d.cl}, and prints the
 * {@code 2 * DATA_SIZE} resulting floats (interleaved real/imaginary), which
 * should reproduce the input up to rounding error.
 *
 * <p>{@code DATA_SIZE} must be a power of two: the number of butterfly stages
 * is {@code log2(DATA_SIZE)} and each stage doubles the sub-transform size.
 */
public class FFTGPU1D {
    private static final String KERNEL_PATH = "fft1d.cl";
    private static final String KERNEL_INIT = "fft_init";
    private static final String KERNEL_BIT_REVERSAL = "bit_reversal";
    private static final String KERNEL_FFT = "fft";
    private static final String KERNEL_FFT_INVERSE = "fft_inverse";

    private static cl_context context;
    private static cl_command_queue queue;
    private static cl_program program;
    private static cl_kernel kernel_init;
    private static cl_kernel kernel_bit_reversal;
    private static cl_kernel kernel_fft;
    private static cl_kernel kernel_fft_inverse;

    /** Number of complex samples. */
    private static final int DATA_SIZE = 256;

    /** Input samples, interleaved as (real, imaginary) pairs. */
    private static final float[] data = new float[DATA_SIZE << 1];

    /** One work-item per butterfly: half the number of complex samples. */
    private static long[] global_work_size = new long[]{DATA_SIZE >>> 1, 1, 1};
    private static long[] local_work_size = new long[]{1, 1, 1};

    /** One work-item per complex sample (bit reversal and final scaling). */
    private static long[] global_work_size_full = new long[]{DATA_SIZE, 1, 1};

    /**
     * Returns floor(log2(b)), treating {@code b} as an unsigned 32-bit value.
     *
     * @param b value to take the logarithm of
     * @return index of the highest set bit of {@code b}; {@code 0} when {@code b} is 0
     */
    private static int log2(int b) {
        return (b == 0) ? 0 : 31 - Integer.numberOfLeadingZeros(b);
    }

    /**
     * Reads the kernel source file located next to this class on the classpath.
     *
     * @return the kernel source with every line terminated by {@code '\n'}
     * @throws IllegalStateException if the kernel file is not on the classpath
     * @throws Exception if the resource URL cannot be converted to a file or read
     */
    private static String loadKernelSource() throws Exception {
        URL resource = FFTGPU1D.class.getResource(KERNEL_PATH);
        if (resource == null) {
            throw new IllegalStateException("Kernel source not found on classpath: " + KERNEL_PATH);
        }
        File file = Paths.get(resource.toURI()).toFile();
        StringBuilder source = new StringBuilder();
        // try-with-resources closes the file even when reading fails.
        try (Scanner sc = new Scanner(file)) {
            while (sc.hasNextLine()) {
                source.append(sc.nextLine()).append('\n');
            }
        }
        return source.toString();
    }

    /** Enqueues the in-place bit-reversal permutation of {@code buffer}. */
    private static void enqueueBitReversal(cl_mem buffer, int[] sizePtr) {
        clSetKernelArg(kernel_bit_reversal, 0, Sizeof.cl_mem, Pointer.to(buffer));
        clSetKernelArg(kernel_bit_reversal, 1, Sizeof.cl_uint, Pointer.to(sizePtr));
        clEnqueueNDRangeKernel(queue, kernel_bit_reversal, 1, null,
                global_work_size_full, local_work_size, 0, null, null);
    }

    /** Enqueues one in-place butterfly stage of the {@code fft} kernel on {@code buffer}. */
    private static void enqueueFftStage(cl_mem buffer, int[] fftSizePtr, int[] signPtr) {
        clSetKernelArg(kernel_fft, 0, Sizeof.cl_mem, Pointer.to(buffer));
        clSetKernelArg(kernel_fft, 1, Sizeof.cl_uint, Pointer.to(fftSizePtr));
        clSetKernelArg(kernel_fft, 2, Sizeof.cl_int, Pointer.to(signPtr));
        clEnqueueNDRangeKernel(queue, kernel_fft, 1, null,
                global_work_size, local_work_size, 0, null, null);
    }

    public static void main(String[] args) throws Exception {
        CL.setExceptionsEnabled(true);

        cl_platform_id[] platform = new cl_platform_id[1];
        cl_device_id[] device = new cl_device_id[1];
        int[] num_devices = new int[1];

        clGetPlatformIDs(1, platform, null);
        clGetDeviceIDs(platform[0], CL_DEVICE_TYPE_GPU, 1, device, num_devices);

        cl_context_properties props = new cl_context_properties();
        props.addProperty(CL_CONTEXT_PLATFORM, platform[0]);
        context = clCreateContext(props, 1, device, null, null, null);
        queue = clCreateCommandQueue(context, device[0], 0, null);

        program = clCreateProgramWithSource(context, 1,
                new String[] {loadKernelSource()}, null, null);
        String option = "-Werror";
        clBuildProgram(program, 0, null, option, null, null);

        // The samples must exist before the buffer is created, because
        // CL_MEM_COPY_HOST_PTR copies the array contents at creation time.
        generateSample();

        // READ_WRITE: the bit_reversal kernel swaps elements of this buffer in place.
        cl_mem data_mem = clCreateBuffer(context,
                CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,
                Sizeof.cl_float2 * DATA_SIZE, Pointer.to(data), null);
        // Device-side working buffer; it is read back through a mapping below.
        cl_mem processed_mem = clCreateBuffer(context,
                CL_MEM_READ_WRITE,
                Sizeof.cl_float2 * DATA_SIZE, null, null);

        kernel_init = clCreateKernel(program, KERNEL_INIT, null);
        kernel_bit_reversal = clCreateKernel(program, KERNEL_BIT_REVERSAL, null);
        kernel_fft = clCreateKernel(program, KERNEL_FFT, null);
        kernel_fft_inverse = clCreateKernel(program, KERNEL_FFT_INVERSE, null);

        int N = DATA_SIZE;
        int ns = log2(N);
        int[] fftSizePtr = new int[1];
        int[] Ni = new int[] {N};
        int[] signPtr = new int[] {1};

        // ---- forward transform (sign = 1) ----
        enqueueBitReversal(data_mem, Ni);

        int fftSize = 1;
        for (int i = 0; i < ns; i++) {
            fftSize <<= 1;
            fftSizePtr[0] = fftSize;
            if (fftSize != 2) {
                enqueueFftStage(processed_mem, fftSizePtr, signPtr);
            } else {
                // First stage only: reads the raw input and writes the working buffer.
                clSetKernelArg(kernel_init, 0, Sizeof.cl_mem, Pointer.to(data_mem));
                clSetKernelArg(kernel_init, 1, Sizeof.cl_mem, Pointer.to(processed_mem));
                clSetKernelArg(kernel_init, 2, Sizeof.cl_uint, Pointer.to(fftSizePtr));
                clEnqueueNDRangeKernel(queue, kernel_init, 1, null,
                        global_work_size, local_work_size, 0, null, null);
            }
        }

        // ---- inverse transform (sign = -1), in place on the working buffer ----
        enqueueBitReversal(processed_mem, Ni);

        signPtr[0] = -1;
        fftSize = 1;
        for (int i = 0; i < ns; i++) {
            fftSize <<= 1;
            fftSizePtr[0] = fftSize;
            enqueueFftStage(processed_mem, fftSizePtr, signPtr);
        }

        // Final scaling of every element by 1/N.
        clSetKernelArg(kernel_fft_inverse, 0, Sizeof.cl_uint, Pointer.to(Ni));
        clSetKernelArg(kernel_fft_inverse, 1, Sizeof.cl_mem, Pointer.to(processed_mem));
        long[] global_work_size_scale = new long[]{DATA_SIZE, 1, 1};
        long[] local_work_size_scale = new long[]{1, 1, 1};
        clEnqueueNDRangeKernel(queue, kernel_fft_inverse, 1, null,
                global_work_size_scale, local_work_size_scale, 0, null, null);

        // Blocking map for reading; the mapped region is only valid until it is
        // unmapped, so the values are printed before the unmap is enqueued.
        ByteBuffer output = clEnqueueMapBuffer(queue, processed_mem, CL_TRUE, CL_MAP_READ,
                0, Sizeof.cl_float2 * DATA_SIZE, 0, null, null, null);
        output.order(ByteOrder.nativeOrder());
        for (int i = 0; i < DATA_SIZE * 2; i++) {
            System.out.println(output.getFloat());
        }
        clEnqueueUnmapMemObject(queue, processed_mem, output, 0, null, null);
        clFinish(queue);

        // Release everything that was created, dependents before their owners.
        // The device came from clGetDeviceIDs (not a sub-device) and was never
        // retained, so it needs no release call.
        clReleaseMemObject(data_mem);
        clReleaseMemObject(processed_mem);
        clReleaseKernel(kernel_init);
        clReleaseKernel(kernel_bit_reversal);
        clReleaseKernel(kernel_fft);
        clReleaseKernel(kernel_fft_inverse);
        clReleaseProgram(program);
        clReleaseCommandQueue(queue);
        clReleaseContext(context);
    }

    /** Fills {@code data} with the ramp 0, 1, 2, ... as real parts and 0 as imaginary parts. */
    private static void generateSample() {
        for (int i = 0; i < DATA_SIZE * 2; i += 2) {
            data[i] = i / 2;
            data[i + 1] = 0.0f;
        }
    }
}
fft1d.cl.
/*
 * fft1d.cl -- kernels for a one-dimensional radix-2 FFT on float2 data
 * (x = real part, y = imaginary part).
 */

/*
 * Reverses the low-order bits of x.
 *
 * One bit of x is moved per loop iteration; the loop runs once for every
 * right shift needed to bring `stage` down to zero, so `stage` acts as a
 * bit-width mask (e.g. stage = 7 reverses 3 bits).
 */
inline int reverseBit(int x, int stage)
{
    int b = 0;
    int bits = stage;
    while (bits != 0) {
        b <<= 1;
        b |= (x & 1);
        x >>= 1;
        bits >>= 1;
    }
    return b;
}

/*
 * Swaps data[gid] with the element at the bit-reversed index of gid.
 * Only the work-item with the smaller index of each pair performs the swap,
 * so every pair is exchanged exactly once.
 */
__kernel void bit_reversal(__global float2* data, uint N)
{
    size_t gid = get_global_id(0);
    uint rev = reverseBit((int)gid, (int)(N - 1));
    float2 in1;
    float2 in2;
    if (gid < rev) {
        in1 = data[gid];
        in2 = data[rev];
        /* Arguments are cast so they match the %d conversions. */
        printf("pair: %d - %d, N = %d\n", (int)gid, (int)rev, (int)N);
        data[rev] = in1;
        data[gid] = in2;
    }
}

/*
 * First butterfly stage: reads a pair from `data`, writes sum and difference
 * to `F`. No twiddle factor is applied in this kernel.
 */
__kernel void fft_init( __global float2* data, __global float2* F, int N)
{
    int gid = get_global_id(0);
    int stride = N/2;
    /* Integer arithmetic only: (gid / stride) selects the sub-transform,
       and each sub-transform occupies 2 * stride consecutive elements. */
    int index = (gid / stride) * stride + gid;
    float2 in0, in1;
    in0 = data[index];
    in1 = data[index+stride];
    float2 v0;
    v0 = in0;
    in0 = v0 + in1;
    in1 = v0 - in1;
    F[index] = in0;
    F[index + stride] = in1;
    printf("gid=%d, pair: %d - %d, N = %d, s = %d, in0:in1 = %f:%f\n", gid, index, index+stride, N, stride, F[index].x, F[index+stride].x);
}

/*
 * One in-place butterfly stage on F for sub-transforms of size N.
 * `sign` multiplies the sine term of the twiddle factor; the host passes
 * 1 for the forward transform and -1 for the inverse transform.
 */
__kernel void fft( __global float2* F, int N, int sign)
{
    int gid = get_global_id(0);
    int stride = N/2;
    /* Integer arithmetic only: (gid / stride) selects the sub-transform,
       and each sub-transform occupies 2 * stride consecutive elements. */
    int index = (gid / stride) * stride + gid;
    float2 in0, in1;
    in0 = F[index];
    in1 = F[index+stride];
    /* index = block * N + k with k < stride, so this angle differs from
       -2*pi*k/N only by a whole number of turns. */
    float angle = -2*M_PI_F*(index)/N;
    float c,s;
    float2 v;
    float2 tmp0;
    c = native_cos(angle);
    s = sign*native_sin(angle);
    /* v = (c + i*s) * in1 */
    v.x = c * (in1.x) - s * in1.y;
    v.y = c * (in1.y) + s * in1.x;
    tmp0 = in0;
    in0 = tmp0 + v;
    in1 = tmp0 - v;
    F[index] = in0;
    F[index + stride] = in1;
    printf("gid=%d, pair: %d - %d, N = %d, s = %d, sign = %d c:s = %f:%f\n in0:in1 = %f:%f\n", gid, index, index+stride, N, stride, sign, c, s, F[index].x, F[index+stride].x);
}

/*
 * Divides element gid of F by N (the scaling step after the inverse stages).
 */
__kernel void fft_inverse( int N, __global float2* F)
{
    size_t gid = get_global_id(0);
    F[gid] /= N;
}
下記はFFTカーネル関数が出力したFFTの処理情報となります。Nが処理点の数、「s =」に続く値が処理点間の距離(ストライド)、pairが実行中の2つの処理点(in0、in1)、gidがグローバルIDとなっています。「c:s」のcはcos関数、sはsin関数の値です。
gid=2, pair: 4 - 5, N = 2, s = 1, in0:in1 = 6.000000:-4.000000 gid=0, pair: 0 - 1, N = 2, s = 1, in0:in1 = 4.000000:-4.000000 gid=3, pair: 6 - 7, N = 2, s = 1, in0:in1 = 10.000000:-4.000000 gid=1, pair: 2 - 3, N = 2, s = 1, in0:in1 = 8.000000:-4.000000 gid=2, pair: 4 - 6, N = 4, s = 2, sign = 1 c:s = 1.000000:-0.000000 in0:in1 = 16.000000:-4.000000 gid=0, pair: 0 - 2, N = 4, s = 2, sign = 1 c:s = 1.000000:-0.000000 in0:in1 = 12.000000:-4.000000 gid=3, pair: 5 - 7, N = 4, s = 2, sign = 1 c:s = -0.000000:-1.000000 in0:in1 = -3.999999:-4.000001 gid=1, pair: 1 - 3, N = 4, s = 2, sign = 1 c:s = -0.000000:-1.000000 in0:in1 = -4.000000:-4.000000 gid=2, pair: 2 - 6, N = 8, s = 4, sign = 1 c:s = -0.000000:-1.000000 in0:in1 = -3.999999:-4.000001 gid=0, pair: 0 - 4, N = 8, s = 4, sign = 1 c:s = 1.000000:-0.000000 in0:in1 = 28.000000:-4.000000 gid=3, pair: 3 - 7, N = 8, s = 4, sign = 1 c:s = -0.707110:-0.707110 in0:in1 = -3.999999:-4.000001 gid=1, pair: 1 - 5, N = 8, s = 4, sign = 1 c:s = 0.707110:-0.707110 in0:in1 = -3.999999:-4.000001
下記は上と同じく処理情報ですが、今度は逆(inverse)FFTの情報を採取しています。sign変数が「-1」となっていることに注目してください。
gid=0, pair: 0 - 1, N = 2, s = 1, sign = -1 c:s = 1.000000:0.000000 in0:in1 = 24.000000:32.000000 gid=2, pair: 4 - 5, N = 2, s = 1, sign = -1 c:s = 1.000000:0.000000 in0:in1 = -8.000000:0.000001 gid=3, pair: 6 - 7, N = 2, s = 1, sign = -1 c:s = 1.000000:0.000000 in0:in1 = -8.000000:0.000002 gid=1, pair: 2 - 3, N = 2, s = 1, sign = -1 c:s = 1.000000:0.000000 in0:in1 = -8.000000:0.000002 gid=2, pair: 4 - 6, N = 4, s = 2, sign = -1 c:s = 1.000000:0.000000 in0:in1 = -15.999999:-0.000001 gid=0, pair: 0 - 2, N = 4, s = 2, sign = -1 c:s = 1.000000:0.000000 in0:in1 = 16.000000:32.000000 gid=3, pair: 5 - 7, N = 4, s = 2, sign = -1 c:s = -0.000000:1.000000 in0:in1 = -11.313758:11.313760 gid=1, pair: 1 - 3, N = 4, s = 2, sign = -1 c:s = -0.000000:1.000000 in0:in1 = 24.000000:40.000000 gid=2, pair: 2 - 6, N = 8, s = 4, sign = -1 c:s = -0.000000:1.000000 in0:in1 = 16.000000:48.000000 gid=0, pair: 0 - 4, N = 8, s = 4, sign = -1 c:s = 1.000000:0.000000 in0:in1 = 0.000001:32.000000 gid=3, pair: 3 - 7, N = 8, s = 4, sign = -1 c:s = -0.707110:0.707110 in0:in1 = 23.999861:56.000137 gid=1, pair: 1 - 5, N = 8, s = 4, sign = -1 c:s = 0.707110:0.707110 in0:in1 = 7.999863:40.000137
プログラムが処理を終えた結果は以下のようになります。
1.1920929E-7 -3.5762787E-7 0.99998283 -2.526322E-7 2.0 -5.9604645E-8 2.9999826 -4.7709625E-7 4.0 2.3841858E-7 5.000017 1.9301845E-7 6.0 1.7881393E-7 7.000017 -6.554011E-7
元のデータを、ほとんど完全な形で復元することに成功しています。
Copyright 2018-2019, by Masaki Komatsu