diff options
39 files changed, 818 insertions, 420 deletions
diff --git a/include/spirv-tools/linker.hpp b/include/spirv-tools/linker.hpp index 43c725da..a36aa75f 100644 --- a/include/spirv-tools/linker.hpp +++ b/include/spirv-tools/linker.hpp @@ -26,7 +26,9 @@ namespace spvtools { class LinkerOptions { public: - LinkerOptions() : createLibrary_(false) {} + LinkerOptions() + : createLibrary_(false), + verifyIds_(false) {} // Returns whether a library or an executable should be produced by the // linking phase. @@ -36,13 +38,25 @@ class LinkerOptions { // The returned value will be true if creating a library, and false if // creating an executable. bool GetCreateLibrary() const { return createLibrary_; } + // Sets whether a library or an executable should be produced. void SetCreateLibrary(bool create_library) { createLibrary_ = create_library; } + // Returns whether to verify the uniqueness of the unique ids in the merged + // context. + bool GetVerifyIds() const { return verifyIds_; } + + // Sets whether to verify the uniqueness of the unique ids in the merged + // context. + void SetVerifyIds(bool verifyIds) { + verifyIds_ = verifyIds; + } + private: bool createLibrary_; + bool verifyIds_; }; class Linker { diff --git a/source/link/linker.cpp b/source/link/linker.cpp index 59ea36ce..7f1b5cd1 100644 --- a/source/link/linker.cpp +++ b/source/link/linker.cpp @@ -38,6 +38,7 @@ namespace spvtools { using ir::Instruction; +using ir::IRContext; using ir::Module; using ir::Operand; using opt::PassManager; @@ -69,31 +70,34 @@ using LinkageTable = std::vector<LinkageEntry>; // is returned in |max_id_bound|. // // Both |modules| and |max_id_bound| should not be null, and |modules| should -// not be empty either. +// not be empty either. Furthermore |modules| should not contain any null +// pointers. static spv_result_t ShiftIdsInModules( const MessageConsumer& consumer, - std::vector<std::unique_ptr<ir::Module>>* modules, uint32_t* max_id_bound); + std::vector<ir::Module*>* modules, uint32_t* max_id_bound); // Generates the header for the linked module and returns it in |header|. // -// |header| should not be null, |modules| should not be empty and -// |max_id_bound| should be strictly greater than 0. +// |header| should not be null, |modules| should not be empty and pointers +// should be non-null. |max_id_bound| should be strictly greater than 0. // // TODO(pierremoreau): What to do when binaries use different versions of // SPIR-V? For now, use the max of all versions found in // the input modules. static spv_result_t GenerateHeader( const MessageConsumer& consumer, - const std::vector<std::unique_ptr<ir::Module>>& modules, + const std::vector<ir::Module*>& modules, uint32_t max_id_bound, ir::ModuleHeader* header); -// Merge all the modules from |inModules| into |linked_module|. +// Merge all the modules from |inModules| into a single module owned by +// |linked_context|. // -// |linked_module| should not be null. +// |linked_context| should not be null. static spv_result_t MergeModules( const MessageConsumer& consumer, - const std::vector<std::unique_ptr<Module>>& inModules, - const libspirv::AssemblyGrammar& grammar, Module* linked_module); + const std::vector<Module*>& inModules, + const libspirv::AssemblyGrammar& grammar, + IRContext* linked_context); // Compute all pairs of import and export and return it in |linkings_to_do|. // @@ -123,7 +127,7 @@ static spv_result_t CheckImportExportCompatibility( // functions, declarations of imported variables, import (and export if // necessary) linkage attribtes. // -// |linked_module| and |decoration_manager| should not be null, and the +// |linked_context| and |decoration_manager| should not be null, and the // 'RemoveDuplicatePass' should be run first. // // TODO(pierremoreau): Linkage attributes applied by a group decoration are @@ -136,6 +140,11 @@ static spv_result_t RemoveLinkageSpecificInstructions( const LinkageTable& linkings_to_do, DecorationManager* decoration_manager, ir::IRContext* linked_context); +// Verify that the unique ids of each instruction in |linked_context| (i.e. the +// merged module) are truly unique. Does not check the validity of other ids +static spv_result_t VerifyIds(const MessageConsumer& consumer, + ir::IRContext* linked_context); + // Structs for holding the data members for SpvLinker. struct Linker::Impl { explicit Impl(spv_target_env env) : context(spvContextCreate(env)) { @@ -186,7 +195,8 @@ spv_result_t Linker::Link(const uint32_t* const* binaries, SPV_ERROR_INVALID_BINARY) << "No modules were given."; - std::vector<std::unique_ptr<Module>> modules; + std::vector<std::unique_ptr<IRContext>> contexts; + std::vector<Module*> modules; modules.reserve(num_binaries); for (size_t i = 0u; i < num_binaries; ++i) { const uint32_t schema = binaries[i][4u]; @@ -197,13 +207,14 @@ spv_result_t Linker::Link(const uint32_t* const* binaries, << "Schema is non-zero for module " << i << "."; } - std::unique_ptr<Module> module = BuildModule( + std::unique_ptr<IRContext> context = BuildModule( impl_->context->target_env, consumer, binaries[i], binary_sizes[i]); - if (module == nullptr) + if (context == nullptr) return libspirv::DiagnosticStream(position, consumer, SPV_ERROR_INVALID_BINARY) - << "Failed to build a module out of " << modules.size() << "."; - modules.push_back(std::move(module)); + << "Failed to build a module out of " << contexts.size() << "."; + modules.push_back(context->module()); + contexts.push_back(std::move(context)); } // Phase 1: Shift the IDs used in each binary so that they occupy a disjoint @@ -216,14 +227,18 @@ spv_result_t Linker::Link(const uint32_t* const* binaries, ir::ModuleHeader header; res = GenerateHeader(consumer, modules, max_id_bound, &header); if (res != SPV_SUCCESS) return res; - auto linked_module = MakeUnique<Module>(); - linked_module->SetHeader(header); + IRContext linked_context(consumer); + linked_context.module()->SetHeader(header); // Phase 3: Merge all the binaries into a single one. libspirv::AssemblyGrammar grammar(impl_->context); - res = MergeModules(consumer, modules, grammar, linked_module.get()); + res = MergeModules(consumer, modules, grammar, &linked_context); if (res != SPV_SUCCESS) return res; - ir::IRContext linked_context(std::move(linked_module), consumer); + + if (options.GetVerifyIds()) { + res = VerifyIds(consumer, &linked_context); + if (res != SPV_SUCCESS) return res; + } // Phase 4: Find the import/export pairs LinkageTable linkings_to_do; @@ -270,7 +285,7 @@ spv_result_t Linker::Link(const uint32_t* const* binaries, static spv_result_t ShiftIdsInModules( const MessageConsumer& consumer, - std::vector<std::unique_ptr<ir::Module>>* modules, uint32_t* max_id_bound) { + std::vector<ir::Module*>* modules, uint32_t* max_id_bound) { spv_position_t position = {}; if (modules == nullptr) @@ -289,7 +304,7 @@ static spv_result_t ShiftIdsInModules( uint32_t id_bound = modules->front()->IdBound() - 1u; for (auto module_iter = modules->begin() + 1; module_iter != modules->end(); ++module_iter) { - Module* module = module_iter->get(); + Module* module = *module_iter; module->ForEachInst([&id_bound](Instruction* insn) { insn->ForEachId([&id_bound](uint32_t* id) { *id += id_bound; }); }); @@ -313,7 +328,7 @@ static spv_result_t ShiftIdsInModules( static spv_result_t GenerateHeader( const MessageConsumer& consumer, - const std::vector<std::unique_ptr<ir::Module>>& modules, + const std::vector<ir::Module*>& modules, uint32_t max_id_bound, ir::ModuleHeader* header) { spv_position_t position = {}; @@ -341,28 +356,32 @@ static spv_result_t GenerateHeader( static spv_result_t MergeModules( const MessageConsumer& consumer, - const std::vector<std::unique_ptr<Module>>& input_modules, - const libspirv::AssemblyGrammar& grammar, Module* linked_module) { + const std::vector<Module*>& input_modules, + const libspirv::AssemblyGrammar& grammar, IRContext* linked_context) { spv_position_t position = {}; - if (linked_module == nullptr) + if (linked_context == nullptr) return libspirv::DiagnosticStream(position, consumer, SPV_ERROR_INVALID_DATA) << "|linked_module| of MergeModules should not be null."; + Module* linked_module = linked_context->module(); if (input_modules.empty()) return SPV_SUCCESS; for (const auto& module : input_modules) for (const auto& inst : module->capabilities()) - linked_module->AddCapability(MakeUnique<Instruction>(inst)); + linked_module->AddCapability( + std::unique_ptr<Instruction>(inst.Clone(linked_context))); for (const auto& module : input_modules) for (const auto& inst : module->extensions()) - linked_module->AddExtension(MakeUnique<Instruction>(inst)); + linked_module->AddExtension( + std::unique_ptr<Instruction>(inst.Clone(linked_context))); for (const auto& module : input_modules) for (const auto& inst : module->ext_inst_imports()) - linked_module->AddExtInstImport(MakeUnique<Instruction>(inst)); + linked_module->AddExtInstImport( + std::unique_ptr<Instruction>(inst.Clone(linked_context))); do { const Instruction* memory_model_inst = input_modules[0]->GetMemoryModel(); @@ -402,7 +421,7 @@ static spv_result_t MergeModules( if (memory_model_inst != nullptr) linked_module->SetMemoryModel( - MakeUnique<Instruction>(*memory_model_inst)); + std::unique_ptr<Instruction>(memory_model_inst->Clone(linked_context))); } while (false); std::vector<std::pair<uint32_t, const char*>> entry_points; @@ -424,25 +443,30 @@ static spv_result_t MergeModules( << "The entry point \"" << name << "\", with execution model " << desc->name << ", was already defined."; } - linked_module->AddEntryPoint(MakeUnique<Instruction>(inst)); + linked_module->AddEntryPoint( + std::unique_ptr<Instruction>(inst.Clone(linked_context))); entry_points.emplace_back(model, name); } for (const auto& module : input_modules) for (const auto& inst : module->execution_modes()) - linked_module->AddExecutionMode(MakeUnique<Instruction>(inst)); + linked_module->AddExecutionMode( + std::unique_ptr<Instruction>(inst.Clone(linked_context))); for (const auto& module : input_modules) for (const auto& inst : module->debugs1()) - linked_module->AddDebug1Inst(MakeUnique<Instruction>(inst)); + linked_module->AddDebug1Inst( + std::unique_ptr<Instruction>(inst.Clone(linked_context))); for (const auto& module : input_modules) for (const auto& inst : module->debugs2()) - linked_module->AddDebug2Inst(MakeUnique<Instruction>(inst)); + linked_module->AddDebug2Inst( + std::unique_ptr<Instruction>(inst.Clone(linked_context))); for (const auto& module : input_modules) for (const auto& inst : module->annotations()) - linked_module->AddAnnotationInst(MakeUnique<Instruction>(inst)); + linked_module->AddAnnotationInst( + std::unique_ptr<Instruction>(inst.Clone(linked_context))); // TODO(pierremoreau): Since the modules have not been validate, should we // expect SpvStorageClassFunction variables outside @@ -450,7 +474,8 @@ static spv_result_t MergeModules( uint32_t num_global_values = 0u; for (const auto& module : input_modules) { for (const auto& inst : module->types_values()) { - linked_module->AddType(MakeUnique<Instruction>(inst)); + linked_module->AddType( + std::unique_ptr<Instruction>(inst.Clone(linked_context))); num_global_values += inst.opcode() == SpvOpVariable; } } @@ -462,8 +487,7 @@ static spv_result_t MergeModules( // Process functions and their basic blocks for (const auto& module : input_modules) { for (const auto& func : *module) { - std::unique_ptr<ir::Function> cloned_func = - MakeUnique<ir::Function>(func); + std::unique_ptr<ir::Function> cloned_func(func.Clone(linked_context)); cloned_func->SetParent(linked_module); linked_module->AddFunction(std::move(cloned_func)); } @@ -711,4 +735,19 @@ static spv_result_t RemoveLinkageSpecificInstructions( return SPV_SUCCESS; } +spv_result_t VerifyIds(const MessageConsumer& consumer, ir::IRContext* linked_context) { + std::unordered_set<uint32_t> ids; + bool ok = true; + linked_context->module()->ForEachInst([&ids,&ok](const ir::Instruction* inst) { + ok &= ids.insert(inst->unique_id()).second; + }); + + if (!ok) { + consumer(SPV_MSG_INTERNAL_ERROR, "", {}, "Non-unique id in merged module"); + return SPV_ERROR_INVALID_ID; + } + + return SPV_SUCCESS; +} + } // namespace spvtools diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp index 7be3f9ae..6071b942 100644 --- a/source/opt/aggressive_dead_code_elim_pass.cpp +++ b/source/opt/aggressive_dead_code_elim_pass.cpp @@ -176,7 +176,7 @@ void AggressiveDCEPass::ComputeInst2BlockMap(ir::Function* func) { void AggressiveDCEPass::AddBranch(uint32_t labelId, ir::BasicBlock* bp) { std::unique_ptr<ir::Instruction> newBranch(new ir::Instruction( - SpvOpBranch, 0, 0, + context(), SpvOpBranch, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}})); get_def_use_mgr()->AnalyzeInstDefUse(&*newBranch); bp->AddInstruction(std::move(newBranch)); diff --git a/source/opt/basic_block.cpp b/source/opt/basic_block.cpp index 7e0f421f..fccd396d 100644 --- a/source/opt/basic_block.cpp +++ b/source/opt/basic_block.cpp @@ -13,6 +13,8 @@ // limitations under the License. #include "basic_block.h" +#include "function.h" +#include "module.h" #include "make_unique.h" @@ -27,12 +29,13 @@ const uint32_t kSelectionMergeMergeBlockIdInIdx = 0; } // namespace -BasicBlock::BasicBlock(const BasicBlock& bb) - : function_(nullptr), - label_(MakeUnique<Instruction>(bb.GetLabelInst())), - insts_() { - for (auto& inst : bb.insts_) - AddInstruction(std::unique_ptr<Instruction>(inst.Clone())); +BasicBlock* BasicBlock::Clone(IRContext* context) const { + BasicBlock* clone = + new BasicBlock(std::unique_ptr<Instruction>(GetLabelInst().Clone(context))); + for (const auto& inst : insts_) + // Use the incoming context + clone->AddInstruction(std::unique_ptr<Instruction>(inst.Clone(context))); + return clone; } const Instruction* BasicBlock::GetMergeInst() const { diff --git a/source/opt/basic_block.h b/source/opt/basic_block.h index 32550e73..f4405f2a 100644 --- a/source/opt/basic_block.h +++ b/source/opt/basic_block.h @@ -31,6 +31,7 @@ namespace spvtools { namespace ir { class Function; +class IRContext; // A SPIR-V basic block. class BasicBlock { @@ -41,15 +42,20 @@ class BasicBlock { // Creates a basic block with the given starting |label|. inline explicit BasicBlock(std::unique_ptr<Instruction> label); - // Creates a basic block from the given basic block |bb|. + explicit BasicBlock(const BasicBlock& bb) = delete; + + // Creates a clone of the basic block in the given |context| // // The parent function will default to null and needs to be explicitly set by // the user. - explicit BasicBlock(const BasicBlock& bb); + BasicBlock* Clone(IRContext*) const; // Sets the enclosing function for this basic block. void SetParent(Function* function) { function_ = function; } + // Return the enclosing function + inline Function* GetParent() const { return function_; } + // Appends an instruction to this basic block. inline void AddInstruction(std::unique_ptr<Instruction> i); diff --git a/source/opt/build_module.cpp b/source/opt/build_module.cpp index e3439f3d..42dbdd7d 100644 --- a/source/opt/build_module.cpp +++ b/source/opt/build_module.cpp @@ -14,6 +14,7 @@ #include "build_module.h" +#include"ir_context.h" #include "ir_loader.h" #include "make_unique.h" #include "table.h" @@ -43,15 +44,15 @@ spv_result_t SetSpvInst(void* builder, const spv_parsed_instruction_t* inst) { } // annoymous namespace -std::unique_ptr<ir::Module> BuildModule(spv_target_env env, +std::unique_ptr<ir::IRContext> BuildModule(spv_target_env env, MessageConsumer consumer, const uint32_t* binary, const size_t size) { auto context = spvContextCreate(env); SetContextMessageConsumer(context, consumer); - auto module = MakeUnique<ir::Module>(); - ir::IrLoader loader(context->consumer, module.get()); + auto irContext = MakeUnique<ir::IRContext>(consumer); + ir::IrLoader loader(consumer, irContext->module()); spv_result_t status = spvBinaryParse(context, &loader, binary, size, SetSpvHeader, SetSpvInst, nullptr); @@ -59,13 +60,13 @@ std::unique_ptr<ir::Module> BuildModule(spv_target_env env, spvContextDestroy(context); - return status == SPV_SUCCESS ? std::move(module) : nullptr; + return status == SPV_SUCCESS ? std::move(irContext) : nullptr; } -std::unique_ptr<ir::Module> BuildModule(spv_target_env env, - MessageConsumer consumer, - const std::string& text, - uint32_t assemble_options) { +std::unique_ptr<ir::IRContext> BuildModule(spv_target_env env, + MessageConsumer consumer, + const std::string& text, + uint32_t assemble_options) { SpirvTools t(env); t.SetMessageConsumer(consumer); std::vector<uint32_t> binary; diff --git a/source/opt/build_module.h b/source/opt/build_module.h index 36ea74f2..3ee66072 100644 --- a/source/opt/build_module.h +++ b/source/opt/build_module.h @@ -18,23 +18,25 @@ #include <memory> #include <string> +#include "ir_context.h" #include "module.h" #include "spirv-tools/libspirv.hpp" namespace spvtools { -// Builds and returns an ir::Module from the given SPIR-V |binary|. |size| -// specifies number of words in |binary|. The |binary| will be decoded -// according to the given target |env|. Returns nullptr if erors occur and -// sends the errors to |consumer|. -std::unique_ptr<ir::Module> BuildModule(spv_target_env env, +// Builds an ir::Module returns the owning ir::IRContext from the given SPIR-V +// |binary|. |size| specifies number of words in |binary|. The |binary| will be +// decoded according to the given target |env|. Returns nullptr if errors occur +// and sends the errors to |consumer|. +std::unique_ptr<ir::IRContext> BuildModule(spv_target_env env, MessageConsumer consumer, const uint32_t* binary, size_t size); -// Builds and returns an ir::Module from the given SPIR-V assembly |text|. -// The |text| will be encoded according to the given target |env|. Returns -// nullptr if erors occur and sends the errors to |consumer|. -std::unique_ptr<ir::Module> BuildModule( +// Builds an ir::Module and returns the owning ir::IRContext from the given +// SPIR-V assembly |text|. The |text| will be encoded according to the given +// target |env|. Returns nullptr if errors occur and sends the errors to +// |consumer|. +std::unique_ptr<ir::IRContext> BuildModule( spv_target_env env, MessageConsumer consumer, const std::string& text, uint32_t assemble_options = SpirvTools::kDefaultAssembleOption); diff --git a/source/opt/cfg.cpp b/source/opt/cfg.cpp index 6adc1103..a0b78c78 100644 --- a/source/opt/cfg.cpp +++ b/source/opt/cfg.cpp @@ -29,9 +29,9 @@ const int kInvalidId = 0x400000; CFG::CFG(ir::Module* module) : module_(module), pseudo_entry_block_(std::unique_ptr<ir::Instruction>( - new ir::Instruction(SpvOpLabel, 0, 0, {}))), + new ir::Instruction(module->context(), SpvOpLabel, 0, 0, {}))), pseudo_exit_block_(std::unique_ptr<ir::Instruction>( - new ir::Instruction(SpvOpLabel, 0, kInvalidId, {}))) { + new ir::Instruction(module->context(), SpvOpLabel, 0, kInvalidId, {}))) { for (auto& fn : *module) { for (auto& blk : fn) { uint32_t blkId = blk.id(); diff --git a/source/opt/common_uniform_elim_pass.cpp b/source/opt/common_uniform_elim_pass.cpp index d68ed71a..3339c4ab 100644 --- a/source/opt/common_uniform_elim_pass.cpp +++ b/source/opt/common_uniform_elim_pass.cpp @@ -231,7 +231,7 @@ void CommonUniformElimPass::GenACLoadRepl( ir::Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, std::initializer_list<uint32_t>{varId})); std::unique_ptr<ir::Instruction> newLoad(new ir::Instruction( - SpvOpLoad, varPteTypeId, ldResultId, load_in_operands)); + context(), SpvOpLoad, varPteTypeId, ldResultId, load_in_operands)); get_def_use_mgr()->AnalyzeInstDefUse(&*newLoad); newInsts->emplace_back(std::move(newLoad)); @@ -254,7 +254,7 @@ void CommonUniformElimPass::GenACLoadRepl( ++iidIdx; }); std::unique_ptr<ir::Instruction> newExt(new ir::Instruction( - SpvOpCompositeExtract, ptrPteTypeId, extResultId, ext_in_opnds)); + context(), SpvOpCompositeExtract, ptrPteTypeId, extResultId, ext_in_opnds)); get_def_use_mgr()->AnalyzeInstDefUse(&*newExt); newInsts->emplace_back(std::move(newExt)); *resultId = extResultId; @@ -388,7 +388,7 @@ bool CommonUniformElimPass::CommonUniformLoadElimination(ir::Function* func) { // Copy load into most recent dominating block and remember it replId = TakeNextId(); std::unique_ptr<ir::Instruction> newLoad(new ir::Instruction( - SpvOpLoad, ii->type_id(), replId, + context(), SpvOpLoad, ii->type_id(), replId, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}}})); get_def_use_mgr()->AnalyzeInstDefUse(&*newLoad); insertItr = insertItr.InsertBefore(std::move(newLoad)); @@ -460,7 +460,7 @@ bool CommonUniformElimPass::CommonExtractElimination(ir::Function* func) { if (idxItr.second.size() < 2) continue; uint32_t replId = TakeNextId(); std::unique_ptr<ir::Instruction> newExtract( - new ir::Instruction(*idxItr.second.front())); + idxItr.second.front()->Clone(context())); newExtract->SetResultId(replId); get_def_use_mgr()->AnalyzeInstDefUse(&*newExtract); ++ii; diff --git a/source/opt/dead_branch_elim_pass.cpp b/source/opt/dead_branch_elim_pass.cpp index e3bf25fd..f1c9bf17 100644 --- a/source/opt/dead_branch_elim_pass.cpp +++ b/source/opt/dead_branch_elim_pass.cpp @@ -74,7 +74,7 @@ bool DeadBranchElimPass::GetConstInteger(uint32_t selId, uint32_t* selVal) { void DeadBranchElimPass::AddBranch(uint32_t labelId, ir::BasicBlock* bp) { std::unique_ptr<ir::Instruction> newBranch(new ir::Instruction( - SpvOpBranch, 0, 0, + context(), SpvOpBranch, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}})); get_def_use_mgr()->AnalyzeInstDefUse(&*newBranch); bp->AddInstruction(std::move(newBranch)); @@ -83,7 +83,7 @@ void DeadBranchElimPass::AddBranch(uint32_t labelId, ir::BasicBlock* bp) { void DeadBranchElimPass::AddSelectionMerge(uint32_t labelId, ir::BasicBlock* bp) { std::unique_ptr<ir::Instruction> newMerge(new ir::Instruction( - SpvOpSelectionMerge, 0, 0, + context(), SpvOpSelectionMerge, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}, {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {0}}})); get_def_use_mgr()->AnalyzeInstDefUse(&*newMerge); @@ -95,7 +95,7 @@ void DeadBranchElimPass::AddBranchConditional(uint32_t condId, uint32_t falseLabId, ir::BasicBlock* bp) { std::unique_ptr<ir::Instruction> newBranchCond(new ir::Instruction( - SpvOpBranchConditional, 0, 0, + context(), SpvOpBranchConditional, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {condId}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {trueLabId}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {falseLabId}}})); @@ -302,7 +302,7 @@ bool DeadBranchElimPass::EliminateDeadBranches(ir::Function* func) { ++icnt; }); std::unique_ptr<ir::Instruction> newPhi(new ir::Instruction( - SpvOpPhi, pii->type_id(), replId, phi_in_opnds)); + context(), SpvOpPhi, pii->type_id(), replId, phi_in_opnds)); get_def_use_mgr()->AnalyzeInstDefUse(&*newPhi); pii = pii.InsertBefore(std::move(newPhi)); ++pii; diff --git a/source/opt/decoration_manager.cpp b/source/opt/decoration_manager.cpp index aa926dbf..b25c20ff 100644 --- a/source/opt/decoration_manager.cpp +++ b/source/opt/decoration_manager.cpp @@ -70,11 +70,11 @@ bool DecorationManager::AreDecorationsTheSame( // for (uint32_t i = 2u; i < inst.NumInOperands(); ++i) { // const auto& j = constants.find(inst.GetSingleWordInOperand(i)); // if (j == constants.end()) - // return Instruction(); + // return Instruction(inst.context()); // const auto operand = j->second->GetOperand(0u); // operands.emplace_back(operand.type, operand.words); // } - // return Instruction(SpvOpDecorate, 0u, 0u, operands); + // return Instruction(inst.context(), SpvOpDecorate, 0u, 0u, operands); // }; // Instruction tmpA = (deco1.opcode() == SpvOpDecorateId) ? // decorateIdToDecorate(deco1) : deco1; @@ -261,7 +261,7 @@ void DecorationManager::CloneDecorations( case SpvOpMemberDecorate: case SpvOpDecorateId: { // simply clone decoration and change |target-id| to |to| - std::unique_ptr<ir::Instruction> new_inst(inst->Clone()); + std::unique_ptr<ir::Instruction> new_inst(inst->Clone(module_->context())); new_inst->SetInOperand(0, {to}); id_to_decoration_insts_[to].push_back(new_inst.get()); f(*new_inst, true); diff --git a/source/opt/flatten_decoration_pass.cpp b/source/opt/flatten_decoration_pass.cpp index e92935d8..eac82973 100644 --- a/source/opt/flatten_decoration_pass.cpp +++ b/source/opt/flatten_decoration_pass.cpp @@ -91,7 +91,7 @@ Pass::Status FlattenDecorationPass::Process(ir::IRContext* c) { const auto normal_uses_iter = normal_uses.find(group); if (normal_uses_iter != normal_uses.end()) { for (auto target : normal_uses[group]) { - std::unique_ptr<Instruction> new_inst(new Instruction(*inst_iter)); + std::unique_ptr<Instruction> new_inst(inst_iter->Clone(context())); new_inst->SetInOperand(0, Words{target}); inst_iter = inst_iter.InsertBefore(std::move(new_inst)); ++inst_iter; @@ -116,8 +116,8 @@ Pass::Status FlattenDecorationPass::Process(ir::IRContext* c) { decoration_operands_iter++; // Skip the group target. operands.insert(operands.end(), decoration_operands_iter, inst_iter->end()); - std::unique_ptr<Instruction> new_inst( - new Instruction(SpvOp::SpvOpMemberDecorate, 0, 0, operands)); + std::unique_ptr<Instruction> new_inst(new Instruction( + context(), SpvOp::SpvOpMemberDecorate, 0, 0, operands)); inst_iter = inst_iter.InsertBefore(std::move(new_inst)); ++inst_iter; replace = true; diff --git a/source/opt/fold_spec_constant_op_and_composite_pass.cpp b/source/opt/fold_spec_constant_op_and_composite_pass.cpp index a630d8a3..e91d1fb1 100644 --- a/source/opt/fold_spec_constant_op_and_composite_pass.cpp +++ b/source/opt/fold_spec_constant_op_and_composite_pass.cpp @@ -724,22 +724,23 @@ std::unique_ptr<ir::Instruction> FoldSpecConstantOpAndCompositePass::CreateInstruction(uint32_t id, analysis::Constant* c) { if (c->AsNullConstant()) { - return MakeUnique<ir::Instruction>(SpvOp::SpvOpConstantNull, + return MakeUnique<ir::Instruction>(context(), SpvOp::SpvOpConstantNull, type_mgr_->GetId(c->type()), id, std::initializer_list<ir::Operand>{}); } else if (analysis::BoolConstant* bc = c->AsBoolConstant()) { return MakeUnique<ir::Instruction>( + context(), bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse, type_mgr_->GetId(c->type()), id, std::initializer_list<ir::Operand>{}); } else if (analysis::IntConstant* ic = c->AsIntConstant()) { return MakeUnique<ir::Instruction>( - SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id, + context(), SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id, std::initializer_list<ir::Operand>{ir::Operand( spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, ic->words())}); } else if (analysis::FloatConstant* fc = c->AsFloatConstant()) { return MakeUnique<ir::Instruction>( - SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id, + context(), SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id, std::initializer_list<ir::Operand>{ir::Operand( spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, fc->words())}); @@ -765,7 +766,8 @@ FoldSpecConstantOpAndCompositePass::CreateCompositeInstruction( operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID, std::initializer_list<uint32_t>{id}); } - return MakeUnique<ir::Instruction>(SpvOp::SpvOpConstantComposite, + return MakeUnique<ir::Instruction>(context(), + SpvOp::SpvOpConstantComposite, type_mgr_->GetId(cc->type()), result_id, std::move(operands)); } diff --git a/source/opt/function.cpp b/source/opt/function.cpp index 4ad2dce9..dc5320f6 100644 --- a/source/opt/function.cpp +++ b/source/opt/function.cpp @@ -19,27 +19,25 @@ namespace spvtools { namespace ir { -Function::Function(const Function& f) - : module_(nullptr), - def_inst_(MakeUnique<Instruction>(f.DefInst())), - params_(), - blocks_(), - end_inst_() { - params_.reserve(f.params_.size()); - f.ForEachParam( - [this](const Instruction* insn) { - AddParameter(MakeUnique<Instruction>(*insn)); +Function* Function::Clone(IRContext* context) const { + Function* clone = + new Function(std::unique_ptr<Instruction>(DefInst().Clone(context))); + clone->params_.reserve(params_.size()); + ForEachParam( + [clone,context](const Instruction* inst) { + clone->AddParameter(std::unique_ptr<Instruction>(inst->Clone(context))); }, true); - blocks_.reserve(f.blocks_.size()); - for (const auto& b : f.blocks_) { - std::unique_ptr<BasicBlock> bb = MakeUnique<BasicBlock>(*b); - bb->SetParent(this); - AddBasicBlock(std::move(bb)); + clone->blocks_.reserve(blocks_.size()); + for (const auto& b : blocks_) { + std::unique_ptr<BasicBlock> bb(b->Clone(context)); + bb->SetParent(clone); + clone->AddBasicBlock(std::move(bb)); } - SetFunctionEnd(MakeUnique<Instruction>(f.function_end())); + clone->SetFunctionEnd(std::unique_ptr<Instruction>(function_end().Clone(context))); + return clone; } void Function::ForEachInst(const std::function<void(Instruction*)>& f, diff --git a/source/opt/function.h b/source/opt/function.h index 618eb7d2..9cd72095 100644 --- a/source/opt/function.h +++ b/source/opt/function.h @@ -27,6 +27,7 @@ namespace spvtools { namespace ir { +class IRContext; class Module; // A SPIR-V function. @@ -38,17 +39,22 @@ class Function { // Creates a function instance declared by the given OpFunction instruction // |def_inst|. inline explicit Function(std::unique_ptr<Instruction> def_inst); - // Creates a function instance based on the given function |f|. + + explicit Function(const Function& f) = delete; + + // Creates a clone of the instruction in the given |context| // // The parent module will default to null and needs to be explicitly set by // the user. - explicit Function(const Function& f); + Function* Clone(IRContext*) const; // The OpFunction instruction that begins the definition of this function. Instruction& DefInst() { return *def_inst_; } const Instruction& DefInst() const { return *def_inst_; } // Sets the enclosing module for this function. void SetParent(Module* module) { module_ = module; } + // Gets the enclosing module for this function + Module* GetParent() const { return module_; } // Appends a parameter to this function. inline void AddParameter(std::unique_ptr<Instruction> p); // Appends a basic block to this function. diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp index f52277bd..5c6e3fb4 100644 --- a/source/opt/inline_pass.cpp +++ b/source/opt/inline_pass.cpp @@ -49,7 +49,7 @@ uint32_t InlinePass::AddPointerToType(uint32_t type_id, SpvStorageClass storage_class) { uint32_t resultId = TakeNextId(); std::unique_ptr<ir::Instruction> type_inst(new ir::Instruction( - SpvOpTypePointer, 0, resultId, + context(), SpvOpTypePointer, 0, resultId, {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, {uint32_t(storage_class)}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}})); @@ -60,7 +60,7 @@ uint32_t InlinePass::AddPointerToType(uint32_t type_id, void InlinePass::AddBranch(uint32_t label_id, std::unique_ptr<ir::BasicBlock>* block_ptr) { std::unique_ptr<ir::Instruction> newBranch(new ir::Instruction( - SpvOpBranch, 0, 0, + context(), SpvOpBranch, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {label_id}}})); (*block_ptr)->AddInstruction(std::move(newBranch)); } @@ -69,7 +69,7 @@ void InlinePass::AddBranchCond(uint32_t cond_id, uint32_t true_id, uint32_t false_id, std::unique_ptr<ir::BasicBlock>* block_ptr) { std::unique_ptr<ir::Instruction> newBranch(new ir::Instruction( - SpvOpBranchConditional, 0, 0, + context(), SpvOpBranchConditional, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {cond_id}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {true_id}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {false_id}}})); @@ -79,7 +79,7 @@ void InlinePass::AddBranchCond(uint32_t cond_id, uint32_t true_id, void InlinePass::AddLoopMerge(uint32_t merge_id, uint32_t continue_id, std::unique_ptr<ir::BasicBlock>* block_ptr) { std::unique_ptr<ir::Instruction> newLoopMerge(new ir::Instruction( - SpvOpLoopMerge, 0, 0, + context(), SpvOpLoopMerge, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {merge_id}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {continue_id}}, {spv_operand_type_t::SPV_OPERAND_TYPE_LOOP_CONTROL, {0}}})); @@ -89,7 +89,7 @@ void InlinePass::AddLoopMerge(uint32_t merge_id, uint32_t continue_id, void InlinePass::AddStore(uint32_t ptr_id, uint32_t val_id, std::unique_ptr<ir::BasicBlock>* block_ptr) { std::unique_ptr<ir::Instruction> newStore(new ir::Instruction( - SpvOpStore, 0, 0, + context(), SpvOpStore, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {val_id}}})); (*block_ptr)->AddInstruction(std::move(newStore)); @@ -98,14 +98,14 @@ void InlinePass::AddStore(uint32_t ptr_id, uint32_t val_id, void InlinePass::AddLoad(uint32_t type_id, uint32_t resultId, uint32_t ptr_id, std::unique_ptr<ir::BasicBlock>* block_ptr) { std::unique_ptr<ir::Instruction> newLoad(new ir::Instruction( - SpvOpLoad, type_id, resultId, + context(), SpvOpLoad, type_id, resultId, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}})); (*block_ptr)->AddInstruction(std::move(newLoad)); } std::unique_ptr<ir::Instruction> InlinePass::NewLabel(uint32_t label_id) { std::unique_ptr<ir::Instruction> newLabel( - new ir::Instruction(SpvOpLabel, 0, label_id, {})); + new ir::Instruction(context(), SpvOpLabel, 0, label_id, {})); return newLabel; } @@ -143,7 +143,8 @@ void InlinePass::CloneAndMapLocals( auto callee_block_itr = calleeFn->begin(); auto callee_var_itr = callee_block_itr->begin(); while (callee_var_itr->opcode() == SpvOp::SpvOpVariable) { - std::unique_ptr<ir::Instruction> var_inst(callee_var_itr->Clone()); + std::unique_ptr<ir::Instruction> var_inst( + callee_var_itr->Clone(callee_var_itr->context())); uint32_t newId = TakeNextId(); get_decoration_mgr()->CloneDecorations(callee_var_itr->result_id(), newId, update_def_use_mgr_); var_inst->SetResultId(newId); @@ -169,7 +170,7 @@ uint32_t InlinePass::CreateReturnVar( // Add return var to new function scope variables. returnVarId = TakeNextId(); std::unique_ptr<ir::Instruction> var_inst(new ir::Instruction( - SpvOpVariable, returnVarTypeId, returnVarId, + context(), SpvOpVariable, returnVarTypeId, returnVarId, {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}})); new_vars->push_back(std::move(var_inst)); @@ -195,7 +196,8 @@ void InlinePass::CloneSameBlockOps( if (mapItr2 != (*preCallSB).end()) { // Clone pre-call same-block ops, map result id. const ir::Instruction* inInst = mapItr2->second; - std::unique_ptr<ir::Instruction> sb_inst(inInst->Clone()); + std::unique_ptr<ir::Instruction> sb_inst( + inInst->Clone(inInst->context())); CloneSameBlockOps(&sb_inst, postCallSB, preCallSB, block_ptr); const uint32_t rid = sb_inst->result_id(); const uint32_t nid = this->TakeNextId(); @@ -325,7 +327,7 @@ void InlinePass::GenInlineCode( // Copy contents of original caller block up to call instruction. for (auto cii = call_block_itr->begin(); cii != call_inst_itr; ++cii) { - std::unique_ptr<ir::Instruction> cp_inst(cii->Clone()); + std::unique_ptr<ir::Instruction> cp_inst(cii->Clone(context())); // Remember same-block ops for possible regeneration. if (IsSameBlockOp(&*cp_inst)) { auto* sb_inst_ptr = cp_inst.get(); @@ -434,7 +436,7 @@ void InlinePass::GenInlineCode( // Copy remaining instructions from caller block. auto cii = call_inst_itr; for (++cii; cii != call_block_itr->end(); ++cii) { - std::unique_ptr<ir::Instruction> cp_inst(cii->Clone()); + std::unique_ptr<ir::Instruction> cp_inst(cii->Clone(context())); // If multiple blocks generated, regenerate any same-block // instruction that has not been seen in this last block. if (multiBlocks) { @@ -452,7 +454,7 @@ void InlinePass::GenInlineCode( } break; default: { // Copy callee instruction and remap all input Ids. - std::unique_ptr<ir::Instruction> cp_inst(cpi->Clone()); + std::unique_ptr<ir::Instruction> cp_inst(cpi->Clone(context())); cp_inst->ForEachInId([&callee2caller, &callee_result_ids, this](uint32_t* iid) { const auto mapItr = callee2caller.find(*iid); @@ -497,7 +499,7 @@ void InlinePass::GenInlineCode( auto loop_merge_itr = last->tail(); --loop_merge_itr; assert(loop_merge_itr->opcode() == SpvOpLoopMerge); - std::unique_ptr<ir::Instruction> cp_inst(loop_merge_itr->Clone()); + std::unique_ptr<ir::Instruction> cp_inst(loop_merge_itr->Clone(context())); if (caller_is_single_block_loop) { // Also, update its continue target to point to the last block. cp_inst->SetInOperand(kSpvLoopMergeContinueTargetIdInIdx, {last->id()}); diff --git a/source/opt/instruction.cpp b/source/opt/instruction.cpp index f26fb1d3..df2dcb73 100644 --- a/source/opt/instruction.cpp +++ b/source/opt/instruction.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include "instruction.h" +#include "ir_context.h" #include <initializer_list> @@ -21,11 +22,29 @@ namespace spvtools { namespace ir { -Instruction::Instruction(const spv_parsed_instruction_t& inst, +Instruction::Instruction(IRContext* c) + : utils::IntrusiveNodeBase<Instruction>(), + context_(c), + opcode_(SpvOpNop), + type_id_(0), + result_id_(0), + unique_id_(c->TakeNextUniqueId()) {} + +Instruction::Instruction(IRContext* c, SpvOp op) + : utils::IntrusiveNodeBase<Instruction>(), + context_(c), + opcode_(op), + type_id_(0), + result_id_(0), + unique_id_(c->TakeNextUniqueId()) {} + +Instruction::Instruction(IRContext* c, const spv_parsed_instruction_t& inst, std::vector<Instruction>&& dbg_line) - : opcode_(static_cast<SpvOp>(inst.opcode)), + : context_(c), + opcode_(static_cast<SpvOp>(inst.opcode)), type_id_(inst.type_id), result_id_(inst.result_id), + unique_id_(c->TakeNextUniqueId()), dbg_line_insts_(std::move(dbg_line)) { assert((!IsDebugLineInst(opcode_) || dbg_line.empty()) && "Op(No)Line attaching to Op(No)Line found"); @@ -38,12 +57,14 @@ Instruction::Instruction(const spv_parsed_instruction_t& inst, } } -Instruction::Instruction(SpvOp op, uint32_t ty_id, uint32_t res_id, +Instruction::Instruction(IRContext* c, SpvOp op, uint32_t ty_id, uint32_t res_id, const std::vector<Operand>& in_operands) : utils::IntrusiveNodeBase<Instruction>(), + context_(c), opcode_(op), type_id_(ty_id), result_id_(res_id), + unique_id_(c->TakeNextUniqueId()), operands_() { if (type_id_ != 0) { operands_.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_TYPE_ID, @@ -61,6 +82,7 @@ Instruction::Instruction(Instruction&& that) opcode_(that.opcode_), type_id_(that.type_id_), result_id_(that.result_id_), + unique_id_(that.unique_id_), operands_(std::move(that.operands_)), dbg_line_insts_(std::move(that.dbg_line_insts_)) {} @@ -68,16 +90,18 @@ Instruction& Instruction::operator=(Instruction&& that) { opcode_ = that.opcode_; type_id_ = that.type_id_; result_id_ = that.result_id_; + unique_id_ = that.unique_id_; operands_ = std::move(that.operands_); dbg_line_insts_ = std::move(that.dbg_line_insts_); return *this; } -Instruction* Instruction::Clone() const { - Instruction* clone = new Instruction(); +Instruction* Instruction::Clone(IRContext *c) const { + Instruction* clone = new Instruction(c); clone->opcode_ = opcode_; clone->type_id_ = type_id_; clone->result_id_ = result_id_; + clone->unique_id_ = c->TakeNextUniqueId(); clone->operands_ = operands_; clone->dbg_line_insts_ = dbg_line_insts_; return clone; diff --git a/source/opt/instruction.h b/source/opt/instruction.h index ff0acdb0..4c964740 100644 --- a/source/opt/instruction.h +++ b/source/opt/instruction.h @@ -31,6 +31,7 @@ namespace spvtools { namespace ir { class Function; +class IRContext; class Module; class InstructionList; @@ -84,28 +85,30 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> { using const_iterator = std::vector<Operand>::const_iterator; // Creates a default OpNop instruction. + // This exists solely for containers that can't do without. Should be removed. Instruction() : utils::IntrusiveNodeBase<Instruction>(), + context_(nullptr), opcode_(SpvOpNop), type_id_(0), - result_id_(0) {} + result_id_(0), + unique_id_(0) {} + + // Creates a default OpNop instruction. + Instruction(IRContext*); // Creates an instruction with the given opcode |op| and no additional logical // operands. - Instruction(SpvOp op) - : utils::IntrusiveNodeBase<Instruction>(), - opcode_(op), - type_id_(0), - result_id_(0) {} + Instruction(IRContext*, SpvOp); // Creates an instruction using the given spv_parsed_instruction_t |inst|. All // the data inside |inst| will be copied and owned in this instance. And keep // record of line-related debug instructions |dbg_line| ahead of this // instruction, if any. - Instruction(const spv_parsed_instruction_t& inst, + Instruction(IRContext* c, const spv_parsed_instruction_t& inst, std::vector<Instruction>&& dbg_line = {}); // Creates an instruction with the given opcode |op|, type id: |ty_id|, // result id: |res_id| and input operands: |in_operands|. - Instruction(SpvOp op, uint32_t ty_id, uint32_t res_id, + Instruction(IRContext* c, SpvOp op, uint32_t ty_id, uint32_t res_id, const std::vector<Operand>& in_operands); // TODO: I will want to remove these, but will first have to remove the use of @@ -123,7 +126,9 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> { // It is the responsibility of the caller to make sure that the storage is // removed. It is the caller's responsibility to make sure that there is only // one instruction for each result id. - Instruction* Clone() const; + Instruction* Clone(IRContext *c) const; + + IRContext* context() const { return context_; } SpvOp opcode() const { return opcode_; } // Sets the opcode of this instruction to a specific opcode. Note this may @@ -133,6 +138,7 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> { void SetOpcode(SpvOp op) { opcode_ = op; } uint32_t type_id() const { return type_id_; } uint32_t result_id() const { return result_id_; } + uint32_t unique_id() const { assert(unique_id_ != 0); return unique_id_; } // Returns the vector of line-related debug instructions attached to this // instruction and the caller can directly modify them. std::vector<Instruction>& dbg_line_insts() { return dbg_line_insts_; } @@ -241,15 +247,21 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> { // Returns true if the instruction annotates an id with a decoration. inline bool IsDecoration(); + inline bool operator==(const Instruction&) const; + inline bool operator!=(const Instruction&) const; + inline bool operator<(const Instruction&) const; + private: // Returns the total count of result type id and result id. uint32_t TypeResultIdCount() const { return (type_id_ != 0) + (result_id_ != 0); } + IRContext* context_; // IR Context SpvOp opcode_; // Opcode uint32_t type_id_; // Result type id. A value of 0 means no result type id. uint32_t result_id_; // Result id. A value of 0 means no result id. + uint32_t unique_id_; // Unique instruction id // All logical operands, including result type id and result id. std::vector<Operand> operands_; // Opline and OpNoLine instructions preceding this instruction. Note that for @@ -260,6 +272,18 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> { friend InstructionList; }; +inline bool Instruction::operator==(const Instruction& other) const { + return unique_id() == other.unique_id(); +} + +inline bool Instruction::operator!=(const Instruction& other) const { + return !(*this == other); +} + +inline bool Instruction::operator<(const Instruction& other) const { + return unique_id() < other.unique_id(); +} + inline const Operand& Instruction::GetOperand(uint32_t index) const { assert(index < operands_.size() && "operand index out of bound"); return operands_[index]; diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h index 23f59ed9..eedee1ef 100644 --- a/source/opt/ir_context.h +++ b/source/opt/ir_context.h @@ -21,6 +21,7 @@ #include <algorithm> #include <iostream> +#include <limits> namespace spvtools { namespace ir { @@ -53,11 +54,26 @@ class IRContext { friend inline Analysis operator<<(Analysis a, int shift); friend inline Analysis& operator<<=(Analysis& a, int shift); + // Create an |IRContext| that contains an owned |Module| + IRContext(spvtools::MessageConsumer c) + : unique_id_(0), + module_(new Module()), + consumer_(std::move(c)), + def_use_mgr_(nullptr), + valid_analyses_(kAnalysisNone) + { + module_->SetContext(this); + } + IRContext(std::unique_ptr<Module>&& m, spvtools::MessageConsumer c) - : module_(std::move(m)), + : unique_id_(0), + module_(std::move(m)), consumer_(std::move(c)), def_use_mgr_(nullptr), - valid_analyses_(kAnalysisNone) {} + valid_analyses_(kAnalysisNone) + { + module_->SetContext(this); + } Module* module() const { return module_.get(); } inline void SetIdBound(uint32_t i); @@ -239,6 +255,14 @@ class IRContext { // Kill all name and decorate ops targeting the result id of |inst|. void KillNamesAndDecorates(ir::Instruction* inst); + // Returns the next unique id for use by an instruction. + inline uint32_t TakeNextUniqueId() { + assert(unique_id_ != std::numeric_limits<uint32_t>::max()); + + // Skip zero. + return ++unique_id_; + } + private: // Builds the def-use manager from scratch, even if it was already valid. void BuildDefUseManager() { @@ -264,6 +288,13 @@ class IRContext { valid_analyses_ = valid_analyses_ | kAnalysisDecorations; } + // An unique identifier for this instruction. Can be used to order + // instructions in a container. + // + // This member is initialized to 0, but always issues this value plus one. + // Therefore, 0 is not a valid unique id for an instruction. + uint32_t unique_id_; + std::unique_ptr<Module> module_; spvtools::MessageConsumer consumer_; std::unique_ptr<opt::analysis::DefUseManager> def_use_mgr_; diff --git a/source/opt/ir_loader.cpp b/source/opt/ir_loader.cpp index e3d84842..b705343e 100644 --- a/source/opt/ir_loader.cpp +++ b/source/opt/ir_loader.cpp @@ -20,9 +20,9 @@ namespace spvtools { namespace ir { -IrLoader::IrLoader(const MessageConsumer& consumer, Module* module) +IrLoader::IrLoader(const MessageConsumer& consumer, Module* m) : consumer_(consumer), - module_(module), + module_(m), source_("<instruction>"), inst_index_(0) {} @@ -30,12 +30,12 @@ bool IrLoader::AddInstruction(const spv_parsed_instruction_t* inst) { ++inst_index_; const auto opcode = static_cast<SpvOp>(inst->opcode); if (IsDebugLineInst(opcode)) { - dbg_line_info_.push_back(Instruction(*inst)); + dbg_line_info_.push_back(Instruction(module()->context(), *inst)); return true; } std::unique_ptr<Instruction> spv_inst( - new Instruction(*inst, std::move(dbg_line_info_))); + new Instruction(module()->context(), *inst, std::move(dbg_line_info_))); dbg_line_info_.clear(); const char* src = source_.c_str(); diff --git a/source/opt/ir_loader.h b/source/opt/ir_loader.h index bcb55f1e..2f0ca8b0 100644 --- a/source/opt/ir_loader.h +++ b/source/opt/ir_loader.h @@ -39,11 +39,13 @@ class IrLoader { // All internal messages will be communicated to the outside via the given // message |consumer|. This instance only keeps a reference to the |consumer|, // so the |consumer| should outlive this instance. - IrLoader(const MessageConsumer& consumer, Module* module); + IrLoader(const MessageConsumer& consumer, Module* m); // Sets the source name of the module. void SetSource(const std::string& src) { source_ = src; } + Module* module() const { return module_; } + // Sets the fields in the module's header to the given parameters. void SetModuleHeader(uint32_t magic, uint32_t version, uint32_t generator, uint32_t bound, uint32_t reserved) { diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp index 9663e880..ff5f9124 100644 --- a/source/opt/local_access_chain_convert_pass.cpp +++ b/source/opt/local_access_chain_convert_pass.cpp @@ -44,7 +44,7 @@ void LocalAccessChainConvertPass::BuildAndAppendInst( const std::vector<ir::Operand>& in_opnds, std::vector<std::unique_ptr<ir::Instruction>>* newInsts) { std::unique_ptr<ir::Instruction> newInst( - new ir::Instruction(opcode, typeId, resultId, in_opnds)); + new ir::Instruction(context(), opcode, typeId, resultId, in_opnds)); get_def_use_mgr()->AnalyzeInstDefUse(&*newInst); newInsts->emplace_back(std::move(newInst)); } diff --git a/source/opt/mem_pass.cpp b/source/opt/mem_pass.cpp index 72e4f73f..ae5baa81 100644 --- a/source/opt/mem_pass.cpp +++ b/source/opt/mem_pass.cpp @@ -287,7 +287,7 @@ uint32_t MemPass::Type2Undef(uint32_t type_id) { if (uitr != type2undefs_.end()) return uitr->second; const uint32_t undefId = TakeNextId(); std::unique_ptr<ir::Instruction> undef_inst( - new ir::Instruction(SpvOpUndef, type_id, undefId, {})); + new ir::Instruction(context(), SpvOpUndef, type_id, undefId, {})); get_def_use_mgr()->AnalyzeInstDefUse(&*undef_inst); get_module()->AddGlobalValue(std::move(undef_inst)); type2undefs_[type_id] = undefId; @@ -402,7 +402,7 @@ void MemPass::SSABlockInitLoopHeader( } const uint32_t phiId = TakeNextId(); std::unique_ptr<ir::Instruction> newPhi( - new ir::Instruction(SpvOpPhi, typeId, phiId, phi_in_operands)); + new ir::Instruction(context(), SpvOpPhi, typeId, phiId, phi_in_operands)); // The only phis requiring patching are the ones we create. phis_to_patch_.insert(phiId); // Only analyze the phi define now; analyze the phi uses after the @@ -470,7 +470,7 @@ void MemPass::SSABlockInitMultiPred(ir::BasicBlock* block_ptr) { } const uint32_t phiId = TakeNextId(); std::unique_ptr<ir::Instruction> newPhi( - new ir::Instruction(SpvOpPhi, typeId, phiId, phi_in_operands)); + new ir::Instruction(context(), SpvOpPhi, typeId, phiId, phi_in_operands)); get_def_use_mgr()->AnalyzeInstDefUse(&*newPhi); insertItr = insertItr.InsertBefore(std::move(newPhi)); ++insertItr; diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp index 9374a915..e8228850 100644 --- a/source/opt/merge_return_pass.cpp +++ b/source/opt/merge_return_pass.cpp @@ -60,7 +60,7 @@ bool MergeReturnPass::MergeReturnBlocks( // Create a label for the new return block std::unique_ptr<ir::Instruction> returnLabel( - new ir::Instruction(SpvOpLabel, 0u, TakeNextId(), {})); + new ir::Instruction(context(), SpvOpLabel, 0u, TakeNextId(), {})); uint32_t returnId = returnLabel->result_id(); // Create the new basic block @@ -84,13 +84,14 @@ bool MergeReturnPass::MergeReturnBlocks( // Need a PHI node to select the correct return value. uint32_t phiResultId = TakeNextId(); uint32_t phiTypeId = function->type_id(); - std::unique_ptr<ir::Instruction> phiInst( - new ir::Instruction(SpvOpPhi, phiTypeId, phiResultId, phiOps)); + std::unique_ptr<ir::Instruction> phiInst(new ir::Instruction( + context(), SpvOpPhi, phiTypeId, phiResultId, phiOps)); retBlockIter->AddInstruction(std::move(phiInst)); ir::BasicBlock::iterator phiIter = retBlockIter->tail(); - std::unique_ptr<ir::Instruction> returnInst(new ir::Instruction( - SpvOpReturnValue, 0u, 0u, {{SPV_OPERAND_TYPE_ID, {phiResultId}}})); + std::unique_ptr<ir::Instruction> returnInst( + new ir::Instruction(context(), SpvOpReturnValue, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {phiResultId}}})); retBlockIter->AddInstruction(std::move(returnInst)); ir::BasicBlock::iterator ret = retBlockIter->tail(); @@ -98,7 +99,7 @@ bool MergeReturnPass::MergeReturnBlocks( get_def_use_mgr()->AnalyzeInstDef(&*ret); } else { std::unique_ptr<ir::Instruction> returnInst( - new ir::Instruction(SpvOpReturn)); + new ir::Instruction(context(), SpvOpReturn)); retBlockIter->AddInstruction(std::move(returnInst)); } diff --git a/source/opt/module.cpp b/source/opt/module.cpp index e329b3cb..9d46a1b6 100644 --- a/source/opt/module.cpp +++ b/source/opt/module.cpp @@ -65,7 +65,7 @@ uint32_t Module::GetGlobalValue(SpvOp opcode) const { void Module::AddGlobalValue(SpvOp opcode, uint32_t result_id, uint32_t type_id) { std::unique_ptr<ir::Instruction> newGlobal( - new ir::Instruction(opcode, type_id, result_id, {})); + new ir::Instruction(context(), opcode, type_id, result_id, {})); AddGlobalValue(std::move(newGlobal)); } diff --git a/source/opt/module.h b/source/opt/module.h index e4c03e27..d3fe2b5e 100644 --- a/source/opt/module.h +++ b/source/opt/module.h @@ -27,6 +27,8 @@ namespace spvtools { namespace ir { +class IRContext; + // A struct for containing the module header information. struct ModuleHeader { uint32_t magic_number; @@ -223,11 +225,18 @@ class Module { // Returns 0 if not found. uint32_t GetExtInstImportId(const char* extstr); + // Sets the associated context for this module + void SetContext(IRContext* c) { context_ = c; } + + // Gets the associated context for this module + IRContext* context() const { return context_; } + private: ModuleHeader header_; // Module header // The following fields respect the "Logical Layout of a Module" in // Section 2.4 of the SPIR-V specification. + IRContext* context_; InstructionList capabilities_; InstructionList extensions_; InstructionList ext_inst_imports_; diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index 3527c1f7..ac913dd1 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -105,19 +105,18 @@ Optimizer& Optimizer::RegisterSizePasses() { bool Optimizer::Run(const uint32_t* original_binary, const size_t original_binary_size, std::vector<uint32_t>* optimized_binary) const { - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(impl_->target_env, impl_->pass_manager.consumer(), original_binary, original_binary_size); - if (module == nullptr) return false; - ir::IRContext context(std::move(module), impl_->pass_manager.consumer()); + if (context == nullptr) return false; - auto status = impl_->pass_manager.Run(&context); + auto status = impl_->pass_manager.Run(context.get()); if (status == opt::Pass::Status::SuccessWithChange || (status == opt::Pass::Status::SuccessWithoutChange && (optimized_binary->data() != original_binary || optimized_binary->size() != original_binary_size))) { optimized_binary->clear(); - context.module()->ToBinary(optimized_binary, /* skip_nop = */ true); + context->module()->ToBinary(optimized_binary, /* skip_nop = */ true); } return status != opt::Pass::Status::Failure; diff --git a/source/opt/strength_reduction_pass.cpp b/source/opt/strength_reduction_pass.cpp index 5c08f5e3..f2aee915 100644 --- a/source/opt/strength_reduction_pass.cpp +++ b/source/opt/strength_reduction_pass.cpp @@ -100,7 +100,7 @@ bool StrengthReductionPass::ReplaceMultiplyByPowerOf2( {shiftConstResultId}); newOperands.push_back(shiftOperand); std::unique_ptr<ir::Instruction> newInstruction( - new ir::Instruction(SpvOp::SpvOpShiftLeftLogical, inst->type_id(), + new ir::Instruction(context(), SpvOp::SpvOpShiftLeftLogical, inst->type_id(), newResultId, newOperands)); // Insert the new instruction and update the data structures. @@ -161,7 +161,7 @@ uint32_t StrengthReductionPass::GetConstantId(uint32_t val) { ir::Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}); std::unique_ptr<ir::Instruction> newConstant(new ir::Instruction( - SpvOp::SpvOpConstant, uint32_type_id_, resultId, {constant})); + context(), SpvOp::SpvOpConstant, uint32_type_id_, resultId, {constant})); get_module()->AddGlobalValue(std::move(newConstant)); // Store the result id for next time. @@ -199,7 +199,7 @@ uint32_t StrengthReductionPass::CreateUint32Type() { ir::Operand signOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {0}); std::unique_ptr<ir::Instruction> newType(new ir::Instruction( - SpvOp::SpvOpTypeInt, type_id, 0, {widthOperand, signOperand})); + context(), SpvOp::SpvOpTypeInt, type_id, 0, {widthOperand, signOperand})); context()->AddType(std::move(newType)); return type_id; } diff --git a/test/link/CMakeLists.txt b/test/link/CMakeLists.txt index 9768ab39..f2ced245 100644 --- a/test/link/CMakeLists.txt +++ b/test/link/CMakeLists.txt @@ -41,3 +41,8 @@ add_spvtools_unittest(TARGET link_matching_imports_to_exports SRCS matching_imports_to_exports_test.cpp LIBS SPIRV-Tools-opt SPIRV-Tools-link ) + +add_spvtools_unittest(TARGET link_unique_ids + SRCS unique_ids_test.cpp + LIBS SPIRV-Tools-opt SPIRV-Tools-link +) diff --git a/test/link/unique_ids_test.cpp b/test/link/unique_ids_test.cpp new file mode 100644 index 00000000..8b67d34a --- /dev/null +++ b/test/link/unique_ids_test.cpp @@ -0,0 +1,137 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" +#include "linker_fixture.h" + +namespace { + +using UniqueIds = spvtest::LinkerTest; + +TEST_F(UniqueIds, UniquelyMerged) { + std::vector<std::string> bodies(2); + bodies[0] = + // clang-format off + "OpCapability Shader\n" + "%1 = OpExtInstImport \"GLSL.std.450\"\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint Vertex %main \"main\"\n" + "OpSource ESSL 310\n" + "OpName %main \"main\"\n" + "OpName %f_ \"f(\"\n" + "OpName %gv1 \"gv1\"\n" + "OpName %gv2 \"gv2\"\n" + "OpName %lv1 \"lv1\"\n" + "OpName %lv2 \"lv2\"\n" + "OpName %lv1_0 \"lv1\"\n" + "%void = OpTypeVoid\n" + "%10 = OpTypeFunction %void\n" + "%float = OpTypeFloat 32\n" + "%12 = OpTypeFunction %float\n" + "%_ptr_Private_float = OpTypePointer Private %float\n" + "%gv1 = OpVariable %_ptr_Private_float Private\n" + "%float_10 = OpConstant %float 10\n" + "%gv2 = OpVariable %_ptr_Private_float Private\n" + "%float_100 = OpConstant %float 100\n" + "%_ptr_Function_float = OpTypePointer Function %float\n" + "%main = OpFunction %void None %10\n" + "%17 = OpLabel\n" + "%lv1_0 = OpVariable %_ptr_Function_float Function\n" + "OpStore %gv1 %float_10\n" + "OpStore %gv2 %float_100\n" + "%18 = OpLoad %float %gv1\n" + "%19 = OpLoad %float %gv2\n" + "%20 = OpFSub %float %18 %19\n" + "OpStore %lv1_0 %20\n" + "OpReturn\n" + "OpFunctionEnd\n" + "%f_ = OpFunction %float None %12\n" + "%21 = OpLabel\n" + "%lv1 = OpVariable %_ptr_Function_float Function\n" + "%lv2 = OpVariable %_ptr_Function_float Function\n" + "%22 = OpLoad %float %gv1\n" + "%23 = OpLoad %float %gv2\n" + "%24 = OpFAdd %float %22 %23\n" + "OpStore %lv1 %24\n" + "%25 = OpLoad %float %gv1\n" + "%26 = OpLoad %float %gv2\n" + "%27 = OpFMul %float %25 %26\n" + "OpStore %lv2 %27\n" + "%28 = OpLoad %float %lv1\n" + "%29 = OpLoad %float %lv2\n" + "%30 = OpFDiv %float %28 %29\n" + "OpReturnValue %30\n" + "OpFunctionEnd\n"; + // clang-format on + bodies[1] = + // clang-format off + "OpCapability Shader\n" + "%1 = OpExtInstImport \"GLSL.std.450\"\n" + "OpMemoryModel Logical GLSL450\n" + "OpSource ESSL 310\n" + "OpName %main \"main2\"\n" + "OpName %f_ \"f(\"\n" + "OpName %gv1 \"gv12\"\n" + "OpName %gv2 \"gv22\"\n" + "OpName %lv1 \"lv12\"\n" + "OpName %lv2 \"lv22\"\n" + "OpName %lv1_0 \"lv12\"\n" + "%void = OpTypeVoid\n" + "%10 = OpTypeFunction %void\n" + "%float = OpTypeFloat 32\n" + "%12 = OpTypeFunction %float\n" + "%_ptr_Private_float = OpTypePointer Private %float\n" + "%gv1 = OpVariable %_ptr_Private_float Private\n" + "%float_10 = OpConstant %float 10\n" + "%gv2 = OpVariable %_ptr_Private_float Private\n" + "%float_100 = OpConstant %float 100\n" + "%_ptr_Function_float = OpTypePointer Function %float\n" + "%main = OpFunction %void None %10\n" + "%17 = OpLabel\n" + "%lv1_0 = OpVariable %_ptr_Function_float Function\n" + "OpStore %gv1 %float_10\n" + "OpStore %gv2 %float_100\n" + "%18 = OpLoad %float %gv1\n" + "%19 = OpLoad %float %gv2\n" + "%20 = OpFSub %float %18 %19\n" + "OpStore %lv1_0 %20\n" + "OpReturn\n" + "OpFunctionEnd\n" + "%f_ = OpFunction %float None %12\n" + "%21 = OpLabel\n" + "%lv1 = OpVariable %_ptr_Function_float Function\n" + "%lv2 = OpVariable %_ptr_Function_float Function\n" + "%22 = OpLoad %float %gv1\n" + "%23 = OpLoad %float %gv2\n" + "%24 = OpFAdd %float %22 %23\n" + "OpStore %lv1 %24\n" + "%25 = OpLoad %float %gv1\n" + "%26 = OpLoad %float %gv2\n" + "%27 = OpFMul %float %25 %26\n" + "OpStore %lv2 %27\n" + "%28 = OpLoad %float %lv1\n" + "%29 = OpLoad %float %lv2\n" + "%30 = OpFDiv %float %28 %29\n" + "OpReturnValue %30\n" + "OpFunctionEnd\n"; + // clang-format on + + spvtest::Binary linked_binary; + spvtools::LinkerOptions options; + options.SetVerifyIds(true); + spv_result_t res = AssembleAndLink(bodies, &linked_binary, options); + EXPECT_EQ(SPV_SUCCESS, res); +} + +} // anonymous namespace diff --git a/test/opt/def_use_test.cpp b/test/opt/def_use_test.cpp index aa889786..bd1ac7ed 100644 --- a/test/opt/def_use_test.cpp +++ b/test/opt/def_use_test.cpp @@ -21,6 +21,7 @@ #include "opt/build_module.h" #include "opt/def_use_manager.h" #include "opt/ir_context.h" +#include "opt/module.h" #include "pass_utils.h" #include "spirv-tools/libspirv.hpp" @@ -131,12 +132,12 @@ TEST_P(ParseDefUseTest, Case) { // Build module. const std::vector<const char*> text = {tc.text}; - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text)); - ASSERT_NE(nullptr, module); + ASSERT_NE(nullptr, context); // Analyze def and use. - opt::analysis::DefUseManager manager(module.get()); + opt::analysis::DefUseManager manager(context->module()); CheckDef(tc.du, manager.id_to_defs()); CheckUse(tc.du, manager.id_to_uses()); @@ -512,23 +513,22 @@ TEST_P(ReplaceUseTest, Case) { // Build module. const std::vector<const char*> text = {tc.before}; - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text)); - ASSERT_NE(nullptr, module); - ir::IRContext context(std::move(module), spvtools::MessageConsumer()); + ASSERT_NE(nullptr, context); // Force a re-build of def-use manager. - context.InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); - (void)context.get_def_use_mgr(); + context->InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); + (void)context->get_def_use_mgr(); // Do the substitution. for (const auto& candidate : tc.candidates) { - context.ReplaceAllUsesWith(candidate.first, candidate.second); + context->ReplaceAllUsesWith(candidate.first, candidate.second); } - EXPECT_EQ(tc.after, DisassembleModule(context.module())); - CheckDef(tc.du, context.get_def_use_mgr()->id_to_defs()); - CheckUse(tc.du, context.get_def_use_mgr()->id_to_uses()); + EXPECT_EQ(tc.after, DisassembleModule(context->module())); + CheckDef(tc.du, context->get_def_use_mgr()->id_to_defs()); + CheckUse(tc.du, context->get_def_use_mgr()->id_to_uses()); } // clang-format off @@ -816,20 +816,19 @@ TEST_P(KillDefTest, Case) { // Build module. const std::vector<const char*> text = {tc.before}; - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text)); - ASSERT_NE(nullptr, module); - ir::IRContext context(std::move(module), spvtools::MessageConsumer()); + ASSERT_NE(nullptr, context); // Analyze def and use. - opt::analysis::DefUseManager manager(module.get()); + opt::analysis::DefUseManager manager(context->module()); // Do the substitution. - for (const auto id : tc.ids_to_kill) context.KillDef(id); + for (const auto id : tc.ids_to_kill) context->KillDef(id); - EXPECT_EQ(tc.after, DisassembleModule(context.module())); - CheckDef(tc.du, context.get_def_use_mgr()->id_to_defs()); - CheckUse(tc.du, context.get_def_use_mgr()->id_to_uses()); + EXPECT_EQ(tc.after, DisassembleModule(context->module())); + CheckDef(tc.du, context->get_def_use_mgr()->id_to_defs()); + CheckUse(tc.du, context->get_def_use_mgr()->id_to_uses()); } // clang-format off @@ -1067,19 +1066,18 @@ TEST(DefUseTest, OpSwitch) { " OpReturnValue %6 " " OpFunctionEnd"; - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, original_text); - ASSERT_NE(nullptr, module); - ir::IRContext context(std::move(module), spvtools::MessageConsumer()); + ASSERT_NE(nullptr, context); // Force a re-build of def-use manager. - context.InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); - (void)context.get_def_use_mgr(); + context->InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); + (void)context->get_def_use_mgr(); // Do a bunch replacements. - context.ReplaceAllUsesWith(9, 900); // to unused id - context.ReplaceAllUsesWith(10, 1000); // to unused id - context.ReplaceAllUsesWith(11, 7); // to existing id + context->ReplaceAllUsesWith(9, 900); // to unused id + context->ReplaceAllUsesWith(10, 1000); // to unused id + context->ReplaceAllUsesWith(11, 7); // to existing id // clang-format off const char modified_text[] = @@ -1103,7 +1101,7 @@ TEST(DefUseTest, OpSwitch) { "OpFunctionEnd"; // clang-format on - EXPECT_EQ(modified_text, DisassembleModule(context.module())); + EXPECT_EQ(modified_text, DisassembleModule(context->module())); InstDefUse def_uses = {}; def_uses.defs = { @@ -1118,10 +1116,10 @@ TEST(DefUseTest, OpSwitch) { {10, "%10 = OpLabel"}, {11, "%11 = OpLabel"}, }; - CheckDef(def_uses, context.get_def_use_mgr()->id_to_defs()); + CheckDef(def_uses, context->get_def_use_mgr()->id_to_defs()); { - auto* use_list = context.get_def_use_mgr()->GetUses(6); + auto* use_list = context->get_def_use_mgr()->GetUses(6); ASSERT_NE(nullptr, use_list); EXPECT_EQ(2u, use_list->size()); std::vector<SpvOp> opcodes = {use_list->front().inst->opcode(), @@ -1129,7 +1127,7 @@ TEST(DefUseTest, OpSwitch) { EXPECT_THAT(opcodes, UnorderedElementsAre(SpvOpSwitch, SpvOpReturnValue)); } { - auto* use_list = context.get_def_use_mgr()->GetUses(7); + auto* use_list = context->get_def_use_mgr()->GetUses(7); ASSERT_NE(nullptr, use_list); EXPECT_EQ(6u, use_list->size()); std::vector<SpvOp> opcodes; @@ -1143,44 +1141,15 @@ TEST(DefUseTest, OpSwitch) { } // Check all ids only used by OpSwitch after replacement. for (const auto id : {8, 900, 1000}) { - auto* use_list = context.get_def_use_mgr()->GetUses(id); + auto* use_list = context->get_def_use_mgr()->GetUses(id); ASSERT_NE(nullptr, use_list); EXPECT_EQ(1u, use_list->size()); EXPECT_EQ(SpvOpSwitch, use_list->front().inst->opcode()); } } -// Creates an |result_id| = OpTypeInt 32 1 instruction. -ir::Instruction Int32TypeInstruction(uint32_t result_id) { - return ir::Instruction(SpvOp::SpvOpTypeInt, 0, result_id, - {ir::Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {32}), - ir::Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {1})}); -} - -// Creates an |result_id| = OpConstantTrue/Flase |type_id| instruction. -ir::Instruction ConstantBoolInstruction(bool value, uint32_t type_id, - uint32_t result_id) { - return ir::Instruction( - value ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse, type_id, - result_id, {}); -} - -// Creates an |result_id| = OpLabel instruction. -ir::Instruction LabelInstruction(uint32_t result_id) { - return ir::Instruction(SpvOp::SpvOpLabel, 0, result_id, {}); -} - -// Creates an OpBranch |target_id| instruction. -ir::Instruction BranchInstruction(uint32_t target_id) { - return ir::Instruction(SpvOp::SpvOpBranch, 0, 0, - { - ir::Operand(SPV_OPERAND_TYPE_ID, {target_id}), - }); -} - // Test case for analyzing individual instructions. struct AnalyzeInstDefUseTestCase { - std::vector<ir::Instruction> insts; // instrutions to be analyzed in order. const char* module_text; InstDefUse expected_define_use; }; @@ -1193,15 +1162,12 @@ TEST_P(AnalyzeInstDefUseTest, Case) { auto tc = GetParam(); // Build module. - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.module_text); - ASSERT_NE(nullptr, module); + ASSERT_NE(nullptr, context); // Analyze the instructions. - opt::analysis::DefUseManager manager(module.get()); - for (ir::Instruction& inst : tc.insts) { - manager.AnalyzeInstDefUse(&inst); - } + opt::analysis::DefUseManager manager(context->module()); CheckDef(tc.expected_define_use, manager.id_to_defs()); CheckUse(tc.expected_define_use, manager.id_to_uses()); @@ -1212,8 +1178,7 @@ INSTANTIATE_TEST_CASE_P( TestCase, AnalyzeInstDefUseTest, ::testing::ValuesIn(std::vector<AnalyzeInstDefUseTestCase>{ { // A type declaring instruction. - {Int32TypeInstruction(1)}, - "", + "%1 = OpTypeInt 32 1", { // defs {{1, "%1 = OpTypeInt 32 1"}}, @@ -1221,88 +1186,79 @@ INSTANTIATE_TEST_CASE_P( }, }, { // A type declaring instruction and a constant value. - { - Int32TypeInstruction(1), - ConstantBoolInstruction(true, 1, 2), - }, - "", - { - { // defs - {1, "%1 = OpTypeInt 32 1"}, - {2, "%2 = OpConstantTrue %1"}, // It is fine the SPIR-V code here is invalid. - }, - { // uses - {1, {"%2 = OpConstantTrue %1"}}, - }, - }, - }, - { // Analyze two instrutions that have same result id. The def use info - // of the result id from the first instruction should be overwritten by - // the second instruction. - { - ConstantBoolInstruction(true, 1, 2), - // The def-use info of the following instruction should overwrite the - // records of the above one. - ConstantBoolInstruction(false, 3, 2), - }, - "", - { - // defs - {{2, "%2 = OpConstantFalse %3"}}, - // uses - {{3, {"%2 = OpConstantFalse %3"}}} - } - }, - { // Analyze forward reference instruction, also instruction that does - // not have result id. - { - BranchInstruction(2), - LabelInstruction(2), - }, - "", - { - // defs - {{2, "%2 = OpLabel"}}, - // uses - {{2, {"OpBranch %2"}}}, - } - }, - { // Analyzing an additional instruction with new result id to an - // existing module. - { - ConstantBoolInstruction(true, 1, 2), - }, - "%1 = OpTypeInt 32 1 ", + "%1 = OpTypeBool " + "%2 = OpConstantTrue %1", { { // defs - {1, "%1 = OpTypeInt 32 1"}, + {1, "%1 = OpTypeBool"}, {2, "%2 = OpConstantTrue %1"}, }, { // uses {1, {"%2 = OpConstantTrue %1"}}, }, - } - }, - { // Analyzing an additional instruction with existing result id to an - // existing module. - { - ConstantBoolInstruction(true, 1, 2), }, - "%1 = OpTypeInt 32 1 " - "%2 = OpTypeBool ", - { - { // defs - {1, "%1 = OpTypeInt 32 1"}, - {2, "%2 = OpConstantTrue %1"}, - }, - { // uses - {1, {"%2 = OpConstantTrue %1"}}, - }, - } }, })); // clang-format on +using AnalyzeInstDefUse = ::testing::Test; + +TEST(AnalyzeInstDefUse, UseWithNoResultId) { + ir::IRContext context(nullptr); + + // Analyze the instructions. + opt::analysis::DefUseManager manager(context.module()); + + ir::Instruction label(&context, SpvOpLabel, 0, 2, {}); + manager.AnalyzeInstDefUse(&label); + + ir::Instruction branch(&context, SpvOpBranch, 0, 0, + {{SPV_OPERAND_TYPE_ID, {2}}}); + manager.AnalyzeInstDefUse(&branch); + + InstDefUse expected = + { + // defs + { + {2, "%2 = OpLabel"}, + }, + // uses + {{2, {"OpBranch %2"}}}, + }; + + CheckDef(expected, manager.id_to_defs()); + CheckUse(expected, manager.id_to_uses()); +} + +TEST(AnalyzeInstDefUse, AddNewInstruction) { + const std::string input = "%1 = OpTypeBool"; + + // Build module. + std::unique_ptr<ir::IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input); + ASSERT_NE(nullptr, context); + + // Analyze the instructions. + opt::analysis::DefUseManager manager(context->module()); + + ir::Instruction newInst(context.get(), SpvOpConstantTrue, 1, 2, {}); + manager.AnalyzeInstDefUse(&newInst); + + InstDefUse expected = + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpConstantTrue %1"}, + }, + { // uses + {1, {"%2 = OpConstantTrue %1"}}, + }, + }; + + CheckDef(expected, manager.id_to_defs()); + CheckUse(expected, manager.id_to_uses()); +} + struct KillInstTestCase { const char* before; std::unordered_set<uint32_t> indices_for_inst_to_kill; @@ -1316,27 +1272,26 @@ TEST_P(KillInstTest, Case) { auto tc = GetParam(); // Build module. - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.before); - ASSERT_NE(nullptr, module); - ir::IRContext context(std::move(module), spvtools::MessageConsumer()); + ASSERT_NE(nullptr, context); // Force a re-build of the def-use manager. - context.InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); - (void)context.get_def_use_mgr(); + context->InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); + (void)context->get_def_use_mgr(); // KillInst uint32_t index = 0; - context.module()->ForEachInst([&index, &tc, &context](ir::Instruction* inst) { + context->module()->ForEachInst([&index, &tc, &context](ir::Instruction* inst) { if (tc.indices_for_inst_to_kill.count(index) != 0) { - context.KillInst(inst); + context->KillInst(inst); } index++; }); - EXPECT_EQ(tc.after, DisassembleModule(context.module())); - CheckDef(tc.expected_define_use, context.get_def_use_mgr()->id_to_defs()); - CheckUse(tc.expected_define_use, context.get_def_use_mgr()->id_to_uses()); + EXPECT_EQ(tc.after, DisassembleModule(context->module())); + CheckDef(tc.expected_define_use, context->get_def_use_mgr()->id_to_defs()); + CheckUse(tc.expected_define_use, context->get_def_use_mgr()->id_to_uses()); } // clang-format off @@ -1428,12 +1383,12 @@ TEST_P(GetAnnotationsTest, Case) { const GetAnnotationsTestCase& tc = GetParam(); // Build module. - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.code); - ASSERT_NE(nullptr, module); + ASSERT_NE(nullptr, context); // Get annotations - opt::analysis::DefUseManager manager(module.get()); + opt::analysis::DefUseManager manager(context->module()); auto insts = manager.GetAnnotations(tc.id); // Check diff --git a/test/opt/instruction_test.cpp b/test/opt/instruction_test.cpp index 2db4ed2c..f930d464 100644 --- a/test/opt/instruction_test.cpp +++ b/test/opt/instruction_test.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include "opt/instruction.h" +#include "opt/ir_context.h" #include "gmock/gmock.h" @@ -23,6 +24,7 @@ namespace { using spvtest::MakeInstruction; using spvtools::ir::Instruction; +using spvtools::ir::IRContext; using spvtools::ir::Operand; using ::testing::Eq; @@ -39,7 +41,8 @@ TEST(InstructionTest, CreateTrivial) { } TEST(InstructionTest, CreateWithOpcodeAndNoOperands) { - Instruction inst(SpvOpReturn); + IRContext context(nullptr); + Instruction inst(&context, SpvOpReturn); EXPECT_EQ(SpvOpReturn, inst.opcode()); EXPECT_EQ(0u, inst.type_id()); EXPECT_EQ(0u, inst.result_id()); @@ -119,7 +122,8 @@ spv_parsed_instruction_t kSampleControlBarrierInstruction = { 3}; TEST(InstructionTest, CreateWithOpcodeAndOperands) { - Instruction inst(kSampleParsedInstruction); + IRContext context(nullptr); + Instruction inst(&context, kSampleParsedInstruction); EXPECT_EQ(SpvOpTypeInt, inst.opcode()); EXPECT_EQ(0u, inst.type_id()); EXPECT_EQ(44u, inst.result_id()); @@ -129,20 +133,23 @@ TEST(InstructionTest, CreateWithOpcodeAndOperands) { } TEST(InstructionTest, GetOperand) { - Instruction inst(kSampleParsedInstruction); + IRContext context(nullptr); + Instruction inst(&context, kSampleParsedInstruction); EXPECT_THAT(inst.GetOperand(0).words, Eq(std::vector<uint32_t>{44})); EXPECT_THAT(inst.GetOperand(1).words, Eq(std::vector<uint32_t>{32})); EXPECT_THAT(inst.GetOperand(2).words, Eq(std::vector<uint32_t>{1})); } TEST(InstructionTest, GetInOperand) { - Instruction inst(kSampleParsedInstruction); + IRContext context(nullptr); + Instruction inst(&context, kSampleParsedInstruction); EXPECT_THAT(inst.GetInOperand(0).words, Eq(std::vector<uint32_t>{32})); EXPECT_THAT(inst.GetInOperand(1).words, Eq(std::vector<uint32_t>{1})); } TEST(InstructionTest, OperandConstIterators) { - Instruction inst(kSampleParsedInstruction); + IRContext context(nullptr); + Instruction inst(&context, kSampleParsedInstruction); // Spot check iteration across operands. auto cbegin = inst.cbegin(); auto cend = inst.cend(); @@ -168,7 +175,8 @@ TEST(InstructionTest, OperandConstIterators) { } TEST(InstructionTest, OperandIterators) { - Instruction inst(kSampleParsedInstruction); + IRContext context(nullptr); + Instruction inst(&context, kSampleParsedInstruction); // Spot check iteration across operands, with mutable iterators. auto begin = inst.begin(); auto end = inst.end(); @@ -198,7 +206,8 @@ TEST(InstructionTest, OperandIterators) { } TEST(InstructionTest, ForInIdStandardIdTypes) { - Instruction inst(kSampleAccessChainInstruction); + IRContext context(nullptr); + Instruction inst(&context, kSampleAccessChainInstruction); std::vector<uint32_t> ids; inst.ForEachInId([&ids](const uint32_t* idptr) { ids.push_back(*idptr); }); @@ -210,7 +219,8 @@ TEST(InstructionTest, ForInIdStandardIdTypes) { } TEST(InstructionTest, ForInIdNonstandardIdTypes) { - Instruction inst(kSampleControlBarrierInstruction); + IRContext context(nullptr); + Instruction inst(&context, kSampleControlBarrierInstruction); std::vector<uint32_t> ids; inst.ForEachInId([&ids](const uint32_t* idptr) { ids.push_back(*idptr); }); @@ -221,4 +231,60 @@ TEST(InstructionTest, ForInIdNonstandardIdTypes) { EXPECT_THAT(ids, Eq(std::vector<uint32_t>{100, 101, 102})); } +TEST(InstructionTest, UniqueIds) { + IRContext context(nullptr); + Instruction inst1(&context); + Instruction inst2(&context); + EXPECT_NE(inst1.unique_id(), inst2.unique_id()); +} + +TEST(InstructionTest, CloneUniqueIdDifferent) { + IRContext context(nullptr); + Instruction inst(&context); + std::unique_ptr<Instruction> clone(inst.Clone(&context)); + EXPECT_EQ(inst.context(), clone->context()); + EXPECT_NE(inst.unique_id(), clone->unique_id()); +} + +TEST(InstructionTest, CloneDifferentContext) { + IRContext c1(nullptr); + IRContext c2(nullptr); + Instruction inst(&c1); + std::unique_ptr<Instruction> clone(inst.Clone(&c2)); + EXPECT_EQ(&c1, inst.context()); + EXPECT_EQ(&c2, clone->context()); + EXPECT_NE(&c1, &c2); +} + +TEST(InstructionTest, CloneDifferentContextDifferentUniqueId) { + IRContext c1(nullptr); + IRContext c2(nullptr); + Instruction inst(&c1); + Instruction other(&c2); + std::unique_ptr<Instruction> clone(inst.Clone(&c2)); + EXPECT_EQ(&c2, clone->context()); + EXPECT_NE(other.unique_id(), clone->unique_id()); +} + +TEST(InstructionTest, EqualsEqualsOperator) { + IRContext context(nullptr); + Instruction i1(&context); + Instruction i2(&context); + std::unique_ptr<Instruction> clone(i1.Clone(&context)); + EXPECT_TRUE(i1 == i1); + EXPECT_FALSE(i1 == i2); + EXPECT_FALSE(i1 == *clone); + EXPECT_FALSE(i2 == *clone); +} + +TEST(InstructionTest, LessThanOperator) { + IRContext context(nullptr); + Instruction i1(&context); + Instruction i2(&context); + std::unique_ptr<Instruction> clone(i1.Clone(&context)); + EXPECT_TRUE(i1 < i2); + EXPECT_TRUE(i1 < *clone); + EXPECT_TRUE(i2 < *clone); +} + } // anonymous namespace diff --git a/test/opt/ir_context_test.cpp b/test/opt/ir_context_test.cpp index e2ace777..770c3063 100644 --- a/test/opt/ir_context_test.cpp +++ b/test/opt/ir_context_test.cpp @@ -62,97 +62,97 @@ using IRContextTest = PassTest<::testing::Test>; TEST_F(IRContextTest, IndividualValidAfterBuild) { std::unique_ptr<ir::Module> module(new ir::Module()); - IRContext context(std::move(module), spvtools::MessageConsumer()); + IRContext localContext(std::move(module), spvtools::MessageConsumer()); for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - context.BuildInvalidAnalyses(i); - EXPECT_TRUE(context.AreAnalysesValid(i)); + localContext.BuildInvalidAnalyses(i); + EXPECT_TRUE(localContext.AreAnalysesValid(i)); } } TEST_F(IRContextTest, AllValidAfterBuild) { std::unique_ptr<ir::Module> module = MakeUnique<ir::Module>(); - IRContext context(std::move(module), spvtools::MessageConsumer()); + IRContext localContext(std::move(module), spvtools::MessageConsumer()); Analysis built_analyses = IRContext::kAnalysisNone; for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - context.BuildInvalidAnalyses(i); + localContext.BuildInvalidAnalyses(i); built_analyses |= i; } - EXPECT_TRUE(context.AreAnalysesValid(built_analyses)); + EXPECT_TRUE(localContext.AreAnalysesValid(built_analyses)); } TEST_F(IRContextTest, AllValidAfterPassNoChange) { std::unique_ptr<ir::Module> module = MakeUnique<ir::Module>(); - IRContext context(std::move(module), spvtools::MessageConsumer()); + IRContext localContext(std::move(module), spvtools::MessageConsumer()); Analysis built_analyses = IRContext::kAnalysisNone; for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - context.BuildInvalidAnalyses(i); + localContext.BuildInvalidAnalyses(i); built_analyses |= i; } DummyPassPreservesNothing pass(opt::Pass::Status::SuccessWithoutChange); - opt::Pass::Status s = pass.Run(&context); + opt::Pass::Status s = pass.Run(&localContext); EXPECT_EQ(s, opt::Pass::Status::SuccessWithoutChange); - EXPECT_TRUE(context.AreAnalysesValid(built_analyses)); + EXPECT_TRUE(localContext.AreAnalysesValid(built_analyses)); } TEST_F(IRContextTest, NoneValidAfterPassWithChange) { std::unique_ptr<ir::Module> module = MakeUnique<ir::Module>(); - IRContext context(std::move(module), spvtools::MessageConsumer()); + IRContext localContext(std::move(module), spvtools::MessageConsumer()); for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - context.BuildInvalidAnalyses(i); + localContext.BuildInvalidAnalyses(i); } DummyPassPreservesNothing pass(opt::Pass::Status::SuccessWithChange); - opt::Pass::Status s = pass.Run(&context); + opt::Pass::Status s = pass.Run(&localContext); EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange); for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - EXPECT_FALSE(context.AreAnalysesValid(i)); + EXPECT_FALSE(localContext.AreAnalysesValid(i)); } } TEST_F(IRContextTest, AllPreservedAfterPassWithChange) { std::unique_ptr<ir::Module> module = MakeUnique<ir::Module>(); - IRContext context(std::move(module), spvtools::MessageConsumer()); + IRContext localContext(std::move(module), spvtools::MessageConsumer()); for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - context.BuildInvalidAnalyses(i); + localContext.BuildInvalidAnalyses(i); } DummyPassPreservesAll pass(opt::Pass::Status::SuccessWithChange); - opt::Pass::Status s = pass.Run(&context); + opt::Pass::Status s = pass.Run(&localContext); EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange); for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - EXPECT_TRUE(context.AreAnalysesValid(i)); + EXPECT_TRUE(localContext.AreAnalysesValid(i)); } } TEST_F(IRContextTest, PreserveFirstOnlyAfterPassWithChange) { std::unique_ptr<ir::Module> module = MakeUnique<ir::Module>(); - IRContext context(std::move(module), spvtools::MessageConsumer()); + IRContext localContext(std::move(module), spvtools::MessageConsumer()); for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - context.BuildInvalidAnalyses(i); + localContext.BuildInvalidAnalyses(i); } DummyPassPreservesFirst pass(opt::Pass::Status::SuccessWithChange); - opt::Pass::Status s = pass.Run(&context); + opt::Pass::Status s = pass.Run(&localContext); EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange); - EXPECT_TRUE(context.AreAnalysesValid(IRContext::kAnalysisBegin)); + EXPECT_TRUE(localContext.AreAnalysesValid(IRContext::kAnalysisBegin)); for (Analysis i = IRContext::kAnalysisBegin << 1; i < IRContext::kAnalysisEnd; i <<= 1) { - EXPECT_FALSE(context.AreAnalysesValid(i)); + EXPECT_FALSE(localContext.AreAnalysesValid(i)); } } @@ -178,25 +178,31 @@ TEST_F(IRContextTest, KillMemberName) { OpFunctionEnd )"; - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - ir::IRContext context(std::move(module), spvtools::MessageConsumer()); // Build the decoration manager. - context.get_decoration_mgr(); + context->get_decoration_mgr(); // Delete the OpTypeStruct. Should delete the OpName, OpMemberName, and // OpMemberDecorate associated with it. - context.KillDef(3); + context->KillDef(3); // Make sure all of the name are removed. - for (auto& inst : context.debugs2()) { + for (auto& inst : context->debugs2()) { EXPECT_EQ(inst.opcode(), SpvOpNop); } // Make sure all of the decorations are removed. - for (auto& inst : context.annotations()) { + for (auto& inst : context->annotations()) { EXPECT_EQ(inst.opcode(), SpvOpNop); } } + +TEST_F(IRContextTest, TakeNextUniqueIdIncrementing) { + const uint32_t NUM_TESTS = 1000; + IRContext localContext(nullptr); + for (uint32_t i = 1; i < NUM_TESTS; ++i) + EXPECT_EQ(i, localContext.TakeNextUniqueId()); +} } // anonymous namespace diff --git a/test/opt/ir_loader_test.cpp b/test/opt/ir_loader_test.cpp index b61f7cbc..ae46df90 100644 --- a/test/opt/ir_loader_test.cpp +++ b/test/opt/ir_loader_test.cpp @@ -14,9 +14,11 @@ #include <gtest/gtest.h> #include <algorithm> +#include <unordered_set> #include "message.h" #include "opt/build_module.h" +#include "opt/ir_context.h" #include "spirv-tools/libspirv.hpp" namespace { @@ -25,12 +27,12 @@ using namespace spvtools; void DoRoundTripCheck(const std::string& text) { SpirvTools t(SPV_ENV_UNIVERSAL_1_1); - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - ASSERT_NE(nullptr, module) << "Failed to assemble\n" << text; + ASSERT_NE(nullptr, context) << "Failed to assemble\n" << text; std::vector<uint32_t> binary; - module->ToBinary(&binary, /* skip_nop = */ false); + context->module()->ToBinary(&binary, /* skip_nop = */ false); std::string disassembled_text; EXPECT_TRUE(t.Disassemble(binary, &disassembled_text)); @@ -212,17 +214,17 @@ TEST(IrBuilder, OpUndefOutsideFunction) { // clang-format on SpirvTools t(SPV_ENV_UNIVERSAL_1_1); - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - ASSERT_NE(nullptr, module); + ASSERT_NE(nullptr, context); const auto opundef_count = std::count_if( - module->types_values_begin(), module->types_values_end(), + context->module()->types_values_begin(), context->module()->types_values_end(), [](const ir::Instruction& inst) { return inst.opcode() == SpvOpUndef; }); EXPECT_EQ(3, opundef_count); std::vector<uint32_t> binary; - module->ToBinary(&binary, /* skip_nop = */ false); + context->module()->ToBinary(&binary, /* skip_nop = */ false); std::string disassembled_text; EXPECT_TRUE(t.Disassemble(binary, &disassembled_text)); @@ -322,9 +324,9 @@ void DoErrorMessageCheck(const std::string& assembly, }; SpirvTools t(SPV_ENV_UNIVERSAL_1_1); - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, std::move(consumer), assembly); - EXPECT_EQ(nullptr, module); + EXPECT_EQ(nullptr, context); } TEST(IrBuilder, FunctionInsideFunction) { @@ -378,4 +380,69 @@ TEST(IrBuilder, NotAllowedInstAppearingInFunction) { "block"); } +TEST(IrBuilder, UniqueIds) { + const std::string text = + // clang-format off + "OpCapability Shader\n" + "%1 = OpExtInstImport \"GLSL.std.450\"\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint Vertex %main \"main\"\n" + "OpSource ESSL 310\n" + "OpName %main \"main\"\n" + "OpName %f_ \"f(\"\n" + "OpName %gv1 \"gv1\"\n" + "OpName %gv2 \"gv2\"\n" + "OpName %lv1 \"lv1\"\n" + "OpName %lv2 \"lv2\"\n" + "OpName %lv1_0 \"lv1\"\n" + "%void = OpTypeVoid\n" + "%10 = OpTypeFunction %void\n" + "%float = OpTypeFloat 32\n" + "%12 = OpTypeFunction %float\n" + "%_ptr_Private_float = OpTypePointer Private %float\n" + "%gv1 = OpVariable %_ptr_Private_float Private\n" + "%float_10 = OpConstant %float 10\n" + "%gv2 = OpVariable %_ptr_Private_float Private\n" + "%float_100 = OpConstant %float 100\n" + "%_ptr_Function_float = OpTypePointer Function %float\n" + "%main = OpFunction %void None %10\n" + "%17 = OpLabel\n" + "%lv1_0 = OpVariable %_ptr_Function_float Function\n" + "OpStore %gv1 %float_10\n" + "OpStore %gv2 %float_100\n" + "%18 = OpLoad %float %gv1\n" + "%19 = OpLoad %float %gv2\n" + "%20 = OpFSub %float %18 %19\n" + "OpStore %lv1_0 %20\n" + "OpReturn\n" + "OpFunctionEnd\n" + "%f_ = OpFunction %float None %12\n" + "%21 = OpLabel\n" + "%lv1 = OpVariable %_ptr_Function_float Function\n" + "%lv2 = OpVariable %_ptr_Function_float Function\n" + "%22 = OpLoad %float %gv1\n" + "%23 = OpLoad %float %gv2\n" + "%24 = OpFAdd %float %22 %23\n" + "OpStore %lv1 %24\n" + "%25 = OpLoad %float %gv1\n" + "%26 = OpLoad %float %gv2\n" + "%27 = OpFMul %float %25 %26\n" + "OpStore %lv2 %27\n" + "%28 = OpLoad %float %lv1\n" + "%29 = OpLoad %float %lv2\n" + "%30 = OpFDiv %float %28 %29\n" + "OpReturnValue %30\n" + "OpFunctionEnd\n"; + // clang-format on + + std::unique_ptr<ir::IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); + ASSERT_NE(nullptr, context); + + std::unordered_set<uint32_t> ids; + context->module()->ForEachInst([&ids](const ir::Instruction* inst) { + EXPECT_TRUE(ids.insert(inst->unique_id()).second); + }); +} + } // anonymous namespace diff --git a/test/opt/module_test.cpp b/test/opt/module_test.cpp index 622d920f..4a434edb 100644 --- a/test/opt/module_test.cpp +++ b/test/opt/module_test.cpp @@ -26,6 +26,7 @@ namespace { +using spvtools::ir::IRContext; using spvtools::ir::Module; using spvtest::GetIdBound; using ::testing::Eq; @@ -42,31 +43,31 @@ TEST(ModuleTest, SetIdBound) { EXPECT_EQ(102u, GetIdBound(m)); } -// Returns a module formed by assembling the given text, +// Returns an IRContext owning the module formed by assembling the given text, // then loading the result. -inline std::unique_ptr<Module> BuildModule(std::string text) { +inline std::unique_ptr<IRContext> BuildModule(std::string text) { return spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); } TEST(ModuleTest, ComputeIdBound) { // Emtpy module case. - EXPECT_EQ(1u, BuildModule("")->ComputeIdBound()); + EXPECT_EQ(1u, BuildModule("")->module()->ComputeIdBound()); // Sensitive to result id - EXPECT_EQ(2u, BuildModule("%void = OpTypeVoid")->ComputeIdBound()); + EXPECT_EQ(2u, BuildModule("%void = OpTypeVoid")->module()->ComputeIdBound()); // Sensitive to type id - EXPECT_EQ(1000u, BuildModule("%a = OpTypeArray !999 3")->ComputeIdBound()); + EXPECT_EQ(1000u, BuildModule("%a = OpTypeArray !999 3")->module()->ComputeIdBound()); // Sensitive to a regular Id parameter - EXPECT_EQ(2000u, BuildModule("OpDecorate !1999 0")->ComputeIdBound()); + EXPECT_EQ(2000u, BuildModule("OpDecorate !1999 0")->module()->ComputeIdBound()); // Sensitive to a scope Id parameter. EXPECT_EQ(3000u, BuildModule("%f = OpFunction %void None %fntype %a = OpLabel " "OpMemoryBarrier !2999 %b\n") - ->ComputeIdBound()); + ->module()->ComputeIdBound()); // Sensitive to a semantics Id parameter EXPECT_EQ(4000u, BuildModule("%f = OpFunction %void None %fntype %a = OpLabel " "OpMemoryBarrier %b !3999\n") - ->ComputeIdBound()); + ->module()->ComputeIdBound()); } } // anonymous namespace diff --git a/test/opt/pass_fixture.h b/test/opt/pass_fixture.h index 7ad71817..fdc4398f 100644 --- a/test/opt/pass_fixture.h +++ b/test/opt/pass_fixture.h @@ -46,36 +46,35 @@ class PassTest : public TestT { public: PassTest() : consumer_(nullptr), + context_(nullptr), tools_(SPV_ENV_UNIVERSAL_1_1), manager_(new opt::PassManager()), assemble_options_(SpirvTools::kDefaultAssembleOption), disassemble_options_(SpirvTools::kDefaultDisassembleOption) {} // Runs the given |pass| on the binary assembled from the |original|. - // Returns a tuple of the optimized binary and the boolean value returned + // Returns a tuple of the optimized binary and the boolean value returned // from pass Process() function. std::tuple<std::vector<uint32_t>, opt::Pass::Status> OptimizeToBinary( opt::Pass* pass, const std::string& original, bool skip_nop) { - std::unique_ptr<ir::Module> module = BuildModule( - SPV_ENV_UNIVERSAL_1_1, consumer_, original, assemble_options_); - EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" - << original << std::endl; - if (!module) { - return std::make_tuple(std::vector<uint32_t>(), + context_ = std::move(BuildModule(SPV_ENV_UNIVERSAL_1_1, consumer_, original, + assemble_options_)); + EXPECT_NE(nullptr, context()) << "Assembling failed for shader:\n" + << original << std::endl; + if (!context()) { + return std::make_tuple(std::vector<uint32_t>(), opt::Pass::Status::Failure); } - ir::IRContext context(std::move(module), consumer()); - - const auto status = pass->Run(&context); + const auto status = pass->Run(context()); std::vector<uint32_t> binary; - context.module()->ToBinary(&binary, skip_nop); + context()->module()->ToBinary(&binary, skip_nop); return std::make_tuple(binary, status); } // Runs a single pass of class |PassT| on the binary assembled from the - // |assembly|. Returns a tuple of the optimized binary and the boolean value + // |assembly|. Returns a tuple of the optimized binary and the boolean value // from the pass Process() function. template <typename PassT, typename... Args> std::tuple<std::vector<uint32_t>, opt::Pass::Status> SinglePassRunToBinary( @@ -106,7 +105,7 @@ class PassTest : public TestT { // Runs a single pass of class |PassT| on the binary assembled from the // |original| assembly, and checks whether the optimized binary can be // disassembled to the |expected| assembly. Optionally will also validate - // the optimized binary. This does *not* involve pass manager. Callers + // the optimized binary. This does *not* involve pass manager. Callers // are suggested to use SCOPED_TRACE() for better messages. template <typename PassT, typename... Args> void SinglePassRunAndCheck(const std::string& original, @@ -122,16 +121,16 @@ class PassTest : public TestT { status == opt::Pass::Status::SuccessWithoutChange); if (do_validation) { spv_target_env target_env = SPV_ENV_UNIVERSAL_1_1; - spv_context context = spvContextCreate(target_env); + spv_context spvContext = spvContextCreate(target_env); spv_diagnostic diagnostic = nullptr; spv_const_binary_t binary = {optimized_bin.data(), optimized_bin.size()}; - spv_result_t error = spvValidate(context, &binary, &diagnostic); + spv_result_t error = spvValidate(spvContext, &binary, &diagnostic); EXPECT_EQ(error, 0); if (error != 0) spvDiagnosticPrint(diagnostic); spvDiagnosticDestroy(diagnostic); - spvContextDestroy(context); + spvContextDestroy(spvContext); } std::string optimized_asm; EXPECT_TRUE(tools_.Disassemble(optimized_bin, &optimized_asm, @@ -191,15 +190,14 @@ class PassTest : public TestT { void RunAndCheck(const std::string& original, const std::string& expected) { assert(manager_->NumPasses()); - std::unique_ptr<ir::Module> module = BuildModule( - SPV_ENV_UNIVERSAL_1_1, nullptr, original, assemble_options_); - ASSERT_NE(nullptr, module); - ir::IRContext context(std::move(module), consumer()); + context_ = std::move(BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, original, + assemble_options_)); + ASSERT_NE(nullptr, context()); - manager_->Run(&context); + manager_->Run(context()); std::vector<uint32_t> binary; - context.module()->ToBinary(&binary, /* skip_nop = */ false); + context()->module()->ToBinary(&binary, /* skip_nop = */ false); std::string optimized; EXPECT_TRUE(tools_.Disassemble(binary, &optimized, @@ -216,8 +214,10 @@ class PassTest : public TestT { } MessageConsumer consumer() { return consumer_;} + ir::IRContext* context() { return context_.get(); } private: MessageConsumer consumer_; // Message consumer. + std::unique_ptr<ir::IRContext> context_; // IR context SpirvTools tools_; // An instance for calling SPIRV-Tools functionalities. std::unique_ptr<opt::PassManager> manager_; // The pass manager. uint32_t assemble_options_; diff --git a/test/opt/pass_manager_test.cpp b/test/opt/pass_manager_test.cpp index 43d70055..77ed38b5 100644 --- a/test/opt/pass_manager_test.cpp +++ b/test/opt/pass_manager_test.cpp @@ -75,7 +75,7 @@ class AppendOpNopPass : public opt::Pass { public: const char* name() const override { return "AppendOpNop"; } Status Process(ir::IRContext* irContext) override { - irContext->AddDebug1Inst(MakeUnique<ir::Instruction>()); + irContext->AddDebug1Inst(MakeUnique<ir::Instruction>(irContext)); return Status::SuccessWithChange; } }; @@ -89,7 +89,7 @@ class AppendMultipleOpNopPass : public opt::Pass { const char* name() const override { return "AppendOpNop"; } Status Process(ir::IRContext* irContext) override { for (uint32_t i = 0; i < num_nop_; i++) { - irContext->AddDebug1Inst(MakeUnique<ir::Instruction>()); + irContext->AddDebug1Inst(MakeUnique<ir::Instruction>(irContext)); } return Status::SuccessWithChange; } @@ -103,7 +103,8 @@ class DuplicateInstPass : public opt::Pass { public: const char* name() const override { return "DuplicateInst"; } Status Process(ir::IRContext* irContext) override { - auto inst = MakeUnique<ir::Instruction>(*(--irContext->debug1_end())); + auto inst = MakeUnique<ir::Instruction>( + *(--irContext->debug1_end())->Clone(irContext)); irContext->AddDebug1Inst(std::move(inst)); return Status::SuccessWithChange; } @@ -140,7 +141,7 @@ class AppendTypeVoidInstPass : public opt::Pass { const char* name() const override { return "AppendTypeVoidInstPass"; } Status Process(ir::IRContext* irContext) override { - auto inst = MakeUnique<ir::Instruction>(SpvOpTypeVoid, 0, result_id_, + auto inst = MakeUnique<ir::Instruction>(irContext, SpvOpTypeVoid, 0, result_id_, std::vector<ir::Operand>{}); irContext->AddType(std::move(inst)); return Status::SuccessWithChange; diff --git a/test/opt/pass_test.cpp b/test/opt/pass_test.cpp index 8c62b289..5ff1a121 100644 --- a/test/opt/pass_test.cpp +++ b/test/opt/pass_test.cpp @@ -76,18 +76,18 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { )"; // clang-format on - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> localContext = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" - << text << std::endl; + EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" + << text << std::endl; DummyPass testPass; std::vector<uint32_t> processed; opt::Pass::ProcessFunction mark_visited = [&processed](ir::Function* fp) { processed.push_back(fp->result_id()); return false; }; - testPass.ProcessEntryPointCallTree(mark_visited, module.get()); + testPass.ProcessEntryPointCallTree(mark_visited, localContext->module()); EXPECT_THAT(processed, UnorderedElementsAre(10, 11)); } @@ -132,12 +132,11 @@ TEST_F(PassClassTest, BasicVisitReachable) { )"; // clang-format on - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> localContext = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" - << text << std::endl; - ir::IRContext context(std::move(module), consumer()); + EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" + << text << std::endl; DummyPass testPass; std::vector<uint32_t> processed; @@ -145,7 +144,7 @@ TEST_F(PassClassTest, BasicVisitReachable) { processed.push_back(fp->result_id()); return false; }; - testPass.ProcessReachableCallTree(mark_visited, &context); + testPass.ProcessReachableCallTree(mark_visited, localContext.get()); EXPECT_THAT(processed, UnorderedElementsAre(10, 11, 12, 13)); } @@ -185,12 +184,11 @@ TEST_F(PassClassTest, BasicVisitOnlyOnce) { )"; // clang-format on - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> localContext = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" - << text << std::endl; - ir::IRContext context(std::move(module), consumer()); + EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" + << text << std::endl; DummyPass testPass; std::vector<uint32_t> processed; @@ -198,7 +196,7 @@ TEST_F(PassClassTest, BasicVisitOnlyOnce) { processed.push_back(fp->result_id()); return false; }; - testPass.ProcessReachableCallTree(mark_visited, &context); + testPass.ProcessReachableCallTree(mark_visited, localContext.get()); EXPECT_THAT(processed, UnorderedElementsAre(10, 11, 12)); } @@ -228,12 +226,11 @@ TEST_F(PassClassTest, BasicDontVisitExportedVariable) { )"; // clang-format on - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> localContext = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" - << text << std::endl; - ir::IRContext context(std::move(module), consumer()); + EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" + << text << std::endl; DummyPass testPass; std::vector<uint32_t> processed; @@ -241,7 +238,7 @@ TEST_F(PassClassTest, BasicDontVisitExportedVariable) { processed.push_back(fp->result_id()); return false; }; - testPass.ProcessReachableCallTree(mark_visited, &context); + testPass.ProcessReachableCallTree(mark_visited, localContext.get()); EXPECT_THAT(processed, UnorderedElementsAre(10)); } } // namespace diff --git a/test/opt/type_manager_test.cpp b/test/opt/type_manager_test.cpp index 17eb2a49..1c8f2db1 100644 --- a/test/opt/type_manager_test.cpp +++ b/test/opt/type_manager_test.cpp @@ -88,9 +88,9 @@ TEST(TypeManager, TypeStrings) { {28, "named_barrier"}, }; - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, *module); + opt::analysis::TypeManager manager(nullptr, *context->module()); EXPECT_EQ(type_id_strs.size(), manager.NumTypes()); EXPECT_EQ(2u, manager.NumForwardPointers()); @@ -118,9 +118,9 @@ TEST(TypeManager, DecorationOnStruct) { %struct4 = OpTypeStruct %u32 %f32 ; the same %struct7 = OpTypeStruct %f32 ; no decoration )"; - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, *module); + opt::analysis::TypeManager manager(nullptr, *context->module()); ASSERT_EQ(7u, manager.NumTypes()); ASSERT_EQ(0u, manager.NumForwardPointers()); @@ -168,9 +168,9 @@ TEST(TypeManager, DecorationOnMember) { %struct7 = OpTypeStruct %u32 %f32 ; extra decoration on the struct %struct10 = OpTypeStruct %u32 %f32 ; no member decoration )"; - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, *module); + opt::analysis::TypeManager manager(nullptr, *context->module()); ASSERT_EQ(10u, manager.NumTypes()); ASSERT_EQ(0u, manager.NumForwardPointers()); @@ -206,9 +206,9 @@ TEST(TypeManager, DecorationEmpty) { %struct2 = OpTypeStruct %f32 %u32 %struct5 = OpTypeStruct %f32 )"; - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, *module); + opt::analysis::TypeManager manager(nullptr, *context->module()); ASSERT_EQ(5u, manager.NumTypes()); ASSERT_EQ(0u, manager.NumForwardPointers()); @@ -228,9 +228,9 @@ TEST(TypeManager, DecorationEmpty) { TEST(TypeManager, BeginEndForEmptyModule) { const std::string text = ""; - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, *module); + opt::analysis::TypeManager manager(nullptr, *context->module()); ASSERT_EQ(0u, manager.NumTypes()); ASSERT_EQ(0u, manager.NumForwardPointers()); @@ -245,9 +245,9 @@ TEST(TypeManager, BeginEnd) { %u32 = OpTypeInt 32 0 %f64 = OpTypeFloat 64 )"; - std::unique_ptr<ir::Module> module = + std::unique_ptr<ir::IRContext> context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, *module); + opt::analysis::TypeManager manager(nullptr, *context->module()); ASSERT_EQ(5u, manager.NumTypes()); ASSERT_EQ(0u, manager.NumForwardPointers()); |