summaryrefslogtreecommitdiff
path: root/source/opt/constants.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/opt/constants.cpp')
-rw-r--r--source/opt/constants.cpp32
1 files changed, 30 insertions, 2 deletions
diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp
index 238dab38..ba7fc6b8 100644
--- a/source/opt/constants.cpp
+++ b/source/opt/constants.cpp
@@ -131,6 +131,24 @@ std::vector<const Constant*> ConstantManager::GetOperandConstants(
return constants;
}
+uint32_t ConstantManager::FindDeclaredConstant(const Constant* c,
+ uint32_t type_id) const {
+ c = FindConstant(c);
+ if (c == nullptr) {
+ return 0;
+ }
+
+ for (auto range = const_val_to_id_.equal_range(c);
+ range.first != range.second; ++range.first) {
+ Instruction* const_def =
+ context()->get_def_use_mgr()->GetDef(range.first->second);
+ if (type_id == 0 || const_def->type_id() == type_id) {
+ return range.first->second;
+ }
+ }
+ return 0;
+}
+
std::vector<const Constant*> ConstantManager::GetConstantsFromIds(
const std::vector<uint32_t>& ids) const {
std::vector<const Constant*> constants;
@@ -163,7 +181,7 @@ Instruction* ConstantManager::GetDefiningInstruction(
const Constant* c, uint32_t type_id, Module::inst_iterator* pos) {
assert(type_id == 0 ||
context()->get_type_mgr()->GetType(type_id) == c->type());
- uint32_t decl_id = FindDeclaredConstant(c);
+ uint32_t decl_id = FindDeclaredConstant(c, type_id);
if (decl_id == 0) {
auto iter = context()->types_values_end();
if (pos == nullptr) pos = &iter;
@@ -295,8 +313,17 @@ std::unique_ptr<Instruction> ConstantManager::CreateInstruction(
std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction(
uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const {
std::vector<Operand> operands;
+ Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id);
+ uint32_t component_index = 0;
for (const Constant* component_const : cc->GetComponents()) {
- uint32_t id = FindDeclaredConstant(component_const);
+ uint32_t component_type_id = 0;
+ if (type_inst && type_inst->opcode() == SpvOpTypeStruct) {
+ component_type_id = type_inst->GetSingleWordInOperand(component_index);
+ } else if (type_inst && type_inst->opcode() == SpvOpTypeArray) {
+ component_type_id = type_inst->GetSingleWordInOperand(0);
+ }
+ uint32_t id = FindDeclaredConstant(component_const, component_type_id);
+
if (id == 0) {
// Cannot get the id of the component constant, while all components
// should have been added to the module prior to the composite constant.
@@ -305,6 +332,7 @@ std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction(
}
operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
std::initializer_list<uint32_t>{id});
+ component_index++;
}
uint32_t type =
(type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id;