diff options
Diffstat (limited to 'source/opt/constants.cpp')
-rw-r--r-- | source/opt/constants.cpp | 32 |
1 files changed, 30 insertions, 2 deletions
diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp index 238dab38..ba7fc6b8 100644 --- a/source/opt/constants.cpp +++ b/source/opt/constants.cpp @@ -131,6 +131,24 @@ std::vector<const Constant*> ConstantManager::GetOperandConstants( return constants; } +uint32_t ConstantManager::FindDeclaredConstant(const Constant* c, + uint32_t type_id) const { + c = FindConstant(c); + if (c == nullptr) { + return 0; + } + + for (auto range = const_val_to_id_.equal_range(c); + range.first != range.second; ++range.first) { + Instruction* const_def = + context()->get_def_use_mgr()->GetDef(range.first->second); + if (type_id == 0 || const_def->type_id() == type_id) { + return range.first->second; + } + } + return 0; +} + std::vector<const Constant*> ConstantManager::GetConstantsFromIds( const std::vector<uint32_t>& ids) const { std::vector<const Constant*> constants; @@ -163,7 +181,7 @@ Instruction* ConstantManager::GetDefiningInstruction( const Constant* c, uint32_t type_id, Module::inst_iterator* pos) { assert(type_id == 0 || context()->get_type_mgr()->GetType(type_id) == c->type()); - uint32_t decl_id = FindDeclaredConstant(c); + uint32_t decl_id = FindDeclaredConstant(c, type_id); if (decl_id == 0) { auto iter = context()->types_values_end(); if (pos == nullptr) pos = &iter; @@ -295,8 +313,17 @@ std::unique_ptr<Instruction> ConstantManager::CreateInstruction( std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction( uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const { std::vector<Operand> operands; + Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id); + uint32_t component_index = 0; for (const Constant* component_const : cc->GetComponents()) { - uint32_t id = FindDeclaredConstant(component_const); + uint32_t component_type_id = 0; + if (type_inst && type_inst->opcode() == SpvOpTypeStruct) { + component_type_id = type_inst->GetSingleWordInOperand(component_index); + } else if (type_inst && type_inst->opcode() == SpvOpTypeArray) { + component_type_id = type_inst->GetSingleWordInOperand(0); + } + uint32_t id = FindDeclaredConstant(component_const, component_type_id); + if (id == 0) { // Cannot get the id of the component constant, while all components // should have been added to the module prior to the composite constant. @@ -305,6 +332,7 @@ std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction( } operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID, std::initializer_list<uint32_t>{id}); + component_index++; } uint32_t type = (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id; |