diff options
author | Junyan He <junyan.he@linux.intel.com> | 2014-06-10 12:52:54 +0800 |
---|---|---|
committer | Zhigang Gong <zhigang.gong@intel.com> | 2014-06-11 11:04:04 +0800 |
commit | 42cdacda1b5217c593d23621775651a7143a353c (patch) | |
tree | 20e60920e80d69c9424a0b449df6c7f2f1450763 | |
parent | 91afb05c0f6440ada79a0e6896b64aa0bc3b2ca2 (diff) |
Add the PrintfParser llvm parser into the llvm backend.
The PrintfParser will work before the llvm gen backend.
It will filter out all the printf function call. When
the printf call found, we will analyse the print format
and % place holder here. Replace the print call with
STORE or CONV+STORE instruction if needed.
Signed-off-by: Junyan He <junyan.he@linux.intel.com>
Reviewed-by: Zhigang Gong <zhigang.gong@linux.intel.com>
-rw-r--r-- | backend/src/CMakeLists.txt | 1 | ||||
-rw-r--r-- | backend/src/ir/function.cpp | 1 | ||||
-rw-r--r-- | backend/src/ir/function.hpp | 4 | ||||
-rw-r--r-- | backend/src/llvm/llvm_gen_backend.cpp | 5 | ||||
-rw-r--r-- | backend/src/llvm/llvm_gen_backend.hpp | 4 | ||||
-rw-r--r-- | backend/src/llvm/llvm_printf_parser.cpp | 677 | ||||
-rw-r--r-- | backend/src/llvm/llvm_to_gen.cpp | 1 |
7 files changed, 693 insertions, 0 deletions
diff --git a/backend/src/CMakeLists.txt b/backend/src/CMakeLists.txt index 0716d357..60901747 100644 --- a/backend/src/CMakeLists.txt +++ b/backend/src/CMakeLists.txt @@ -147,6 +147,7 @@ else (GBE_USE_BLOB) llvm/llvm_scalarize.cpp llvm/llvm_intrinsic_lowering.cpp llvm/llvm_barrier_nodup.cpp + llvm/llvm_printf_parser.cpp llvm/llvm_to_gen.cpp llvm/llvm_loadstore_optimization.cpp llvm/llvm_gen_backend.hpp diff --git a/backend/src/ir/function.cpp b/backend/src/ir/function.cpp index b0df412d..a46108ea 100644 --- a/backend/src/ir/function.cpp +++ b/backend/src/ir/function.cpp @@ -48,6 +48,7 @@ namespace ir { initProfile(*this); samplerSet = GBE_NEW(SamplerSet); imageSet = GBE_NEW(ImageSet); + printfSet = GBE_NEW(PrintfSet); } Function::~Function(void) { diff --git a/backend/src/ir/function.hpp b/backend/src/ir/function.hpp index 266e6526..63bb6ea5 100644 --- a/backend/src/ir/function.hpp +++ b/backend/src/ir/function.hpp @@ -29,6 +29,7 @@ #include "ir/instruction.hpp" #include "ir/profile.hpp" #include "ir/sampler.hpp" +#include "ir/printf.hpp" #include "ir/image.hpp" #include "sys/vector.hpp" #include "sys/set.hpp" @@ -329,6 +330,8 @@ namespace ir { SamplerSet* getSamplerSet(void) const {return samplerSet; } /*! Get image set in this function */ ImageSet* getImageSet(void) const {return imageSet; } + /*! Get printf set in this function */ + PrintfSet* getPrintfSet(void) const {return printfSet; } /*! Set required work group size. */ void setCompileWorkGroupSize(size_t x, size_t y, size_t z) { compileWgSize[0] = x; compileWgSize[1] = y; compileWgSize[2] = z; } /*! Get required work group size. */ @@ -360,6 +363,7 @@ namespace ir { uint32_t stackSize; //!< stack size for private memory. SamplerSet *samplerSet; //!< samplers used in this function. ImageSet* imageSet; //!< Image set in this function's arguments.. + PrintfSet *printfSet; //!< printfSet store the printf info. size_t compileWgSize[3]; //!< required work group size specified by // __attribute__((reqd_work_group_size(X, Y, Z))). GBE_CLASS(Function); //!< Use custom allocator diff --git a/backend/src/llvm/llvm_gen_backend.cpp b/backend/src/llvm/llvm_gen_backend.cpp index 78028182..4bb90393 100644 --- a/backend/src/llvm/llvm_gen_backend.cpp +++ b/backend/src/llvm/llvm_gen_backend.cpp @@ -2970,7 +2970,12 @@ namespace gbe #undef DEF case GEN_OCL_PRINTF: + { + ir::PrintfSet::PrintfFmt* fmt = (ir::PrintfSet::PrintfFmt*)getPrintfInfo(&I); + ctx.getFunction().getPrintfSet()->append(fmt, unit); + assert(fmt); break; + } case GEN_OCL_PRINTF_BUF_ADDR: case GEN_OCL_PRINTF_INDEX_BUF_ADDR: default: break; diff --git a/backend/src/llvm/llvm_gen_backend.hpp b/backend/src/llvm/llvm_gen_backend.hpp index 26323a3e..cc5cdad6 100644 --- a/backend/src/llvm/llvm_gen_backend.hpp +++ b/backend/src/llvm/llvm_gen_backend.hpp @@ -95,6 +95,10 @@ namespace gbe /*! Convert the Intrinsic call to gen function */ llvm::BasicBlockPass *createIntrinsicLoweringPass(); + /*! Passer the printf function call. */ + llvm::FunctionPass* createPrintfParserPass(); + + void* getPrintfInfo(llvm::CallInst* inst); } /* namespace gbe */ #endif /* __GBE_LLVM_GEN_BACKEND_HPP__ */ diff --git a/backend/src/llvm/llvm_printf_parser.cpp b/backend/src/llvm/llvm_printf_parser.cpp new file mode 100644 index 00000000..ec8e76d8 --- /dev/null +++ b/backend/src/llvm/llvm_printf_parser.cpp @@ -0,0 +1,677 @@ +/* + * Copyright © 2012 Intel Corporation + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library. If not, see <http://www.gnu.org/licenses/>. + */ + +/** + * \file llvm_printf.cpp + * + * When there are printf functions existing, we have something to do here. + * Because the GPU's feature, it is relatively hard to parse and caculate the + * printf's format string. OpenCL 1.2 restrict the format string to be a + * constant string and can be decided at compiling time. So we add a pass here + * to parse the format string and check whether the parameters is valid. + * If all are valid, we will generate the according instruction to store the + * parameter content into the printf buffer. And if something is invalid, a + * warning is generated and the printf instruction is skipped in order to avoid + * GPU error. We also keep the relationship between the printf format and printf + * content in GPU's printf buffer here, and use the system's C standard printf to + * print the content after kernel executed. + */ +#include <stdio.h> +#include <stdlib.h> + +#include "llvm/Config/config.h" +#if LLVM_VERSION_MINOR <= 2 +#include "llvm/Function.h" +#include "llvm/InstrTypes.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Module.h" +#else +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#endif /* LLVM_VERSION_MINOR <= 2 */ +#include "llvm/Pass.h" +#if LLVM_VERSION_MINOR <= 1 +#include "llvm/Support/IRBuilder.h" +#elif LLVM_VERSION_MINOR == 2 +#include "llvm/IRBuilder.h" +#else +#include "llvm/IR/IRBuilder.h" +#endif /* LLVM_VERSION_MINOR <= 1 */ +#include "llvm/Support/CallSite.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/IR/Attributes.h" + +#include "llvm/llvm_gen_backend.hpp" +#include "sys/map.hpp" +#include "ir/printf.hpp" + +using namespace llvm; + +namespace gbe +{ + using namespace ir; + + /* Return the conversion_specifier if succeed, -1 if failed. */ + static char __parse_printf_state(char *begin, char *end, char** rend, PrintfState * state) + { + const char *fmt; + state->left_justified = 0; + state->sign_symbol = 0; //0 for nothing, 1 for sign, 2 for space. + state->alter_form = 0; + state->zero_padding = 0; + state->vector_n = 0; + state->min_width = 0; + state->precision = 0; + state->length_modifier = 0; + state->conversion_specifier = PRINTF_CONVERSION_INVALID; + state->out_buf_sizeof_offset = -1; + + fmt = begin; + + if (*fmt != '%') + return -1; + +#define FMT_PLUS_PLUS do { \ + if (fmt + 1 < end) fmt++; \ + else { \ + printf("Error, line: %d, fmt > end\n", __LINE__); \ + return -1; \ + } \ + } while(0) + + FMT_PLUS_PLUS; + + // parse the flags. + switch (*fmt) { + case '-': + /* The result of the conversion is left-justified within the field. */ + state->left_justified = 1; + FMT_PLUS_PLUS; + break; + case '+': + /* The result of a signed conversion always begins with a plus or minus sign. */ + state->sign_symbol = 1; + FMT_PLUS_PLUS; + break; + case ' ': + /* If the first character of a signed conversion is not a sign, or if a signed + conversion results in no characters, a space is prefixed to the result. + If the space and + flags both appear,the space flag is ignored. */ + if (state->sign_symbol == 0) state->sign_symbol = 2; + FMT_PLUS_PLUS; + break; + case '#': + /*The result is converted to an alternative form. */ + state->alter_form = 1; + FMT_PLUS_PLUS; + break; + case '0': + if (!state->left_justified) state->zero_padding = 1; + FMT_PLUS_PLUS; + break; + default: + break; + } + + // The minimum field width + while ((*fmt >= '0') && (*fmt <= '9')) { + state->min_width = state->min_width * 10 + (*fmt - '0'); + FMT_PLUS_PLUS; + } + + // The precision + if (*fmt == '.') { + FMT_PLUS_PLUS; + while (*fmt >= '0' && *fmt <= '9') { + state->precision = state->precision * 10 + (*fmt - '0'); + FMT_PLUS_PLUS; + } + } + + // handle the vector specifier. + if (*fmt == 'v') { + FMT_PLUS_PLUS; + switch (*fmt) { + case '2': + state->vector_n = 2; + FMT_PLUS_PLUS; + break; + case '3': + state->vector_n = 3; + FMT_PLUS_PLUS; + break; + case '4': + state->vector_n = 4; + FMT_PLUS_PLUS; + break; + case '8': + state->vector_n = 8; + FMT_PLUS_PLUS; + break; + case '1': + FMT_PLUS_PLUS; + if (*fmt == '6') { + state->vector_n = 16; + FMT_PLUS_PLUS; + } else + return -1; + break; + default: + //Wrong vector, error. + return -1; + } + } + + // length modifiers + if (*fmt == 'h') { + FMT_PLUS_PLUS; + if (*fmt == 'h') { //hh + state->length_modifier = PRINTF_LM_HH; + FMT_PLUS_PLUS; + } else if (*fmt == 'l') { //hl + state->length_modifier = PRINTF_LM_HL; + FMT_PLUS_PLUS; + } else { //h + state->length_modifier = PRINTF_LM_H; + } + } else if (*fmt == 'l') { + state->length_modifier = PRINTF_LM_L; + FMT_PLUS_PLUS; + } + +#define CONVERSION_SPEC_AND_RET(XXX, xxx) \ + case XXX: \ + state->conversion_specifier = PRINTF_CONVERSION_##xxx; \ + FMT_PLUS_PLUS; \ + *rend = (char *)fmt; \ + return XXX; \ + break; + + // conversion specifiers + switch (*fmt) { + CONVERSION_SPEC_AND_RET('d', D) + CONVERSION_SPEC_AND_RET('i', I) + CONVERSION_SPEC_AND_RET('o', O) + CONVERSION_SPEC_AND_RET('u', U) + CONVERSION_SPEC_AND_RET('x', x) + CONVERSION_SPEC_AND_RET('X', X) + CONVERSION_SPEC_AND_RET('f', f) + CONVERSION_SPEC_AND_RET('F', F) + CONVERSION_SPEC_AND_RET('e', e) + CONVERSION_SPEC_AND_RET('E', E) + CONVERSION_SPEC_AND_RET('g', g) + CONVERSION_SPEC_AND_RET('G', G) + CONVERSION_SPEC_AND_RET('a', a) + CONVERSION_SPEC_AND_RET('A', A) + CONVERSION_SPEC_AND_RET('c', C) + CONVERSION_SPEC_AND_RET('s', A) + CONVERSION_SPEC_AND_RET('p', P) + + // %% has been handled + + default: + return -1; + } + } + + static PrintfSet::PrintfFmt* parser_printf_fmt(char* format, int& num) + { + char* begin; + char* end; + char* p; + char ret_char; + char* rend; + PrintfState state; + PrintfSet::PrintfFmt* printf_fmt = new PrintfSet::PrintfFmt(); + + p = format; + begin = format; + end = format + strlen(format); + + /* Now parse it. */ + while (*begin) { + p = begin; + +again: + while (p < end && *p != '%') { + p++; + } + if (p < end && p + 1 == end) { // String with % at end. + printf("string end with %%\n"); + goto error; + } + if (*(p + 1) == '%') { // %% + p += 2; + goto again; + } + + if (p != begin) { + std::string s = std::string(begin, size_t(p - begin)); + printf_fmt->push_back(PrintfSlot(s.c_str())); + } + + if (p == end) // finish + break; + + /* Now parse the % start conversion_specifier. */ + ret_char = __parse_printf_state(p, end, &rend, &state); + if (ret_char < 0) + goto error; + + printf_fmt->push_back(&state); + + if (rend == end) + break; + + begin = rend; + } + + for (auto &s : *printf_fmt) { + if (s.type == PRINTF_SLOT_TYPE_STATE) { + num++; +#if 0 + printf("---- %d ---: state : \n", j); + printf(" left_justified : %d\n", s.state->left_justified); + printf(" sign_symbol: %d\n", s.state->sign_symbol); + printf(" alter_form : %d\n", s.state->alter_form); + printf(" zero_padding : %d\n", s.state->zero_padding); + printf(" vector_n : %d\n", s.state->vector_n); + printf(" min_width : %d\n", s.state->min_width); + printf(" precision : %d\n", s.state->precision); + printf(" length_modifier : %d\n", s.state->length_modifier); + printf(" conversion_specifier : %d\n", s.state->conversion_specifier); +#endif + } else if (s.type == PRINTF_SLOT_TYPE_STRING) { + //printf("---- %d ---: string : %s\n", j, s.str); + } + } + + return printf_fmt; + +error: + printf("error format string.\n"); + delete printf_fmt; + return NULL; + } + + class PrintfParser : public FunctionPass + { + public: + static char ID; + typedef std::pair<Instruction*, bool> PrintfInst; + std::vector<PrintfInst> deadprintfs; + Module* module; + IRBuilder<>* builder; + Type* intTy; + Value* pbuf_ptr; + Value* index_buf_ptr; + int out_buf_sizeof_offset; + static map<CallInst*, PrintfSet::PrintfFmt*> printfs; + int printf_num; + + PrintfParser(void) : FunctionPass(ID) { + module = NULL; + builder = NULL; + intTy = NULL; + out_buf_sizeof_offset = 0; + printfs.clear(); + pbuf_ptr = NULL; + index_buf_ptr = NULL; + printf_num = 0; + } + + ~PrintfParser(void) { + for (auto &s : printfs) { + delete s.second; + s.second = NULL; + } + printfs.clear(); + } + + + bool parseOnePrintfInstruction(CallInst *& call); + int generateOneParameterInst(PrintfSlot& slot, Value& arg); + + virtual const char *getPassName() const { + return "Printf Parser"; + } + + virtual bool runOnFunction(llvm::Function &F); + }; + + bool PrintfParser::parseOnePrintfInstruction(CallInst *& call) + { + CallSite CS(call); + CallSite::arg_iterator CI_FMT = CS.arg_begin(); + int param_num = 0; + + llvm::Constant* arg0 = dyn_cast<llvm::ConstantExpr>(*CI_FMT); + llvm::Constant* arg0_ptr = dyn_cast<llvm::Constant>(arg0->getOperand(0)); + if (!arg0_ptr) { + return false; + } + + ConstantDataSequential* fmt_arg = dyn_cast<ConstantDataSequential>(arg0_ptr->getOperand(0)); + if (!fmt_arg || !fmt_arg->isCString()) { + return false; + } + + std::string fmt = fmt_arg->getAsCString(); + + PrintfSet::PrintfFmt* printf_fmt = NULL; + + if (!(printf_fmt = parser_printf_fmt((char *)fmt.c_str(), param_num))) {//at lease print something + return false; + } + + /* iff parameter more than %, error. */ + /* str_fmt arg0 arg1 ... NULL */ + if (param_num + 2 < static_cast<int>(call->getNumOperands())) { + delete printf_fmt; + return false; + } + + /* FIXME: Because the OpenCL language do not support va macro, and we do not want + to introduce the va_list, va_start and va_end into our code, we just simulate + the function calls to caculate the offset caculation here. */ + CallInst* group_id_2 = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_get_group_id2", + IntegerType::getInt32Ty(module->getContext()), + NULL))); + CallInst* group_id_1 = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_get_group_id1", + IntegerType::getInt32Ty(module->getContext()), + NULL))); + CallInst* group_id_0 = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_get_group_id0", + IntegerType::getInt32Ty(module->getContext()), + NULL))); + + CallInst* global_size_2 = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_get_global_size2", + IntegerType::getInt32Ty(module->getContext()), + NULL))); + CallInst* global_size_1 = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_get_global_size1", + IntegerType::getInt32Ty(module->getContext()), + NULL))); + CallInst* global_size_0 = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_get_global_size0", + IntegerType::getInt32Ty(module->getContext()), + NULL))); + + CallInst* local_id_2 = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_get_local_id2", + IntegerType::getInt32Ty(module->getContext()), + NULL))); + CallInst* local_id_1 = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_get_local_id1", + IntegerType::getInt32Ty(module->getContext()), + NULL))); + CallInst* local_id_0 = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_get_local_id0", + IntegerType::getInt32Ty(module->getContext()), + NULL))); + + CallInst* local_size_2 = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_get_local_size2", + IntegerType::getInt32Ty(module->getContext()), + NULL))); + CallInst* local_size_1 = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_get_local_size1", + IntegerType::getInt32Ty(module->getContext()), + NULL))); + CallInst* local_size_0 = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_get_local_size0", + IntegerType::getInt32Ty(module->getContext()), + NULL))); + Value* op0 = NULL; + Value* val = NULL; + /* offset = ((local_id_2 + local_size_2 * group_id_2) * (global_size_1 * global_size_0) + + (local_id_1 + local_size_1 * group_id_1) * global_size_0 + + (local_id_0 + local_size_0 * group_id_0)) * sizeof(type) */ + + // local_size_2 * group_id_2 + val = builder->CreateMul(local_size_2, group_id_2); + // local_id_2 + local_size_2 * group_id_2 + val = builder->CreateAdd(local_id_2, val); + // global_size_1 * global_size_0 + op0 = builder->CreateMul(global_size_1, global_size_0); + // (local_id_2 + local_size_2 * group_id_2) * (global_size_1 * global_size_0) + Value* offset1 = builder->CreateMul(val, op0); + // local_size_1 * group_id_1 + val = builder->CreateMul(local_size_1, group_id_1); + // local_id_1 + local_size_1 * group_id_1 + val = builder->CreateAdd(local_id_1, val); + // (local_id_1 + local_size_1 * group_id_1) * global_size_0 + Value* offset2 = builder->CreateMul(val, global_size_0); + // local_size_0 * group_id_0 + val = builder->CreateMul(local_size_0, group_id_0); + // local_id_0 + local_size_0 * group_id_0 + val = builder->CreateAdd(local_id_0, val); + // The total sum + val = builder->CreateAdd(val, offset1); + Value* offset = builder->CreateAdd(val, offset2); + + ///////////////////////////////////////////////////// + /* calculate index address. + index_addr = (index_offset + offset )* sizeof(int) + index_buf_ptr + index_offset = global_size_2 * global_size_1 * global_size_0 * printf_num */ + + // global_size_2 * global_size_1 + op0 = builder->CreateMul(global_size_2, global_size_1); + // global_size_2 * global_size_1 * global_size_0 + Value* glXg2Xg3 = builder->CreateMul(op0, global_size_0); + Value* index_offset = builder->CreateMul(glXg2Xg3, ConstantInt::get(intTy, printf_num)); + // index_offset + offset + op0 = builder->CreateAdd(index_offset, offset); + // (index_offset + offset)* sizeof(int) + op0 = builder->CreateMul(op0, ConstantInt::get(intTy, sizeof(int))); + // Final index address = index_buf_ptr + (index_offset + offset)* sizeof(int) + op0 = builder->CreateAdd(op0, index_buf_ptr); + Value* index_addr = builder->CreateIntToPtr(op0, Type::getInt32PtrTy(module->getContext(), 1)); + builder->CreateStore(ConstantInt::get(intTy, 1), index_addr);// The flag + + int i = 1; + Value* data_addr = NULL; + for (auto &s : *printf_fmt) { + if (s.type == PRINTF_SLOT_TYPE_STRING) + continue; + + assert(i < static_cast<int>(call->getNumOperands()) - 1); + + int sizeof_size = generateOneParameterInst(s, *call->getOperand(i)); + if (!sizeof_size) { + printf("Printf: %d, parameter %d may have no result because some error\n", + printf_num, i - 1); + continue; + } + + ///////////////////////////////////////////////////// + /* Calculate the data address. + data_addr = data_offset + pbuf_ptr + offset * sizeof(specify) + data_offset = global_size_2 * global_size_1 * global_size_0 * out_buf_sizeof_offset + + //global_size_2 * global_size_1 * global_size_0 * out_buf_sizeof_offset */ + op0 = builder->CreateMul(glXg2Xg3, ConstantInt::get(intTy, out_buf_sizeof_offset)); + //offset * sizeof(specify) + val = builder->CreateMul(offset, ConstantInt::get(intTy, sizeof_size)); + //data_offset + pbuf_ptr + op0 = builder->CreateAdd(op0, pbuf_ptr); + op0 = builder->CreateAdd(op0, val); + data_addr = builder->CreateIntToPtr(op0, Type::getInt32PtrTy(module->getContext(), 1)); + builder->CreateStore(call->getOperand(i), data_addr); + s.state->out_buf_sizeof_offset = out_buf_sizeof_offset; + out_buf_sizeof_offset += sizeof_size; + i++; + } + + CallInst* printf_inst = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_printf", Type::getVoidTy(module->getContext()), + NULL))); + assert(printfs[printf_inst] == NULL); + printfs[printf_inst] = printf_fmt; + printf_num++; + return true; + } + + bool PrintfParser::runOnFunction(llvm::Function &F) + { + bool changed = false; + switch (F.getCallingConv()) { +#if LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR <= 2 + case CallingConv::PTX_Device: + return false; + case CallingConv::PTX_Kernel: +#else + case CallingConv::C: +#endif + break; + default: + GBE_ASSERTM(false, "Unsupported calling convention"); + } + + module = F.getParent(); + intTy = IntegerType::get(module->getContext(), 32); + builder = new IRBuilder<>(module->getContext()); + + // As we inline all function calls, so skip non-kernel functions + bool bKernel = isKernelFunction(F); + if(!bKernel) return false; + + /* Iter the function and find printf. */ + for (llvm::Function::iterator B = F.begin(), BE = F.end(); B != BE; B++) { + for (BasicBlock::iterator instI = B->begin(), + instE = B->end(); instI != instE; ++instI) { + + llvm::CallInst* call = dyn_cast<llvm::CallInst>(instI); + if (!call) { + continue; + } + + if (call->getCalledFunction()->getIntrinsicID() != 0) + continue; + + Value *Callee = call->getCalledValue(); + const std::string fnName = Callee->getName(); + + if (fnName != "__gen_ocl_printf_stub") + continue; + + changed = true; + + builder->SetInsertPoint(call); + + if (!pbuf_ptr) { + /* alloc a new buffer ptr to collect the print output. */ + pbuf_ptr = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_printf_get_buf_addr", Type::getInt32Ty(module->getContext()), + NULL))); + } + if (!index_buf_ptr) { + /* alloc a new buffer ptr to collect the print valid index. */ + index_buf_ptr = builder->CreateCall(cast<llvm::Function>(module->getOrInsertFunction( + "__gen_ocl_printf_get_index_buf_addr", Type::getInt32Ty(module->getContext()), + NULL))); + } + + deadprintfs.push_back(PrintfInst(cast<Instruction>(call),parseOnePrintfInstruction(call))); + } + } + + /* Replace the instruction's operand if using printf's return value. */ + for (llvm::Function::iterator B = F.begin(), BE = F.end(); B != BE; B++) { + for (BasicBlock::iterator instI = B->begin(), + instE = B->end(); instI != instE; ++instI) { + + for (unsigned i = 0; i < instI->getNumOperands(); i++) { + for (auto &prf : deadprintfs) { + if (instI->getOperand(i) == prf.first) { + + if (prf.second == true) { + instI->setOperand(i, ConstantInt::get(intTy, 0)); + } else { + instI->setOperand(i, ConstantInt::get(intTy, -1)); + } + } + } + } + } + } + + /* Kill the dead printf instructions. */ + for (auto &prf : deadprintfs) { + prf.first->dropAllReferences(); + if (prf.first->use_empty()) + prf.first->eraseFromParent(); + } + + delete builder; + + return changed; + } + + int PrintfParser::generateOneParameterInst(PrintfSlot& slot, Value& arg) + { + assert(slot.type == PRINTF_SLOT_TYPE_STATE); + assert(builder); + + /* Check whether the arg match the format specifer. If needed, some + conversion need to be applied. */ + switch (arg.getType()->getTypeID()) { + case Type::IntegerTyID: { + switch (slot.state->conversion_specifier) { + case PRINTF_CONVERSION_I: + case PRINTF_CONVERSION_D: + /* Int to Int, just store. */ + return sizeof(int); + break; + + default: + return 0; + } + } + break; + + default: + return 0; + } + + return 0; + } + + map<CallInst*, PrintfSet::PrintfFmt*> PrintfParser::printfs; + + void* getPrintfInfo(CallInst* inst) + { + if (PrintfParser::printfs[inst]) + return (void*)PrintfParser::printfs[inst]; + return NULL; + } + + FunctionPass* createPrintfParserPass() + { + return new PrintfParser(); + } + char PrintfParser::ID = 0; + +} // end namespace diff --git a/backend/src/llvm/llvm_to_gen.cpp b/backend/src/llvm/llvm_to_gen.cpp index 2bb9da11..ee5eb88d 100644 --- a/backend/src/llvm/llvm_to_gen.cpp +++ b/backend/src/llvm/llvm_to_gen.cpp @@ -208,6 +208,7 @@ namespace gbe passes.add(createLowerSwitchPass()); passes.add(createPromoteMemoryToRegisterPass()); passes.add(createGVNPass()); // Remove redundancies + passes.add(createPrintfParserPass()); passes.add(createScalarizePass()); // Expand all vector ops passes.add(createDeadInstEliminationPass()); // Remove simplified instructions passes.add(createCFGSimplificationPass()); // Merge & remove BBs |