| // Copyright 2025 The Khronos Group Inc. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "fnvar.h" |
| |
| #include <initializer_list> |
| #include <memory> |
| #include <sstream> |
| |
| #include "source/opt/instruction.h" |
| |
| namespace spvtools { |
| |
| using opt::Function; |
| using opt::Instruction; |
| using opt::analysis::Type; |
| |
| namespace { |
| // Helper functions |
| |
// Parses a CSV source string for the purpose of this extension.
//
// The first non-empty line is the header; its column names must equal
// required_cols, in the same order. Every following non-empty line is a data
// row that is appended to result and must contain exactly as many values as
// the header has columns. Values are separated by CSV_SEP; quoting and
// escaping are not supported. A line ending in CSV_SEP yields a trailing empty
// value. Empty lines are skipped.
//
// Lines may be terminated by "\n" or "\r\n": a single trailing '\r' is
// stripped from every line so that files with Windows line endings do not end
// up with a stray '\r' in their last column.
//
// Returns true on success, false on error (with error message stored in
// err_msg). On error, result may already contain the rows parsed before the
// offending one.
bool ParseCsv(const std::string& source,
              const std::vector<std::string>& required_cols,
              std::stringstream& err_msg,
              std::vector<std::vector<std::string>>& result) {
  constexpr char CSV_SEP = ',';

  std::stringstream csv_stream(source);
  std::vector<std::string> columns;
  std::string line;
  bool first_line = true;

  while (std::getline(csv_stream, line, '\n')) {
    // Tolerate CRLF line endings.
    if (!line.empty() && line.back() == '\r') {
      line.pop_back();
    }
    if (line.empty()) {
      continue;
    }

    std::vector<std::string> vals;
    std::string val;
    std::stringstream line_stream(line);
    while (std::getline(line_stream, val, CSV_SEP)) {
      vals.push_back(val);
    }
    // std::getline does not report an empty value after a trailing separator,
    // so add it explicitly.
    if (line.back() == CSV_SEP) {
      vals.push_back("");
    }

    if (first_line) {
      columns.swap(vals);
      first_line = false;
      continue;
    }

    if (vals.size() != columns.size()) {
      err_msg << "Number of values does not match the number of columns. "
                 "Offending line:\n"
              << line;
      return false;
    }
    result.push_back(vals);
  }

  // Check that the header matches the required columns (ordering matters).

  if (columns.size() != required_cols.size()) {
    err_msg << "Invalid number of CSV columns: " << columns.size()
            << ", expected " << required_cols.size() << ".";
    return false;
  }

  for (size_t i = 0; i < columns.size(); ++i) {
    if (columns[i] != required_cols[i]) {
      err_msg << "Invalid name of column " << i + 1 << ". Expected '"
              << required_cols[i] << "', got '" << columns[i] << "'.";
      return false;
    }
  }

  return true;
}
| |
| // Annotate ID with ConditionalINTEL decoration |
| void DecorateConditional(IRContext* context, uint32_t id_to_decorate, |
| uint32_t spec_const_id) { |
| auto decor_instr = |
| std::make_unique<Instruction>(context, spv::Op::OpDecorate); |
| decor_instr->AddOperand({SPV_OPERAND_TYPE_ID, {id_to_decorate}}); |
| decor_instr->AddOperand({SPV_OPERAND_TYPE_DECORATION, |
| {uint32_t(spv::Decoration::ConditionalINTEL)}}); |
| decor_instr->AddOperand({SPV_OPERAND_TYPE_ID, {spec_const_id}}); |
| context->module()->AddAnnotationInst(std::move(decor_instr)); |
| } |
| |
| // Finds entry point corresponding to a function |
| // |
| // Returns null if not found, otherwise returns pointer to the EP Instruction. |
| Instruction* FindEntryPoint(const Instruction& fn_inst) { |
| auto* mod = fn_inst.context()->module(); |
| for (auto& entry_point : mod->entry_points()) { |
| const int ep_i = |
| entry_point.opcode() == spv::Op::OpConditionalEntryPointINTEL ? 2 : 1; |
| if (entry_point.GetOperand(ep_i).AsId() == fn_inst.result_id()) { |
| return &entry_point; |
| } |
| } |
| return nullptr; |
| } |
| |
| // If the function has an entry point, converts it to a conditional one |
| void ConvertEPToConditional(Module* module, const Function& fn, |
| uint32_t spec_const_id) { |
| for (const auto& ep_inst : module->entry_points()) { |
| if (ep_inst.opcode() == spv::Op::OpEntryPoint) { |
| auto* entry_point = FindEntryPoint(fn.DefInst()); |
| if (entry_point != nullptr) { |
| std::vector<opt::Operand> old_operands; |
| for (auto operand : *entry_point) { |
| old_operands.push_back(operand); |
| } |
| entry_point->ToNop(); |
| entry_point->SetOpcode(spv::Op::OpConditionalEntryPointINTEL); |
| entry_point->AddOperand({SPV_OPERAND_TYPE_ID, {spec_const_id}}); |
| for (auto old_operand : old_operands) { |
| entry_point->AddOperand(old_operand); |
| } |
| } |
| } |
| } |
| } |
| |
| // Finds ID of a bool type (returns 0 if not found) |
| uint32_t FindIdOfBoolType(const Module* const mod) { |
| return mod->context()->get_type_mgr()->GetBoolTypeId(); |
| } |
| |
| // Combines IDs using OpSpecConstantOp with the operation defined by cmp_op. |
| // |
| // Returns the ID of the final result. If there are no IDs, returns 0. If there |
| // is one ID, does not generate any instructions and returns the ID. |
| uint32_t CombineIds(IRContext* const context, const std::vector<uint32_t>& ids, |
| spv::Op cmp_op) { |
| if (ids.empty()) { |
| return 0; |
| } else if (ids.size() == 1) { |
| return ids[0]; |
| } else { |
| uint32_t bool_id = FindIdOfBoolType(context->module()); |
| assert(bool_id != 0); |
| |
| uint32_t prev_spec_const_id = ids[0]; |
| |
| for (size_t i = 1; i < ids.size(); ++i) { |
| const uint32_t id = ids[i]; |
| const uint32_t spec_const_op_id = context->TakeNextId(); |
| |
| auto inst = std::make_unique<Instruction>( |
| context, spv::Op::OpSpecConstantOp, bool_id, spec_const_op_id, |
| std::initializer_list<opt::Operand>{ |
| {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {(uint32_t)(cmp_op)}}, |
| {SPV_OPERAND_TYPE_ID, {prev_spec_const_id}}, |
| {SPV_OPERAND_TYPE_ID, {id}}}); |
| context->module()->AddType(std::move(inst)); |
| |
| prev_spec_const_id = spec_const_op_id; |
| } |
| |
| return prev_spec_const_id; |
| } |
| } |
| |
| // Returns whether instruction can be shared between variant modules and |
| // combined using spec constants (such as conditional capabilities). |
| bool CanBeFnVarCombined(const Instruction* inst) { |
| const spv::Op opcode = inst->opcode(); |
| |
| if ((opcode != spv::Op::OpExtInstImport) && |
| (opcode != spv::Op::OpCapability) && (opcode != spv::Op::OpExtension) && |
| !spvOpcodeGeneratesType(opcode)) { |
| return false; |
| } |
| |
| if ((opcode == spv::Op::OpCapability) && |
| ((inst->GetSingleWordOperand(0) == |
| static_cast<uint32_t>(spv::Capability::FunctionVariantsINTEL)) || |
| (inst->GetSingleWordOperand(0) == |
| static_cast<uint32_t>(spv::Capability::SpecConditionalINTEL)))) { |
| // Always enabled |
| return false; |
| } |
| |
| if ((opcode == spv::Op::OpExtension) && |
| (inst->GetOperand(0).AsString() == FNVAR_EXT_NAME)) { |
| // Always enabled |
| return false; |
| } |
| |
| return true; |
| } |
| |
| // Calculates hash of an instruction. |
| // |
| // Applicable only to instructions that can be combined (ie. with |
| // CanBeFnVarCombined being true) and from those, hash can be only computed for |
| // selected instructions. Computing hash from other instruction is unsupported. |
| size_t HashInst(const Instruction* inst) { |
| if (CanBeFnVarCombined(inst)) { |
| if (spvOpcodeGeneratesType(inst->opcode())) { |
| const Type* t = |
| inst->context()->get_type_mgr()->GetType(inst->result_id()); |
| assert(t != nullptr); |
| return t->HashValue(); |
| } |
| |
| if (inst->opcode() == spv::Op::OpExtension) { |
| const auto name = inst->GetOperand(0).AsString(); |
| return std::hash<std::string>()(name); |
| } |
| |
| if (inst->opcode() == spv::Op::OpCapability) { |
| const auto cap = inst->GetSingleWordOperand(0); |
| return std::hash<uint32_t>()(cap); |
| } |
| |
| if (inst->opcode() == spv::Op::OpExtInstImport) { |
| const auto name = inst->GetOperand(1).AsString(); |
| return std::hash<std::string>()(name); |
| } |
| } |
| |
| assert(false && "Unsupported instruction hash"); |
| return std::hash<const Instruction*>()(inst); |
| } |
| |
| std::string GetFnName(const Instruction& fn_inst) { |
| // Check entry point |
| const auto* ep_inst = FindEntryPoint(fn_inst); |
| if (ep_inst != nullptr) { |
| const int name_i = |
| ep_inst->opcode() == spv::Op::OpConditionalEntryPointINTEL ? 3 : 2; |
| return ep_inst->GetOperand(name_i).AsString(); |
| } |
| |
| // Check name of export linkage attribute decoration |
| const auto* decor_mgr = fn_inst.context()->get_decoration_mgr(); |
| for (const auto* inst : |
| decor_mgr->GetDecorationsFor(fn_inst.result_id(), true)) { |
| const auto decoration = inst->GetOperand(1); |
| if ((decoration.type == SPV_OPERAND_TYPE_DECORATION) && |
| (decoration.words.size() == 1) && |
| (decoration.words[0] == |
| static_cast<uint32_t>(spv::Decoration::LinkageAttributes))) { |
| const auto linkage = inst->GetOperand(3); |
| if ((linkage.type == SPV_OPERAND_TYPE_LINKAGE_TYPE) && |
| (linkage.words.size() == 1) && |
| (linkage.words[0] == |
| static_cast<uint32_t>(spv::LinkageType::Export))) { |
| // decorates fn with LinkageAttribute and Export linkage type -> get the |
| // name |
| return inst->GetOperand(2).AsString(); |
| } |
| } |
| } |
| |
| return ""; |
| } |
| |
| uint32_t FindSpecConstByName(const Module* mod, std::string name) { |
| for (const auto* const_inst : mod->context()->GetConstants()) { |
| if (opt::IsSpecConstantInst(const_inst->opcode())) { |
| const auto id = const_inst->result_id(); |
| for (const auto& name_inst : mod->debugs2()) { |
| if ((name_inst.opcode() == spv::Op::OpName) && |
| (name_inst.GetOperand(0).AsId() == id) && |
| (name_inst.GetOperand(1).AsString() == name)) { |
| return id; |
| } |
| } |
| } |
| } |
| return 0; |
| } |
| |
| uint32_t CombineVariantDefs(const std::vector<VariantDef>& variant_defs, |
| const std::vector<size_t> var_ids, |
| IRContext* context, |
| std::map<std::vector<size_t>, uint32_t>& cache) { |
| assert(var_ids.size() <= variant_defs.size()); |
| uint32_t spec_const_comb_id = 0; |
| if (var_ids.size() != variant_defs.size()) { |
| // if not used by all variants |
| if (cache.find(var_ids) == cache.end()) { |
| // cache variant combinations |
| std::vector<uint32_t> spec_const_ids; |
| for (const auto& var_id : var_ids) { |
| const auto var_name = variant_defs[var_id].GetName(); |
| const auto var_spec_id = |
| FindSpecConstByName(context->module(), var_name); |
| spec_const_ids.push_back(var_spec_id); |
| } |
| spec_const_comb_id = |
| CombineIds(context, spec_const_ids, spv::Op::OpLogicalOr); |
| assert(spec_const_comb_id != 0); |
| cache.insert({var_ids, spec_const_comb_id}); |
| } else { |
| spec_const_comb_id = cache[var_ids]; |
| } |
| } |
| return spec_const_comb_id; |
| } |
| |
// Converts a string made up solely of the decimal digits '0'-'9' into an
// unsigned 32-bit integer stored in *x.
//
// Returns false if s contains any other character (including signs and
// whitespace) or if the stream extraction into *x fails, which is the case for
// an empty string and for values that do not fit. Returns true otherwise.
bool strToInt(std::string s, uint32_t* x) {
  const bool digits_only =
      s.find_first_not_of("0123456789") == std::string::npos;
  if (!digits_only) {
    return false;
  }

  std::stringstream parser(s);
  parser >> *x;
  return !parser.fail();
}
| |
| } // anonymous namespace |
| |
| bool VariantDefs::ProcessFnVar(const LinkerOptions& options, |
| const std::vector<Module*>& modules) { |
| assert(variant_defs_.empty()); |
| assert(modules.size() == options.GetInFiles().size()); |
| |
| for (size_t i = 0; i < modules.size(); ++i) { |
| const auto* feat_mgr = modules[i]->context()->get_feature_mgr(); |
| if ((feat_mgr->HasCapability(spv::Capability::FunctionVariantsINTEL)) || |
| (feat_mgr->HasCapability(spv::Capability::SpecConditionalINTEL)) || |
| (feat_mgr->HasExtension(kSPV_INTEL_function_variants))) { |
| // In principle, it can be done but it's complicated due to having to |
| // combine the existing conditionals with the new ones. For example, |
| // conditional capabilities would need to become "doubly-conditional". |
| err_ << "Creating multitarget modules from multitarget modules is not " |
| "supported. Offending file: " |
| << options.GetInFiles()[i]; |
| return false; |
| } |
| } |
| |
| std::vector<std::vector<std::string>> target_rows; |
| std::vector<std::vector<std::string>> architecture_rows; |
| |
| if (!options.GetFnVarTargetsCsv().empty()) { |
| const std::vector<std::string> tgt_cols = {"module", "target", "features"}; |
| if (!ParseCsv(options.GetFnVarTargetsCsv(), tgt_cols, err_, target_rows)) { |
| return false; |
| } |
| } |
| |
| if (!options.GetFnVarArchitecturesCsv().empty()) { |
| const std::vector<std::string> arch_cols = {"module", "category", "family", |
| "op", "architecture"}; |
| if (!ParseCsv(options.GetFnVarArchitecturesCsv(), arch_cols, err_, |
| architecture_rows)) { |
| return false; |
| } |
| } |
| |
| // check that all modules defined in the CSV exist |
| |
| for (const auto& tgt_vals : target_rows) { |
| bool found = false; |
| for (const auto& in_file : options.GetInFiles()) { |
| if (tgt_vals[0] == in_file) { |
| found = true; |
| } |
| } |
| if (!found) { |
| err_ << "Module '" << tgt_vals[0] |
| << "' found in targets CSV not passed to the CLI."; |
| return false; |
| } |
| } |
| |
| for (const auto& arch_vals : architecture_rows) { |
| bool found = false; |
| for (const auto& in_file : options.GetInFiles()) { |
| if (arch_vals[0] == in_file) { |
| found = true; |
| } |
| } |
| if (!found) { |
| err_ << "Module '" << arch_vals[0] |
| << "' found in architectures CSV not passed to the CLI."; |
| return false; |
| } |
| } |
| |
| // create per-module variant defs |
| |
| for (size_t i = 0; i < modules.size(); ++i) { |
| // first module passed to the CLI is considered the base module |
| bool is_base = i == 0; |
| const auto name = options.GetInFiles()[i]; |
| auto variant_def = VariantDef(is_base, name, modules[i]); |
| |
| for (const auto& arch_row : architecture_rows) { |
| const auto row_name = arch_row[0]; |
| if (row_name == name) { |
| uint32_t category, family, op, architecture; |
| |
| if (!strToInt(arch_row[1], &category)) { |
| err_ << "Error converting " << arch_row[1] |
| << " to architecture category."; |
| return false; |
| } |
| if (!strToInt(arch_row[2], &family)) { |
| err_ << "Error converting " << arch_row[2] |
| << " to architecture family."; |
| return false; |
| } |
| if (!strToInt(arch_row[3], &op)) { |
| err_ << "Error converting " << arch_row[3] << " to architecture op."; |
| return false; |
| } |
| if (!strToInt(arch_row[4], &architecture)) { |
| err_ << "Error converting " << arch_row[4] << " to architecture."; |
| return false; |
| } |
| |
| variant_def.AddArchDef(category, family, op, architecture); |
| } |
| } |
| |
| for (const auto& tgt_row : target_rows) { |
| const auto row_name = tgt_row[0]; |
| if (row_name == name) { |
| uint32_t target; |
| std::vector<uint32_t> features; |
| |
| if (!strToInt(tgt_row[1], &target)) { |
| err_ << "Error converting " << tgt_row[1] << " to target."; |
| return false; |
| } |
| |
| // get features as FEAT_SEP-delimited integers |
| |
| std::stringstream feat_stream(tgt_row[2]); |
| std::string feat; |
| while (std::getline(feat_stream, feat, FEAT_SEP)) { |
| uint32_t ufeat; |
| // if (!(std::stringstream(feat) >> ufeat)) { |
| if (!strToInt(feat, &ufeat)) { |
| err_ << "Error converting " << feat << " in " << tgt_row[2] |
| << " to target feature."; |
| return false; |
| } |
| features.push_back(ufeat); |
| } |
| |
| variant_def.AddTgtDef(target, features); |
| } |
| } |
| |
| if (options.GetHasFnVarCapabilities()) { |
| variant_def.InferCapabilities(); |
| } |
| |
| variant_defs_.push_back(variant_def); |
| } |
| |
| return true; |
| } |
| |
| bool VariantDefs::ProcessVariantDefs() { |
| EnsureBoolType(); |
| CollectVarInsts(); |
| if (!GenerateFnVarConstants()) { |
| return false; |
| } |
| CollectBaseFnCalls(); |
| return true; |
| } |
| |
| void VariantDefs::GenerateHeader(IRContext* linked_context) { |
| linked_context->AddCapability(spv::Capability::SpecConditionalINTEL); |
| linked_context->AddCapability(spv::Capability::FunctionVariantsINTEL); |
| linked_context->AddExtension(std::string(FNVAR_EXT_NAME)); |
| |
| // Specifies used registry version |
| auto inst = |
| std::make_unique<Instruction>(linked_context, spv::Op::OpModuleProcessed); |
| std::stringstream line; |
| line << "SPV_INTEL_function_variants registry version " |
| << FNVAR_REGISTRY_VERSION; |
| inst->AddOperand( |
| {SPV_OPERAND_TYPE_LITERAL_STRING, utils::MakeVector(line.str())}); |
| linked_context->AddDebug3Inst(std::move(inst)); |
| } |
| |
| void VariantDefs::CombineVariantInstructions(IRContext* linked_context) { |
| CombineBaseFnCalls(linked_context); |
| CombineInstructions(linked_context); |
| } |
| |
| void VariantDefs::EnsureBoolType() { |
| for (auto& variant_def : variant_defs_) { |
| Module* module = variant_def.GetModule(); |
| IRContext* context = module->context(); |
| |
| uint32_t bool_id = FindIdOfBoolType(module); |
| if (bool_id == 0) { |
| bool_id = context->TakeNextId(); |
| auto variant_bool = std::make_unique<Instruction>( |
| context, spv::Op::OpTypeBool, 0, bool_id, |
| std::initializer_list<opt::Operand>{}); |
| module->AddType(std::move(variant_bool)); |
| } |
| } |
| } |
| |
| void VariantDefs::CollectVarInsts() { |
| for (size_t i = 0; i < variant_defs_.size(); ++i) { |
| const auto variant_def = variant_defs_[i]; |
| const auto* var_mod = variant_def.GetModule(); |
| |
| var_mod->ForEachInst([this, &i](const Instruction* inst) { |
| if (CanBeFnVarCombined(inst)) { |
| const size_t inst_hash = HashInst(inst); |
| if (fnvar_usage_.find(inst_hash) == fnvar_usage_.end()) { |
| fnvar_usage_.insert({inst_hash, {i}}); |
| } else { |
| assert(fnvar_usage_[inst_hash].size() < variant_defs_.size()); |
| fnvar_usage_[inst_hash].push_back(i); |
| } |
| } |
| }); |
| } |
| } |
| |
| bool VariantDefs::GenerateFnVarConstants() { |
| assert(variant_defs_.size() > 0); |
| assert(variant_defs_[0].IsBase()); |
| |
| if (variant_defs_.size() == 1) { |
| return true; |
| } |
| |
| for (auto& variant_def : variant_defs_) { |
| Module* module = variant_def.GetModule(); |
| IRContext* context = module->context(); |
| |
| uint32_t bool_id = FindIdOfBoolType(module); |
| if (bool_id == 0) { |
| // add a bool type if not present already |
| bool_id = context->TakeNextId(); |
| auto variant_bool = std::make_unique<Instruction>( |
| context, spv::Op::OpTypeBool, 0, bool_id, |
| std::initializer_list<opt::Operand>{}); |
| module->AddType(std::move(variant_bool)); |
| } |
| |
| // Spec constant architecture and target |
| |
| std::vector<uint32_t> spec_const_arch_ids; |
| for (const auto& arch_def : variant_def.GetArchDefs()) { |
| const uint32_t spec_const_arch_id = context->TakeNextId(); |
| spec_const_arch_ids.push_back(spec_const_arch_id); |
| |
| auto inst = std::make_unique<Instruction>( |
| context, spv::Op::OpSpecConstantArchitectureINTEL, bool_id, |
| spec_const_arch_id, |
| std::initializer_list<opt::Operand>{ |
| {SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.category}}, |
| {SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.family}}, |
| // Using spec op opcode here expects then next operand to be |
| // a type: |
| {SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.op}}, |
| {SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.architecture}}, |
| }); |
| module->AddType(std::move(inst)); |
| } |
| |
| std::vector<uint32_t> spec_const_tgt_ids; |
| for (const auto& tgt_def : variant_def.GetTgtDefs()) { |
| const uint32_t spec_const_tgt_id = context->TakeNextId(); |
| spec_const_tgt_ids.push_back(spec_const_tgt_id); |
| |
| auto inst = std::make_unique<Instruction>( |
| context, spv::Op::OpSpecConstantTargetINTEL, bool_id, |
| spec_const_tgt_id, |
| std::initializer_list<opt::Operand>{ |
| {SPV_OPERAND_TYPE_LITERAL_INTEGER, {tgt_def.target}}, |
| }); |
| for (const auto& feat : tgt_def.features) { |
| inst->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {feat}}); |
| } |
| module->AddType(std::move(inst)); |
| } |
| |
| std::vector<uint32_t> spec_const_ids; |
| |
| // Spec constant capabilities |
| |
| const auto variant_capabilities = variant_def.GetCapabilities(); |
| if (!variant_capabilities.empty()) { |
| const uint32_t spec_const_cap_id = context->TakeNextId(); |
| auto inst = std::make_unique<Instruction>( |
| context, spv::Op::OpSpecConstantCapabilitiesINTEL, bool_id, |
| spec_const_cap_id, std::initializer_list<opt::Operand>{}); |
| for (const auto& cap : variant_capabilities) { |
| inst->AddOperand({SPV_OPERAND_TYPE_CAPABILITY, {uint32_t(cap)}}); |
| } |
| module->AddType(std::move(inst)); |
| spec_const_ids.push_back(spec_const_cap_id); |
| } |
| |
| // Combine architectures such that, for the same module, those with the same |
| // category and family are combined with AND and different cat/fam are |
| // combined with OR. |
| // This lets you create combinations like "architecture between X and Y". |
| |
| // map (category, family) -> IDs |
| std::map<std::pair<uint32_t, uint32_t>, std::vector<uint32_t>> arch_map_and; |
| |
| for (size_t i = 0; i < spec_const_arch_ids.size(); ++i) { |
| const auto& arch_def = variant_def.GetArchDefs()[i]; |
| const auto id = spec_const_arch_ids[i]; |
| const auto key = std::make_pair(arch_def.category, arch_def.family); |
| if (arch_map_and.find(key) == arch_map_and.end()) { |
| arch_map_and[key] = {id}; |
| } else { |
| arch_map_and[key].push_back(id); |
| } |
| } |
| |
| std::vector<uint32_t> arch_ids_or; |
| for (const auto& it : arch_map_and) { |
| const auto id = CombineIds(context, it.second, spv::Op::OpLogicalAnd); |
| if (id > 0) { |
| arch_ids_or.push_back(id); |
| } |
| } |
| |
| const uint32_t spec_const_arch_id = |
| CombineIds(context, arch_ids_or, spv::Op::OpLogicalOr); |
| if (spec_const_arch_id > 0) { |
| spec_const_ids.push_back(spec_const_arch_id); |
| } |
| |
| const uint32_t spec_const_tgt_id = |
| CombineIds(context, spec_const_tgt_ids, spv::Op::OpLogicalOr); |
| if (spec_const_tgt_id > 0) { |
| spec_const_ids.push_back(spec_const_tgt_id); |
| } |
| |
| uint32_t combined_spec_const_id = |
| CombineIds(context, spec_const_ids, spv::Op::OpLogicalAnd); |
| if (combined_spec_const_id == 0) { |
| // If the variant module has no constraints, use SpecConstantTrue |
| combined_spec_const_id = context->TakeNextId(); |
| auto inst = std::make_unique<Instruction>( |
| context, spv::Op::OpSpecConstantTrue, bool_id, combined_spec_const_id, |
| std::initializer_list<opt::Operand>{}); |
| context->module()->AddType(std::move(inst)); |
| } |
| assert(combined_spec_const_id != 0); |
| |
| // Add a name the combined boolean ID so we can look it up after the IDs are |
| // shifted |
| auto inst = std::make_unique<Instruction>(context, spv::Op::OpName); |
| inst->AddOperand({SPV_OPERAND_TYPE_ID, {combined_spec_const_id}}); |
| std::vector<uint32_t> str_words; |
| utils::AppendToVector(variant_def.GetName(), &str_words); |
| inst->AddOperand({SPV_OPERAND_TYPE_LITERAL_STRING, {str_words}}); |
| module->AddDebug2Inst(std::move(inst)); |
| |
| // Annotate all instructions in the types section (eg. constants) with |
| // ConditionalINTEL, unless they can be shared between variant_defs_ (eg. |
| // types). Spec constants are excluded because they might have been |
| // generated by this extension. |
| for (const auto& type_inst : module->types_values()) { |
| if (!CanBeFnVarCombined(&type_inst) && |
| !spvOpcodeIsSpecConstant(type_inst.opcode())) { |
| DecorateConditional(context, type_inst.result_id(), |
| combined_spec_const_id); |
| } |
| } |
| } |
| |
| // Annotate functions with ConditionalINTEL |
| |
| for (const auto& base_fn : *variant_defs_[0].GetModule()) { |
| // For each function of the base module, find matching variant functions in |
| // other modules |
| |
| auto base_fn_name = GetFnName(base_fn.DefInst()); |
| if (base_fn_name.empty()) { |
| err_ << "Could not find name of a function " << base_fn.result_id() |
| << " in a base module " << variant_defs_[0].GetName() |
| << ". To be usable by SPV_INTEL_function_variants, a function " |
| "must either have an entry point or an export " |
| "LinkAttribute decoration."; |
| return false; |
| } |
| |
| bool base_fn_needs_conditional = false; |
| for (size_t i = 1; i < variant_defs_.size(); ++i) { |
| const auto& variant_def = variant_defs_[i]; |
| auto* variant_module = variant_def.GetModule(); |
| auto* variant_context = variant_module->context(); |
| |
| for (const auto& var_fn : *variant_module) { |
| auto var_fn_name = GetFnName(var_fn.DefInst()); |
| if (var_fn_name.empty()) { |
| err_ << "Could not find name of a function " << var_fn.result_id() |
| << " in a base module " << variant_def.GetName() |
| << ". To be usable by SPV_INTEL_function_variants, a function " |
| "must either have an entry point or an export " |
| "LinkAttribute decoration."; |
| return false; |
| } |
| |
| if (base_fn_name == var_fn_name) { |
| base_fn_needs_conditional = true; |
| } |
| |
| // each function in a variant module gets a ConditionalINTEL decoration |
| |
| uint32_t spec_const_id = |
| FindSpecConstByName(variant_module, variant_def.GetName()); |
| assert(spec_const_id != 0); |
| DecorateConditional(variant_context, var_fn.result_id(), spec_const_id); |
| ConvertEPToConditional(variant_module, var_fn, spec_const_id); |
| } |
| } |
| |
| if (base_fn_needs_conditional) { |
| // only a base function that has a variant in another module gets a |
| // ConditionalINTEL decoration, the others are common for all |
| // variant_defs_ |
| auto* base_module = variant_defs_[0].GetModule(); |
| auto* base_context = base_module->context(); |
| uint32_t spec_const_id = |
| FindSpecConstByName(base_module, variant_defs_[0].GetName()); |
| assert(spec_const_id != 0); |
| DecorateConditional(base_context, base_fn.result_id(), spec_const_id); |
| ConvertEPToConditional(base_module, base_fn, spec_const_id); |
| } |
| } |
| |
| return true; |
| } |
| |
| void VariantDefs::CollectBaseFnCalls() { |
| auto* base_mod = variant_defs_[0].GetModule(); |
| assert(variant_defs_[0].IsBase()); |
| const auto* base_def_use_mgr = base_mod->context()->get_def_use_mgr(); |
| |
| base_mod->ForEachInst([this, &base_def_use_mgr](const Instruction* inst) { |
| if (inst->opcode() == spv::Op::OpFunctionCall) { |
| // For each function call in base module, get the function name |
| const auto fn_id = inst->GetOperand(2).AsId(); |
| const auto* called_fn_inst = base_def_use_mgr->GetDef(fn_id); |
| assert(called_fn_inst != nullptr); |
| const auto called_fn_name = GetFnName(*called_fn_inst); |
| assert(!called_fn_name.empty()); |
| |
| std::vector<std::pair<std::string, const opt::Function*>> called_fns; |
| for (size_t i = 1; i < variant_defs_.size(); ++i) { |
| // ... then see in which variant the called function was defined |
| const auto& variant_def = variant_defs_[i]; |
| assert(!variant_def.IsBase()); |
| |
| for (const auto& fn : *variant_def.GetModule()) { |
| const auto fn_name = GetFnName(fn.DefInst()); |
| if (fn_name == called_fn_name) { |
| called_fns.push_back(std::make_pair(variant_def.GetName(), &fn)); |
| } |
| } |
| } |
| |
| if (!called_fns.empty()) { |
| base_fn_calls_[inst->result_id()] = called_fns; |
| } |
| } |
| }); |
| } |
| |
| void VariantDefs::CombineBaseFnCalls(IRContext* linked_context) { |
| for (auto kv : base_fn_calls_) { |
| const uint32_t call_id = kv.first; |
| const auto called_fns = kv.second; |
| |
| if (called_fns.empty()) { |
| return; |
| } |
| |
| opt::BasicBlock* fn_call_bb = linked_context->get_instr_block(call_id); |
| |
| Instruction* found_call_inst = nullptr; |
| auto bb_iter = fn_call_bb->begin(); |
| while (bb_iter != fn_call_bb->end() && found_call_inst == nullptr) { |
| if (bb_iter->HasResultId() && bb_iter->result_id() == call_id) { |
| found_call_inst = &*bb_iter; |
| } |
| ++bb_iter; |
| } |
| |
| if (found_call_inst == nullptr) { |
| return; |
| } |
| |
| const auto base_spec_const_id = FindSpecConstByName( |
| variant_defs_[0].GetModule(), variant_defs_[0].GetName()); |
| const auto base_type_op = found_call_inst->context() |
| ->get_def_use_mgr() |
| ->GetDef(found_call_inst->type_id()) |
| ->opcode(); |
| const auto base_call_id = found_call_inst->result_id(); |
| |
| // decorate the base call with ConditionalINTEL |
| DecorateConditional(linked_context, base_call_id, base_spec_const_id); |
| |
| // Add OpFunctionCall for each variant |
| Instruction* last_inst = found_call_inst; |
| std::vector<std::pair<uint32_t, uint32_t>> var_call_ids; |
| for (const auto& kv2 : called_fns) { |
| const std::string var_name = kv2.first; |
| const opt::Function* fn = kv2.second; |
| const uint32_t spec_const_id = |
| FindSpecConstByName(linked_context->module(), var_name); |
| assert(spec_const_id != 0); |
| const uint32_t var_call_id = linked_context->TakeNextId(); |
| var_call_ids.push_back(std::make_pair(spec_const_id, var_call_id)); |
| |
| auto* var_call_inst = found_call_inst->Clone(linked_context); |
| var_call_inst->SetResultId(var_call_id); |
| var_call_inst->SetOperand(2, {fn->result_id()}); |
| var_call_inst->InsertAfter(last_inst); |
| linked_context->set_instr_block(var_call_inst, fn_call_bb); |
| last_inst = var_call_inst; |
| |
| // decorate the variant call with ConditionalINTEL |
| DecorateConditional(linked_context, var_call_id, spec_const_id); |
| } |
| |
| if (base_type_op != spv::Op::OpTypeVoid) { |
| // Add OpConditionalCopyObjectINTEL combining the function calls |
| const uint32_t result_id = linked_context->TakeNextId(); |
| auto conditional_copy_inst = new Instruction( |
| linked_context, spv::Op::OpConditionalCopyObjectINTEL, |
| found_call_inst->type_id(), result_id, |
| {{SPV_OPERAND_TYPE_ID, {base_spec_const_id}}, |
| {SPV_OPERAND_TYPE_ID, {found_call_inst->result_id()}}}); |
| |
| for (const auto& kv3 : var_call_ids) { |
| const auto spec_const_id = kv3.first; |
| const auto var_call_id = kv3.second; |
| conditional_copy_inst->AddOperand( |
| {SPV_OPERAND_TYPE_ID, {spec_const_id}}); |
| conditional_copy_inst->AddOperand({SPV_OPERAND_TYPE_ID, {var_call_id}}); |
| } |
| conditional_copy_inst->InsertAfter(last_inst); |
| linked_context->set_instr_block(conditional_copy_inst, fn_call_bb); |
| last_inst = conditional_copy_inst; |
| |
| // In all remaining instructions within the basic block, replace all |
| // usages of the base call ID with the result of |
| // OpConditionalCopyObjectINTEL |
| do { |
| last_inst = last_inst->NextNode(); |
| last_inst->ForEachInId([base_call_id, result_id](uint32_t* id) { |
| if (*id == base_call_id) { |
| *id = result_id; |
| } |
| }); |
| } while (last_inst != nullptr && *last_inst != *fn_call_bb->tail()); |
| } |
| } |
| |
| // Combine spec consts for the base module (base module is activated if all |
| // variant defs are inactive AND the base module constraints are satisfied) |
| |
| std::vector<uint32_t> var_spec_const_ids; |
| for (const auto& variant_def : variant_defs_) { |
| if (variant_def.IsBase()) { |
| continue; |
| } |
| |
| const auto id = |
| FindSpecConstByName(linked_context->module(), variant_def.GetName()); |
| assert(id != 0); |
| var_spec_const_ids.push_back(id); |
| } |
| const uint32_t base_or_id = |
| CombineIds(linked_context, var_spec_const_ids, spv::Op::OpLogicalOr); |
| |
| if (base_or_id != 0) { |
| const uint32_t bool_id = FindIdOfBoolType(linked_context->module()); |
| assert(bool_id != 0); |
| |
| const uint32_t base_not_id = linked_context->TakeNextId(); |
| auto spec_const_op_inst = std::make_unique<Instruction>( |
| linked_context, spv::Op::OpSpecConstantOp, bool_id, base_not_id, |
| std::initializer_list<opt::Operand>{ |
| {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, |
| {(uint32_t)(spv::Op::OpLogicalNot)}}, |
| {SPV_OPERAND_TYPE_ID, {base_or_id}}}); |
| linked_context->module()->AddType(std::move(spec_const_op_inst)); |
| |
| // Update any ConditionalINTEL annotations, names and entry points |
| // referencing the old spec const ID to use the new one |
| |
| const uint32_t old_base_spec_const_id = FindSpecConstByName( |
| linked_context->module(), variant_defs_[0].GetName()); |
| assert(old_base_spec_const_id != 0); |
| const uint32_t base_spec_const_id = |
| CombineIds(linked_context, {old_base_spec_const_id, base_not_id}, |
| spv::Op::OpLogicalAnd); |
| |
| for (auto& annot_inst : linked_context->module()->annotations()) { |
| if ((annot_inst.GetSingleWordOperand(1) == |
| uint32_t(spv::Decoration::ConditionalINTEL)) && |
| (annot_inst.GetOperand(2).AsId() == old_base_spec_const_id)) { |
| annot_inst.SetOperand(2, {base_spec_const_id}); |
| } |
| } |
| |
| for (auto& name_inst : linked_context->module()->debugs2()) { |
| if ((name_inst.opcode() == spv::Op::OpName) && |
| (name_inst.GetOperand(0).AsId() == old_base_spec_const_id)) { |
| name_inst.SetOperand(0, {base_spec_const_id}); |
| } |
| } |
| |
| for (auto& ep_inst : linked_context->module()->entry_points()) { |
| if ((ep_inst.opcode() == spv::Op::OpConditionalEntryPointINTEL) && |
| (ep_inst.GetOperand(0).AsId() == old_base_spec_const_id)) { |
| ep_inst.SetOperand(0, {base_spec_const_id}); |
| } |
| } |
| |
| linked_context->module()->ForEachInst( |
| [old_base_spec_const_id, base_spec_const_id](Instruction* inst) { |
| if (inst->opcode() == spv::Op::OpConditionalCopyObjectINTEL) { |
| inst->ForEachInId( |
| [old_base_spec_const_id, base_spec_const_id](uint32_t* id) { |
| if (*id == old_base_spec_const_id) { |
| *id = base_spec_const_id; |
| } |
| }); |
| } |
| }); |
| } |
| } |
| |
| void VariantDefs::CombineInstructions(IRContext* linked_context) { |
| // cache for existing variant ID combinations |
| std::map<std::vector<size_t>, uint32_t> spec_const_comb_ids; |
| |
| linked_context->module()->ForEachInst( |
| [this, &linked_context, &spec_const_comb_ids](Instruction* inst) { |
| if (!CanBeFnVarCombined(inst)) { |
| return; |
| } |
| |
| const size_t inst_hash = HashInst(inst); |
| if (fnvar_usage_.find(inst_hash) != fnvar_usage_.end()) { |
| const std::vector<size_t> var_ids = fnvar_usage_[inst_hash]; |
| const uint32_t spec_const_comb_id = CombineVariantDefs( |
| variant_defs_, var_ids, linked_context, spec_const_comb_ids); |
| if (spec_const_comb_id != 0) { |
| if (inst->HasResultId()) { |
| DecorateConditional(linked_context, inst->result_id(), |
| spec_const_comb_id); |
| } else if (inst->opcode() == spv::Op::OpCapability) { |
| const uint32_t cap = inst->GetSingleWordOperand(0); |
| inst->SetOpcode(spv::Op::OpConditionalCapabilityINTEL); |
| inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {spec_const_comb_id}}, |
| {SPV_OPERAND_TYPE_CAPABILITY, {cap}}}); |
| } else if (inst->opcode() == spv::Op::OpExtension) { |
| const std::string ext_name = inst->GetOperand(0).AsString(); |
| inst->SetOpcode(spv::Op::OpConditionalExtensionINTEL); |
| inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {spec_const_comb_id}}, |
| {SPV_OPERAND_TYPE_LITERAL_STRING, |
| {utils::MakeVector(ext_name)}}}); |
| } else { |
| assert(false && "Unsupported"); |
| } |
| } |
| } |
| }); |
| } |
| |
| } // namespace spvtools |