diff options
author | Junyan He <junyan.he@linux.intel.com> | 2015-12-01 16:10:36 +0800 |
---|---|---|
committer | Yang Rong <rong.r.yang@intel.com> | 2015-12-14 15:11:48 +0800 |
commit | eb22b9895c97504c78c5338a6a0354b130cb6d81 (patch) | |
tree | 7626d54e0669d47ea94fac32f50cc1e0470f7b0c | |
parent | 1f030e70ae8123b388076974aa2501bd98eb6b4b (diff) |
Backend: Implement reduce min and max in gen_context
Signed-off-by: Junyan He <junyan.he@linux.intel.com>
Reviewed-by: Yang Rong <rong.r.yang@intel.com>
-rw-r--r-- | backend/src/backend/gen_context.cpp | 284 |
1 files changed, 279 insertions, 5 deletions
diff --git a/backend/src/backend/gen_context.cpp b/backend/src/backend/gen_context.cpp index 1bb36d5f..880d8b72 100644 --- a/backend/src/backend/gen_context.cpp +++ b/backend/src/backend/gen_context.cpp @@ -2814,9 +2814,7 @@ namespace gbe p->push(); { GenRegister ffid = GenRegister::toUniform(data, GEN_TYPE_UD); GenRegister tmp = GenRegister::toUniform(profilingReg[3], GEN_TYPE_UD); - GenRegister stateReg = GenRegister(GEN_ARCHITECTURE_REGISTER_FILE, GEN_ARF_STATE, 0, - GEN_TYPE_UD, GEN_VERTICAL_STRIDE_0, GEN_WIDTH_1, GEN_HORIZONTAL_STRIDE_1); - p->curr.predicate = GEN_PREDICATE_NONE; + GenRegister stateReg = GenRegister::sr(0, 0); p->curr.noMask = 1; p->curr.execWidth = 1; p->MOV(ffid, stateReg); @@ -2828,8 +2826,7 @@ namespace gbe p->MOV(genInfo, stateReg); p->AND(genInfo, genInfo, GenRegister::immud(0x0ff07)); //The dispatch mask - stateReg = GenRegister(GEN_ARCHITECTURE_REGISTER_FILE, GEN_ARF_STATE, 2, - GEN_TYPE_UD, GEN_VERTICAL_STRIDE_0, GEN_WIDTH_1, GEN_HORIZONTAL_STRIDE_1); + stateReg = GenRegister::sr(0, 2); p->MOV(tmp, stateReg); p->AND(tmp, tmp, GenRegister::immud(0x0000ffff)); p->SHL(tmp, tmp, GenRegister::immud(16)); @@ -2851,7 +2848,284 @@ namespace gbe } p->pop(); } + static void workgroupOpBetweenThread(GenRegister msgData, GenRegister theVal, GenRegister threadData, + uint32_t simd, uint32_t wg_op, GenEncoder *p) { + p->push(); + p->curr.predicate = GEN_PREDICATE_NONE; + p->curr.noMask = 1; + p->curr.execWidth = 1; + + if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) { + uint32_t cond; + if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN) + cond = GEN_CONDITIONAL_LE; + else + cond = GEN_CONDITIONAL_GE; + + p->SEL_CMP(cond, msgData, threadData, msgData); + } + p->pop(); + } + + static void initValue(GenEncoder *p, GenRegister dataReg, uint32_t wg_op) { + if (dataReg.type == GEN_TYPE_UD) { + if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MIN + || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MIN) { + p->MOV(dataReg, GenRegister::immud(0xFFFFFFFF)); + } else { + GBE_ASSERT(wg_op == ir::WORKGROUP_OP_REDUCE_MAX || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MAX + || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MAX); + p->MOV(dataReg, GenRegister::immud(0)); + } + } else if (dataReg.type == GEN_TYPE_F) { + if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MIN + || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MIN) { + p->MOV(GenRegister::retype(dataReg, GEN_TYPE_UD), GenRegister::immud(0x7F800000)); // inf + } else if (wg_op == ir::WORKGROUP_OP_REDUCE_MAX || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MAX + || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MAX) { + p->MOV(GenRegister::retype(dataReg, GEN_TYPE_UD), GenRegister::immud(0xFF800000)); // -inf + } + } else { + GBE_ASSERT(0); + } + } + + static void workgroupOpInThread(GenRegister msgData, GenRegister theVal, GenRegister threadData, + GenRegister tmp, uint32_t simd, uint32_t wg_op, GenEncoder *p) { + p->push(); + p->curr.predicate = GEN_PREDICATE_NONE; + p->curr.noMask = 1; + p->curr.execWidth = 1; + + /* Setting the init value here. */ + threadData = GenRegister::retype(threadData, theVal.type); + initValue(p, threadData, wg_op); + + if (theVal.hstride != GEN_HORIZONTAL_STRIDE_0) { + /* We need to set the value out of dispatch mask to MAX. */ + tmp = GenRegister::retype(tmp, theVal.type); + p->push(); + p->curr.predicate = GEN_PREDICATE_NONE; + p->curr.noMask = 1; + p->curr.execWidth = simd; + initValue(p, tmp, wg_op); + p->curr.noMask = 0; + p->MOV(tmp, theVal); + p->pop(); + } + + if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) { + uint32_t cond; + if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN) + cond = GEN_CONDITIONAL_LE; + else + cond = GEN_CONDITIONAL_GE; + + if (theVal.hstride == GEN_HORIZONTAL_STRIDE_0) { // an uniform value. + p->SEL_CMP(cond, threadData, threadData, theVal); + } else { + GBE_ASSERT(tmp.type == theVal.type); + GenRegister v = GenRegister::toUniform(tmp, theVal.type); + for (uint32_t i = 0; i < simd; i++) { + p->SEL_CMP(cond, threadData, threadData, v); + v.subnr += typeSize(theVal.type); + if (v.subnr == 32) { + v.subnr = 0; + v.nr++; + } + } + } + } + + p->pop(); + } + +#define SEND_RESULT_MSG() \ +do { \ + p->push(); { /* then send msg. */ \ + p->curr.noMask = 1; \ + p->curr.predicate = GEN_PREDICATE_NONE; \ + p->curr.execWidth = 1; \ + GenRegister offLen = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 20), GEN_TYPE_UD); \ + offLen.vstride = GEN_VERTICAL_STRIDE_0; \ + offLen.width = GEN_WIDTH_1; \ + offLen.hstride = GEN_HORIZONTAL_STRIDE_0; \ + uint32_t szEnc = typeSize(theVal.type) >> 1; \ + if (szEnc == 4) { \ + szEnc = 3; \ + } \ + p->MOV(offLen, GenRegister::immud((szEnc << 8) | (nextThreadID.nr << 21))); \ + \ + GenRegister tidEuid = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 16), GEN_TYPE_UD); \ + tidEuid.vstride = GEN_VERTICAL_STRIDE_0; \ + tidEuid.width = GEN_WIDTH_1; \ + tidEuid.hstride = GEN_HORIZONTAL_STRIDE_0; \ + p->SHL(tidEuid, tidEuid, GenRegister::immud(16)); \ + \ + p->curr.execWidth = 8; \ + p->FWD_GATEWAY_MSG(nextThreadID, 2); \ + } p->pop(); \ +} while(0) + + + /* The basic idea is like this: + 1. All the threads firstly calculate the max/min/add value within their own thread, that is finding + the max/min/add value within their 16 work items when SIMD == 16. + 2. The logical thread ID 0 begins to send the MSG to thread 1, and that message contains the calculated + result of the first step. Except the thread 0, all other threads wait on the n0.2 for message forwarding. + 3. Each thread is waken up because of getting the forwarding message from the thread_id - 1. Then it + compares the result in the message and the result within its thread, then forward the correct result to + the next thread by sending a message again. If it is the last thread, send it to thread 0. + 4. Thread 0 finally get the message from the last one and broadcast the final result. */ void GenContext::emitWorkGroupOpInstruction(const SelectionInstruction &insn) { + const GenRegister dst = ra->genReg(insn.dst(0)); + const GenRegister tmp = ra->genReg(insn.dst(2)); + GenRegister flagReg = GenRegister::flag(insn.state.flag, insn.state.subFlag); + GenRegister nextThreadID = ra->genReg(insn.src(1)); + const GenRegister theVal = ra->genReg(insn.src(0)); + GenRegister threadid = ra->genReg(GenRegister::ud1grf(ir::ocl::threadid)); + GenRegister threadnum = ra->genReg(GenRegister::ud1grf(ir::ocl::threadn)); + GenRegister msgData = GenRegister::retype(nextThreadID, dst.type); // The data forward. + msgData.vstride = GEN_VERTICAL_STRIDE_0; + msgData.width = GEN_WIDTH_1; + msgData.hstride = GEN_HORIZONTAL_STRIDE_0; + GenRegister threadData = + GenRegister::retype(GenRegister::offset(nextThreadID, 0, 24), dst.type); // Res within thread. + threadData.vstride = GEN_VERTICAL_STRIDE_0; + threadData.width = GEN_WIDTH_1; + threadData.hstride = GEN_HORIZONTAL_STRIDE_0; + uint32_t wg_op = insn.extra.workgroupOp; + uint32_t simd = p->curr.execWidth; + GenRegister flag_save = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 8), GEN_TYPE_UW); + flag_save.vstride = GEN_VERTICAL_STRIDE_0; + flag_save.width = GEN_WIDTH_1; + flag_save.hstride = GEN_HORIZONTAL_STRIDE_0; + int32_t jip; + int32_t oneThreadJip = -1; + + p->push(); { /* First, so something within thread. */ + p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr()); + /* Do some calculation within each thread. */ + workgroupOpInThread(msgData, theVal, threadData, tmp, simd, wg_op, p); + } p->pop(); + + /* If we are the only one thread, no need to send msg, just broadcast the result.*/ + p->push(); { + p->curr.predicate = GEN_PREDICATE_NONE; + p->curr.noMask = 1; + p->curr.execWidth = 1; + p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr()); + p->CMP(GEN_CONDITIONAL_EQ, threadnum, GenRegister::immud(0x1)); + + /* Broadcast result. */ + if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN) { + p->curr.predicate = GEN_PREDICATE_NORMAL; + p->curr.inversePredicate = 1; + p->MOV(flag_save, GenRegister::immuw(0x0)); + p->curr.inversePredicate = 0; + p->MOV(flag_save, GenRegister::immuw(0xffff)); + p->curr.predicate = GEN_PREDICATE_NONE; + p->MOV(flagReg, flag_save); + p->curr.predicate = GEN_PREDICATE_NORMAL; + p->curr.execWidth = simd; + p->MOV(dst, threadData); + } + + /* Bail out. */ + p->curr.predicate = GEN_PREDICATE_NORMAL; + p->curr.inversePredicate = 0; + p->curr.execWidth = 1; + oneThreadJip = p->n_instruction(); + p->JMPI(GenRegister::immud(0)); + } p->pop(); + + p->push(); { + p->curr.predicate = GEN_PREDICATE_NONE; + p->curr.noMask = 1; + p->curr.execWidth = 1; + p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr()); + p->CMP(GEN_CONDITIONAL_EQ, threadid, GenRegister::immud(0x0)); + + p->curr.predicate = GEN_PREDICATE_NORMAL; + p->curr.inversePredicate = 1; + p->MOV(flag_save, GenRegister::immuw(0x0)); + p->curr.inversePredicate = 0; + p->MOV(flag_save, GenRegister::immuw(0xffff)); + + p->curr.predicate = GEN_PREDICATE_NONE; + p->MOV(flagReg, flag_save); + } p->pop(); + + p->push(); { + p->curr.noMask = 1; + p->curr.execWidth = 1; + + /* threadid 0, send the msg and wait */ + p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr()); + p->curr.inversePredicate = 1; + p->curr.predicate = GEN_PREDICATE_NORMAL; + jip = p->n_instruction(); + p->JMPI(GenRegister::immud(0)); + p->curr.predicate = GEN_PREDICATE_NONE; + p->MOV(msgData, threadData); + SEND_RESULT_MSG(); + p->WAIT(2); + p->patchJMPI(jip, (p->n_instruction() - jip), 0); + + /* Others wait and send msg, and do something when we get the msg. */ + p->curr.predicate = GEN_PREDICATE_NORMAL; + p->curr.inversePredicate = 0; + jip = p->n_instruction(); + p->JMPI(GenRegister::immud(0)); + p->curr.predicate = GEN_PREDICATE_NONE; + p->WAIT(2); + workgroupOpBetweenThread(msgData, theVal, threadData, simd, wg_op, p); + SEND_RESULT_MSG(); + p->patchJMPI(jip, (p->n_instruction() - jip), 0); + + /* Restore the flag. */ + p->curr.predicate = GEN_PREDICATE_NONE; + p->MOV(flagReg, flag_save); + } p->pop(); + + /* Broadcast the result. */ + if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) { + p->push(); { + p->curr.predicate = GEN_PREDICATE_NORMAL; + p->curr.noMask = 1; + p->curr.execWidth = 1; + p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr()); + p->curr.inversePredicate = 0; + + /* Not the first thread, wait for msg first. */ + jip = p->n_instruction(); + p->JMPI(GenRegister::immud(0)); + p->curr.predicate = GEN_PREDICATE_NONE; + p->WAIT(2); + p->patchJMPI(jip, (p->n_instruction() - jip), 0); + + /* Do something when get the msg. */ + p->curr.execWidth = simd; + p->MOV(dst, msgData); + + p->curr.execWidth = 8; + p->FWD_GATEWAY_MSG(nextThreadID, 2); + + p->curr.execWidth = 1; + p->curr.inversePredicate = 1; + p->curr.predicate = GEN_PREDICATE_NORMAL; + + /* The first thread, the last one will notify us. */ + jip = p->n_instruction(); + p->JMPI(GenRegister::immud(0)); + p->curr.predicate = GEN_PREDICATE_NONE; + p->WAIT(2); + p->patchJMPI(jip, (p->n_instruction() - jip), 0); + } p->pop(); + } + + if (oneThreadJip >=0) + p->patchJMPI(oneThreadJip, (p->n_instruction() - oneThreadJip), 0); } void GenContext::setA0Content(uint16_t new_a0[16], uint16_t max_offset, int sz) { |