diff options
author | Denis Steckelmacher <steckdenis@yahoo.fr> | 2011-07-10 16:32:08 +0200 |
---|---|---|
committer | Denis Steckelmacher <steckdenis@yahoo.fr> | 2011-07-10 16:32:08 +0200 |
commit | e1a734b1064bda51ebd910208fc5e4f6a1bef5f8 (patch) | |
tree | f04f120fc2f4b9720c1eaf76ed2f03661c2ed6d5 | |
parent | b812a5c65a720e3bfe49a9e44b57deef656ce2ef (diff) |
Implement clSetKernelArg.
-rw-r--r-- | src/api/api_kernel.cpp | 5 | ||||
-rw-r--r-- | src/core/kernel.cpp | 167 | ||||
-rw-r--r-- | src/core/kernel.h | 15 | ||||
-rw-r--r-- | src/runtime/stdlib.h | 2 | ||||
-rw-r--r-- | tests/test_kernel.cpp | 19 |
5 files changed, 165 insertions, 43 deletions
diff --git a/src/api/api_kernel.cpp b/src/api/api_kernel.cpp index e30ad01..94f6ea3 100644 --- a/src/api/api_kernel.cpp +++ b/src/api/api_kernel.cpp @@ -135,7 +135,10 @@ clSetKernelArg(cl_kernel kernel, size_t arg_size, const void * arg_value) { - return 0; + if (!kernel) + return CL_INVALID_KERNEL; + + return kernel->setArg(arg_indx, arg_size, arg_value); } cl_int diff --git a/src/core/kernel.cpp b/src/core/kernel.cpp index a6e6dc6..f9fcf35 100644 --- a/src/core/kernel.cpp +++ b/src/core/kernel.cpp @@ -2,6 +2,7 @@ #include <string> #include <iostream> +#include <cstring> #include <llvm/Support/Casting.h> #include <llvm/Module.h> @@ -78,72 +79,79 @@ cl_int Kernel::addFunction(DeviceInterface *device, llvm::Function *function, a.kind = Arg::Invalid; a.vec_dim = 1; + a.file = Arg::Private; + a.kernel_alloc_size = 0; + a.set = false; if (arg_type->isPointerTy()) { // It's a pointer, dereference it const llvm::PointerType *p_type = llvm::cast<llvm::PointerType>(arg_type); - a.kind = Arg::Buffer; // Buffer by default, can be refined + a.file = (Arg::File)p_type->getAddressSpace(); arg_type = p_type->getElementType(); - } - - if (arg_type->isVectorTy()) - { - // It's a vector, we need its element's type - const llvm::VectorType *v_type = llvm::cast<llvm::VectorType>(arg_type); - - a.vec_dim = v_type->getNumElements(); - arg_type = v_type->getElementType(); - } - // Get type kind - if (arg_type->isFloatTy()) - { - a.kind = Arg::Float; - } - else if (arg_type->isDoubleTy()) - { - a.kind = Arg::Double; - } - else if (arg_type->isIntegerTy()) - { - const llvm::IntegerType *i_type = llvm::cast<llvm::IntegerType>(arg_type); + // Get the name of the type to see if it's something like image2d, etc + std::string name = module->getTypeName(arg_type); - if (i_type->getBitWidth() == 8) + if (name == "image2d") { - a.kind = Arg::Int8; + // TODO: Address space qualifiers for image types, and read_only + a.kind = Arg::Image2D; } - else if (i_type->getBitWidth() == 16) + else if (name == "image3d") { - a.kind = Arg::Int16; + a.kind = Arg::Image3D; } - else if (i_type->getBitWidth() == 32) + else if (name == "sampler") { - a.kind = Arg::Int32; + // TODO: Sampler } - else if (i_type->getBitWidth() == 64) + else { - a.kind = Arg::Int64; + a.kind = Arg::Buffer; } } else { - // Get the name of the type to see if it's something like image2d, etc - std::string name = module->getTypeName(arg_type); + if (arg_type->isVectorTy()) + { + // It's a vector, we need its element's type + const llvm::VectorType *v_type = llvm::cast<llvm::VectorType>(arg_type); - if (name == "image2d") + a.vec_dim = v_type->getNumElements(); + arg_type = v_type->getElementType(); + } + + // Get type kind + if (arg_type->isFloatTy()) { - // TODO: Address space qualifiers for image types, and read_only - a.kind = Arg::Image2D; + a.kind = Arg::Float; } - else if (name == "image3d") + else if (arg_type->isDoubleTy()) { - a.kind = Arg::Image3D; + a.kind = Arg::Double; } - else if (name == "sampler") + else if (arg_type->isIntegerTy()) { - // TODO: Sampler + const llvm::IntegerType *i_type = llvm::cast<llvm::IntegerType>(arg_type); + + if (i_type->getBitWidth() == 8) + { + a.kind = Arg::Int8; + } + else if (i_type->getBitWidth() == 16) + { + a.kind = Arg::Int16; + } + else if (i_type->getBitWidth() == 32) + { + a.kind = Arg::Int32; + } + else if (i_type->getBitWidth() == 64) + { + a.kind = Arg::Int64; + } } } @@ -172,6 +180,85 @@ llvm::Function *Kernel::function(DeviceInterface *device) const return dep.function; } +size_t Kernel::Arg::valueSize() const +{ + switch (kind) + { + case Invalid: + return 0; + case Int8: + return 1; + case Int16: + return 2; + case Int32: + return 4; + case Int64: + return 8; + case Float: + return sizeof(cl_float); + case Double: + return sizeof(double); + case Buffer: + case Image2D: + case Image3D: + return sizeof(cl_mem); + } +} + +cl_int Kernel::setArg(cl_uint index, size_t size, const void *value) +{ + if (index > p_args.size()) + return CL_INVALID_ARG_INDEX; + + Arg &arg = p_args[index]; + + // Special case for __local pointers + if (arg.file == Arg::Local) + { + if (size == 0) + return CL_INVALID_ARG_SIZE; + + if (value != 0) + return CL_INVALID_ARG_VALUE; + + arg.kernel_alloc_size = size; + + return CL_SUCCESS; + } + + // Check that size corresponds to the arg type + size_t arg_size = arg.valueSize(); + + if (size != arg_size) + return CL_INVALID_ARG_SIZE; + + // Check for null values + if (!value) + { + switch (arg.kind) + { + case Arg::Buffer: + case Arg::Image2D: + case Arg::Image3D: + // Special case buffers : value can be 0 (or point to 0) + arg.value.cl_mem_val = 0; + arg.set = true; + return CL_SUCCESS; + + // TODO samplers + default: + return CL_INVALID_ARG_VALUE; + } + } + + // Copy the data + std::memcpy(&arg.value, value, arg_size); + + arg.set = true; + + return CL_SUCCESS; +} + Program *Kernel::program() const { return p_program; diff --git a/src/core/kernel.h b/src/core/kernel.h index 3f30d53..accc8a5 100644 --- a/src/core/kernel.h +++ b/src/core/kernel.h @@ -29,12 +29,23 @@ class Kernel cl_int addFunction(DeviceInterface *device, llvm::Function *function, llvm::Module *module); llvm::Function *function(DeviceInterface *device) const; + cl_int setArg(cl_uint index, size_t size, const void *value); Program *program() const; struct Arg { unsigned short vec_dim; + bool set; + size_t kernel_alloc_size; /*!< Size of the memory that must be allocated at kernel execution */ + + enum File + { + Private = 0, + Global = 1, + Local = 2, + Constant = 3 + } file; enum Kind { @@ -50,6 +61,7 @@ class Kernel Image3D // TODO: Sampler } kind; + union { #define TYPE_VAL(type) type type##_val @@ -61,12 +73,13 @@ class Kernel TYPE_VAL(double); TYPE_VAL(cl_mem); #undef TYPE_VAL - }; + } value; inline bool operator !=(const Arg &b) { return (kind != b.kind) || (vec_dim != b.vec_dim); } + size_t valueSize() const; }; private: diff --git a/src/runtime/stdlib.h b/src/runtime/stdlib.h index f673010..f1c7749 100644 --- a/src/runtime/stdlib.h +++ b/src/runtime/stdlib.h @@ -39,10 +39,10 @@ COAL_VECTOR_SET(float); #undef COAL_VECTOR /* Address spaces */ +#define __private __attribute__((address_space(0))) #define __global __attribute__((address_space(1))) #define __local __attribute__((address_space(2))) #define __constant __attribute__((address_space(3))) -#define __private __attribute__((address_space(4))) #define global __global #define local __local diff --git a/tests/test_kernel.cpp b/tests/test_kernel.cpp index 94f4e0d..bce4e03 100644 --- a/tests/test_kernel.cpp +++ b/tests/test_kernel.cpp @@ -1,3 +1,5 @@ +#include <iostream> + #include "test_kernel.h" #include "CL/cl.h" @@ -45,10 +47,13 @@ START_TEST (test_compiled_kernel) cl_int result; cl_kernel kernels[2]; cl_uint num_kernels; + cl_mem buf; const char *src = source; size_t program_len = sizeof(source); + int buffer[64]; + result = clGetDeviceIDs(platform, CL_DEVICE_TYPE_DEFAULT, 1, &device, 0); fail_if( result != CL_SUCCESS, @@ -99,6 +104,20 @@ START_TEST (test_compiled_kernel) "unable to get the two kernels of the program" ); + // Try to run kernel2 + buf = clCreateBuffer(ctx, CL_MEM_READ_WRITE | CL_MEM_USE_HOST_PTR, + sizeof(buffer), buffer, &result); + fail_if( + result != CL_SUCCESS, + "cannot create a valid CL_MEM_COPY_HOST_PTR read-write buffer" + ); + + result = clSetKernelArg(kernels[1], 0, sizeof(cl_mem), &buf); + fail_if( + result != CL_SUCCESS, + "cannot set kernel argument" + ); + clReleaseKernel(kernels[0]); clReleaseKernel(kernels[1]); clReleaseProgram(program); |