summaryrefslogtreecommitdiff
path: root/source
diff options
context:
space:
mode:
authordan sinclair <dj2@everburning.com>2018-07-12 09:12:23 -0400
committerGitHub <noreply@github.com>2018-07-12 09:12:23 -0400
commit4cc6cd184ad74d35cb3e995f206f85ad23c2febc (patch)
treee3d0057926f103922aa01185c0dd168dbb051247 /source
parentf96b7f1cb9f6a5f06a16882b10e4b1e528e3aaee (diff)
Pass the IRContext into the folding rules. (#1709)
This CL updates the folding rules to receive the IRContext as a paramter instead of retrieving off of the Instruction. Issue #1703
Diffstat (limited to 'source')
-rw-r--r--source/opt/const_folding_rules.cpp24
-rw-r--r--source/opt/const_folding_rules.h2
-rw-r--r--source/opt/fold.cpp18
-rw-r--r--source/opt/fold.h4
-rw-r--r--source/opt/folding_rules.cpp102
-rw-r--r--source/opt/folding_rules.h2
-rw-r--r--source/opt/ir_context.h2
7 files changed, 63 insertions, 91 deletions
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
index 42714b32..f0e413a1 100644
--- a/source/opt/const_folding_rules.cpp
+++ b/source/opt/const_folding_rules.cpp
@@ -35,7 +35,7 @@ bool HasFloatingPoint(const analysis::Type* type) {
// Folds an OpcompositeExtract where input is a composite constant.
ConstantFoldingRule FoldExtractWithConstants() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
@@ -47,7 +47,6 @@ ConstantFoldingRule FoldExtractWithConstants() {
uint32_t element_index = inst->GetSingleWordInOperand(i);
if (c->AsNullConstant()) {
// Return Null for the return type.
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
@@ -63,7 +62,7 @@ ConstantFoldingRule FoldExtractWithConstants() {
}
ConstantFoldingRule FoldVectorShuffleWithConstants() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
assert(inst->opcode() == SpvOpVectorShuffle);
@@ -73,7 +72,6 @@ ConstantFoldingRule FoldVectorShuffleWithConstants() {
return nullptr;
}
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* element_type = c1->type()->AsVector()->element_type();
@@ -116,11 +114,10 @@ ConstantFoldingRule FoldVectorShuffleWithConstants() {
}
ConstantFoldingRule FoldVectorTimesScalar() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
assert(inst->opcode() == SpvOpVectorTimesScalar);
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
@@ -194,10 +191,9 @@ ConstantFoldingRule FoldVectorTimesScalar() {
ConstantFoldingRule FoldCompositeWithConstants() {
// Folds an OpCompositeConstruct where all of the inputs are constants to a
// constant. A new constant is created if necessary.
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
@@ -238,10 +234,9 @@ using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
// whose element type is |Float| or |Integer|.
ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
- return [scalar_rule](opt::Instruction* inst,
+ return [scalar_rule](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
@@ -288,10 +283,9 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
// that |constants| contains 2 entries. If they are not |nullptr|, then their
// type is either |Float| or a |Vector| whose element type is |Float|.
ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
- return [scalar_rule](opt::Instruction* inst,
+ return [scalar_rule](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
@@ -518,10 +512,9 @@ ConstantFoldingRule FoldFUnordGreaterThanEqual() {
// Folds an OpDot where all of the inputs are constants to a
// constant. A new constant is created if necessary.
ConstantFoldingRule FoldOpDotWithConstants() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
@@ -614,10 +607,9 @@ UnaryScalarFoldingRule FoldFNegateOp() {
ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
- return [cmp_opcode](opt::Instruction* inst,
+ return [cmp_opcode](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
diff --git a/source/opt/const_folding_rules.h b/source/opt/const_folding_rules.h
index 354ec6b9..543df1c6 100644
--- a/source/opt/const_folding_rules.h
+++ b/source/opt/const_folding_rules.h
@@ -48,7 +48,7 @@ namespace opt {
// fold an instruction, the later rules will not be attempted.
using ConstantFoldingRule = std::function<const analysis::Constant*(
- opt::Instruction* inst,
+ IRContext* ctx, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)>;
class ConstantFoldingRules {
diff --git a/source/opt/fold.cpp b/source/opt/fold.cpp
index e7ff1814..fcef7509 100644
--- a/source/opt/fold.cpp
+++ b/source/opt/fold.cpp
@@ -178,7 +178,6 @@ uint32_t InstructionFolder::OperateWords(
}
bool InstructionFolder::FoldInstructionInternal(opt::Instruction* inst) const {
- opt::IRContext* context = inst->context();
auto identity_map = [](uint32_t id) { return id; };
opt::Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map);
if (folded_inst != nullptr) {
@@ -188,13 +187,13 @@ bool InstructionFolder::FoldInstructionInternal(opt::Instruction* inst) const {
}
SpvOp opcode = inst->opcode();
- analysis::ConstantManager* const_manager = context->get_constant_mgr();
+ analysis::ConstantManager* const_manager = context_->get_constant_mgr();
std::vector<const analysis::Constant*> constants =
const_manager->GetOperandConstants(inst);
for (const FoldingRule& rule : GetFoldingRules().GetRulesForOpcode(opcode)) {
- if (rule(inst, constants)) {
+ if (rule(context_, inst, constants)) {
return true;
}
}
@@ -233,8 +232,7 @@ bool InstructionFolder::FoldBinaryIntegerOpToConstant(
opt::Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
uint32_t* result) const {
SpvOp opcode = inst->opcode();
- opt::IRContext* context = inst->context();
- analysis::ConstantManager* const_manger = context->get_constant_mgr();
+ analysis::ConstantManager* const_manger = context_->get_constant_mgr();
uint32_t ids[2];
const analysis::IntConstant* constants[2];
@@ -417,8 +415,7 @@ bool InstructionFolder::FoldBinaryBooleanOpToConstant(
opt::Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
uint32_t* result) const {
SpvOp opcode = inst->opcode();
- opt::IRContext* context = inst->context();
- analysis::ConstantManager* const_manger = context->get_constant_mgr();
+ analysis::ConstantManager* const_manger = context_->get_constant_mgr();
uint32_t ids[2];
const analysis::BoolConstant* constants[2];
@@ -574,8 +571,7 @@ bool InstructionFolder::IsFoldableConstant(
opt::Instruction* InstructionFolder::FoldInstructionToConstant(
opt::Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const {
- opt::IRContext* context = inst->context();
- analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
if (!inst->IsFoldableByFoldScalar() &&
!GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
@@ -600,13 +596,13 @@ opt::Instruction* InstructionFolder::FoldInstructionToConstant(
const analysis::Constant* folded_const = nullptr;
for (auto rule :
GetConstantFoldingRules().GetRulesForOpcode(inst->opcode())) {
- folded_const = rule(inst, constants);
+ folded_const = rule(context_, inst, constants);
if (folded_const != nullptr) {
opt::Instruction* const_inst =
const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
assert(const_inst->type_id() == inst->type_id());
// May be a new instruction that needs to be analysed.
- context->UpdateDefUse(const_inst);
+ context_->UpdateDefUse(const_inst);
return const_inst;
}
}
diff --git a/source/opt/fold.h b/source/opt/fold.h
index c4e0dbc2..0e027b1d 100644
--- a/source/opt/fold.h
+++ b/source/opt/fold.h
@@ -28,6 +28,8 @@ namespace opt {
class InstructionFolder {
public:
+ explicit InstructionFolder(IRContext* context) : context_(context) {}
+
// Returns the result of folding a scalar instruction with the given |opcode|
// and |operands|. Each entry in |operands| is a pointer to an
// analysis::Constant instance, which should've been created with the constant
@@ -154,6 +156,8 @@ class InstructionFolder {
const std::function<uint32_t(uint32_t)>& id_map,
uint32_t* result) const;
+ IRContext* context_;
+
// Folding rules used by |FoldInstructionToConstant| and |FoldInstruction|.
ConstantFoldingRules const_folding_rules;
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index 66829860..edc73488 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -199,10 +199,9 @@ uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
// Replaces fdiv where second operand is constant with fmul.
FoldingRule ReciprocalFDiv() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFDiv);
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
@@ -244,11 +243,10 @@ FoldingRule ReciprocalFDiv() {
// Elides consecutive negate instructions.
FoldingRule MergeNegateArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
(void)constants;
- opt::IRContext* context = inst->context();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
@@ -279,11 +277,10 @@ FoldingRule MergeNegateArithmetic() {
// -(x / 2) = x / -2
// -(2 / x) = -2 / x
FoldingRule MergeNegateMulDivArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
(void)constants;
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
@@ -338,11 +335,10 @@ FoldingRule MergeNegateMulDivArithmetic() {
// -(x - 2) = 2 - x
// -(2 - x) = x - 2
FoldingRule MergeNegateAddSubArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
(void)constants;
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
@@ -571,10 +567,9 @@ uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
// (x * 2) * 2 = x * 4
// (2 * x) * 2 = x * 4
FoldingRule MergeMulMulArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
@@ -624,10 +619,9 @@ FoldingRule MergeMulMulArithmetic() {
// (y / x) * x = y
// x * (y / x) = y
FoldingRule MergeMulDivArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFMul);
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
@@ -699,10 +693,9 @@ FoldingRule MergeMulDivArithmetic() {
// (-x) * 2 = x * -2
// 2 * (-x) = x * -2
FoldingRule MergeMulNegateArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
@@ -740,10 +733,9 @@ FoldingRule MergeMulNegateArithmetic() {
// (4 / x) / 2 = 2 / x
// (x / 2) / 2 = x / 4
FoldingRule MergeDivDivArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFDiv);
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
@@ -812,10 +804,9 @@ FoldingRule MergeDivDivArithmetic() {
// (x * y) / x = y
// (y * x) / x = y
FoldingRule MergeDivMulArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFDiv);
- opt::IRContext* context = inst->context();
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
@@ -885,11 +876,10 @@ FoldingRule MergeDivMulArithmetic() {
// (-x) / 2 = x / -2
// 2 / (-x) = 2 / -x
FoldingRule MergeDivNegateArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
inst->opcode() == SpvOpUDiv);
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
@@ -931,10 +921,9 @@ FoldingRule MergeDivNegateArithmetic() {
// (-x) + 2 = 2 - x
// 2 + (-x) = 2 - x
FoldingRule MergeAddNegateArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
- opt::IRContext* context = inst->context();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
bool uses_float = HasFloatingPoint(type);
@@ -965,10 +954,9 @@ FoldingRule MergeAddNegateArithmetic() {
// (-x) - 2 = -2 - x
// 2 - (-x) = x + 2
FoldingRule MergeSubNegateArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
@@ -1014,10 +1002,9 @@ FoldingRule MergeSubNegateArithmetic() {
// 2 + (x + 2) = x + 4
// 2 + (2 + x) = x + 4
FoldingRule MergeAddAddArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
- opt::IRContext* context = inst->context();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
@@ -1062,10 +1049,9 @@ FoldingRule MergeAddAddArithmetic() {
// 2 + (x - 2) = x + 0
// 2 + (2 - x) = 4 - x
FoldingRule MergeAddSubArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
- opt::IRContext* context = inst->context();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
@@ -1122,10 +1108,9 @@ FoldingRule MergeAddSubArithmetic() {
// 2 - (x + 2) = 0 - x
// 2 - (2 + x) = 0 - x
FoldingRule MergeSubAddArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
- opt::IRContext* context = inst->context();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
@@ -1188,10 +1173,9 @@ FoldingRule MergeSubAddArithmetic() {
// 2 - (x - 2) = 4 - x
// 2 - (2 - x) = x + 0
FoldingRule MergeSubSubArithmetic() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
- opt::IRContext* context = inst->context();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
@@ -1255,7 +1239,7 @@ FoldingRule MergeSubSubArithmetic() {
}
FoldingRule IntMultipleBy1() {
- return [](opt::Instruction* inst,
+ return [](IRContext*, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul.");
for (uint32_t i = 0; i < 2; i++) {
@@ -1281,14 +1265,14 @@ FoldingRule IntMultipleBy1() {
}
FoldingRule CompositeConstructFeedingExtract() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>&) {
// If the input to an OpCompositeExtract is an OpCompositeConstruct,
// then we can simply use the appropriate element in the construction.
assert(inst->opcode() == SpvOpCompositeExtract &&
"Wrong opcode. Should be OpCompositeExtract.");
- analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
- analysis::TypeManager* type_mgr = inst->context()->get_type_mgr();
+ analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+ analysis::TypeManager* type_mgr = context->get_type_mgr();
uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
opt::Instruction* cinst = def_use_mgr->GetDef(cid);
@@ -1366,11 +1350,11 @@ FoldingRule CompositeExtractFeedingConstruct() {
//
// This is a common code pattern because of the way that scalar replacement
// works.
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>&) {
assert(inst->opcode() == SpvOpCompositeConstruct &&
"Wrong opcode. Should be OpCompositeConstruct.");
- analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
+ analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
uint32_t original_id = 0;
// Check each element to make sure they are:
@@ -1417,11 +1401,11 @@ FoldingRule CompositeExtractFeedingConstruct() {
}
FoldingRule InsertFeedingExtract() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>&) {
assert(inst->opcode() == SpvOpCompositeExtract &&
"Wrong opcode. Should be OpCompositeExtract.");
- analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
+ analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
opt::Instruction* cinst = def_use_mgr->GetDef(cid);
@@ -1492,12 +1476,12 @@ FoldingRule InsertFeedingExtract() {
// operands of the VectorShuffle. We just need to adjust the index in the
// extract instruction.
FoldingRule VectorShuffleFeedingExtract() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>&) {
assert(inst->opcode() == SpvOpCompositeExtract &&
"Wrong opcode. Should be OpCompositeExtract.");
- analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
- analysis::TypeManager* type_mgr = inst->context()->get_type_mgr();
+ analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+ analysis::TypeManager* type_mgr = context->get_type_mgr();
uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
opt::Instruction* cinst = def_use_mgr->GetDef(cid);
@@ -1540,11 +1524,10 @@ FoldingRule VectorShuffleFeedingExtract() {
// corresponding |a| in the FMix is 0 or 1, we can extract from one of the
// operands of the FMix.
FoldingRule FMixFeedingExtract() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>&) {
assert(inst->opcode() == SpvOpCompositeExtract &&
"Wrong opcode. Should be OpCompositeExtract.");
- opt::IRContext* context = inst->context();
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
@@ -1557,7 +1540,7 @@ FoldingRule FMixFeedingExtract() {
}
uint32_t inst_set_id =
- inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+ context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
inst_set_id ||
@@ -1568,7 +1551,7 @@ FoldingRule FMixFeedingExtract() {
// Get the |a| for the FMix instruction.
uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
- std::unique_ptr<opt::Instruction> a(inst->Clone(inst->context()));
+ std::unique_ptr<opt::Instruction> a(inst->Clone(context));
a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
context->get_instruction_folder().FoldInstruction(a.get());
@@ -1612,7 +1595,7 @@ FoldingRule FMixFeedingExtract() {
FoldingRule RedundantPhi() {
// An OpPhi instruction where all values are the same or the result of the phi
// itself, can be replaced by the value itself.
- return [](opt::Instruction* inst,
+ return [](IRContext*, opt::Instruction* inst,
const std::vector<const analysis::Constant*>&) {
assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi.");
@@ -1647,7 +1630,7 @@ FoldingRule RedundantPhi() {
FoldingRule RedundantSelect() {
// An OpSelect instruction where both values are the same or the condition is
// constant can be replaced by one of the values
- return [](opt::Instruction* inst,
+ return [](IRContext*, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpSelect &&
"Wrong opcode. Should be OpSelect.");
@@ -1763,7 +1746,7 @@ FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
}
FoldingRule RedundantFAdd() {
- return [](opt::Instruction* inst,
+ return [](IRContext*, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
assert(constants.size() == 2);
@@ -1788,7 +1771,7 @@ FoldingRule RedundantFAdd() {
}
FoldingRule RedundantFSub() {
- return [](opt::Instruction* inst,
+ return [](IRContext*, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
assert(constants.size() == 2);
@@ -1819,7 +1802,7 @@ FoldingRule RedundantFSub() {
}
FoldingRule RedundantFMul() {
- return [](opt::Instruction* inst,
+ return [](IRContext*, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
assert(constants.size() == 2);
@@ -1852,7 +1835,7 @@ FoldingRule RedundantFMul() {
}
FoldingRule RedundantFDiv() {
- return [](opt::Instruction* inst,
+ return [](IRContext*, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
assert(constants.size() == 2);
@@ -1883,7 +1866,7 @@ FoldingRule RedundantFDiv() {
}
FoldingRule RedundantFMix() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpExtInst &&
"Wrong opcode. Should be OpExtInst.");
@@ -1893,7 +1876,7 @@ FoldingRule RedundantFMix() {
}
uint32_t instSetId =
- inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+ context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
@@ -1920,11 +1903,10 @@ FoldingRule RedundantFMix() {
// This rule look for a dot with a constant vector containing a single 1 and
// the rest 0s. This is the same as doing an extract.
FoldingRule DotProductDoingExtract() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
- opt::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
if (!inst->IsFloatingPointFoldingAllowed()) {
@@ -1995,11 +1977,10 @@ FoldingRule DotProductDoingExtract() {
// TODO: We can do something similar for OpImageWrite, but checking for volatile
// is complicated. Waiting to see if it is needed.
FoldingRule StoringUndef() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>&) {
assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore.");
- opt::IRContext* context = inst->context();
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
// If this is a volatile store, the store cannot be removed.
@@ -2020,12 +2001,11 @@ FoldingRule StoringUndef() {
}
FoldingRule VectorShuffleFeedingShuffle() {
- return [](opt::Instruction* inst,
+ return [](IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>&) {
assert(inst->opcode() == SpvOpVectorShuffle &&
"Wrong opcode. Should be OpVectorShuffle.");
- IRContext* context = inst->context();
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
diff --git a/source/opt/folding_rules.h b/source/opt/folding_rules.h
index 19fe2174..d807ec1e 100644
--- a/source/opt/folding_rules.h
+++ b/source/opt/folding_rules.h
@@ -52,7 +52,7 @@ namespace opt {
// the later rules will not be attempted.
using FoldingRule = std::function<bool(
- opt::Instruction* inst,
+ IRContext* context, opt::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)>;
class FoldingRules {
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index 5d37fcbc..ac4e4578 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -445,7 +445,7 @@ class IRContext {
const opt::InstructionFolder& get_instruction_folder() {
if (!inst_folder_) {
- inst_folder_.reset(new opt::InstructionFolder());
+ inst_folder_.reset(new opt::InstructionFolder(this));
}
return *inst_folder_;
}