summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJunyan He <junyan.he@linux.intel.com>2015-12-01 16:10:36 +0800
committerYang Rong <rong.r.yang@intel.com>2015-12-14 15:11:48 +0800
commiteb22b9895c97504c78c5338a6a0354b130cb6d81 (patch)
tree7626d54e0669d47ea94fac32f50cc1e0470f7b0c
parent1f030e70ae8123b388076974aa2501bd98eb6b4b (diff)
Backend: Implement reduce min and max in gen_context
Signed-off-by: Junyan He <junyan.he@linux.intel.com> Reviewed-by: Yang Rong <rong.r.yang@intel.com>
-rw-r--r--backend/src/backend/gen_context.cpp284
1 files changed, 279 insertions, 5 deletions
diff --git a/backend/src/backend/gen_context.cpp b/backend/src/backend/gen_context.cpp
index 1bb36d5f..880d8b72 100644
--- a/backend/src/backend/gen_context.cpp
+++ b/backend/src/backend/gen_context.cpp
@@ -2814,9 +2814,7 @@ namespace gbe
p->push(); {
GenRegister ffid = GenRegister::toUniform(data, GEN_TYPE_UD);
GenRegister tmp = GenRegister::toUniform(profilingReg[3], GEN_TYPE_UD);
- GenRegister stateReg = GenRegister(GEN_ARCHITECTURE_REGISTER_FILE, GEN_ARF_STATE, 0,
- GEN_TYPE_UD, GEN_VERTICAL_STRIDE_0, GEN_WIDTH_1, GEN_HORIZONTAL_STRIDE_1);
- p->curr.predicate = GEN_PREDICATE_NONE;
+ GenRegister stateReg = GenRegister::sr(0, 0);
p->curr.noMask = 1;
p->curr.execWidth = 1;
p->MOV(ffid, stateReg);
@@ -2828,8 +2826,7 @@ namespace gbe
p->MOV(genInfo, stateReg);
p->AND(genInfo, genInfo, GenRegister::immud(0x0ff07));
//The dispatch mask
- stateReg = GenRegister(GEN_ARCHITECTURE_REGISTER_FILE, GEN_ARF_STATE, 2,
- GEN_TYPE_UD, GEN_VERTICAL_STRIDE_0, GEN_WIDTH_1, GEN_HORIZONTAL_STRIDE_1);
+ stateReg = GenRegister::sr(0, 2);
p->MOV(tmp, stateReg);
p->AND(tmp, tmp, GenRegister::immud(0x0000ffff));
p->SHL(tmp, tmp, GenRegister::immud(16));
@@ -2851,7 +2848,284 @@ namespace gbe
} p->pop();
}
+ static void workgroupOpBetweenThread(GenRegister msgData, GenRegister theVal, GenRegister threadData,
+ uint32_t simd, uint32_t wg_op, GenEncoder *p) {
+ p->push();
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->curr.noMask = 1;
+ p->curr.execWidth = 1;
+
+ if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+ uint32_t cond;
+ if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN)
+ cond = GEN_CONDITIONAL_LE;
+ else
+ cond = GEN_CONDITIONAL_GE;
+
+ p->SEL_CMP(cond, msgData, threadData, msgData);
+ }
+ p->pop();
+ }
+
+ static void initValue(GenEncoder *p, GenRegister dataReg, uint32_t wg_op) {
+ if (dataReg.type == GEN_TYPE_UD) {
+ if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MIN
+ || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MIN) {
+ p->MOV(dataReg, GenRegister::immud(0xFFFFFFFF));
+ } else {
+ GBE_ASSERT(wg_op == ir::WORKGROUP_OP_REDUCE_MAX || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MAX
+ || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MAX);
+ p->MOV(dataReg, GenRegister::immud(0));
+ }
+ } else if (dataReg.type == GEN_TYPE_F) {
+ if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MIN
+ || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MIN) {
+ p->MOV(GenRegister::retype(dataReg, GEN_TYPE_UD), GenRegister::immud(0x7F800000)); // inf
+ } else if (wg_op == ir::WORKGROUP_OP_REDUCE_MAX || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MAX
+ || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MAX) {
+ p->MOV(GenRegister::retype(dataReg, GEN_TYPE_UD), GenRegister::immud(0xFF800000)); // -inf
+ }
+ } else {
+ GBE_ASSERT(0);
+ }
+ }
+
+ static void workgroupOpInThread(GenRegister msgData, GenRegister theVal, GenRegister threadData,
+ GenRegister tmp, uint32_t simd, uint32_t wg_op, GenEncoder *p) {
+ p->push();
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->curr.noMask = 1;
+ p->curr.execWidth = 1;
+
+ /* Setting the init value here. */
+ threadData = GenRegister::retype(threadData, theVal.type);
+ initValue(p, threadData, wg_op);
+
+ if (theVal.hstride != GEN_HORIZONTAL_STRIDE_0) {
+ /* We need to set the value out of dispatch mask to MAX. */
+ tmp = GenRegister::retype(tmp, theVal.type);
+ p->push();
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->curr.noMask = 1;
+ p->curr.execWidth = simd;
+ initValue(p, tmp, wg_op);
+ p->curr.noMask = 0;
+ p->MOV(tmp, theVal);
+ p->pop();
+ }
+
+ if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+ uint32_t cond;
+ if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN)
+ cond = GEN_CONDITIONAL_LE;
+ else
+ cond = GEN_CONDITIONAL_GE;
+
+ if (theVal.hstride == GEN_HORIZONTAL_STRIDE_0) { // an uniform value.
+ p->SEL_CMP(cond, threadData, threadData, theVal);
+ } else {
+ GBE_ASSERT(tmp.type == theVal.type);
+ GenRegister v = GenRegister::toUniform(tmp, theVal.type);
+ for (uint32_t i = 0; i < simd; i++) {
+ p->SEL_CMP(cond, threadData, threadData, v);
+ v.subnr += typeSize(theVal.type);
+ if (v.subnr == 32) {
+ v.subnr = 0;
+ v.nr++;
+ }
+ }
+ }
+ }
+
+ p->pop();
+ }
+
+#define SEND_RESULT_MSG() \
+do { \
+ p->push(); { /* then send msg. */ \
+ p->curr.noMask = 1; \
+ p->curr.predicate = GEN_PREDICATE_NONE; \
+ p->curr.execWidth = 1; \
+ GenRegister offLen = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 20), GEN_TYPE_UD); \
+ offLen.vstride = GEN_VERTICAL_STRIDE_0; \
+ offLen.width = GEN_WIDTH_1; \
+ offLen.hstride = GEN_HORIZONTAL_STRIDE_0; \
+ uint32_t szEnc = typeSize(theVal.type) >> 1; \
+ if (szEnc == 4) { \
+ szEnc = 3; \
+ } \
+ p->MOV(offLen, GenRegister::immud((szEnc << 8) | (nextThreadID.nr << 21))); \
+ \
+ GenRegister tidEuid = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 16), GEN_TYPE_UD); \
+ tidEuid.vstride = GEN_VERTICAL_STRIDE_0; \
+ tidEuid.width = GEN_WIDTH_1; \
+ tidEuid.hstride = GEN_HORIZONTAL_STRIDE_0; \
+ p->SHL(tidEuid, tidEuid, GenRegister::immud(16)); \
+ \
+ p->curr.execWidth = 8; \
+ p->FWD_GATEWAY_MSG(nextThreadID, 2); \
+ } p->pop(); \
+} while(0)
+
+
+ /* The basic idea is like this:
+ 1. All the threads firstly calculate the max/min/add value within their own thread, that is finding
+ the max/min/add value within their 16 work items when SIMD == 16.
+ 2. The logical thread ID 0 begins to send the MSG to thread 1, and that message contains the calculated
+ result of the first step. Except the thread 0, all other threads wait on the n0.2 for message forwarding.
+ 3. Each thread is waken up because of getting the forwarding message from the thread_id - 1. Then it
+ compares the result in the message and the result within its thread, then forward the correct result to
+ the next thread by sending a message again. If it is the last thread, send it to thread 0.
+ 4. Thread 0 finally get the message from the last one and broadcast the final result. */
void GenContext::emitWorkGroupOpInstruction(const SelectionInstruction &insn) {
+ const GenRegister dst = ra->genReg(insn.dst(0));
+ const GenRegister tmp = ra->genReg(insn.dst(2));
+ GenRegister flagReg = GenRegister::flag(insn.state.flag, insn.state.subFlag);
+ GenRegister nextThreadID = ra->genReg(insn.src(1));
+ const GenRegister theVal = ra->genReg(insn.src(0));
+ GenRegister threadid = ra->genReg(GenRegister::ud1grf(ir::ocl::threadid));
+ GenRegister threadnum = ra->genReg(GenRegister::ud1grf(ir::ocl::threadn));
+ GenRegister msgData = GenRegister::retype(nextThreadID, dst.type); // The data forward.
+ msgData.vstride = GEN_VERTICAL_STRIDE_0;
+ msgData.width = GEN_WIDTH_1;
+ msgData.hstride = GEN_HORIZONTAL_STRIDE_0;
+ GenRegister threadData =
+ GenRegister::retype(GenRegister::offset(nextThreadID, 0, 24), dst.type); // Res within thread.
+ threadData.vstride = GEN_VERTICAL_STRIDE_0;
+ threadData.width = GEN_WIDTH_1;
+ threadData.hstride = GEN_HORIZONTAL_STRIDE_0;
+ uint32_t wg_op = insn.extra.workgroupOp;
+ uint32_t simd = p->curr.execWidth;
+ GenRegister flag_save = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 8), GEN_TYPE_UW);
+ flag_save.vstride = GEN_VERTICAL_STRIDE_0;
+ flag_save.width = GEN_WIDTH_1;
+ flag_save.hstride = GEN_HORIZONTAL_STRIDE_0;
+ int32_t jip;
+ int32_t oneThreadJip = -1;
+
+ p->push(); { /* First, so something within thread. */
+ p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+ /* Do some calculation within each thread. */
+ workgroupOpInThread(msgData, theVal, threadData, tmp, simd, wg_op, p);
+ } p->pop();
+
+ /* If we are the only one thread, no need to send msg, just broadcast the result.*/
+ p->push(); {
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->curr.noMask = 1;
+ p->curr.execWidth = 1;
+ p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+ p->CMP(GEN_CONDITIONAL_EQ, threadnum, GenRegister::immud(0x1));
+
+ /* Broadcast result. */
+ if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN) {
+ p->curr.predicate = GEN_PREDICATE_NORMAL;
+ p->curr.inversePredicate = 1;
+ p->MOV(flag_save, GenRegister::immuw(0x0));
+ p->curr.inversePredicate = 0;
+ p->MOV(flag_save, GenRegister::immuw(0xffff));
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->MOV(flagReg, flag_save);
+ p->curr.predicate = GEN_PREDICATE_NORMAL;
+ p->curr.execWidth = simd;
+ p->MOV(dst, threadData);
+ }
+
+ /* Bail out. */
+ p->curr.predicate = GEN_PREDICATE_NORMAL;
+ p->curr.inversePredicate = 0;
+ p->curr.execWidth = 1;
+ oneThreadJip = p->n_instruction();
+ p->JMPI(GenRegister::immud(0));
+ } p->pop();
+
+ p->push(); {
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->curr.noMask = 1;
+ p->curr.execWidth = 1;
+ p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+ p->CMP(GEN_CONDITIONAL_EQ, threadid, GenRegister::immud(0x0));
+
+ p->curr.predicate = GEN_PREDICATE_NORMAL;
+ p->curr.inversePredicate = 1;
+ p->MOV(flag_save, GenRegister::immuw(0x0));
+ p->curr.inversePredicate = 0;
+ p->MOV(flag_save, GenRegister::immuw(0xffff));
+
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->MOV(flagReg, flag_save);
+ } p->pop();
+
+ p->push(); {
+ p->curr.noMask = 1;
+ p->curr.execWidth = 1;
+
+ /* threadid 0, send the msg and wait */
+ p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+ p->curr.inversePredicate = 1;
+ p->curr.predicate = GEN_PREDICATE_NORMAL;
+ jip = p->n_instruction();
+ p->JMPI(GenRegister::immud(0));
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->MOV(msgData, threadData);
+ SEND_RESULT_MSG();
+ p->WAIT(2);
+ p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+
+ /* Others wait and send msg, and do something when we get the msg. */
+ p->curr.predicate = GEN_PREDICATE_NORMAL;
+ p->curr.inversePredicate = 0;
+ jip = p->n_instruction();
+ p->JMPI(GenRegister::immud(0));
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->WAIT(2);
+ workgroupOpBetweenThread(msgData, theVal, threadData, simd, wg_op, p);
+ SEND_RESULT_MSG();
+ p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+
+ /* Restore the flag. */
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->MOV(flagReg, flag_save);
+ } p->pop();
+
+ /* Broadcast the result. */
+ if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+ p->push(); {
+ p->curr.predicate = GEN_PREDICATE_NORMAL;
+ p->curr.noMask = 1;
+ p->curr.execWidth = 1;
+ p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+ p->curr.inversePredicate = 0;
+
+ /* Not the first thread, wait for msg first. */
+ jip = p->n_instruction();
+ p->JMPI(GenRegister::immud(0));
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->WAIT(2);
+ p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+
+ /* Do something when get the msg. */
+ p->curr.execWidth = simd;
+ p->MOV(dst, msgData);
+
+ p->curr.execWidth = 8;
+ p->FWD_GATEWAY_MSG(nextThreadID, 2);
+
+ p->curr.execWidth = 1;
+ p->curr.inversePredicate = 1;
+ p->curr.predicate = GEN_PREDICATE_NORMAL;
+
+ /* The first thread, the last one will notify us. */
+ jip = p->n_instruction();
+ p->JMPI(GenRegister::immud(0));
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->WAIT(2);
+ p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+ } p->pop();
+ }
+
+ if (oneThreadJip >=0)
+ p->patchJMPI(oneThreadJip, (p->n_instruction() - oneThreadJip), 0);
}
void GenContext::setA0Content(uint16_t new_a0[16], uint16_t max_offset, int sz) {