diff options
author | Alan Baker <alanbaker@google.com> | 2018-07-23 11:23:11 -0400 |
---|---|---|
committer | Alan Baker <alanbaker@google.com> | 2018-07-31 13:42:47 -0400 |
commit | 755e5c94207ede680cf5f1b84626f20e3a24524f (patch) | |
tree | ea7cfad040a3e60d6cc284f0922fc72cfc63896f | |
parent | 8a0ec22f1303c158a10b0b6604dd8473d63f6131 (diff) |
Transform to combine consecutive access chains
* Combines OpAccessChain, OpInBoundsAccessChain, OpPtrAccessChain and
OpInBoundsPtrAccessChain
* New folding rule to fold add with 0 for integers
* Converts to a bitcast if the result type does not match the operand
type
V
-rw-r--r-- | Android.mk | 1 | ||||
-rw-r--r-- | include/spirv-tools/optimizer.hpp | 5 | ||||
-rw-r--r-- | source/opt/CMakeLists.txt | 2 | ||||
-rw-r--r-- | source/opt/combine_access_chains.cpp | 288 | ||||
-rw-r--r-- | source/opt/combine_access_chains.h | 80 | ||||
-rw-r--r-- | source/opt/folding_rules.cpp | 32 | ||||
-rw-r--r-- | source/opt/optimizer.cpp | 8 | ||||
-rw-r--r-- | source/opt/passes.h | 1 | ||||
-rw-r--r-- | test/opt/CMakeLists.txt | 5 | ||||
-rw-r--r-- | test/opt/combine_access_chains_test.cpp | 752 | ||||
-rw-r--r-- | test/opt/fold_test.cpp | 126 | ||||
-rw-r--r-- | tools/opt/opt.cpp | 3 |
12 files changed, 1299 insertions, 4 deletions
@@ -68,6 +68,7 @@ SPVTOOLS_OPT_SRC_FILES := \ source/opt/cfg.cpp \ source/opt/cfg_cleanup_pass.cpp \ source/opt/ccp_pass.cpp \ + source/opt/combine_access_chains.cpp \ source/opt/common_uniform_elim_pass.cpp \ source/opt/compact_ids_pass.cpp \ source/opt/composite.cpp \ diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp index 6e193bdb..2849e28c 100644 --- a/include/spirv-tools/optimizer.hpp +++ b/include/spirv-tools/optimizer.hpp @@ -636,6 +636,11 @@ Optimizer::PassToken CreateVectorDCEPass(); // a load of the specific elements. Optimizer::PassToken CreateReduceLoadSizePass(); +// Create a pass to combine chained access chains. +// This pass looks for access chains fed by other access chains and combines +// them into a single instruction where possible. +Optimizer::PassToken CreateCombineAccessChainsPass(); + } // namespace spvtools #endif // SPIRV_TOOLS_OPTIMIZER_HPP_ diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index 484dd5f3..d9fc8c9f 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt @@ -19,6 +19,7 @@ add_library(SPIRV-Tools-opt ccp_pass.h cfg_cleanup_pass.h cfg.h + combine_access_chains.h common_uniform_elim_pass.h compact_ids_pass.h composite.h @@ -106,6 +107,7 @@ add_library(SPIRV-Tools-opt ccp_pass.cpp cfg_cleanup_pass.cpp cfg.cpp + combine_access_chains.cpp common_uniform_elim_pass.cpp compact_ids_pass.cpp composite.cpp diff --git a/source/opt/combine_access_chains.cpp b/source/opt/combine_access_chains.cpp new file mode 100644 index 00000000..ccefb258 --- /dev/null +++ b/source/opt/combine_access_chains.cpp @@ -0,0 +1,288 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "combine_access_chains.h" + +#include "constants.h" +#include "ir_builder.h" +#include "ir_context.h" + +namespace spvtools { +namespace opt { + +Pass::Status CombineAccessChains::Process() { + bool modified = false; + + for (auto& function : *get_module()) { + modified |= ProcessFunction(function); + } + + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +bool CombineAccessChains::ProcessFunction(Function& function) { + bool modified = false; + + cfg()->ForEachBlockInReversePostOrder( + function.entry().get(), [&modified, this](BasicBlock* block) { + block->ForEachInst([&modified, this](Instruction* inst) { + switch (inst->opcode()) { + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpPtrAccessChain: + case SpvOpInBoundsPtrAccessChain: + modified |= CombineAccessChain(inst); + break; + default: + break; + } + }); + }); + + return modified; +} + +uint32_t CombineAccessChains::GetConstantValue( + const analysis::Constant* constant_inst) { + if (constant_inst->type()->AsInteger()->width() <= 32) { + if (constant_inst->type()->AsInteger()->IsSigned()) { + return static_cast<uint32_t>(constant_inst->GetS32()); + } else { + return constant_inst->GetU32(); + } + } else { + assert(false); + return 0u; + } +} + +uint32_t CombineAccessChains::GetArrayStride(const Instruction* inst) { + uint32_t array_stride = 0; + context()->get_decoration_mgr()->WhileEachDecoration( + inst->type_id(), SpvDecorationArrayStride, + [&array_stride](const Instruction& decoration) { + assert(decoration.opcode() != SpvOpDecorateId); + if (decoration.opcode() == SpvOpDecorate) { + array_stride = decoration.GetSingleWordInOperand(1); + } else { + array_stride = decoration.GetSingleWordInOperand(2); + } + return false; + }); + return array_stride; +} + +const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + + Instruction* base_ptr = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + const analysis::Type* type = type_mgr->GetType(base_ptr->type_id()); + assert(type->AsPointer()); + type = type->AsPointer()->pointee_type(); + std::vector<uint32_t> element_indices; + uint32_t starting_index = 1; + if (IsPtrAccessChain(inst->opcode())) { + // Skip the first index of OpPtrAccessChain as it does not affect type + // resolution. + starting_index = 2; + } + for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) { + Instruction* index_inst = + def_use_mgr->GetDef(inst->GetSingleWordInOperand(i)); + const analysis::Constant* index_constant = + context()->get_constant_mgr()->GetConstantFromInst(index_inst); + if (index_constant) { + uint32_t index_value = GetConstantValue(index_constant); + element_indices.push_back(index_value); + } else { + // This index must not matter to resolve the type in valid SPIR-V. + element_indices.push_back(0); + } + } + type = type_mgr->GetMemberType(type, element_indices); + return type; +} + +bool CombineAccessChains::CombineIndices(Instruction* ptr_input, + Instruction* inst, + std::vector<Operand>* new_operands) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + analysis::ConstantManager* constant_mgr = context()->get_constant_mgr(); + + Instruction* last_index_inst = def_use_mgr->GetDef( + ptr_input->GetSingleWordInOperand(ptr_input->NumInOperands() - 1)); + const analysis::Constant* last_index_constant = + constant_mgr->GetConstantFromInst(last_index_inst); + + Instruction* element_inst = + def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); + const analysis::Constant* element_constant = + constant_mgr->GetConstantFromInst(element_inst); + + // Combine the last index of the AccessChain (|ptr_inst|) with the element + // operand of the PtrAccessChain (|inst|). + const bool combining_element_operands = + IsPtrAccessChain(inst->opcode()) && + IsPtrAccessChain(ptr_input->opcode()) && ptr_input->NumInOperands() == 2; + uint32_t new_value_id = 0; + const analysis::Type* type = GetIndexedType(ptr_input); + if (last_index_constant && element_constant) { + // Combine the constants. + uint32_t new_value = GetConstantValue(last_index_constant) + + GetConstantValue(element_constant); + const analysis::Constant* new_value_constant = + constant_mgr->GetConstant(last_index_constant->type(), {new_value}); + Instruction* new_value_inst = + constant_mgr->GetDefiningInstruction(new_value_constant); + new_value_id = new_value_inst->result_id(); + } else if (!type->AsStruct() || combining_element_operands) { + // Generate an addition of the two indices. + InstructionBuilder builder( + context(), inst, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + Instruction* addition = builder.AddIAdd(last_index_inst->type_id(), + last_index_inst->result_id(), + element_inst->result_id()); + new_value_id = addition->result_id(); + } else { + // Indexing into structs must be constant, so bail out here. + return false; + } + new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}}); + return true; +} + +bool CombineAccessChains::CreateNewInputOperands( + Instruction* ptr_input, Instruction* inst, + std::vector<Operand>* new_operands) { + // Start by copying all the input operands of the feeder access chain. + for (uint32_t i = 0; i != ptr_input->NumInOperands() - 1; ++i) { + new_operands->push_back(ptr_input->GetInOperand(i)); + } + + // Deal with the last index of the feeder access chain. + if (IsPtrAccessChain(inst->opcode())) { + // The last index of the feeder should be combined with the element operand + // of |inst|. + if (!CombineIndices(ptr_input, inst, new_operands)) return false; + } else { + // The indices aren't being combined so now add the last index operand of + // |ptr_input|. + new_operands->push_back( + ptr_input->GetInOperand(ptr_input->NumInOperands() - 1)); + } + + // Copy the remaining index operands. + uint32_t starting_index = IsPtrAccessChain(inst->opcode()) ? 2 : 1; + for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) { + new_operands->push_back(inst->GetInOperand(i)); + } + + return true; +} + +bool CombineAccessChains::CombineAccessChain(Instruction* inst) { + assert((inst->opcode() == SpvOpPtrAccessChain || + inst->opcode() == SpvOpAccessChain || + inst->opcode() == SpvOpInBoundsAccessChain || + inst->opcode() == SpvOpInBoundsPtrAccessChain) && + "Wrong opcode. Expected an access chain."); + + Instruction* ptr_input = + context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0)); + if (ptr_input->opcode() != SpvOpAccessChain && + ptr_input->opcode() != SpvOpInBoundsAccessChain && + ptr_input->opcode() != SpvOpPtrAccessChain && + ptr_input->opcode() != SpvOpInBoundsPtrAccessChain) { + return false; + } + + if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false; + + // Handles the following cases: + // 1. |ptr_input| is an index-less access chain. Replace the pointer + // in |inst| with |ptr_input|'s pointer. + // 2. |inst| is a index-less access chain. Change |inst| to an + // OpCopyObject. + // 3. |inst| is not a pointer access chain. + // |inst|'s indices are appended to |ptr_input|'s indices. + // 4. |ptr_input| is not pointer access chain. + // |inst| is a pointer access chain. + // |inst|'s element operand is combined with the last index in + // |ptr_input| to form a new operand. + // 5. |ptr_input| is a pointer access chain. + // Like the above scenario, |inst|'s element operand is combined + // with |ptr_input|'s last index. This results is either a + // combined element operand or combined regular index. + + // TODO(alan-baker): Support this properly. Requires analyzing the + // size/alignment of the type and converting the stride into an element + // index. + uint32_t array_stride = GetArrayStride(ptr_input); + if (array_stride != 0) return false; + + if (ptr_input->NumInOperands() == 1) { + // The input is effectively a no-op. + inst->SetInOperand(0, {ptr_input->GetSingleWordInOperand(0)}); + context()->AnalyzeUses(inst); + } else if (inst->NumInOperands() == 1) { + // |inst| is a no-op, change it to a copy. Instruction simplification will + // clean it up. + inst->SetOpcode(SpvOpCopyObject); + } else { + std::vector<Operand> new_operands; + if (!CreateNewInputOperands(ptr_input, inst, &new_operands)) return false; + + // Update the instruction. + inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode())); + inst->SetInOperands(std::move(new_operands)); + context()->AnalyzeUses(inst); + } + return true; +} + +SpvOp CombineAccessChains::UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode) { + auto IsInBounds = [](SpvOp opcode) { + return opcode == SpvOpInBoundsPtrAccessChain || + opcode == SpvOpInBoundsAccessChain; + }; + + if (input_opcode == SpvOpInBoundsPtrAccessChain) { + if (!IsInBounds(base_opcode)) return SpvOpPtrAccessChain; + } else if (input_opcode == SpvOpInBoundsAccessChain) { + if (!IsInBounds(base_opcode)) return SpvOpAccessChain; + } + + return input_opcode; +} + +bool CombineAccessChains::IsPtrAccessChain(SpvOp opcode) { + return opcode == SpvOpPtrAccessChain || opcode == SpvOpInBoundsPtrAccessChain; +} + +bool CombineAccessChains::Has64BitIndices(Instruction* inst) { + for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { + Instruction* index_inst = + context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(i)); + const analysis::Type* index_type = + context()->get_type_mgr()->GetType(index_inst->type_id()); + if (!index_type->AsInteger() || index_type->AsInteger()->width() != 32) + return true; + } + return false; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/combine_access_chains.h b/source/opt/combine_access_chains.h new file mode 100644 index 00000000..63027f12 --- /dev/null +++ b/source/opt/combine_access_chains.h @@ -0,0 +1,80 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIBSPIRV_OPT_COMBINE_ACCESS_CHAINS_H_ +#define LIBSPIRV_OPT_COMBINE_ACCESS_CHAINS_H_ + +#include "pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class CombineAccessChains : public Pass { + public: + const char* name() const override { return "combine-access-chains"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisNameMap; + } + + private: + // Combine access chains in |function|. Blocks are processed in reverse + // post-order. Returns true if the function is modified. + bool ProcessFunction(Function& function); + + // Combines an access chain (normal, in bounds or pointer) |inst| if its base + // pointer is another access chain. Returns true if the access chain was + // modified. + bool CombineAccessChain(Instruction* inst); + + // Returns the value of |constant_inst| as a uint32_t. + uint32_t GetConstantValue(const analysis::Constant* constant_inst); + + // Returns the array stride of |inst|'s type. + uint32_t GetArrayStride(const Instruction* inst); + + // Returns the type by resolving the index operands |inst|. |inst| must be an + // access chain instruction. + const analysis::Type* GetIndexedType(Instruction* inst); + + // Populates |new_operands| with the operands for the combined access chain. + // Returns false if the access chains cannot be combined. + bool CreateNewInputOperands(Instruction* ptr_input, Instruction* inst, + std::vector<Operand>* new_operands); + + // Combines the last index of |ptr_input| with the element operand of |inst|. + // Adds the combined operand to |new_operands|. + bool CombineIndices(Instruction* ptr_input, Instruction* inst, + std::vector<Operand>* new_operands); + + // Returns the opcode to use for the combined access chain. + SpvOp UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode); + + // Returns true if |opcode| is a pointer access chain. + bool IsPtrAccessChain(SpvOp opcode); + + // Returns true if |inst| (an access chain) has 64-bit indices. + bool Has64BitIndices(Instruction* inst); +}; + +} // namespace opt +} // namespace spvtools + +#endif // LIBSPIRV_OPT_COMBINE_ACCESS_CHAINS_H_ diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index c3daa6fc..7faa3610 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -1907,6 +1907,37 @@ FoldingRule RedundantFMix() { }; } +// This rule handles addition of zero for integers. +FoldingRule RedundantIAdd() { + return [](IRContext* context, Instruction* inst, + const std::vector<const analysis::Constant*>& constants) { + assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd."); + + uint32_t operand = std::numeric_limits<uint32_t>::max(); + const analysis::Type* operand_type = nullptr; + if (constants[0] && constants[0]->IsZero()) { + operand = inst->GetSingleWordInOperand(1); + operand_type = constants[0]->type(); + } else if (constants[1] && constants[1]->IsZero()) { + operand = inst->GetSingleWordInOperand(0); + operand_type = constants[1]->type(); + } + + if (operand != std::numeric_limits<uint32_t>::max()) { + const analysis::Type* inst_type = + context->get_type_mgr()->GetType(inst->type_id()); + if (inst_type->IsSame(operand_type)) { + inst->SetOpcode(SpvOpCopyObject); + } else { + inst->SetOpcode(SpvOpBitcast); + } + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}}); + return true; + } + return false; + }; +} + // This rule look for a dot with a constant vector containing a single 1 and // the rest 0s. This is the same as doing an extract. FoldingRule DotProductDoingExtract() { @@ -2177,6 +2208,7 @@ FoldingRules::FoldingRules() { rules_[SpvOpFSub].push_back(MergeSubAddArithmetic()); rules_[SpvOpFSub].push_back(MergeSubSubArithmetic()); + rules_[SpvOpIAdd].push_back(RedundantIAdd()); rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic()); rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic()); rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic()); diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index f91e80bd..60791ca0 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -159,6 +159,7 @@ Optimizer& Optimizer::RegisterPerformancePasses() { .RegisterPass(CreateCCPPass()) .RegisterPass(CreateAggressiveDCEPass()) .RegisterPass(CreateRedundancyEliminationPass()) + .RegisterPass(CreateCombineAccessChainsPass()) .RegisterPass(CreateSimplificationPass()) .RegisterPass(CreateVectorDCEPass()) .RegisterPass(CreateDeadInsertElimPass()) @@ -303,6 +304,8 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) { RegisterPass(CreateInlineExhaustivePass()); } else if (pass_name == "inline-entry-points-opaque") { RegisterPass(CreateInlineOpaquePass()); + } else if (pass_name == "combine-access-chains") { + RegisterPass(CreateCombineAccessChainsPass()); } else if (pass_name == "convert-local-access-chains") { RegisterPass(CreateLocalAccessChainConvertPass()); } else if (pass_name == "eliminate-dead-code-aggressive") { @@ -710,4 +713,9 @@ Optimizer::PassToken CreateReduceLoadSizePass() { return MakeUnique<Optimizer::PassToken::Impl>( MakeUnique<opt::ReduceLoadSize>()); } + +Optimizer::PassToken CreateCombineAccessChainsPass() { + return MakeUnique<Optimizer::PassToken::Impl>( + MakeUnique<opt::CombineAccessChains>()); +} } // namespace spvtools diff --git a/source/opt/passes.h b/source/opt/passes.h index 07419de4..8931b47a 100644 --- a/source/opt/passes.h +++ b/source/opt/passes.h @@ -21,6 +21,7 @@ #include "block_merge_pass.h" #include "ccp_pass.h" #include "cfg_cleanup_pass.h" +#include "combine_access_chains.h" #include "common_uniform_elim_pass.h" #include "compact_ids_pass.h" #include "copy_prop_arrays.h" diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt index a0b37901..f2741f67 100644 --- a/test/opt/CMakeLists.txt +++ b/test/opt/CMakeLists.txt @@ -332,3 +332,8 @@ add_spvtools_unittest(TARGET constant_manager LIBS SPIRV-Tools-opt ) +add_spvtools_unittest(TARGET combine_access_chains + SRCS combine_access_chains_test.cpp + LIBS SPIRV-Tools-opt +) + diff --git a/test/opt/combine_access_chains_test.cpp b/test/opt/combine_access_chains_test.cpp new file mode 100644 index 00000000..a9ed9104 --- /dev/null +++ b/test/opt/combine_access_chains_test.cpp @@ -0,0 +1,752 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "assembly_builder.h" +#include "gmock/gmock.h" +#include "pass_fixture.h" +#include "pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using CombineAccessChainsTest = PassTest<::testing::Test>; + +#ifdef SPIRV_EFFCEE +TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainConstant) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpAccessChain %ptr_Workgroup_uint %var %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromInBoundsAccessChainConstant) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpInBoundsAccessChain %ptr_Workgroup_uint %var %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainCombineConstant) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2 +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int2]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpAccessChain %ptr_Workgroup_uint %var %uint_1 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_1 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainNonConstant) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[ld1:%\w+]] = OpLoad +; CHECK: [[ld2:%\w+]] = OpLoad +; CHECK: [[add:%\w+]] = OpIAdd [[int]] [[ld1]] [[ld2]] +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[add]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%local_var = OpVariable %ptr_Function_uint Function +%ld1 = OpLoad %uint %local_var +%gep = OpAccessChain %ptr_Workgroup_uint %var %ld1 +%ld2 = OpLoad %uint %local_var +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %ld2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainExtraIndices) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int1:%\w+]] = OpConstant [[int]] 1 +; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int1]] [[int2]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4 +%uint_array_4_array_4_array_4 = OpTypeArray %uint_array_4_array_4 %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4 +%ptr_Workgroup_uint_array_4_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4_array_4_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpAccessChain %ptr_Workgroup_uint_array_4 %var %uint_1 %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_2 %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, + PtrAccessChainFromPtrAccessChainCombineElementOperand) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int6]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3 %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, + PtrAccessChainFromPtrAccessChainOnlyElementOperand) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4 +; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]] +; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpPtrAccessChain [[ptr_array]] [[var]] [[int6]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, + PtrAccessChainFromPtrAccessCombineNonElementIndex) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int3]] [[int3]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3 %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, + AccessChainFromPtrAccessChainOnlyElementOperand) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int3]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, AccessChainFromPtrAccessChainAppend) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int1:%\w+]] = OpConstant [[int]] 1 +; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int1]] [[int2]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_1 %uint_2 +%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, AccessChainFromAccessChainAppend) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int1:%\w+]] = OpConstant [[int]] 1 +; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int1]] [[int2]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%ptr_gep = OpAccessChain %ptr_Workgroup_uint_array_4 %var %uint_1 +%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, NonConstantStructSlide) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[ld:%\w+]] = OpLoad +; CHECK: OpPtrAccessChain {{%\w+}} [[var]] [[ld]] [[int0]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%struct = OpTypeStruct %uint %uint +%ptr_Workgroup_struct = OpTypePointer Workgroup %struct +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%wg_var = OpVariable %ptr_Workgroup_struct Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%1 = OpLabel +%func_var = OpVariable %ptr_Function_uint Function +%ld = OpLoad %uint %func_var +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_struct %wg_var %ld +%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, DontCombineNonConstantStructSlide) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[ld:%\w+]] = OpLoad +; CHECK: [[gep:%\w+]] = OpAccessChain +; CHECK: OpPtrAccessChain {{%\w+}} [[gep]] [[ld]] [[int0]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_4 = OpConstant %uint 4 +%struct = OpTypeStruct %uint %uint +%struct_array_4 = OpTypeArray %struct %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_struct = OpTypePointer Workgroup %struct +%ptr_Workgroup_struct_array_4 = OpTypePointer Workgroup %struct_array_4 +%wg_var = OpVariable %ptr_Workgroup_struct_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%1 = OpLabel +%func_var = OpVariable %ptr_Function_uint Function +%ld = OpLoad %uint %func_var +%gep = OpAccessChain %ptr_Workgroup_struct %wg_var %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %ld %uint_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, CombineNonConstantStructSlideElement) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[ld:%\w+]] = OpLoad +; CHECK: [[add:%\w+]] = OpIAdd {{%\w+}} [[ld]] [[ld]] +; CHECK: OpPtrAccessChain {{%\w+}} [[var]] [[add]] [[int0]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_4 = OpConstant %uint 4 +%struct = OpTypeStruct %uint %uint +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_struct = OpTypePointer Workgroup %struct +%wg_var = OpVariable %ptr_Workgroup_struct Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%1 = OpLabel +%func_var = OpVariable %ptr_Function_uint Function +%ld = OpLoad %uint %func_var +%gep = OpPtrAccessChain %ptr_Workgroup_struct %wg_var %ld +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %ld %uint_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromInBoundsPtrAccessChain) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4 +; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]] +; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpPtrAccessChain [[ptr_array]] [[var]] [[int6]] +OpCapability Shader +OpCapability VariablePointers +OpCapability Addresses +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, InBoundsPtrAccessChainFromPtrAccessChain) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4 +; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]] +; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpPtrAccessChain [[ptr_array]] [[var]] [[int6]] +OpCapability Shader +OpCapability VariablePointers +OpCapability Addresses +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, + InBoundsPtrAccessChainFromInBoundsPtrAccessChain) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4 +; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]] +; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpInBoundsPtrAccessChain [[ptr_array]] [[var]] [[int6]] +OpCapability Shader +OpCapability VariablePointers +OpCapability Addresses +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, NoIndexAccessChains) { + const std::string text = R"( +; CHECK: [[var:%\w+]] = OpVariable +; CHECK-NOT: OpConstant +; CHECK: [[gep:%\w+]] = OpAccessChain {{%\w+}} [[var]] +; CHECK: OpAccessChain {{%\w+}} [[var]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%var = OpVariable %ptr_Workgroup_uint Workgroup +%void_func = OpTypeFunction %void +%func = OpFunction %void None %void_func +%1 = OpLabel +%gep1 = OpAccessChain %ptr_Workgroup_uint %var +%gep2 = OpAccessChain %ptr_Workgroup_uint %gep1 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, NoIndexPtrAccessChains) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: [[gep:%\w+]] = OpPtrAccessChain {{%\w+}} [[var]] [[int0]] +; CHECK: OpCopyObject {{%\w+}} [[gep]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%var = OpVariable %ptr_Workgroup_uint Workgroup +%void_func = OpTypeFunction %void +%func = OpFunction %void None %void_func +%1 = OpLabel +%gep1 = OpPtrAccessChain %ptr_Workgroup_uint %var %uint_0 +%gep2 = OpAccessChain %ptr_Workgroup_uint %gep1 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, NoIndexPtrAccessChains2) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: OpPtrAccessChain {{%\w+}} [[var]] [[int0]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%var = OpVariable %ptr_Workgroup_uint Workgroup +%void_func = OpTypeFunction %void +%func = OpFunction %void None %void_func +%1 = OpLabel +%gep1 = OpAccessChain %ptr_Workgroup_uint %var +%gep2 = OpPtrAccessChain %ptr_Workgroup_uint %gep1 %uint_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} + +TEST_F(CombineAccessChainsTest, CombineMixedSign) { + const std::string text = R"( +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: [[uint2:%\w+]] = OpConstant [[uint]] 2 +; CHECK: OpInBoundsPtrAccessChain {{%\w+}} [[var]] [[uint2]] +OpCapability Shader +OpCapability VariablePointers +OpCapability Addresses +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%int = OpTypeInt 32 1 +%uint_1 = OpConstant %uint 1 +%int_1 = OpConstant %int 1 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%var = OpVariable %ptr_Workgroup_uint Workgroup +%void_func = OpTypeFunction %void +%func = OpFunction %void None %void_func +%1 = OpLabel +%gep1 = OpInBoundsPtrAccessChain %ptr_Workgroup_uint %var %uint_1 +%gep2 = OpInBoundsPtrAccessChain %ptr_Workgroup_uint %gep1 %int_1 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<CombineAccessChains>(text, true); +} +#endif // SPIRV_EFFCEE + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 9a870a6e..e4a52c71 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -162,6 +162,7 @@ OpName %main "main" %_ptr_v2float = OpTypePointer Function %v2float %_ptr_v2double = OpTypePointer Function %v2double %short_0 = OpConstant %short 0 +%short_2 = OpConstant %short 2 %short_3 = OpConstant %short 3 %100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps. %103 = OpConstant %int 7 ; Need a def with an numerical id to define id maps. @@ -176,12 +177,15 @@ OpName %main "main" %long_2 = OpConstant %long 2 %long_3 = OpConstant %long 3 %uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 %uint_2 = OpConstant %uint 2 %uint_3 = OpConstant %uint 3 %uint_4 = OpConstant %uint 4 %uint_32 = OpConstant %uint 32 %uint_max = OpConstant %uint 4294967295 %v2int_undef = OpUndef %v2int +%v2int_0_0 = OpConstantComposite %v2int %int_0 %int_0 +%v2int_1_0 = OpConstantComposite %v2int %int_1 %int_0 %v2int_2_2 = OpConstantComposite %v2int %int_2 %int_2 %v2int_2_3 = OpConstantComposite %v2int %int_2 %int_3 %v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2 @@ -2589,19 +2593,19 @@ INSTANTIATE_TEST_CASE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTes "OpReturn\n" + "OpFunctionEnd", 2, 0), - // Test case 38: Don't fold 0 + 3 (long), bad length + // Test case 38: Don't fold 2 + 3 (long), bad length InstructionFoldingCase<uint32_t>( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + - "%2 = OpIAdd %long %long_0 %long_3\n" + + "%2 = OpIAdd %long %long_2 %long_3\n" + "OpReturn\n" + "OpFunctionEnd", 2, 0), - // Test case 39: Don't fold 0 + 3 (short), bad length + // Test case 39: Don't fold 2 + 3 (short), bad length InstructionFoldingCase<uint32_t>( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + - "%2 = OpIAdd %short %short_0 %short_3\n" + + "%2 = OpIAdd %short %short_2 %short_3\n" + "OpReturn\n" + "OpFunctionEnd", 2, 0), @@ -3326,6 +3330,90 @@ INSTANTIATE_TEST_CASE_P(DoubleVectorRedundantFoldingTest, GeneralInstructionFold 2, 3) )); +INSTANTIATE_TEST_CASE_P(IntegerRedundantFoldingTest, GeneralInstructionFoldingTest, + ::testing::Values( + // Test case 0: Don't fold n + 1 + InstructionFoldingCase<uint32_t>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %3 %uint_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: Don't fold 1 + n + InstructionFoldingCase<uint32_t>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %uint_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 2: Fold n + 0 + InstructionFoldingCase<uint32_t>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %3 %uint_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 3: Fold 0 + n + InstructionFoldingCase<uint32_t>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %uint_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 4: Don't fold n + (1,0) + InstructionFoldingCase<uint32_t>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%3 = OpLoad %v2int %n\n" + + "%2 = OpIAdd %v2int %3 %v2int_1_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 5: Don't fold (1,0) + n + InstructionFoldingCase<uint32_t>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%3 = OpLoad %v2int %n\n" + + "%2 = OpIAdd %v2int %v2int_1_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 6: Fold n + (0,0) + InstructionFoldingCase<uint32_t>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%3 = OpLoad %v2int %n\n" + + "%2 = OpIAdd %v2int %3 %v2int_0_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 7: Fold (0,0) + n + InstructionFoldingCase<uint32_t>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%3 = OpLoad %v2int %n\n" + + "%2 = OpIAdd %v2int %v2int_0_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3) +)); + INSTANTIATE_TEST_CASE_P(ClampAndCmpLHS, GeneralInstructionFoldingTest, ::testing::Values( // Test case 0: Don't Fold 0.0 < clamp(-1, 1) @@ -3785,6 +3873,36 @@ TEST_P(MatchingInstructionFoldingTest, Case) { } } +INSTANTIATE_TEST_CASE_P(RedundantIntegerMatching, MatchingInstructionFoldingTest, +::testing::Values( + // Test case 0: Fold 0 + n (change sign) + InstructionFoldingCase<bool>( + Header() + + "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: %2 = OpBitcast [[uint]] %3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %int_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 2, true), + // Test case 0: Fold 0 + n (change sign) + InstructionFoldingCase<bool>( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: %2 = OpBitcast [[int]] %3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%3 = OpLoad %int %n\n" + + "%2 = OpIAdd %int %uint_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 2, true) +)); + INSTANTIATE_TEST_CASE_P(MergeNegateTest, MatchingInstructionFoldingTest, ::testing::Values( // Test case 0: fold consecutive fnegate diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp index 44047449..2fde5e8c 100644 --- a/tools/opt/opt.cpp +++ b/tools/opt/opt.cpp @@ -106,6 +106,9 @@ Options (in lexicographical order): Cleanup the control flow graph. This will remove any unnecessary code from the CFG like unreachable code. Performed on entry point call tree functions and exported functions. + --combine-access-chains + Combines chained access chains to produce a single instruction + where possible. --compact-ids Remap result ids to a compact range starting from %%1 and without any gaps. |