diff options
author | dan sinclair <dj2@everburning.com> | 2018-07-12 09:12:23 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-07-12 09:12:23 -0400 |
commit | 4cc6cd184ad74d35cb3e995f206f85ad23c2febc (patch) | |
tree | e3d0057926f103922aa01185c0dd168dbb051247 /source | |
parent | f96b7f1cb9f6a5f06a16882b10e4b1e528e3aaee (diff) |
Pass the IRContext into the folding rules. (#1709)
This CL updates the folding rules to receive the IRContext as a paramter
instead of retrieving off of the Instruction.
Issue #1703
Diffstat (limited to 'source')
-rw-r--r-- | source/opt/const_folding_rules.cpp | 24 | ||||
-rw-r--r-- | source/opt/const_folding_rules.h | 2 | ||||
-rw-r--r-- | source/opt/fold.cpp | 18 | ||||
-rw-r--r-- | source/opt/fold.h | 4 | ||||
-rw-r--r-- | source/opt/folding_rules.cpp | 102 | ||||
-rw-r--r-- | source/opt/folding_rules.h | 2 | ||||
-rw-r--r-- | source/opt/ir_context.h | 2 |
7 files changed, 63 insertions, 91 deletions
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index 42714b32..f0e413a1 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -35,7 +35,7 @@ bool HasFloatingPoint(const analysis::Type* type) { // Folds an OpcompositeExtract where input is a composite constant. ConstantFoldingRule FoldExtractWithConstants() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) -> const analysis::Constant* { const analysis::Constant* c = constants[kExtractCompositeIdInIdx]; @@ -47,7 +47,6 @@ ConstantFoldingRule FoldExtractWithConstants() { uint32_t element_index = inst->GetSingleWordInOperand(i); if (c->AsNullConstant()) { // Return Null for the return type. - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {}); @@ -63,7 +62,7 @@ ConstantFoldingRule FoldExtractWithConstants() { } ConstantFoldingRule FoldVectorShuffleWithConstants() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) -> const analysis::Constant* { assert(inst->opcode() == SpvOpVectorShuffle); @@ -73,7 +72,6 @@ ConstantFoldingRule FoldVectorShuffleWithConstants() { return nullptr; } - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* element_type = c1->type()->AsVector()->element_type(); @@ -116,11 +114,10 @@ ConstantFoldingRule FoldVectorShuffleWithConstants() { } ConstantFoldingRule FoldVectorTimesScalar() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) -> const analysis::Constant* { assert(inst->opcode() == SpvOpVectorTimesScalar); - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); @@ -194,10 +191,9 @@ ConstantFoldingRule FoldVectorTimesScalar() { ConstantFoldingRule FoldCompositeWithConstants() { // Folds an OpCompositeConstruct where all of the inputs are constants to a // constant. A new constant is created if necessary. - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) -> const analysis::Constant* { - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); const analysis::Type* new_type = type_mgr->GetType(inst->type_id()); @@ -238,10 +234,9 @@ using BinaryScalarFoldingRule = std::function<const analysis::Constant*( // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector| // whose element type is |Float| or |Integer|. ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) { - return [scalar_rule](opt::Instruction* inst, + return [scalar_rule](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) -> const analysis::Constant* { - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); @@ -288,10 +283,9 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) { // that |constants| contains 2 entries. If they are not |nullptr|, then their // type is either |Float| or a |Vector| whose element type is |Float|. ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) { - return [scalar_rule](opt::Instruction* inst, + return [scalar_rule](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) -> const analysis::Constant* { - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); @@ -518,10 +512,9 @@ ConstantFoldingRule FoldFUnordGreaterThanEqual() { // Folds an OpDot where all of the inputs are constants to a // constant. A new constant is created if necessary. ConstantFoldingRule FoldOpDotWithConstants() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) -> const analysis::Constant* { - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); const analysis::Type* new_type = type_mgr->GetType(inst->type_id()); @@ -614,10 +607,9 @@ UnaryScalarFoldingRule FoldFNegateOp() { ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); } ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) { - return [cmp_opcode](opt::Instruction* inst, + return [cmp_opcode](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) -> const analysis::Constant* { - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); diff --git a/source/opt/const_folding_rules.h b/source/opt/const_folding_rules.h index 354ec6b9..543df1c6 100644 --- a/source/opt/const_folding_rules.h +++ b/source/opt/const_folding_rules.h @@ -48,7 +48,7 @@ namespace opt { // fold an instruction, the later rules will not be attempted. using ConstantFoldingRule = std::function<const analysis::Constant*( - opt::Instruction* inst, + IRContext* ctx, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants)>; class ConstantFoldingRules { diff --git a/source/opt/fold.cpp b/source/opt/fold.cpp index e7ff1814..fcef7509 100644 --- a/source/opt/fold.cpp +++ b/source/opt/fold.cpp @@ -178,7 +178,6 @@ uint32_t InstructionFolder::OperateWords( } bool InstructionFolder::FoldInstructionInternal(opt::Instruction* inst) const { - opt::IRContext* context = inst->context(); auto identity_map = [](uint32_t id) { return id; }; opt::Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map); if (folded_inst != nullptr) { @@ -188,13 +187,13 @@ bool InstructionFolder::FoldInstructionInternal(opt::Instruction* inst) const { } SpvOp opcode = inst->opcode(); - analysis::ConstantManager* const_manager = context->get_constant_mgr(); + analysis::ConstantManager* const_manager = context_->get_constant_mgr(); std::vector<const analysis::Constant*> constants = const_manager->GetOperandConstants(inst); for (const FoldingRule& rule : GetFoldingRules().GetRulesForOpcode(opcode)) { - if (rule(inst, constants)) { + if (rule(context_, inst, constants)) { return true; } } @@ -233,8 +232,7 @@ bool InstructionFolder::FoldBinaryIntegerOpToConstant( opt::Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map, uint32_t* result) const { SpvOp opcode = inst->opcode(); - opt::IRContext* context = inst->context(); - analysis::ConstantManager* const_manger = context->get_constant_mgr(); + analysis::ConstantManager* const_manger = context_->get_constant_mgr(); uint32_t ids[2]; const analysis::IntConstant* constants[2]; @@ -417,8 +415,7 @@ bool InstructionFolder::FoldBinaryBooleanOpToConstant( opt::Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map, uint32_t* result) const { SpvOp opcode = inst->opcode(); - opt::IRContext* context = inst->context(); - analysis::ConstantManager* const_manger = context->get_constant_mgr(); + analysis::ConstantManager* const_manger = context_->get_constant_mgr(); uint32_t ids[2]; const analysis::BoolConstant* constants[2]; @@ -574,8 +571,7 @@ bool InstructionFolder::IsFoldableConstant( opt::Instruction* InstructionFolder::FoldInstructionToConstant( opt::Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const { - opt::IRContext* context = inst->context(); - analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::ConstantManager* const_mgr = context_->get_constant_mgr(); if (!inst->IsFoldableByFoldScalar() && !GetConstantFoldingRules().HasFoldingRule(inst->opcode())) { @@ -600,13 +596,13 @@ opt::Instruction* InstructionFolder::FoldInstructionToConstant( const analysis::Constant* folded_const = nullptr; for (auto rule : GetConstantFoldingRules().GetRulesForOpcode(inst->opcode())) { - folded_const = rule(inst, constants); + folded_const = rule(context_, inst, constants); if (folded_const != nullptr) { opt::Instruction* const_inst = const_mgr->GetDefiningInstruction(folded_const, inst->type_id()); assert(const_inst->type_id() == inst->type_id()); // May be a new instruction that needs to be analysed. - context->UpdateDefUse(const_inst); + context_->UpdateDefUse(const_inst); return const_inst; } } diff --git a/source/opt/fold.h b/source/opt/fold.h index c4e0dbc2..0e027b1d 100644 --- a/source/opt/fold.h +++ b/source/opt/fold.h @@ -28,6 +28,8 @@ namespace opt { class InstructionFolder { public: + explicit InstructionFolder(IRContext* context) : context_(context) {} + // Returns the result of folding a scalar instruction with the given |opcode| // and |operands|. Each entry in |operands| is a pointer to an // analysis::Constant instance, which should've been created with the constant @@ -154,6 +156,8 @@ class InstructionFolder { const std::function<uint32_t(uint32_t)>& id_map, uint32_t* result) const; + IRContext* context_; + // Folding rules used by |FoldInstructionToConstant| and |FoldInstruction|. ConstantFoldingRules const_folding_rules; diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 66829860..edc73488 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -199,10 +199,9 @@ uint32_t Reciprocal(analysis::ConstantManager* const_mgr, // Replaces fdiv where second operand is constant with fmul. FoldingRule ReciprocalFDiv() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFDiv); - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); @@ -244,11 +243,10 @@ FoldingRule ReciprocalFDiv() { // Elides consecutive negate instructions. FoldingRule MergeNegateArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate); (void)constants; - opt::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) @@ -279,11 +277,10 @@ FoldingRule MergeNegateArithmetic() { // -(x / 2) = x / -2 // -(2 / x) = -2 / x FoldingRule MergeNegateMulDivArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate); (void)constants; - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); @@ -338,11 +335,10 @@ FoldingRule MergeNegateMulDivArithmetic() { // -(x - 2) = 2 - x // -(2 - x) = x - 2 FoldingRule MergeNegateAddSubArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate); (void)constants; - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); @@ -571,10 +567,9 @@ uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode, // (x * 2) * 2 = x * 4 // (2 * x) * 2 = x * 4 FoldingRule MergeMulMulArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul); - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); @@ -624,10 +619,9 @@ FoldingRule MergeMulMulArithmetic() { // (y / x) * x = y // x * (y / x) = y FoldingRule MergeMulDivArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFMul); - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); @@ -699,10 +693,9 @@ FoldingRule MergeMulDivArithmetic() { // (-x) * 2 = x * -2 // 2 * (-x) = x * -2 FoldingRule MergeMulNegateArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul); - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); @@ -740,10 +733,9 @@ FoldingRule MergeMulNegateArithmetic() { // (4 / x) / 2 = 2 / x // (x / 2) / 2 = x / 4 FoldingRule MergeDivDivArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFDiv); - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); @@ -812,10 +804,9 @@ FoldingRule MergeDivDivArithmetic() { // (x * y) / x = y // (y * x) / x = y FoldingRule MergeDivMulArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFDiv); - opt::IRContext* context = inst->context(); analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); @@ -885,11 +876,10 @@ FoldingRule MergeDivMulArithmetic() { // (-x) / 2 = x / -2 // 2 / (-x) = 2 / -x FoldingRule MergeDivNegateArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv || inst->opcode() == SpvOpUDiv); - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); @@ -931,10 +921,9 @@ FoldingRule MergeDivNegateArithmetic() { // (-x) + 2 = 2 - x // 2 + (-x) = 2 - x FoldingRule MergeAddNegateArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd); - opt::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); bool uses_float = HasFloatingPoint(type); @@ -965,10 +954,9 @@ FoldingRule MergeAddNegateArithmetic() { // (-x) - 2 = -2 - x // 2 - (-x) = x + 2 FoldingRule MergeSubNegateArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub); - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); @@ -1014,10 +1002,9 @@ FoldingRule MergeSubNegateArithmetic() { // 2 + (x + 2) = x + 4 // 2 + (2 + x) = x + 4 FoldingRule MergeAddAddArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd); - opt::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); @@ -1062,10 +1049,9 @@ FoldingRule MergeAddAddArithmetic() { // 2 + (x - 2) = x + 0 // 2 + (2 - x) = 4 - x FoldingRule MergeAddSubArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd); - opt::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); @@ -1122,10 +1108,9 @@ FoldingRule MergeAddSubArithmetic() { // 2 - (x + 2) = 0 - x // 2 - (2 + x) = 0 - x FoldingRule MergeSubAddArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub); - opt::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); @@ -1188,10 +1173,9 @@ FoldingRule MergeSubAddArithmetic() { // 2 - (x - 2) = 4 - x // 2 - (2 - x) = x + 0 FoldingRule MergeSubSubArithmetic() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub); - opt::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); @@ -1255,7 +1239,7 @@ FoldingRule MergeSubSubArithmetic() { } FoldingRule IntMultipleBy1() { - return [](opt::Instruction* inst, + return [](IRContext*, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul."); for (uint32_t i = 0; i < 2; i++) { @@ -1281,14 +1265,14 @@ FoldingRule IntMultipleBy1() { } FoldingRule CompositeConstructFeedingExtract() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>&) { // If the input to an OpCompositeExtract is an OpCompositeConstruct, // then we can simply use the appropriate element in the construction. assert(inst->opcode() == SpvOpCompositeExtract && "Wrong opcode. Should be OpCompositeExtract."); - analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); - analysis::TypeManager* type_mgr = inst->context()->get_type_mgr(); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); opt::Instruction* cinst = def_use_mgr->GetDef(cid); @@ -1366,11 +1350,11 @@ FoldingRule CompositeExtractFeedingConstruct() { // // This is a common code pattern because of the way that scalar replacement // works. - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>&) { assert(inst->opcode() == SpvOpCompositeConstruct && "Wrong opcode. Should be OpCompositeConstruct."); - analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); uint32_t original_id = 0; // Check each element to make sure they are: @@ -1417,11 +1401,11 @@ FoldingRule CompositeExtractFeedingConstruct() { } FoldingRule InsertFeedingExtract() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>&) { assert(inst->opcode() == SpvOpCompositeExtract && "Wrong opcode. Should be OpCompositeExtract."); - analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); opt::Instruction* cinst = def_use_mgr->GetDef(cid); @@ -1492,12 +1476,12 @@ FoldingRule InsertFeedingExtract() { // operands of the VectorShuffle. We just need to adjust the index in the // extract instruction. FoldingRule VectorShuffleFeedingExtract() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>&) { assert(inst->opcode() == SpvOpCompositeExtract && "Wrong opcode. Should be OpCompositeExtract."); - analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); - analysis::TypeManager* type_mgr = inst->context()->get_type_mgr(); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); opt::Instruction* cinst = def_use_mgr->GetDef(cid); @@ -1540,11 +1524,10 @@ FoldingRule VectorShuffleFeedingExtract() { // corresponding |a| in the FMix is 0 or 1, we can extract from one of the // operands of the FMix. FoldingRule FMixFeedingExtract() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>&) { assert(inst->opcode() == SpvOpCompositeExtract && "Wrong opcode. Should be OpCompositeExtract."); - opt::IRContext* context = inst->context(); analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); @@ -1557,7 +1540,7 @@ FoldingRule FMixFeedingExtract() { } uint32_t inst_set_id = - inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); + context->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) != inst_set_id || @@ -1568,7 +1551,7 @@ FoldingRule FMixFeedingExtract() { // Get the |a| for the FMix instruction. uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx); - std::unique_ptr<opt::Instruction> a(inst->Clone(inst->context())); + std::unique_ptr<opt::Instruction> a(inst->Clone(context)); a->SetInOperand(kExtractCompositeIdInIdx, {a_id}); context->get_instruction_folder().FoldInstruction(a.get()); @@ -1612,7 +1595,7 @@ FoldingRule FMixFeedingExtract() { FoldingRule RedundantPhi() { // An OpPhi instruction where all values are the same or the result of the phi // itself, can be replaced by the value itself. - return [](opt::Instruction* inst, + return [](IRContext*, opt::Instruction* inst, const std::vector<const analysis::Constant*>&) { assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi."); @@ -1647,7 +1630,7 @@ FoldingRule RedundantPhi() { FoldingRule RedundantSelect() { // An OpSelect instruction where both values are the same or the condition is // constant can be replaced by one of the values - return [](opt::Instruction* inst, + return [](IRContext*, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpSelect && "Wrong opcode. Should be OpSelect."); @@ -1763,7 +1746,7 @@ FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) { } FoldingRule RedundantFAdd() { - return [](opt::Instruction* inst, + return [](IRContext*, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd."); assert(constants.size() == 2); @@ -1788,7 +1771,7 @@ FoldingRule RedundantFAdd() { } FoldingRule RedundantFSub() { - return [](opt::Instruction* inst, + return [](IRContext*, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub."); assert(constants.size() == 2); @@ -1819,7 +1802,7 @@ FoldingRule RedundantFSub() { } FoldingRule RedundantFMul() { - return [](opt::Instruction* inst, + return [](IRContext*, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul."); assert(constants.size() == 2); @@ -1852,7 +1835,7 @@ FoldingRule RedundantFMul() { } FoldingRule RedundantFDiv() { - return [](opt::Instruction* inst, + return [](IRContext*, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv."); assert(constants.size() == 2); @@ -1883,7 +1866,7 @@ FoldingRule RedundantFDiv() { } FoldingRule RedundantFMix() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpExtInst && "Wrong opcode. Should be OpExtInst."); @@ -1893,7 +1876,7 @@ FoldingRule RedundantFMix() { } uint32_t instSetId = - inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); + context->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId && inst->GetSingleWordInOperand(kExtInstInstructionInIdx) == @@ -1920,11 +1903,10 @@ FoldingRule RedundantFMix() { // This rule look for a dot with a constant vector containing a single 1 and // the rest 0s. This is the same as doing an extract. FoldingRule DotProductDoingExtract() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants) { assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot."); - opt::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); if (!inst->IsFloatingPointFoldingAllowed()) { @@ -1995,11 +1977,10 @@ FoldingRule DotProductDoingExtract() { // TODO: We can do something similar for OpImageWrite, but checking for volatile // is complicated. Waiting to see if it is needed. FoldingRule StoringUndef() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>&) { assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore."); - opt::IRContext* context = inst->context(); analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); // If this is a volatile store, the store cannot be removed. @@ -2020,12 +2001,11 @@ FoldingRule StoringUndef() { } FoldingRule VectorShuffleFeedingShuffle() { - return [](opt::Instruction* inst, + return [](IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>&) { assert(inst->opcode() == SpvOpVectorShuffle && "Wrong opcode. Should be OpVectorShuffle."); - IRContext* context = inst->context(); analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); diff --git a/source/opt/folding_rules.h b/source/opt/folding_rules.h index 19fe2174..d807ec1e 100644 --- a/source/opt/folding_rules.h +++ b/source/opt/folding_rules.h @@ -52,7 +52,7 @@ namespace opt { // the later rules will not be attempted. using FoldingRule = std::function<bool( - opt::Instruction* inst, + IRContext* context, opt::Instruction* inst, const std::vector<const analysis::Constant*>& constants)>; class FoldingRules { diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h index 5d37fcbc..ac4e4578 100644 --- a/source/opt/ir_context.h +++ b/source/opt/ir_context.h @@ -445,7 +445,7 @@ class IRContext { const opt::InstructionFolder& get_instruction_folder() { if (!inst_folder_) { - inst_folder_.reset(new opt::InstructionFolder()); + inst_folder_.reset(new opt::InstructionFolder(this)); } return *inst_folder_; } |