13.7. FFT-2Dで画像処理

2次元高速フーリエ変換については、一通りカバーしましたが、最後にCPU実装例と比較するために、画像処理に応用する場合のコードを掲載します。

画像データの場合はRGBの各チャネルを個別に処理する必要があるため、コーディングの手間は(色チャネルの数だけ)3倍になりますが、CPU実装と比べると読みやすく感じられる方もいらっしゃるかもしれません。

両方の実装例ともオブジェクト指向やデザインパターンを使用していないため、可読性が大きく変わることはないでしょうが、ベクトル型(float2)を使用することで実数部と虚数部を1つの値にパッキングでき、コード量を減らすことが可能となっています。

複素数のパッキングについては以下の「image_to_complex」カーネルで行っています。

// Sampler used by read_imagef below: unnormalized (pixel) coordinates,
// CLK_ADDRESS_CLAMP addressing and nearest-neighbour filtering.
__constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE |
             CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

// Converts one image pixel per work-item into three complex samples
// (one float2 per colour plane): .x holds the channel value, .y is zeroed.
//   src_image : 2-D image read through `sampler`
//   dst_blue / dst_green / dst_red : output planes, indexed row-major with
//                                    a row length of get_global_size(0)
__kernel void image_to_complex(
                read_only image2d_t src_image,
        __global float2* dst_blue,
                __global float2* dst_green,
                __global float2* dst_red) {

        // Column and row of the pixel handled by this work-item.
        size_t gidx = get_global_id(0);
        size_t gidy = get_global_id(1);
        // Number of work-items along dimension 0; used as the row length of the output planes.
        size_t group_size = get_global_size(0);

        int2 coord = (int2)(gidx, gidy);
        // Fetch the pixel as four floats (x, y, z, w components).
        float4 pixel = read_imagef(
                                        src_image,
                                        sampler,
                                        (int2)(coord.x,coord.y)); //(1)

        // Real part <- channel value, imaginary part <- 0 for each plane.
        dst_red[gidy*group_size + gidx].x = pixel.x; //(2)
        dst_red[gidy*group_size + gidx].y = 0.0f; //(3)
        dst_green[gidy*group_size + gidx].x = pixel.y;
        dst_green[gidy*group_size + gidx].y =  0.0f;
        dst_blue[gidy*group_size + gidx].x = pixel.z;
        dst_blue[gidy*group_size + gidx].y = 0.0f;

}

(1)

clCreateImageで作成したイメージオブジェクトのピクセルデータを、read_imagefでfloat4型として読み込みます。チャネルのデータ型にCL_UNORM_INT8を指定しているため、各成分は0.0〜1.0に正規化された値になります。ここではRGBAの順序で読み込まれると想定します。

(2)

実数部(ベクトルのx)に値を設定します。

(3)

虚数部(ベクトルのy)に値を設定します。入力となる画像データは実数なので、虚数部の初期値は0にします。

複素数のパッキングを除けば、CPU実装例および先行する項目で扱った2次元FFTのアルゴリズムをそのまま活用しています。

FFTGPU2Dpng.java. 

package com.book.jocl.fft;

import static org.jocl.CL.*;

import java.awt.Graphics;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferInt;
import java.io.File;
import java.net.URL;
import java.nio.file.Paths;
import java.util.Scanner;

import javax.imageio.ImageIO;
import javax.swing.ImageIcon;

import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.Sizeof;
import org.jocl.cl_command_queue;
import org.jocl.cl_context;
import org.jocl.cl_context_properties;
import org.jocl.cl_device_id;
import org.jocl.cl_image_desc;
import org.jocl.cl_image_format;
import org.jocl.cl_kernel;
import org.jocl.cl_mem;
import org.jocl.cl_platform_id;
import org.jocl.cl_program;

/**
 * Runs a 2-D FFT round trip over the colour planes of a PNG image on an OpenCL GPU (via JOCL)
 * and writes the reconstructed image as {@code fft2d_image.png} into the directory this class
 * was loaded from.
 *
 * <p>Pipeline per colour plane: image -> complex plane, row FFT, transpose, column FFT,
 * (placeholder for a frequency-domain filter), inverse column FFT, transpose, inverse row FFT.
 * All work sizes and buffers are fixed to a {@code ROW_SIZE x COL_SIZE} grid.
 */
public class FFTGPU2Dpng {
    private static final String KERNEL_PATH = "fft2dpng.cl";
    private static final String KERNEL_INIT = "fft_init";
    private static final String KERNEL_BIT_REVERSAL = "bit_reversal";
    private static final String KERNEL_FFT = "fft2d";
    private static final String KERNEL_FFT_INVERSE = "fft_inverse2d";
    private static final String KERNEL_TRANSPOSE = "transpose";
    private static final String KERNEL_IMAGE2COMPLEX = "image_to_complex";

    private static final String IMAGE_PATH = "SAMPLE.png";

    /** Work-group edge for the transpose kernel; also handed to the kernel build as {@code -Dsize}. */
    private static final int LOCAL_SIZE = 4;

    private static final int ROW_SIZE = 256;
    private static final int COL_SIZE = 256;

    /** Number of colour planes; plane index 0/1/2 corresponds to the blue/green/red buffers. */
    private static final int CHANNELS = 3;

    /** One work-item per butterfly: half a row wide, one row per line. */
    private static final long[] GLOBAL_BUTTERFLY = {ROW_SIZE / 2, COL_SIZE, 1};
    private static final long[] LOCAL_UNIT = {1, 1, 1};

    /** One work-item per complex sample. */
    private static final long[] GLOBAL_RECT = {ROW_SIZE, COL_SIZE, 1};
    private static final long[] LOCAL_TILE = {LOCAL_SIZE, LOCAL_SIZE, 1};

    private static cl_context context;
    private static cl_command_queue queue;
    private static cl_program program;
    private static cl_kernel kernel_init;
    private static cl_kernel kernel_bit_reversal;
    private static cl_kernel kernel_fft;
    private static cl_kernel kernel_fft_inverse;
    private static cl_kernel kernel_transpose;
    private static cl_kernel kernel_image_to_complex;

    /**
     * Returns floor(log2(b)), treating {@code b} as an unsigned 32-bit value.
     *
     * @param b value to take the logarithm of
     * @return position of the highest set bit, or 0 when {@code b} is 0
     */
    private static int log2(int b) {
        return b == 0 ? 0 : 31 - Integer.numberOfLeadingZeros(b);
    }

    /**
     * Entry point: sets up OpenCL, runs the FFT round trip and writes the result image.
     *
     * @param args unused
     * @throws Exception if a resource cannot be located or read, or the image cannot be written;
     *         OpenCL failures surface as JOCL exceptions because exceptions are enabled
     */
    public static void main(String[] args) throws Exception {
        CL.setExceptionsEnabled(true);

        cl_platform_id[] platform = new cl_platform_id[1];
        cl_device_id[] device = new cl_device_id[1];
        int[] num_devices = new int[1];

        clGetPlatformIDs(1, platform, null);
        clGetDeviceIDs(platform[0], CL_DEVICE_TYPE_GPU, 1, device, num_devices);

        cl_context_properties props = new cl_context_properties();
        props.addProperty(CL_CONTEXT_PLATFORM, platform[0]);
        context = clCreateContext(props, 1, device, null, null, null);
        queue = clCreateCommandQueue(context, device[0], 0, null);

        program = clCreateProgramWithSource(
                context, 1, new String[] {readKernelSource()}, null, null);
        clBuildProgram(program, 0, null, "-Werror -Dsize=" + LOCAL_SIZE, null, null);

        // Load the source image and paint it into an int-backed RGB raster.
        URL imgResource = requireResource(IMAGE_PATH);
        System.out.println(Paths.get(imgResource.toURI()).toFile().getAbsolutePath());
        ImageIcon icon = new ImageIcon(imgResource);
        int width = icon.getIconWidth();
        int height = icon.getIconHeight();

        BufferedImage bimg = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        Graphics g = bimg.createGraphics();
        icon.paintIcon(null, g, 0, 0);
        g.dispose();
        DataBufferInt pixels = (DataBufferInt) bimg.getRaster().getDataBuffer();

        // NOTE(review): the TYPE_INT_RGB ints are handed to a CL_RGBA / CL_UNORM_INT8 image.
        // On a little-endian host the bytes are presumably laid out B,G,R,X, so the plane the
        // kernel calls "red" would carry blue and vice versa. The packing at the end of this
        // method mirrors that mapping, so the round trip is consistent - confirm before adding
        // a channel-specific filter.
        cl_image_format format = new cl_image_format();
        format.image_channel_data_type = CL_UNORM_INT8;
        format.image_channel_order = CL_RGBA;
        cl_image_desc desc = new cl_image_desc();
        desc.image_height = height;
        desc.image_width = width;
        desc.num_mip_levels = 0;
        desc.num_samples = 0;
        desc.image_type = CL_MEM_OBJECT_IMAGE2D;

        cl_mem img_mem = clCreateImage(context,
                CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR,
                format,
                desc,
                Pointer.to(pixels.getData()),
                null);

        cl_mem[] raw = new cl_mem[CHANNELS];
        cl_mem[] processed = new cl_mem[CHANNELS];
        cl_mem[] transposed = new cl_mem[CHANNELS];
        for (int c = 0; c < CHANNELS; c++) {
            raw[c] = createComplexBuffer();
            processed[c] = createComplexBuffer();
            transposed[c] = createComplexBuffer();
        }

        kernel_init = clCreateKernel(program, KERNEL_INIT, null);
        kernel_bit_reversal = clCreateKernel(program, KERNEL_BIT_REVERSAL, null);
        kernel_fft = clCreateKernel(program, KERNEL_FFT, null);
        kernel_fft_inverse = clCreateKernel(program, KERNEL_FFT_INVERSE, null);
        kernel_transpose = clCreateKernel(program, KERNEL_TRANSPOSE, null);
        kernel_image_to_complex = clCreateKernel(program, KERNEL_IMAGE2COMPLEX, null);

        // Split the image into the three complex planes (kernel args 1..3 = blue, green, red).
        clSetKernelArg(kernel_image_to_complex, 0, Sizeof.cl_mem, Pointer.to(img_mem));
        for (int c = 0; c < CHANNELS; c++) {
            clSetKernelArg(kernel_image_to_complex, 1 + c, Sizeof.cl_mem, Pointer.to(raw[c]));
        }
        clEnqueueNDRangeKernel(queue, kernel_image_to_complex,
                2, null, GLOBAL_RECT, LOCAL_UNIT, 0, null, null);

        final int samples = ROW_SIZE * COL_SIZE;

        // Forward FFT along rows.
        for (cl_mem plane : raw) {
            bit_reverse(samples, plane);
        }
        for (int c = 0; c < CHANNELS; c++) {
            fft_2d(ROW_SIZE, raw[c], processed[c], 1, true);
        }

        // Forward FFT along columns (rows of the transposed planes).
        for (int c = 0; c < CHANNELS; c++) {
            transpose(processed[c], transposed[c]);
        }
        for (cl_mem plane : transposed) {
            bit_reverse(samples, plane);
        }
        for (cl_mem plane : transposed) {
            fft_2d(COL_SIZE, null, plane, 1, false);
        }

        // APPLY FILTER

        // Inverse FFT along columns.
        for (cl_mem plane : transposed) {
            bit_reverse(samples, plane);
        }
        for (cl_mem plane : transposed) {
            fft_2d(COL_SIZE, null, plane, -1, false);
            fft_inverse2d(COL_SIZE, plane);
        }

        // Inverse FFT along rows.
        for (int c = 0; c < CHANNELS; c++) {
            transpose(transposed[c], processed[c]);
        }
        for (cl_mem plane : processed) {
            bit_reverse(samples, plane);
        }
        for (cl_mem plane : processed) {
            fft_2d(ROW_SIZE, null, plane, -1, false);
            fft_inverse2d(ROW_SIZE, plane);
        }

        // Blocking reads: interleaved (real, imaginary) floats per sample.
        float[][] out = new float[CHANNELS][];
        for (int c = 0; c < CHANNELS; c++) {
            out[c] = new float[samples * 2];
            clEnqueueReadBuffer(queue, processed[c],
                    CL_TRUE, 0, Sizeof.cl_float2 * samples,
                    Pointer.to(out[c]), 0, null, null);
        }

        // Pack the real parts of the three planes into one ARGB int per pixel.
        int[] output = new int[samples];
        for (int p = 0; p < samples; p++) {
            int re = 2 * p;
            output[p] = 0xFF000000
                    | (toChannel(out[0][re]) << 16)
                    | (toChannel(out[1][re]) << 8)
                    | toChannel(out[2][re]);
        }

        BufferedImage image = new BufferedImage(ROW_SIZE, COL_SIZE, BufferedImage.TYPE_INT_RGB);
        image.setRGB(0, 0, ROW_SIZE, COL_SIZE, output, 0, ROW_SIZE);

        String path = Paths.get(requireResource("./").toURI()).toFile().getAbsolutePath();
        System.out.println(path + "/fft2d_image.png");
        ImageIO.write(image, "png", new File(path + "/fft2d_image.png"));

        // Release every OpenCL object created above: memory and kernels first, then the
        // program, the queue and finally the context.
        clFinish(queue);
        clReleaseMemObject(img_mem);
        for (int c = 0; c < CHANNELS; c++) {
            clReleaseMemObject(raw[c]);
            clReleaseMemObject(processed[c]);
            clReleaseMemObject(transposed[c]);
        }
        clReleaseKernel(kernel_image_to_complex);
        clReleaseKernel(kernel_transpose);
        clReleaseKernel(kernel_fft_inverse);
        clReleaseKernel(kernel_fft);
        clReleaseKernel(kernel_bit_reversal);
        clReleaseKernel(kernel_init);
        clReleaseProgram(program);
        clReleaseCommandQueue(queue);
        clReleaseContext(context);
        clReleaseDevice(device[0]);
    }

    /**
     * Looks up a resource relative to this class.
     *
     * @param name resource name as accepted by {@link Class#getResource(String)}
     * @return the resource URL, never {@code null}
     * @throws IllegalStateException if the resource does not exist
     */
    private static URL requireResource(String name) {
        URL url = FFTGPU2Dpng.class.getResource(name);
        if (url == null) {
            throw new IllegalStateException(
                    "Resource not found relative to " + FFTGPU2Dpng.class.getName() + ": " + name);
        }
        return url;
    }

    /**
     * Reads the OpenCL program text from {@code KERNEL_PATH}, joining lines with {@code '\n'}.
     *
     * @return the kernel source
     */
    private static String readKernelSource()
            throws java.net.URISyntaxException, java.io.FileNotFoundException {
        File file = Paths.get(requireResource(KERNEL_PATH).toURI()).toFile();
        StringBuilder source = new StringBuilder();
        // try-with-resources so the file handle is closed even if reading fails.
        try (Scanner sc = new Scanner(file)) {
            while (sc.hasNext()) {
                source.append(sc.nextLine()).append('\n');
            }
        }
        return source.toString();
    }

    /**
     * Allocates a device buffer holding {@code ROW_SIZE * COL_SIZE} {@code float2} samples.
     *
     * @return a read-write buffer created with {@code CL_MEM_ALLOC_HOST_PTR}
     */
    private static cl_mem createComplexBuffer() {
        return clCreateBuffer(context,
                CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR,
                Sizeof.cl_float2 * ROW_SIZE * COL_SIZE,
                null, null);
    }

    /**
     * Converts a normalised channel value to an 8-bit channel.
     *
     * <p>The value is rounded to the nearest integer and clamped to 0..255, so small
     * floating-point errors from the FFT round trip (for example 0.99999 or 1.0001) neither
     * darken the pixel by one step nor wrap around to 0. NaN maps to 0.
     *
     * @param normalized channel value, nominally in the range 0.0 to 1.0
     * @return channel value in the range 0 to 255
     */
    private static int toChannel(float normalized) {
        int scaled = Math.round(normalized * 255.0f);
        return Math.max(0, Math.min(255, scaled));
    }

    /**
     * Enqueues the transpose kernel from {@code input_mem} into {@code output_mem}.
     *
     * @param input_mem source plane
     * @param output_mem destination plane
     */
    private static void transpose(cl_mem input_mem, cl_mem output_mem) {
        // Width and height are passed to the kernel as 4-byte ints.
        int[] widthPtr = {ROW_SIZE};
        int[] heightPtr = {COL_SIZE};

        clSetKernelArg(kernel_transpose, 0, Sizeof.cl_mem, Pointer.to(input_mem));
        clSetKernelArg(kernel_transpose, 1, Sizeof.cl_mem, Pointer.to(output_mem));
        clSetKernelArg(kernel_transpose, 2, Sizeof.cl_int, Pointer.to(widthPtr));
        clSetKernelArg(kernel_transpose, 3, Sizeof.cl_int, Pointer.to(heightPtr));

        clEnqueueNDRangeKernel(queue,
                kernel_transpose, 2, null,
                GLOBAL_RECT, LOCAL_TILE,
                0, null, null);
    }

    /**
     * Enqueues the scaling kernel of the inverse transform (every sample divided by {@code N}).
     *
     * @param N divisor, the transform length
     * @param inverse_mem plane scaled in place
     */
    private static void fft_inverse2d(int N, cl_mem inverse_mem) {
        int[] Ni = {N};
        clSetKernelArg(kernel_fft_inverse, 0, Sizeof.cl_int, Pointer.to(Ni));
        clSetKernelArg(kernel_fft_inverse, 1, Sizeof.cl_mem, Pointer.to(inverse_mem));
        clEnqueueNDRangeKernel(queue,
                kernel_fft_inverse, 2, null,
                GLOBAL_RECT, LOCAL_UNIT,
                0, null, null);
    }

    /**
     * Enqueues the bit-reversal permutation kernel over every row of a plane.
     *
     * @param N value passed as the kernel's second argument
     * @param input_mem plane permuted in place
     */
    private static void bit_reverse(int N, cl_mem input_mem) {
        int[] Ni = {N};

        clSetKernelArg(kernel_bit_reversal, 0, Sizeof.cl_mem, Pointer.to(input_mem));
        clSetKernelArg(kernel_bit_reversal, 1, Sizeof.cl_uint, Pointer.to(Ni));

        clEnqueueNDRangeKernel(queue,
                kernel_bit_reversal, 2, null,
                GLOBAL_RECT, LOCAL_UNIT,
                0, null, null);
    }

    /**
     * Enqueues the log2(N) butterfly stages of a row-wise FFT.
     *
     * <p>Stage sizes run 2, 4, ..., N. When {@code is_init} is true the size-2 stage uses the
     * init kernel, which reads {@code input_mem} and writes {@code output_mem}; every other
     * stage runs in place on {@code output_mem}.
     *
     * @param N transform length per row
     * @param input_mem source plane; only used when {@code is_init} is true
     * @param output_mem plane holding the result
     * @param sign value passed as the {@code sign} argument of the FFT kernel
     *        (1 for the forward passes and -1 for the inverse passes in {@code main})
     * @param is_init whether the first stage should copy from {@code input_mem}
     */
    private static void fft_2d(int N, cl_mem input_mem, cl_mem output_mem, int sign, boolean is_init) {
        int stages = log2(N);
        int[] fftSizePtr = new int[1];
        int[] signPtr = {sign};

        int fftSize = 1;
        for (int stage = 0; stage < stages; stage++) {
            fftSize <<= 1;
            fftSizePtr[0] = fftSize;

            if (is_init && fftSize == 2) {
                clSetKernelArg(kernel_init, 0, Sizeof.cl_mem, Pointer.to(input_mem));
                clSetKernelArg(kernel_init, 1, Sizeof.cl_mem, Pointer.to(output_mem));
                clSetKernelArg(kernel_init, 2, Sizeof.cl_uint, Pointer.to(fftSizePtr));

                clEnqueueNDRangeKernel(queue,
                        kernel_init, 2, null,
                        GLOBAL_BUTTERFLY, LOCAL_UNIT,
                        0, null, null);
            } else {
                clSetKernelArg(kernel_fft, 0, Sizeof.cl_mem, Pointer.to(output_mem));
                clSetKernelArg(kernel_fft, 1, Sizeof.cl_uint, Pointer.to(fftSizePtr));
                clSetKernelArg(kernel_fft, 2, Sizeof.cl_int, Pointer.to(signPtr));

                clEnqueueNDRangeKernel(queue,
                        kernel_fft, 2, null,
                        GLOBAL_BUTTERFLY, LOCAL_UNIT,
                        0, null, null);
            }
        }
    }

}

fft2dpng.cl. 

/*
 * Reverses the low-order bits of x.
 *
 * One bit of x is moved per iteration while `stage` is shifted right until
 * it reaches zero, so the number of reversed bits equals the bit length of
 * `stage` (e.g. stage = 255 reverses 8 bits).
 */
inline int reverseBit(int x, int stage) {
        int reversed = 0;
        for (int remaining = stage; remaining != 0; remaining >>= 1) {
                reversed = (reversed << 1) | (x & 1);
                x >>= 1;
        }
        return reversed;
}


// Sampler used by read_imagef in image_to_complex: unnormalized (pixel)
// coordinates, CLK_ADDRESS_CLAMP addressing and nearest-neighbour filtering.
__constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE |
             CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

/*
 * Converts one image pixel per work-item into three complex samples.
 *
 * For each colour plane the channel value becomes the real part (.x) and the
 * imaginary part (.y) is set to zero. The planes are indexed row-major with a
 * row length of get_global_size(0).
 *
 *   src_image  2-D image read through `sampler`
 *   dst_blue   receives the pixel's z component
 *   dst_green  receives the pixel's y component
 *   dst_red    receives the pixel's x component
 */
__kernel void image_to_complex(
                read_only image2d_t src_image,
                __global float2* dst_blue,
                __global float2* dst_green,
                __global float2* dst_red) {

        const size_t col = get_global_id(0);
        const size_t row = get_global_id(1);
        const size_t offset = row * get_global_size(0) + col;

        const float4 texel = read_imagef(src_image, sampler, (int2)(col, row));

        dst_red[offset]   = (float2)(texel.x, 0.0f);
        dst_green[offset] = (float2)(texel.y, 0.0f);
        dst_blue[offset]  = (float2)(texel.z, 0.0f);
}

/*
 * Applies the bit-reversal permutation to every row of `data` in place.
 *
 * Work-item (col, row) swaps element `col` of its row with the element at
 * the bit-reversed column. Only the work-item with the smaller column index
 * performs the swap, so each pair is exchanged once. The row length is
 * get_global_size(0); the parameter N is not used.
 */
__kernel void bit_reversal(__global float2* data, uint N) {
        const size_t col = get_global_id(0);
        const size_t row = get_global_id(1);
        const size_t width = get_global_size(0);

        const uint partner = reverseBit(col, width - 1);
        if (col >= partner) {
                return;
        }

        __global float2* line = data + row * width;
        const float2 lower = line[col];
        const float2 upper = line[partner];
        line[partner] = lower;
        line[col] = upper;
}

/*
 * First butterfly stage without twiddle factors: reads a pair of samples
 * from `data` and writes their sum and difference to `F`.
 *
 *   data  input plane (row-major, row length = 2 * get_global_size(0))
 *   F     output plane, same layout
 *   N     butterfly span; the two samples of a pair are N/2 apart
 *
 * Each work-item handles one pair in row get_global_id(1).
 */
__kernel void fft_init(
                __global float2* data,
                __global float2* F,
                int N)
{
        const int x = get_global_id(0);
        const int row_size = get_global_size(0)*2;
        const int y = get_global_id(1);
        const int stride = N/2;

        // Exact integer form of the pair's first index. x/stride is already an
        // integer quotient, so no float conversion or ceil() is needed.
        const int index = (x/stride)*stride + (x%row_size);
        const int lo = y*row_size + index;
        const int hi = lo + stride;

        const float2 a = data[lo];
        const float2 b = data[hi];

        F[lo] = a + b;
        F[hi] = a - b;
}


/*
 * One in-place butterfly stage of a row-wise FFT.
 *
 *   F     plane transformed in place (row-major, row length =
 *         2 * get_global_size(0))
 *   N     butterfly span of this stage; the two samples of a pair are N/2
 *         apart
 *   sign  multiplier applied to the sine term of the twiddle factor
 *
 * Each work-item handles one pair in row get_global_id(1): the upper sample
 * is rotated by the twiddle factor and then added to / subtracted from the
 * lower sample.
 */
__kernel void fft2d(
                __global float2* F,
                int N,
                int sign)
{
        const int x = get_global_id(0);
        const int row_size = get_global_size(0)*2;
        const int y = get_global_id(1);
        const int stride = N/2;

        // Exact integer form of the pair's first index. x/stride is already an
        // integer quotient, so no float conversion or ceil() is needed.
        const int index = (x/stride)*stride + (x%row_size);
        const int lo = y*row_size + index;
        const int hi = lo + stride;

        const float2 in0 = F[lo];
        const float2 in1 = F[hi];

        // The twiddle factor is periodic in N. Reducing the index modulo N
        // keeps the angle within one period before it reaches native_sin /
        // native_cos, whose accuracy and input range are implementation-defined.
        const float angle = -2*M_PI_F*(index % N)/N;
        const float c = native_cos(angle);
        const float s = sign*native_sin(angle);

        // v = in1 * (c + i*s)
        float2 v;
        v.x = c * in1.x - s * in1.y;
        v.y = c * in1.y + s * in1.x;

        F[lo] = in0 + v;
        F[hi] = in0 - v;
}

/*
 * Divides every complex sample of `F` by N, in place.
 *
 * One work-item per sample; the plane is indexed row-major with a row length
 * of get_global_size(0).
 */
__kernel void fft_inverse2d(int N, __global float2* F) {
        const size_t col = get_global_id(0);
        const size_t row = get_global_id(1);
        const size_t offset = row * get_global_size(0) + col;

        F[offset] = F[offset] / N;
}

/*
 * Transposes `input` into `output` through a tile held in local memory.
 *
 *   input   source plane, row-major with row length `width`
 *   output  destination plane, row-major with row length `height`
 *   width   row length of the input
 *   height  row length of the output
 *
 * `size` is a compile-time macro (supplied with -Dsize=...) giving the tile
 * edge; the tile indexing assumes the work-group is size x size.
 *
 * width and height are declared as int: OpenCL C does not allow size_t as a
 * kernel argument type, and a 4-byte argument set from the host would not
 * match a size_t on devices where it is 8 bytes wide.
 */
__kernel void transpose(
                         __global float2* input,
                         __global float2* output,
                         int width,
                         int height)
{

        // Rows are padded by one element (size+1).
        __local float2 tile[size * (size+1)];
        size_t x = get_global_id(0);
        size_t y = get_global_id(1);

        size_t lx = get_local_id(0);
        size_t ly = get_local_id(1);

        size_t gx = get_group_id(0);
        size_t gy = get_group_id(1);

        // Stage the element at (x, y) in the tile.
        size_t index_input = y * (size_t)width + x;
        size_t index_tile = ly * (size+1) + lx;
        tile[index_tile] = input[index_input];
        // Wait until the whole work-group has filled the tile.
        barrier(CLK_LOCAL_MEM_FENCE);

        // Swap the group coordinates and read the tile with lx/ly exchanged.
        size_t ox = gy * size + lx;
        size_t oy = gx * size + ly;

        size_t index_output = oy * (size_t)height + ox;
        index_tile = lx * (size+1) + ly;
        output[index_output] = tile[index_tile];

}

Copyright 2018-2019, by Masaki Komatsu