// Copyright (c) 2017 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "source/opt/scalar_replacement_pass.h" #include #include #include #include #include "source/enum_string_mapping.h" #include "source/extensions.h" #include "source/opt/make_unique.h" #include "source/opt/reflect.h" #include "source/opt/types.h" namespace spvtools { namespace opt { Pass::Status ScalarReplacementPass::Process() { Status status = Status::SuccessWithoutChange; for (auto& f : *get_module()) { Status functionStatus = ProcessFunction(&f); if (functionStatus == Status::Failure) return functionStatus; else if (functionStatus == Status::SuccessWithChange) status = functionStatus; } return status; } Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) { std::queue worklist; BasicBlock& entry = *function->begin(); for (auto iter = entry.begin(); iter != entry.end(); ++iter) { // Function storage class OpVariables must appear as the first instructions // of the entry block. if (iter->opcode() != SpvOpVariable) break; Instruction* varInst = &*iter; if (CanReplaceVariable(varInst)) { worklist.push(varInst); } } Status status = Status::SuccessWithoutChange; while (!worklist.empty()) { Instruction* varInst = worklist.front(); worklist.pop(); if (!ReplaceVariable(varInst, &worklist)) return Status::Failure; else status = Status::SuccessWithChange; } return status; } bool ScalarReplacementPass::ReplaceVariable( Instruction* inst, std::queue* worklist) { std::vector replacements; CreateReplacementVariables(inst, &replacements); std::vector dead; dead.push_back(inst); if (!get_def_use_mgr()->WhileEachUser( inst, [this, &replacements, &dead](Instruction* user) { if (!IsAnnotationInst(user->opcode())) { switch (user->opcode()) { case SpvOpLoad: ReplaceWholeLoad(user, replacements); dead.push_back(user); break; case SpvOpStore: ReplaceWholeStore(user, replacements); dead.push_back(user); break; case SpvOpAccessChain: case SpvOpInBoundsAccessChain: if (!ReplaceAccessChain(user, replacements)) return false; dead.push_back(user); break; case SpvOpName: case SpvOpMemberName: break; default: assert(false && "Unexpected opcode"); break; } } return true; })) return false; // Clean up some dead code. while (!dead.empty()) { Instruction* toKill = dead.back(); dead.pop_back(); context()->KillInst(toKill); } // Attempt to further scalarize. for (auto var : replacements) { if (var->opcode() == SpvOpVariable) { if (get_def_use_mgr()->NumUsers(var) == 0) { context()->KillInst(var); } else if (CanReplaceVariable(var)) { worklist->push(var); } } } return true; } void ScalarReplacementPass::ReplaceWholeLoad( Instruction* load, const std::vector& replacements) { // Replaces the load of the entire composite with a load from each replacement // variable followed by a composite construction. BasicBlock* block = context()->get_instr_block(load); std::vector loads; loads.reserve(replacements.size()); BasicBlock::iterator where(load); for (auto var : replacements) { // Create a load of each replacement variable. if (var->opcode() != SpvOpVariable) { loads.push_back(var); continue; } Instruction* type = GetStorageType(var); uint32_t loadId = TakeNextId(); std::unique_ptr newLoad( new Instruction(context(), SpvOpLoad, type->result_id(), loadId, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); // Copy memory access attributes which start at index 1. Index 0 is the // pointer to load. for (uint32_t i = 1; i < load->NumInOperands(); ++i) { Operand copy(load->GetInOperand(i)); newLoad->AddOperand(std::move(copy)); } where = where.InsertBefore(std::move(newLoad)); get_def_use_mgr()->AnalyzeInstDefUse(&*where); context()->set_instr_block(&*where, block); loads.push_back(&*where); } // Construct a new composite. uint32_t compositeId = TakeNextId(); where = load; std::unique_ptr compositeConstruct(new Instruction( context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {})); for (auto l : loads) { Operand op(SPV_OPERAND_TYPE_ID, std::initializer_list{l->result_id()}); compositeConstruct->AddOperand(std::move(op)); } where = where.InsertBefore(std::move(compositeConstruct)); get_def_use_mgr()->AnalyzeInstDefUse(&*where); context()->set_instr_block(&*where, block); context()->ReplaceAllUsesWith(load->result_id(), compositeId); } void ScalarReplacementPass::ReplaceWholeStore( Instruction* store, const std::vector& replacements) { // Replaces a store to the whole composite with a series of extract and stores // to each element. uint32_t storeInput = store->GetSingleWordInOperand(1u); BasicBlock* block = context()->get_instr_block(store); BasicBlock::iterator where(store); uint32_t elementIndex = 0; for (auto var : replacements) { // Create the extract. if (var->opcode() != SpvOpVariable) { elementIndex++; continue; } Instruction* type = GetStorageType(var); uint32_t extractId = TakeNextId(); std::unique_ptr extract(new Instruction( context(), SpvOpCompositeExtract, type->result_id(), extractId, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {storeInput}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}})); auto iter = where.InsertBefore(std::move(extract)); get_def_use_mgr()->AnalyzeInstDefUse(&*iter); context()->set_instr_block(&*iter, block); // Create the store. std::unique_ptr newStore( new Instruction(context(), SpvOpStore, 0, 0, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {var->result_id()}}, {SPV_OPERAND_TYPE_ID, {extractId}}})); // Copy memory access attributes which start at index 2. Index 0 is the // pointer and index 1 is the data. for (uint32_t i = 2; i < store->NumInOperands(); ++i) { Operand copy(store->GetInOperand(i)); newStore->AddOperand(std::move(copy)); } iter = where.InsertBefore(std::move(newStore)); get_def_use_mgr()->AnalyzeInstDefUse(&*iter); context()->set_instr_block(&*iter, block); } } bool ScalarReplacementPass::ReplaceAccessChain( Instruction* chain, const std::vector& replacements) { // Replaces the access chain with either another access chain (with one fewer // indexes) or a direct use of the replacement variable. uint32_t indexId = chain->GetSingleWordInOperand(1u); const Instruction* index = get_def_use_mgr()->GetDef(indexId); size_t indexValue = GetConstantInteger(index); if (indexValue > replacements.size()) { // Out of bounds access, this is illegal IR. return false; } else { const Instruction* var = replacements[indexValue]; if (chain->NumInOperands() > 2) { // Replace input access chain with another access chain. BasicBlock::iterator chainIter(chain); uint32_t replacementId = TakeNextId(); std::unique_ptr replacementChain(new Instruction( context(), chain->opcode(), chain->type_id(), replacementId, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); // Add the remaining indexes. for (uint32_t i = 2; i < chain->NumInOperands(); ++i) { Operand copy(chain->GetInOperand(i)); replacementChain->AddOperand(std::move(copy)); } auto iter = chainIter.InsertBefore(std::move(replacementChain)); get_def_use_mgr()->AnalyzeInstDefUse(&*iter); context()->set_instr_block(&*iter, context()->get_instr_block(chain)); context()->ReplaceAllUsesWith(chain->result_id(), replacementId); } else { // Replace with a use of the variable. context()->ReplaceAllUsesWith(chain->result_id(), var->result_id()); } } return true; } void ScalarReplacementPass::CreateReplacementVariables( Instruction* inst, std::vector* replacements) { Instruction* type = GetStorageType(inst); std::unique_ptr> components_used = GetUsedComponents(inst); uint32_t elem = 0; switch (type->opcode()) { case SpvOpTypeStruct: type->ForEachInOperand( [this, inst, &elem, replacements, &components_used](uint32_t* id) { if (!components_used || components_used->count(elem)) { CreateVariable(*id, inst, elem, replacements); } else { replacements->push_back(CreateNullConstant(*id)); } elem++; }); break; case SpvOpTypeArray: for (uint32_t i = 0; i != GetArrayLength(type); ++i) { if (!components_used || components_used->count(i)) { CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements); } else { replacements->push_back( CreateNullConstant(type->GetSingleWordInOperand(0u))); } } break; case SpvOpTypeMatrix: case SpvOpTypeVector: for (uint32_t i = 0; i != GetNumElements(type); ++i) { CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements); } break; default: assert(false && "Unexpected type."); break; } TransferAnnotations(inst, replacements); } void ScalarReplacementPass::TransferAnnotations( const Instruction* source, std::vector* replacements) { // Only transfer invariant and restrict decorations on the variable. There are // no type or member decorations that are necessary to transfer. for (auto inst : get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) { assert(inst->opcode() == SpvOpDecorate); uint32_t decoration = inst->GetSingleWordInOperand(1u); if (decoration == SpvDecorationInvariant || decoration == SpvDecorationRestrict) { for (auto var : *replacements) { std::unique_ptr annotation( new Instruction(context(), SpvOpDecorate, 0, 0, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {var->result_id()}}, {SPV_OPERAND_TYPE_DECORATION, {decoration}}})); for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { Operand copy(inst->GetInOperand(i)); annotation->AddOperand(std::move(copy)); } context()->AddAnnotationInst(std::move(annotation)); get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end()); } } } } void ScalarReplacementPass::CreateVariable( uint32_t typeId, Instruction* varInst, uint32_t index, std::vector* replacements) { uint32_t ptrId = GetOrCreatePointerType(typeId); uint32_t id = TakeNextId(); std::unique_ptr variable(new Instruction( context(), SpvOpVariable, ptrId, id, std::initializer_list{ {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}})); BasicBlock* block = context()->get_instr_block(varInst); block->begin().InsertBefore(std::move(variable)); Instruction* inst = &*block->begin(); // If varInst was initialized, make sure to initialize its replacement. GetOrCreateInitialValue(varInst, index, inst); get_def_use_mgr()->AnalyzeInstDefUse(inst); context()->set_instr_block(inst, block); replacements->push_back(inst); } uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) { auto iter = pointee_to_pointer_.find(id); if (iter != pointee_to_pointer_.end()) return iter->second; analysis::Type* pointeeTy; std::unique_ptr pointerTy; std::tie(pointeeTy, pointerTy) = context()->get_type_mgr()->GetTypeAndPointerType(id, SpvStorageClassFunction); uint32_t ptrId = 0; if (id == context()->get_type_mgr()->GetId(pointeeTy)) { // Non-ambiguous type, just ask the type manager for an id. ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get()); pointee_to_pointer_[id] = ptrId; return ptrId; } // Ambiguous type. We must perform a linear search to try and find the right // type. for (auto global : context()->types_values()) { if (global.opcode() == SpvOpTypePointer && global.GetSingleWordInOperand(0u) == SpvStorageClassFunction && global.GetSingleWordInOperand(1u) == id) { if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) { // Only reuse a decoration-less pointer of the correct type. ptrId = global.result_id(); break; } } } if (ptrId != 0) { pointee_to_pointer_[id] = ptrId; return ptrId; } ptrId = TakeNextId(); context()->AddType(MakeUnique( context(), SpvOpTypePointer, 0, ptrId, std::initializer_list{ {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}, {SPV_OPERAND_TYPE_ID, {id}}})); Instruction* ptr = &*--context()->types_values_end(); get_def_use_mgr()->AnalyzeInstDefUse(ptr); pointee_to_pointer_[id] = ptrId; // Register with the type manager if necessary. context()->get_type_mgr()->RegisterType(ptrId, *pointerTy); return ptrId; } void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source, uint32_t index, Instruction* newVar) { assert(source->opcode() == SpvOpVariable); if (source->NumInOperands() < 2) return; uint32_t initId = source->GetSingleWordInOperand(1u); uint32_t storageId = GetStorageType(newVar)->result_id(); Instruction* init = get_def_use_mgr()->GetDef(initId); uint32_t newInitId = 0; // TODO(dnovillo): Refactor this with constant propagation. if (init->opcode() == SpvOpConstantNull) { // Initialize to appropriate NULL. auto iter = type_to_null_.find(storageId); if (iter == type_to_null_.end()) { newInitId = TakeNextId(); type_to_null_[storageId] = newInitId; context()->AddGlobalValue( MakeUnique(context(), SpvOpConstantNull, storageId, newInitId, std::initializer_list{})); Instruction* newNull = &*--context()->types_values_end(); get_def_use_mgr()->AnalyzeInstDefUse(newNull); } else { newInitId = iter->second; } } else if (IsSpecConstantInst(init->opcode())) { // Create a new constant extract. newInitId = TakeNextId(); context()->AddGlobalValue(MakeUnique( context(), SpvOpSpecConstantOp, storageId, newInitId, std::initializer_list{ {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}}, {SPV_OPERAND_TYPE_ID, {init->result_id()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}})); Instruction* newSpecConst = &*--context()->types_values_end(); get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst); } else if (init->opcode() == SpvOpConstantComposite) { // Get the appropriate index constant. newInitId = init->GetSingleWordInOperand(index); Instruction* element = get_def_use_mgr()->GetDef(newInitId); if (element->opcode() == SpvOpUndef) { // Undef is not a valid initializer for a variable. newInitId = 0; } } else { assert(false); } if (newInitId != 0) { newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}}); } } size_t ScalarReplacementPass::GetIntegerLiteral(const Operand& op) const { assert(op.words.size() <= 2); size_t len = 0; for (uint32_t i = 0; i != op.words.size(); ++i) { len |= (op.words[i] << (32 * i)); } return len; } size_t ScalarReplacementPass::GetConstantInteger( const Instruction* constant) const { assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() == SpvOpTypeInt); assert(constant->opcode() == SpvOpConstant || constant->opcode() == SpvOpConstantNull); if (constant->opcode() == SpvOpConstantNull) { return 0; } const Operand& op = constant->GetInOperand(0u); return GetIntegerLiteral(op); } size_t ScalarReplacementPass::GetArrayLength( const Instruction* arrayType) const { assert(arrayType->opcode() == SpvOpTypeArray); const Instruction* length = get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u)); return GetConstantInteger(length); } size_t ScalarReplacementPass::GetNumElements(const Instruction* type) const { assert(type->opcode() == SpvOpTypeVector || type->opcode() == SpvOpTypeMatrix); const Operand& op = type->GetInOperand(1u); assert(op.words.size() <= 2); size_t len = 0; for (uint32_t i = 0; i != op.words.size(); ++i) { len |= (op.words[i] << (32 * i)); } return len; } Instruction* ScalarReplacementPass::GetStorageType( const Instruction* inst) const { assert(inst->opcode() == SpvOpVariable); uint32_t ptrTypeId = inst->type_id(); uint32_t typeId = get_def_use_mgr()->GetDef(ptrTypeId)->GetSingleWordInOperand(1u); return get_def_use_mgr()->GetDef(typeId); } bool ScalarReplacementPass::CanReplaceVariable( const Instruction* varInst) const { assert(varInst->opcode() == SpvOpVariable); // Can only replace function scope variables. if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) return false; if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) return false; const Instruction* typeInst = GetStorageType(varInst); return CheckType(typeInst) && CheckAnnotations(varInst) && CheckUses(varInst); } bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const { if (!CheckTypeAnnotations(typeInst)) return false; switch (typeInst->opcode()) { case SpvOpTypeStruct: // Don't bother with empty structs or very large structs. if (typeInst->NumInOperands() == 0 || IsLargerThanSizeLimit(typeInst->NumInOperands())) return false; return true; case SpvOpTypeArray: if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) return false; return true; // TODO(alanbaker): Develop some heuristics for when this should be // re-enabled. //// Specifically including matrix and vector in an attempt to reduce the //// number of vector registers required. // case SpvOpTypeMatrix: // case SpvOpTypeVector: // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false; // return true; case SpvOpTypeRuntimeArray: default: return false; } } bool ScalarReplacementPass::CheckTypeAnnotations( const Instruction* typeInst) const { for (auto inst : get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) { uint32_t decoration; if (inst->opcode() == SpvOpDecorate) { decoration = inst->GetSingleWordInOperand(1u); } else { assert(inst->opcode() == SpvOpMemberDecorate); decoration = inst->GetSingleWordInOperand(2u); } switch (decoration) { case SpvDecorationRowMajor: case SpvDecorationColMajor: case SpvDecorationArrayStride: case SpvDecorationMatrixStride: case SpvDecorationCPacked: case SpvDecorationInvariant: case SpvDecorationRestrict: case SpvDecorationOffset: case SpvDecorationAlignment: case SpvDecorationAlignmentId: case SpvDecorationMaxByteOffset: break; default: return false; } } return true; } bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const { for (auto inst : get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) { assert(inst->opcode() == SpvOpDecorate); uint32_t decoration = inst->GetSingleWordInOperand(1u); switch (decoration) { case SpvDecorationInvariant: case SpvDecorationRestrict: case SpvDecorationAlignment: case SpvDecorationAlignmentId: case SpvDecorationMaxByteOffset: break; default: return false; } } return true; } bool ScalarReplacementPass::CheckUses(const Instruction* inst) const { VariableStats stats = {0, 0}; bool ok = CheckUses(inst, &stats); // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when // SRoA is costly, such as when the structure has many (unaccessed?) // members. return ok; } bool ScalarReplacementPass::CheckUses(const Instruction* inst, VariableStats* stats) const { bool ok = true; get_def_use_mgr()->ForEachUse( inst, [this, stats, &ok](const Instruction* user, uint32_t index) { // Annotations are check as a group separately. if (!IsAnnotationInst(user->opcode())) { switch (user->opcode()) { case SpvOpAccessChain: case SpvOpInBoundsAccessChain: if (index == 2u) { uint32_t id = user->GetSingleWordOperand(3u); const Instruction* opInst = get_def_use_mgr()->GetDef(id); if (!IsCompileTimeConstantInst(opInst->opcode())) { ok = false; } else { if (!CheckUsesRelaxed(user)) ok = false; } stats->num_partial_accesses++; } else { ok = false; } break; case SpvOpLoad: if (!CheckLoad(user, index)) ok = false; stats->num_full_accesses++; break; case SpvOpStore: if (!CheckStore(user, index)) ok = false; stats->num_full_accesses++; break; case SpvOpName: case SpvOpMemberName: break; default: ok = false; break; } } }); return ok; } bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const { bool ok = true; get_def_use_mgr()->ForEachUse( inst, [this, &ok](const Instruction* user, uint32_t index) { switch (user->opcode()) { case SpvOpAccessChain: case SpvOpInBoundsAccessChain: if (index != 2u) { ok = false; } else { if (!CheckUsesRelaxed(user)) ok = false; } break; case SpvOpLoad: if (!CheckLoad(user, index)) ok = false; break; case SpvOpStore: if (!CheckStore(user, index)) ok = false; break; default: ok = false; break; } }); return ok; } bool ScalarReplacementPass::CheckLoad(const Instruction* inst, uint32_t index) const { if (index != 2u) return false; if (inst->NumInOperands() >= 2 && inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask) return false; return true; } bool ScalarReplacementPass::CheckStore(const Instruction* inst, uint32_t index) const { if (index != 0u) return false; if (inst->NumInOperands() >= 3 && inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask) return false; return true; } bool ScalarReplacementPass::IsLargerThanSizeLimit(size_t length) const { if (max_num_elements_ == 0) { return false; } return length > max_num_elements_; } std::unique_ptr> ScalarReplacementPass::GetUsedComponents(Instruction* inst) { std::unique_ptr> result( new std::unordered_set()); analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr, this](Instruction* use) { switch (use->opcode()) { case SpvOpLoad: { // Look for extract from the load. std::vector t; if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) { if (use2->opcode() != SpvOpCompositeExtract) { return false; } t.push_back(use2->GetSingleWordInOperand(1)); return true; })) { result->insert(t.begin(), t.end()); return true; } else { result.reset(nullptr); return false; } } case SpvOpStore: // No components are used. Things are just stored to. return true; case SpvOpAccessChain: case SpvOpInBoundsAccessChain: { // Add the first index it if is a constant. // TODO: Could be improved by checking if the address is used in a load. analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); uint32_t index_id = use->GetSingleWordInOperand(1); const analysis::Constant* index_const = const_mgr->FindDeclaredConstant(index_id); if (index_const) { const analysis::Integer* index_type = index_const->type()->AsInteger(); assert(index_type); if (index_type->width() == 32) { result->insert(index_const->GetU32()); return true; } else if (index_type->width() == 64) { result->insert(index_const->GetU64()); return true; } result.reset(nullptr); return false; } else { // Could be any element. Assuming all are used. result.reset(nullptr); return false; } } case SpvOpCopyObject: { // Follow the copy to see which components are used. auto t = GetUsedComponents(use); if (!t) { result.reset(nullptr); return false; } result->insert(t->begin(), t->end()); return true; } default: // We do not know what is happening. Have to assume the worst. result.reset(nullptr); return false; } }); return result; } Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) { analysis::TypeManager* type_mgr = context()->get_type_mgr(); analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); const analysis::Type* type = type_mgr->GetType(type_id); const analysis::Constant* null_const = const_mgr->GetConstant(type, {}); Instruction* null_inst = const_mgr->GetDefiningInstruction(null_const, type_id); context()->UpdateDefUse(null_inst); return null_inst; } } // namespace opt } // namespace spvtools