summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlan Baker <alanbaker@google.com>2018-07-23 11:23:11 -0400
committerAlan Baker <alanbaker@google.com>2018-07-31 13:42:47 -0400
commit755e5c94207ede680cf5f1b84626f20e3a24524f (patch)
treeea7cfad040a3e60d6cc284f0922fc72cfc63896f
parent8a0ec22f1303c158a10b0b6604dd8473d63f6131 (diff)
Transform to combine consecutive access chains
* Combines OpAccessChain, OpInBoundsAccessChain, OpPtrAccessChain and OpInBoundsPtrAccessChain * New folding rule to fold add with 0 for integers * Converts to a bitcast if the result type does not match the operand type V
-rw-r--r--Android.mk1
-rw-r--r--include/spirv-tools/optimizer.hpp5
-rw-r--r--source/opt/CMakeLists.txt2
-rw-r--r--source/opt/combine_access_chains.cpp288
-rw-r--r--source/opt/combine_access_chains.h80
-rw-r--r--source/opt/folding_rules.cpp32
-rw-r--r--source/opt/optimizer.cpp8
-rw-r--r--source/opt/passes.h1
-rw-r--r--test/opt/CMakeLists.txt5
-rw-r--r--test/opt/combine_access_chains_test.cpp752
-rw-r--r--test/opt/fold_test.cpp126
-rw-r--r--tools/opt/opt.cpp3
12 files changed, 1299 insertions, 4 deletions
diff --git a/Android.mk b/Android.mk
index ffd3064b..a1cf282b 100644
--- a/Android.mk
+++ b/Android.mk
@@ -68,6 +68,7 @@ SPVTOOLS_OPT_SRC_FILES := \
source/opt/cfg.cpp \
source/opt/cfg_cleanup_pass.cpp \
source/opt/ccp_pass.cpp \
+ source/opt/combine_access_chains.cpp \
source/opt/common_uniform_elim_pass.cpp \
source/opt/compact_ids_pass.cpp \
source/opt/composite.cpp \
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index 6e193bdb..2849e28c 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -636,6 +636,11 @@ Optimizer::PassToken CreateVectorDCEPass();
// a load of the specific elements.
Optimizer::PassToken CreateReduceLoadSizePass();
+// Create a pass to combine chained access chains.
+// This pass looks for access chains fed by other access chains and combines
+// them into a single instruction where possible.
+Optimizer::PassToken CreateCombineAccessChainsPass();
+
} // namespace spvtools
#endif // SPIRV_TOOLS_OPTIMIZER_HPP_
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt
index 484dd5f3..d9fc8c9f 100644
--- a/source/opt/CMakeLists.txt
+++ b/source/opt/CMakeLists.txt
@@ -19,6 +19,7 @@ add_library(SPIRV-Tools-opt
ccp_pass.h
cfg_cleanup_pass.h
cfg.h
+ combine_access_chains.h
common_uniform_elim_pass.h
compact_ids_pass.h
composite.h
@@ -106,6 +107,7 @@ add_library(SPIRV-Tools-opt
ccp_pass.cpp
cfg_cleanup_pass.cpp
cfg.cpp
+ combine_access_chains.cpp
common_uniform_elim_pass.cpp
compact_ids_pass.cpp
composite.cpp
diff --git a/source/opt/combine_access_chains.cpp b/source/opt/combine_access_chains.cpp
new file mode 100644
index 00000000..ccefb258
--- /dev/null
+++ b/source/opt/combine_access_chains.cpp
@@ -0,0 +1,288 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "combine_access_chains.h"
+
+#include "constants.h"
+#include "ir_builder.h"
+#include "ir_context.h"
+
+namespace spvtools {
+namespace opt {
+
+// Runs the pass over every function of the module. Reports
+// SuccessWithChange if any function was modified, SuccessWithoutChange
+// otherwise.
+Pass::Status CombineAccessChains::Process() {
+  bool changed = false;
+  for (auto& fn : *get_module()) {
+    // Every function is processed, even after a change was already seen.
+    if (ProcessFunction(fn)) changed = true;
+  }
+
+  if (changed) return Status::SuccessWithChange;
+  return Status::SuccessWithoutChange;
+}
+
+// Walks the blocks of |function| in reverse post-order and hands every
+// access chain instruction (plain, in-bounds, pointer, in-bounds pointer) to
+// CombineAccessChain(). Returns true if at least one instruction was changed.
+bool CombineAccessChains::ProcessFunction(Function& function) {
+  bool changed = false;
+
+  auto visit_inst = [&changed, this](Instruction* candidate) {
+    const SpvOp op = candidate->opcode();
+    const bool is_access_chain =
+        op == SpvOpAccessChain || op == SpvOpInBoundsAccessChain ||
+        op == SpvOpPtrAccessChain || op == SpvOpInBoundsPtrAccessChain;
+    if (is_access_chain) changed |= CombineAccessChain(candidate);
+  };
+  auto visit_block = [&visit_inst](BasicBlock* bb) {
+    bb->ForEachInst(visit_inst);
+  };
+
+  cfg()->ForEachBlockInReversePostOrder(function.entry().get(), visit_block);
+  return changed;
+}
+
+// Returns the value of the integer constant |constant_inst| as a uint32_t.
+// Signed constants are converted with a static_cast. Constants wider than 32
+// bits are not supported: they trip an assert (and yield 0 when asserts are
+// compiled out).
+uint32_t CombineAccessChains::GetConstantValue(
+    const analysis::Constant* constant_inst) {
+  if (constant_inst->type()->AsInteger()->width() > 32) {
+    assert(false);
+    return 0u;
+  }
+
+  return constant_inst->type()->AsInteger()->IsSigned()
+             ? static_cast<uint32_t>(constant_inst->GetS32())
+             : constant_inst->GetU32();
+}
+
+// Returns the ArrayStride literal decorating |inst|'s result type, or 0 if the
+// type carries no ArrayStride decoration. Only the first matching decoration
+// is consulted.
+uint32_t CombineAccessChains::GetArrayStride(const Instruction* inst) {
+  uint32_t array_stride = 0;
+  context()->get_decoration_mgr()->WhileEachDecoration(
+      inst->type_id(), SpvDecorationArrayStride,
+      [&array_stride](const Instruction& decoration) {
+        assert(decoration.opcode() != SpvOpDecorateId);
+        // In-operand layout of the decoration instructions:
+        //   OpDecorate       <target> <decoration> <literal...>
+        //   OpMemberDecorate <type> <member> <decoration> <literal...>
+        // The stride is the literal that follows the decoration enum; reading
+        // the enum slot itself would always yield SpvDecorationArrayStride.
+        if (decoration.opcode() == SpvOpDecorate) {
+          array_stride = decoration.GetSingleWordInOperand(2);
+        } else {
+          array_stride = decoration.GetSingleWordInOperand(3);
+        }
+        // One decoration is enough; returning false ends the iteration.
+        return false;
+      });
+  return array_stride;
+}
+
+// Resolves the type that the access chain |inst| points at by walking its
+// index operands through the pointee type of its base pointer. Non-constant
+// indices are treated as 0 because, in valid SPIR-V, they cannot influence
+// which type is selected.
+const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) {
+  analysis::DefUseManager* defs = context()->get_def_use_mgr();
+  analysis::TypeManager* types = context()->get_type_mgr();
+
+  const uint32_t base_id = inst->GetSingleWordInOperand(0);
+  Instruction* base_def = defs->GetDef(base_id);
+  const analysis::Type* base_type = types->GetType(base_def->type_id());
+  assert(base_type->AsPointer());
+  const analysis::Type* pointee = base_type->AsPointer()->pointee_type();
+
+  // The element operand of a pointer access chain (in-operand 1) does not
+  // take part in type resolution, so the walk starts one operand later.
+  const uint32_t first_index = IsPtrAccessChain(inst->opcode()) ? 2u : 1u;
+
+  std::vector<uint32_t> path;
+  for (uint32_t pos = first_index; pos < inst->NumInOperands(); ++pos) {
+    Instruction* index_def = defs->GetDef(inst->GetSingleWordInOperand(pos));
+    const analysis::Constant* index_const =
+        context()->get_constant_mgr()->GetConstantFromInst(index_def);
+    path.push_back(index_const ? GetConstantValue(index_const) : 0u);
+  }
+
+  return types->GetMemberType(pointee, path);
+}
+
+// Computes the operand that results from merging the last index of
+// |ptr_input| with the element operand (in-operand 1) of |inst| and appends
+// it to |new_operands|.
+//  - If both are constants, a constant holding their sum is used.
+//  - Otherwise an OpIAdd of the two ids is created through an
+//    InstructionBuilder positioned at |inst|, unless the type indexed by
+//    |ptr_input| is a struct and the operands being merged are not both
+//    element operands.
+// Returns false (leaving |new_operands| untouched) when no combined operand
+// can be produced.
+bool CombineAccessChains::CombineIndices(Instruction* ptr_input,
+ Instruction* inst,
+ std::vector<Operand>* new_operands) {
+ analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
+ analysis::ConstantManager* constant_mgr = context()->get_constant_mgr();
+
+ // Definition (and constant value, if any) of |ptr_input|'s last index.
+ Instruction* last_index_inst = def_use_mgr->GetDef(
+ ptr_input->GetSingleWordInOperand(ptr_input->NumInOperands() - 1));
+ const analysis::Constant* last_index_constant =
+ constant_mgr->GetConstantFromInst(last_index_inst);
+
+ // Definition (and constant value, if any) of |inst|'s element operand.
+ Instruction* element_inst =
+ def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
+ const analysis::Constant* element_constant =
+ constant_mgr->GetConstantFromInst(element_inst);
+
+ // Combine the last index of the AccessChain (|ptr_input|) with the element
+ // operand of the PtrAccessChain (|inst|).
+ // |combining_element_operands| is true when both instructions are pointer
+ // access chains and |ptr_input| consists of just a base and an element
+ // operand, i.e. the two values being merged are both element operands.
+ const bool combining_element_operands =
+ IsPtrAccessChain(inst->opcode()) &&
+ IsPtrAccessChain(ptr_input->opcode()) && ptr_input->NumInOperands() == 2;
+ uint32_t new_value_id = 0;
+ // Type reached by |ptr_input|'s indices; consulted below to refuse
+ // non-constant sums when it is a struct.
+ const analysis::Type* type = GetIndexedType(ptr_input);
+ if (last_index_constant && element_constant) {
+ // Combine the constants.
+ // The sum is computed in uint32_t and the new constant reuses the type of
+ // |ptr_input|'s last index.
+ uint32_t new_value = GetConstantValue(last_index_constant) +
+ GetConstantValue(element_constant);
+ const analysis::Constant* new_value_constant =
+ constant_mgr->GetConstant(last_index_constant->type(), {new_value});
+ Instruction* new_value_inst =
+ constant_mgr->GetDefiningInstruction(new_value_constant);
+ new_value_id = new_value_inst->result_id();
+ } else if (!type->AsStruct() || combining_element_operands) {
+ // Generate an addition of the two indices.
+ // The addition takes the result type of |ptr_input|'s last index.
+ InstructionBuilder builder(
+ context(), inst,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+ Instruction* addition = builder.AddIAdd(last_index_inst->type_id(),
+ last_index_inst->result_id(),
+ element_inst->result_id());
+ new_value_id = addition->result_id();
+ } else {
+ // Indexing into structs must be constant, so bail out here.
+ return false;
+ }
+ new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}});
+ return true;
+}
+
+// Fills |new_operands| with the in-operands of the merged access chain:
+//   1. |ptr_input|'s base pointer and all of its indices except the last,
+//   2. |ptr_input|'s last index, merged with |inst|'s element operand when
+//      |inst| is a pointer access chain,
+//   3. the remaining index operands of |inst|.
+// Returns false if the indices in step 2 cannot be merged; |new_operands| may
+// be partially filled in that case.
+bool CombineAccessChains::CreateNewInputOperands(
+    Instruction* ptr_input, Instruction* inst,
+    std::vector<Operand>* new_operands) {
+  const bool inst_is_ptr_chain = IsPtrAccessChain(inst->opcode());
+  const uint32_t feeder_last = ptr_input->NumInOperands() - 1;
+
+  // Everything from the feeder except its final index is reused verbatim.
+  for (uint32_t pos = 0; pos != feeder_last; ++pos) {
+    new_operands->push_back(ptr_input->GetInOperand(pos));
+  }
+
+  if (inst_is_ptr_chain) {
+    // The feeder's final index absorbs the element operand of |inst|.
+    if (!CombineIndices(ptr_input, inst, new_operands)) return false;
+  } else {
+    // Nothing to merge: the feeder's final index is kept as is.
+    new_operands->push_back(ptr_input->GetInOperand(feeder_last));
+  }
+
+  // Append what is left of |inst|: skip its base pointer and, for pointer
+  // access chains, the element operand that was merged above.
+  const uint32_t first_remaining = inst_is_ptr_chain ? 2u : 1u;
+  for (uint32_t pos = first_remaining; pos < inst->NumInOperands(); ++pos) {
+    new_operands->push_back(inst->GetInOperand(pos));
+  }
+
+  return true;
+}
+
+// Tries to merge the access chain |inst| with the access chain that produces
+// its base pointer. Returns true if |inst| was rewritten, false if it was
+// left alone.
+bool CombineAccessChains::CombineAccessChain(Instruction* inst) {
+  assert((inst->opcode() == SpvOpPtrAccessChain ||
+          inst->opcode() == SpvOpAccessChain ||
+          inst->opcode() == SpvOpInBoundsAccessChain ||
+          inst->opcode() == SpvOpInBoundsPtrAccessChain) &&
+         "Wrong opcode. Expected an access chain.");
+
+  Instruction* ptr_input =
+      context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0));
+  switch (ptr_input->opcode()) {
+    case SpvOpAccessChain:
+    case SpvOpInBoundsAccessChain:
+    case SpvOpPtrAccessChain:
+    case SpvOpInBoundsPtrAccessChain:
+      break;
+    default:
+      // The base pointer does not come from an access chain.
+      return false;
+  }
+
+  // Only chains whose indices are all 32-bit integers are handled.
+  if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false;
+
+  // TODO(alan-baker): Support this properly. Requires analyzing the
+  // size/alignment of the type and converting the stride into an element
+  // index.
+  if (GetArrayStride(ptr_input) != 0) return false;
+
+  // Shapes handled below:
+  //  1. |ptr_input| has no indices: |inst| simply adopts |ptr_input|'s base
+  //     pointer.
+  //  2. |inst| has no indices: |inst| becomes an OpCopyObject.
+  //  3. |inst| is not a pointer access chain: its indices are appended to
+  //     those of |ptr_input|.
+  //  4. |inst| is a pointer access chain and |ptr_input| is not: |inst|'s
+  //     element operand is combined with |ptr_input|'s last index.
+  //  5. Both are pointer access chains: as in 4, which yields either a
+  //     combined element operand or a combined regular index.
+
+  if (ptr_input->NumInOperands() == 1) {
+    // The input is effectively a no-op.
+    inst->SetInOperand(0, {ptr_input->GetSingleWordInOperand(0)});
+    context()->AnalyzeUses(inst);
+    return true;
+  }
+
+  if (inst->NumInOperands() == 1) {
+    // |inst| is a no-op, change it to a copy. Instruction simplification will
+    // clean it up.
+    inst->SetOpcode(SpvOpCopyObject);
+    return true;
+  }
+
+  std::vector<Operand> merged_operands;
+  if (!CreateNewInputOperands(ptr_input, inst, &merged_operands)) return false;
+
+  // Rewrite |inst| in place with the merged opcode and operand list.
+  inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode()));
+  inst->SetInOperands(std::move(merged_operands));
+  context()->AnalyzeUses(inst);
+  return true;
+}
+
+// Picks the opcode of the merged access chain. The result follows
+// |input_opcode| (the feeding chain), except that the in-bounds variant is
+// kept only when |base_opcode| is an in-bounds access chain as well.
+SpvOp CombineAccessChains::UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode) {
+  const bool base_in_bounds = base_opcode == SpvOpInBoundsPtrAccessChain ||
+                              base_opcode == SpvOpInBoundsAccessChain;
+
+  if (!base_in_bounds) {
+    // Drop the in-bounds guarantee of the feeder.
+    if (input_opcode == SpvOpInBoundsPtrAccessChain) return SpvOpPtrAccessChain;
+    if (input_opcode == SpvOpInBoundsAccessChain) return SpvOpAccessChain;
+  }
+
+  return input_opcode;
+}
+
+// Returns true for OpPtrAccessChain and OpInBoundsPtrAccessChain.
+bool CombineAccessChains::IsPtrAccessChain(SpvOp opcode) {
+  switch (opcode) {
+    case SpvOpPtrAccessChain:
+    case SpvOpInBoundsPtrAccessChain:
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Returns true if any operand of |inst| after the base pointer is not a
+// 32-bit integer. Despite the name, every non-32-bit index type (not only
+// 64-bit) makes this return true.
+bool CombineAccessChains::Has64BitIndices(Instruction* inst) {
+  for (uint32_t pos = 1; pos < inst->NumInOperands(); ++pos) {
+    const uint32_t index_id = inst->GetSingleWordInOperand(pos);
+    Instruction* index_def = context()->get_def_use_mgr()->GetDef(index_id);
+    const analysis::Type* index_type =
+        context()->get_type_mgr()->GetType(index_def->type_id());
+
+    const analysis::Integer* as_int = index_type->AsInteger();
+    const bool is_32_bit_int = as_int != nullptr && as_int->width() == 32;
+    if (!is_32_bit_int) return true;
+  }
+  return false;
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/combine_access_chains.h b/source/opt/combine_access_chains.h
new file mode 100644
index 00000000..63027f12
--- /dev/null
+++ b/source/opt/combine_access_chains.h
@@ -0,0 +1,80 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_COMBINE_ACCESS_CHAINS_H_
+#define LIBSPIRV_OPT_COMBINE_ACCESS_CHAINS_H_
+
+#include "pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// Pass that merges an access chain with the access chain feeding its base
+// pointer. See optimizer.hpp for documentation.
+class CombineAccessChains : public Pass {
+ public:
+ // Name under which the pass is registered.
+ const char* name() const override { return "combine-access-chains"; }
+ // Runs the pass over every function of the module.
+ Status Process() override;
+
+ // Analyses this pass declares as preserved.
+ IRContext::Analysis GetPreservedAnalyses() override {
+ return IRContext::kAnalysisDefUse |
+ IRContext::kAnalysisInstrToBlockMapping |
+ IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators |
+ IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis |
+ IRContext::kAnalysisNameMap;
+ }
+
+ private:
+ // Combine access chains in |function|. Blocks are processed in reverse
+ // post-order. Returns true if the function is modified.
+ bool ProcessFunction(Function& function);
+
+ // Combines an access chain (normal, in bounds or pointer) |inst| if its base
+ // pointer is another access chain. Returns true if the access chain was
+ // modified.
+ bool CombineAccessChain(Instruction* inst);
+
+ // Returns the value of |constant_inst| as a uint32_t. |constant_inst| must
+ // be an integer constant of at most 32 bits; wider constants trigger an
+ // assert.
+ uint32_t GetConstantValue(const analysis::Constant* constant_inst);
+
+ // Looks up the ArrayStride decoration on |inst|'s result type. Returns 0 if
+ // the type has no such decoration; callers only rely on zero vs. non-zero.
+ uint32_t GetArrayStride(const Instruction* inst);
+
+ // Returns the type by resolving the index operands |inst|. |inst| must be an
+ // access chain instruction. Non-constant indices are resolved as index 0.
+ const analysis::Type* GetIndexedType(Instruction* inst);
+
+ // Populates |new_operands| with the operands for the combined access chain.
+ // Returns false if the access chains cannot be combined.
+ bool CreateNewInputOperands(Instruction* ptr_input, Instruction* inst,
+ std::vector<Operand>* new_operands);
+
+ // Combines the last index of |ptr_input| with the element operand of |inst|.
+ // Adds the combined operand to |new_operands|. Returns false if the two
+ // cannot be combined.
+ bool CombineIndices(Instruction* ptr_input, Instruction* inst,
+ std::vector<Operand>* new_operands);
+
+ // Returns the opcode to use for the combined access chain. |base_opcode| is
+ // the opcode of the access chain being rewritten and |input_opcode| the
+ // opcode of the access chain feeding it.
+ SpvOp UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode);
+
+ // Returns true if |opcode| is a pointer access chain.
+ bool IsPtrAccessChain(SpvOp opcode);
+
+ // Returns true if |inst| (an access chain) has an index that is not a
+ // 32-bit integer. Despite the name, this is not limited to 64-bit indices.
+ bool Has64BitIndices(Instruction* inst);
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // LIBSPIRV_OPT_COMBINE_ACCESS_CHAINS_H_
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index c3daa6fc..7faa3610 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -1907,6 +1907,37 @@ FoldingRule RedundantFMix() {
};
}
+// This rule handles addition of zero for integers: |x + 0| and |0 + x| fold
+// to |x|. The instruction becomes an OpCopyObject of |x| when |x| already has
+// the result type, and an OpBitcast of |x| otherwise (the operands of OpIAdd
+// may differ in signedness from each other and from the result).
+FoldingRule RedundantIAdd() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
+
+    // Id of the operand that survives the fold; max() means neither operand
+    // is a zero constant.
+    uint32_t operand = std::numeric_limits<uint32_t>::max();
+    if (constants[0] && constants[0]->IsZero()) {
+      operand = inst->GetSingleWordInOperand(1);
+    } else if (constants[1] && constants[1]->IsZero()) {
+      operand = inst->GetSingleWordInOperand(0);
+    }
+    if (operand == std::numeric_limits<uint32_t>::max()) return false;
+
+    // The copy/bitcast decision must be based on the type of the value being
+    // forwarded. The type of the zero constant says nothing about it: using
+    // it could emit an OpCopyObject whose operand type differs from the
+    // result type, or an OpBitcast between identical types.
+    analysis::TypeManager* type_mgr = context->get_type_mgr();
+    const analysis::Type* inst_type = type_mgr->GetType(inst->type_id());
+    const Instruction* operand_def =
+        context->get_def_use_mgr()->GetDef(operand);
+    const analysis::Type* operand_type =
+        type_mgr->GetType(operand_def->type_id());
+
+    if (inst_type->IsSame(operand_type)) {
+      inst->SetOpcode(SpvOpCopyObject);
+    } else {
+      inst->SetOpcode(SpvOpBitcast);
+    }
+    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
+    return true;
+  };
+}
+
// This rule look for a dot with a constant vector containing a single 1 and
// the rest 0s. This is the same as doing an extract.
FoldingRule DotProductDoingExtract() {
@@ -2177,6 +2208,7 @@ FoldingRules::FoldingRules() {
rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
+ rules_[SpvOpIAdd].push_back(RedundantIAdd());
rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index f91e80bd..60791ca0 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -159,6 +159,7 @@ Optimizer& Optimizer::RegisterPerformancePasses() {
.RegisterPass(CreateCCPPass())
.RegisterPass(CreateAggressiveDCEPass())
.RegisterPass(CreateRedundancyEliminationPass())
+ .RegisterPass(CreateCombineAccessChainsPass())
.RegisterPass(CreateSimplificationPass())
.RegisterPass(CreateVectorDCEPass())
.RegisterPass(CreateDeadInsertElimPass())
@@ -303,6 +304,8 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) {
RegisterPass(CreateInlineExhaustivePass());
} else if (pass_name == "inline-entry-points-opaque") {
RegisterPass(CreateInlineOpaquePass());
+ } else if (pass_name == "combine-access-chains") {
+ RegisterPass(CreateCombineAccessChainsPass());
} else if (pass_name == "convert-local-access-chains") {
RegisterPass(CreateLocalAccessChainConvertPass());
} else if (pass_name == "eliminate-dead-code-aggressive") {
@@ -710,4 +713,9 @@ Optimizer::PassToken CreateReduceLoadSizePass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::ReduceLoadSize>());
}
+
+// Creates a token wrapping a new CombineAccessChains pass.
+Optimizer::PassToken CreateCombineAccessChainsPass() {
+  auto pass = MakeUnique<opt::CombineAccessChains>();
+  return MakeUnique<Optimizer::PassToken::Impl>(std::move(pass));
+}
} // namespace spvtools
diff --git a/source/opt/passes.h b/source/opt/passes.h
index 07419de4..8931b47a 100644
--- a/source/opt/passes.h
+++ b/source/opt/passes.h
@@ -21,6 +21,7 @@
#include "block_merge_pass.h"
#include "ccp_pass.h"
#include "cfg_cleanup_pass.h"
+#include "combine_access_chains.h"
#include "common_uniform_elim_pass.h"
#include "compact_ids_pass.h"
#include "copy_prop_arrays.h"
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index a0b37901..f2741f67 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -332,3 +332,8 @@ add_spvtools_unittest(TARGET constant_manager
LIBS SPIRV-Tools-opt
)
+add_spvtools_unittest(TARGET combine_access_chains
+ SRCS combine_access_chains_test.cpp
+ LIBS SPIRV-Tools-opt
+)
+
diff --git a/test/opt/combine_access_chains_test.cpp b/test/opt/combine_access_chains_test.cpp
new file mode 100644
index 00000000..a9ed9104
--- /dev/null
+++ b/test/opt/combine_access_chains_test.cpp
@@ -0,0 +1,752 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "assembly_builder.h"
+#include "gmock/gmock.h"
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using CombineAccessChainsTest = PassTest<::testing::Test>;
+
+#ifdef SPIRV_EFFCEE
+// An OpPtrAccessChain (element 3) fed by an OpAccessChain (index 0) is
+// expected to become a single OpAccessChain on the variable with index 3.
+TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainConstant) {
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3
+; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]]
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int3]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_3 = OpConstant %uint 3
+%uint_4 = OpConstant %uint 4
+%uint_array_4 = OpTypeArray %uint %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4
+%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%main_lab = OpLabel
+%gep = OpAccessChain %ptr_Workgroup_uint %var %uint_0
+%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);
+}
+
+// Same as the plain OpAccessChain case, but the feeder is an
+// OpInBoundsAccessChain. Because the OpPtrAccessChain is not in-bounds, the
+// expected result is a plain OpAccessChain with index 3.
+TEST_F(CombineAccessChainsTest, PtrAccessChainFromInBoundsAccessChainConstant) {
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3
+; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]]
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int3]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_3 = OpConstant %uint 3
+%uint_4 = OpConstant %uint 4
+%uint_array_4 = OpTypeArray %uint %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4
+%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%main_lab = OpLabel
+%gep = OpInBoundsAccessChain %ptr_Workgroup_uint %var %uint_0
+%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);
+}
+
+// Both the feeder index and the element operand are the constant 1. The
+// expected result is an OpAccessChain indexed by a new constant 2, which
+// appears after the variable in the output.
+TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainCombineConstant) {
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]]
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2
+; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int2]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_1 = OpConstant %uint 1
+%uint_4 = OpConstant %uint 4
+%uint_array_4 = OpTypeArray %uint %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4
+%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%main_lab = OpLabel
+%gep = OpAccessChain %ptr_Workgroup_uint %var %uint_1
+%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_1
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);
+}
+
+// The feeder index and the element operand are both loaded values. The
+// expected result is an OpIAdd of the two loads used as the single index of
+// an OpAccessChain on the variable.
+TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainNonConstant) {
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]]
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: [[ld1:%\w+]] = OpLoad
+; CHECK: [[ld2:%\w+]] = OpLoad
+; CHECK: [[add:%\w+]] = OpIAdd [[int]] [[ld1]] [[ld2]]
+; CHECK: OpAccessChain [[ptr_int]] [[var]] [[add]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_4 = OpConstant %uint 4
+%uint_array_4 = OpTypeArray %uint %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Function_uint = OpTypePointer Function %uint
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4
+%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%main_lab = OpLabel
+%local_var = OpVariable %ptr_Function_uint Function
+%ld1 = OpLoad %uint %local_var
+%gep = OpAccessChain %ptr_Workgroup_uint %var %ld1
+%ld2 = OpLoad %uint %local_var
+%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %ld2
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);
+}
+
+// The feeder has indices (1, 0) and the OpPtrAccessChain has element 2
+// followed by index 3. The expected result is one OpAccessChain with indices
+// (1, 2, 3): the feeder's last index is combined with the element operand and
+// the remaining index is appended.
+TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainExtraIndices) {
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[int1:%\w+]] = OpConstant [[int]] 1
+; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2
+; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3
+; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]]
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int1]] [[int2]] [[int3]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_1 = OpConstant %uint 1
+%uint_2 = OpConstant %uint 2
+%uint_3 = OpConstant %uint 3
+%uint_4 = OpConstant %uint 4
+%uint_array_4 = OpTypeArray %uint %uint_4
+%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4
+%uint_array_4_array_4_array_4 = OpTypeArray %uint_array_4_array_4 %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Function_uint = OpTypePointer Function %uint
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4
+%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4
+%ptr_Workgroup_uint_array_4_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4_array_4
+%var = OpVariable %ptr_Workgroup_uint_array_4_array_4_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%main_lab = OpLabel
+%gep = OpAccessChain %ptr_Workgroup_uint_array_4 %var %uint_1 %uint_0
+%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_2 %uint_3
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);
+}
+
+// Both instructions are OpPtrAccessChain and the feeder has only an element
+// operand (3). The expected result is an OpPtrAccessChain whose element
+// operand is the new constant 6 (3 + 3), followed by the remaining index 3.
+TEST_F(CombineAccessChainsTest,
+ PtrAccessChainFromPtrAccessChainCombineElementOperand) {
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3
+; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]]
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6
+; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int6]] [[int3]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_3 = OpConstant %uint 3
+%uint_4 = OpConstant %uint 4
+%uint_array_4 = OpTypeArray %uint %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4
+%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%main_lab = OpLabel
+%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3
+%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3 %uint_3
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);
+}
+
+// Both instructions are OpPtrAccessChain with only an element operand (3
+// each). The expected result is a single OpPtrAccessChain on the variable
+// whose element operand is the new constant 6.
+TEST_F(CombineAccessChainsTest,
+ PtrAccessChainFromPtrAccessChainOnlyElementOperand) {
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4
+; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]]
+; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]]
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6
+; CHECK: OpPtrAccessChain [[ptr_array]] [[var]] [[int6]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_3 = OpConstant %uint 3
+%uint_4 = OpConstant %uint 4
+%uint_array_4 = OpTypeArray %uint %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4
+%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%main_lab = OpLabel
+%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3
+%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);
+}
+
+// The feeder OpPtrAccessChain has element 3 and a regular index 0. The second
+// OpPtrAccessChain has element 3 and index 3. The expected result keeps the
+// feeder's element operand and combines its last regular index with the
+// second chain's element operand, giving operands (3, 3, 3).
+TEST_F(CombineAccessChainsTest,
+ PtrAccessChainFromPtrAccessCombineNonElementIndex) {
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3
+; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]]
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int3]] [[int3]] [[int3]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_3 = OpConstant %uint 3
+%uint_4 = OpConstant %uint 4
+%uint_array_4 = OpTypeArray %uint %uint_4
+%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Function_uint = OpTypePointer Function %uint
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4
+%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4
+%var = OpVariable %ptr_Workgroup_uint_array_4_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%main_lab = OpLabel
+%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 %uint_0
+%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3 %uint_3
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);
+}
+
+TEST_F(CombineAccessChainsTest,  // Input: OpPtrAccessChain (single index 3) feeding OpAccessChain (index 3).
+ AccessChainFromPtrAccessChainOnlyElementOperand) {  // Expected: an OpPtrAccessChain based directly on the variable with indices 3 3.
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3
+; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]]
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int3]] [[int3]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_3 = OpConstant %uint 3
+%uint_4 = OpConstant %uint 4
+%uint_array_4 = OpTypeArray %uint %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4
+%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%main_lab = OpLabel
+%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3
+%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_3
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);  // Runs the CombineAccessChains pass on the module text above.
+}
+
+TEST_F(CombineAccessChainsTest, AccessChainFromPtrAccessChainAppend) {  // Input: OpPtrAccessChain (indices 1 2) feeding OpAccessChain (index 3).
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[int1:%\w+]] = OpConstant [[int]] 1
+; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2
+; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3
+; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]]
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int1]] [[int2]] [[int3]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_1 = OpConstant %uint 1
+%uint_2 = OpConstant %uint 2
+%uint_3 = OpConstant %uint 3
+%uint_4 = OpConstant %uint 4
+%uint_array_4 = OpTypeArray %uint %uint_4
+%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4
+%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4
+%var = OpVariable %ptr_Workgroup_uint_array_4_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%main_lab = OpLabel
+%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_1 %uint_2
+%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_3
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);  // Expected (per the CHECK lines): an OpPtrAccessChain on the variable with indices 1 2 3.
+}
+
+TEST_F(CombineAccessChainsTest, AccessChainFromAccessChainAppend) {  // Input: OpAccessChain (index 1) feeding OpAccessChain (index 2).
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[int1:%\w+]] = OpConstant [[int]] 1
+; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2
+; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]]
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int1]] [[int2]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_1 = OpConstant %uint 1
+%uint_2 = OpConstant %uint 2
+%uint_3 = OpConstant %uint 3
+%uint_4 = OpConstant %uint 4
+%uint_array_4 = OpTypeArray %uint %uint_4
+%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4
+%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4
+%var = OpVariable %ptr_Workgroup_uint_array_4_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%main_lab = OpLabel
+%ptr_gep = OpAccessChain %ptr_Workgroup_uint_array_4 %var %uint_1
+%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_2
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);  // Expected (per the CHECK lines): an OpAccessChain on the variable with indices 1 2.
+}
+
+TEST_F(CombineAccessChainsTest, NonConstantStructSlide) {  // Input: OpPtrAccessChain on a struct pointer with a loaded (non-constant) element index, feeding OpAccessChain (member 0).
+ const std::string text = R"(
+; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: [[ld:%\w+]] = OpLoad
+; CHECK: OpPtrAccessChain {{%\w+}} [[var]] [[ld]] [[int0]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%struct = OpTypeStruct %uint %uint
+%ptr_Workgroup_struct = OpTypePointer Workgroup %struct
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Function_uint = OpTypePointer Function %uint
+%wg_var = OpVariable %ptr_Workgroup_struct Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%1 = OpLabel
+%func_var = OpVariable %ptr_Function_uint Function
+%ld = OpLoad %uint %func_var
+%ptr_gep = OpPtrAccessChain %ptr_Workgroup_struct %wg_var %ld
+%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_0
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);  // Expected (per the CHECK lines): an OpPtrAccessChain on the variable with indices <loaded value> 0.
+}
+
+TEST_F(CombineAccessChainsTest, DontCombineNonConstantStructSlide) {  // Input: OpAccessChain (index 0 into an array of structs) feeding OpPtrAccessChain with a loaded element index and member 0.
+ const std::string text = R"(
+; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0
+; CHECK: [[ld:%\w+]] = OpLoad
+; CHECK: [[gep:%\w+]] = OpAccessChain
+; CHECK: OpPtrAccessChain {{%\w+}} [[gep]] [[ld]] [[int0]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_4 = OpConstant %uint 4
+%struct = OpTypeStruct %uint %uint
+%struct_array_4 = OpTypeArray %struct %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Function_uint = OpTypePointer Function %uint
+%ptr_Workgroup_struct = OpTypePointer Workgroup %struct
+%ptr_Workgroup_struct_array_4 = OpTypePointer Workgroup %struct_array_4
+%wg_var = OpVariable %ptr_Workgroup_struct_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%1 = OpLabel
+%func_var = OpVariable %ptr_Function_uint Function
+%ld = OpLoad %uint %func_var
+%gep = OpAccessChain %ptr_Workgroup_struct %wg_var %uint_0
+%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %ld %uint_0
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);  // Expected (per the CHECK lines): the OpAccessChain remains and the OpPtrAccessChain still uses its result as the base.
+}
+
+TEST_F(CombineAccessChainsTest, CombineNonConstantStructSlideElement) {  // Input: two OpPtrAccessChains whose element indices are both the same loaded (non-constant) value.
+ const std::string text = R"(
+; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: [[ld:%\w+]] = OpLoad
+; CHECK: [[add:%\w+]] = OpIAdd {{%\w+}} [[ld]] [[ld]]
+; CHECK: OpPtrAccessChain {{%\w+}} [[var]] [[add]] [[int0]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_4 = OpConstant %uint 4
+%struct = OpTypeStruct %uint %uint
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Function_uint = OpTypePointer Function %uint
+%ptr_Workgroup_struct = OpTypePointer Workgroup %struct
+%wg_var = OpVariable %ptr_Workgroup_struct Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%1 = OpLabel
+%func_var = OpVariable %ptr_Function_uint Function
+%ld = OpLoad %uint %func_var
+%gep = OpPtrAccessChain %ptr_Workgroup_struct %wg_var %ld
+%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %ld %uint_0
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);  // Expected (per the CHECK lines): an OpIAdd of the two loaded indices used as the element index of one OpPtrAccessChain on the variable.
+}
+
+TEST_F(CombineAccessChainsTest, PtrAccessChainFromInBoundsPtrAccessChain) {  // Input: OpInBoundsPtrAccessChain (index 3) feeding OpPtrAccessChain (index 3).
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4
+; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]]
+; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]]
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6
+; CHECK: OpPtrAccessChain [[ptr_array]] [[var]] [[int6]]
+OpCapability Shader
+OpCapability VariablePointers
+OpCapability Addresses
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_3 = OpConstant %uint 3
+%uint_4 = OpConstant %uint 4
+%uint_array_4 = OpTypeArray %uint %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4
+%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%main_lab = OpLabel
+%gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3
+%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);  // Expected (per the CHECK lines): a plain OpPtrAccessChain on the variable with a constant index 6.
+}
+
+TEST_F(CombineAccessChainsTest, InBoundsPtrAccessChainFromPtrAccessChain) {  // Input: OpPtrAccessChain (index 3) feeding OpInBoundsPtrAccessChain (index 3).
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4
+; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]]
+; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]]
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6
+; CHECK: OpPtrAccessChain [[ptr_array]] [[var]] [[int6]]
+OpCapability Shader
+OpCapability VariablePointers
+OpCapability Addresses
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_3 = OpConstant %uint 3
+%uint_4 = OpConstant %uint 4
+%uint_array_4 = OpTypeArray %uint %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4
+%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%main_lab = OpLabel
+%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3
+%ptr_gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);  // Expected (per the CHECK lines): a plain OpPtrAccessChain on the variable with a constant index 6.
+}
+
+TEST_F(CombineAccessChainsTest,  // Input: OpInBoundsPtrAccessChain (index 3) feeding another OpInBoundsPtrAccessChain (index 3).
+ InBoundsPtrAccessChainFromInBoundsPtrAccessChain) {  // Expected: an OpInBoundsPtrAccessChain on the variable with a constant index 6.
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4
+; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]]
+; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]]
+; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup
+; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6
+; CHECK: OpInBoundsPtrAccessChain [[ptr_array]] [[var]] [[int6]]
+OpCapability Shader
+OpCapability VariablePointers
+OpCapability Addresses
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%uint_3 = OpConstant %uint 3
+%uint_4 = OpConstant %uint 4
+%uint_array_4 = OpTypeArray %uint %uint_4
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4
+%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%main_lab = OpLabel
+%gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3
+%ptr_gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);  // Runs the CombineAccessChains pass on the module text above.
+}
+
+TEST_F(CombineAccessChainsTest, NoIndexAccessChains) {  // Input: an index-less OpAccessChain feeding another index-less OpAccessChain.
+ const std::string text = R"(
+; CHECK: [[var:%\w+]] = OpVariable
+; CHECK-NOT: OpConstant
+; CHECK: [[gep:%\w+]] = OpAccessChain {{%\w+}} [[var]]
+; CHECK: OpAccessChain {{%\w+}} [[var]]
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %func "func"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%var = OpVariable %ptr_Workgroup_uint Workgroup
+%void_func = OpTypeFunction %void
+%func = OpFunction %void None %void_func
+%1 = OpLabel
+%gep1 = OpAccessChain %ptr_Workgroup_uint %var
+%gep2 = OpAccessChain %ptr_Workgroup_uint %gep1
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);  // Expected (per the CHECK lines): both OpAccessChains use the variable as base and no OpConstant appears between the variable and them.
+}
+
+TEST_F(CombineAccessChainsTest, NoIndexPtrAccessChains) {  // Input: OpPtrAccessChain (index 0) feeding an index-less OpAccessChain.
+ const std::string text = R"(
+; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0
+; CHECK: [[var:%\w+]] = OpVariable
+; CHECK: [[gep:%\w+]] = OpPtrAccessChain {{%\w+}} [[var]] [[int0]]
+; CHECK: OpCopyObject {{%\w+}} [[gep]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %func "func"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%var = OpVariable %ptr_Workgroup_uint Workgroup
+%void_func = OpTypeFunction %void
+%func = OpFunction %void None %void_func
+%1 = OpLabel
+%gep1 = OpPtrAccessChain %ptr_Workgroup_uint %var %uint_0
+%gep2 = OpAccessChain %ptr_Workgroup_uint %gep1
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);  // Expected (per the CHECK lines): the OpPtrAccessChain is kept and followed by an OpCopyObject of its result.
+}
+
+TEST_F(CombineAccessChainsTest, NoIndexPtrAccessChains2) {  // Input: an index-less OpAccessChain feeding OpPtrAccessChain (index 0).
+ const std::string text = R"(
+; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0
+; CHECK: [[var:%\w+]] = OpVariable
+; CHECK: OpPtrAccessChain {{%\w+}} [[var]] [[int0]]
+OpCapability Shader
+OpCapability VariablePointers
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %func "func"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%var = OpVariable %ptr_Workgroup_uint Workgroup
+%void_func = OpTypeFunction %void
+%func = OpFunction %void None %void_func
+%1 = OpLabel
+%gep1 = OpAccessChain %ptr_Workgroup_uint %var
+%gep2 = OpPtrAccessChain %ptr_Workgroup_uint %gep1 %uint_0
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);  // Expected (per the CHECK lines): an OpPtrAccessChain based directly on the variable with index 0.
+}
+
+TEST_F(CombineAccessChainsTest, CombineMixedSign) {  // Input: two OpInBoundsPtrAccessChains whose index constants have different signedness (%uint 1 and %int 1).
+ const std::string text = R"(
+; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
+; CHECK: [[var:%\w+]] = OpVariable
+; CHECK: [[uint2:%\w+]] = OpConstant [[uint]] 2
+; CHECK: OpInBoundsPtrAccessChain {{%\w+}} [[var]] [[uint2]]
+OpCapability Shader
+OpCapability VariablePointers
+OpCapability Addresses
+OpExtension "SPV_KHR_variable_pointers"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %func "func"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%int = OpTypeInt 32 1
+%uint_1 = OpConstant %uint 1
+%int_1 = OpConstant %int 1
+%ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%var = OpVariable %ptr_Workgroup_uint Workgroup
+%void_func = OpTypeFunction %void
+%func = OpFunction %void None %void_func
+%1 = OpLabel
+%gep1 = OpInBoundsPtrAccessChain %ptr_Workgroup_uint %var %uint_1
+%gep2 = OpInBoundsPtrAccessChain %ptr_Workgroup_uint %gep1 %int_1
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<CombineAccessChains>(text, true);  // Expected (per the CHECK lines): an OpInBoundsPtrAccessChain on the variable indexed by an unsigned (OpTypeInt 32 0) constant 2.
+}
+#endif // SPIRV_EFFCEE
+
+} // namespace
+} // namespace opt
+} // namespace spvtools
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index 9a870a6e..e4a52c71 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -162,6 +162,7 @@ OpName %main "main"
%_ptr_v2float = OpTypePointer Function %v2float
%_ptr_v2double = OpTypePointer Function %v2double
%short_0 = OpConstant %short 0
+%short_2 = OpConstant %short 2
%short_3 = OpConstant %short 3
%100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps.
%103 = OpConstant %int 7 ; Need a def with an numerical id to define id maps.
@@ -176,12 +177,15 @@ OpName %main "main"
%long_2 = OpConstant %long 2
%long_3 = OpConstant %long 3
%uint_0 = OpConstant %uint 0
+%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
%uint_3 = OpConstant %uint 3
%uint_4 = OpConstant %uint 4
%uint_32 = OpConstant %uint 32
%uint_max = OpConstant %uint 4294967295
%v2int_undef = OpUndef %v2int
+%v2int_0_0 = OpConstantComposite %v2int %int_0 %int_0
+%v2int_1_0 = OpConstantComposite %v2int %int_1 %int_0
%v2int_2_2 = OpConstantComposite %v2int %int_2 %int_2
%v2int_2_3 = OpConstantComposite %v2int %int_2 %int_3
%v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2
@@ -2589,19 +2593,19 @@ INSTANTIATE_TEST_CASE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTes
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
- // Test case 38: Don't fold 0 + 3 (long), bad length
+ // Test case 38: Don't fold 2 + 3 (long), bad length
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
- "%2 = OpIAdd %long %long_0 %long_3\n" +
+ "%2 = OpIAdd %long %long_2 %long_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
- // Test case 39: Don't fold 0 + 3 (short), bad length
+ // Test case 39: Don't fold 2 + 3 (short), bad length
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
- "%2 = OpIAdd %short %short_0 %short_3\n" +
+ "%2 = OpIAdd %short %short_2 %short_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
@@ -3326,6 +3330,90 @@ INSTANTIATE_TEST_CASE_P(DoubleVectorRedundantFoldingTest, GeneralInstructionFold
2, 3)
));
+INSTANTIATE_TEST_CASE_P(IntegerRedundantFoldingTest, GeneralInstructionFoldingTest,  // Cases for OpIAdd where one operand is a loaded value and the other a constant.
+ ::testing::Values(  // NOTE(review): each case looks like (assembly, id of instruction to fold, expected id; 0 when not folded) — confirm against InstructionFoldingCase.
+ // Test case 0: Don't fold n + 1 (constant operand is non-zero)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_uint Function\n" +
+ "%3 = OpLoad %uint %n\n" +
+ "%2 = OpIAdd %uint %3 %uint_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 1: Don't fold 1 + n (same as case 0 with operands swapped)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_uint Function\n" +
+ "%3 = OpLoad %uint %n\n" +
+ "%2 = OpIAdd %uint %uint_1 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 2: Fold n + 0 (the trailing 3 matches the id of the loaded n)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_uint Function\n" +
+ "%3 = OpLoad %uint %n\n" +
+ "%2 = OpIAdd %uint %3 %uint_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 3: Fold 0 + n (same as case 2 with operands swapped)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_uint Function\n" +
+ "%3 = OpLoad %uint %n\n" +
+ "%2 = OpIAdd %uint %uint_0 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 4: Don't fold n + (1,0) (vector constant has a non-zero component)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v2int Function\n" +
+ "%3 = OpLoad %v2int %n\n" +
+ "%2 = OpIAdd %v2int %3 %v2int_1_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 5: Don't fold (1,0) + n (same as case 4 with operands swapped)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v2int Function\n" +
+ "%3 = OpLoad %v2int %n\n" +
+ "%2 = OpIAdd %v2int %v2int_1_0 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 6: Fold n + (0,0) (all-zero vector constant)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v2int Function\n" +
+ "%3 = OpLoad %v2int %n\n" +
+ "%2 = OpIAdd %v2int %3 %v2int_0_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3),
+ // Test case 7: Fold (0,0) + n (same as case 6 with operands swapped)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v2int Function\n" +
+ "%3 = OpLoad %v2int %n\n" +
+ "%2 = OpIAdd %v2int %v2int_0_0 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 3)
+));
+
INSTANTIATE_TEST_CASE_P(ClampAndCmpLHS, GeneralInstructionFoldingTest,
::testing::Values(
// Test case 0: Don't Fold 0.0 < clamp(-1, 1)
@@ -3785,6 +3873,36 @@ TEST_P(MatchingInstructionFoldingTest, Case) {
}
}
+INSTANTIATE_TEST_CASE_P(RedundantIntegerMatching, MatchingInstructionFoldingTest,  // Cases where OpIAdd with a zero constant of a different integer type is expected to become an OpBitcast.
+::testing::Values(
+ // Test case 0: Fold 0 + n (change sign); expects an OpBitcast of %3 to the OpTypeInt 32 0 type
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
+ "; CHECK: %2 = OpBitcast [[uint]] %3\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_int Function\n" +
+ "%3 = OpLoad %uint %n\n" +  // NOTE(review): %n is declared with %_ptr_int but loaded as %uint here — confirm this mismatch is intended.
+ "%2 = OpIAdd %uint %int_0 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 2, true),
+ // Test case 1: Fold 0 + n (change sign); expects an OpBitcast of %3 to the OpTypeInt 32 1 type
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: %2 = OpBitcast [[int]] %3\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_int Function\n" +
+ "%3 = OpLoad %int %n\n" +
+ "%2 = OpIAdd %int %uint_0 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 2, true)
+));
+
INSTANTIATE_TEST_CASE_P(MergeNegateTest, MatchingInstructionFoldingTest,
::testing::Values(
// Test case 0: fold consecutive fnegate
diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp
index 44047449..2fde5e8c 100644
--- a/tools/opt/opt.cpp
+++ b/tools/opt/opt.cpp
@@ -106,6 +106,9 @@ Options (in lexicographical order):
Cleanup the control flow graph. This will remove any unnecessary
code from the CFG like unreachable code. Performed on entry
point call tree functions and exported functions.
+ --combine-access-chains
+ Combines chained access chains to produce a single instruction
+ where possible.
--compact-ids
Remap result ids to a compact range starting from %%1 and without
any gaps.