blob: 352e7ea339d3c2543d19c862ac59e7d3cd960096 [file] [log] [blame] [edit]
// Copyright 2025 The Khronos Group Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fnvar.h"
#include <initializer_list>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "source/opt/instruction.h"
namespace spvtools {
using opt::Function;
using opt::Instruction;
using opt::analysis::Type;
namespace {
// Helper functions
// Parses a CSV source string for the purpose of this extension.
//
// The first non-empty line is the header. It must match required_cols exactly
// (same names, same order). Every following non-empty line must hold as many
// values as there are columns. Values are separated by CSV_SEP; no quoting or
// whitespace trimming is performed. A trailing '\r' is dropped from each line
// so that files with CRLF line endings parse the same as LF files. The input
// string is assumed to be the output of io::ReadTextFile; no other validation
// apart from the CSV parsing is performed.
//
// Parsed data rows are appended to result (rows parsed before an error stay
// there). Returns true on success, false on error (with an error message
// appended to err_msg).
bool ParseCsv(const std::string& source,
              const std::vector<std::string>& required_cols,
              std::stringstream& err_msg,
              std::vector<std::vector<std::string>>& result) {
  constexpr char CSV_SEP = ',';
  std::stringstream fn_variants_csv_stream(source);
  std::string line;
  std::vector<std::string> columns;
  bool first_line = true;
  while (std::getline(fn_variants_csv_stream, line, '\n')) {
    // Tolerate CRLF line endings: std::getline only strips the '\n'.
    if (!line.empty() && line.back() == '\r') {
      line.pop_back();
    }
    if (line.empty()) {
      continue;
    }
    std::vector<std::string> vals;
    // The header line fills `columns`; every later line fills `vals`.
    std::vector<std::string>* vec = first_line ? &columns : &vals;
    std::stringstream line_stream(line);
    std::string val;
    while (std::getline(line_stream, val, CSV_SEP)) {
      vec->push_back(val);
    }
    // std::getline yields no field after a trailing separator ("a,b,"). In
    // that case the final, failing call leaves `val` empty (otherwise it
    // still holds the last field), so the empty trailing value is added here.
    if (!line_stream && val.empty()) {
      vec->push_back("");
    }
    if (!first_line) {
      if (vals.size() != columns.size()) {
        err_msg << "Number of values does not match the number of columns. "
                   "Offending line:\n"
                << line;
        return false;
      }
      result.push_back(vals);
    }
    first_line = false;
  }
  // check if required columns match actual columns (ordering matters)
  if (columns.size() != required_cols.size()) {
    err_msg << "Invalid number of CSV columns: " << columns.size()
            << ", expected " << required_cols.size() << ".";
    return false;
  }
  for (size_t i = 0; i < columns.size(); ++i) {
    if (columns[i] != required_cols[i]) {
      err_msg << "Invalid name of column " << i + 1 << ". Expected '"
              << required_cols[i] << "', got '" << columns[i] << "'.";
      return false;
    }
  }
  return true;
}
// Appends `OpDecorate %id_to_decorate ConditionalINTEL %spec_const_id` to the
// annotation section of the module owned by `context`.
void DecorateConditional(IRContext* context, uint32_t id_to_decorate,
                         uint32_t spec_const_id) {
  const uint32_t conditional =
      static_cast<uint32_t>(spv::Decoration::ConditionalINTEL);
  auto decoration = std::make_unique<Instruction>(context, spv::Op::OpDecorate);
  decoration->AddOperand({SPV_OPERAND_TYPE_ID, {id_to_decorate}});
  decoration->AddOperand({SPV_OPERAND_TYPE_DECORATION, {conditional}});
  decoration->AddOperand({SPV_OPERAND_TYPE_ID, {spec_const_id}});
  context->module()->AddAnnotationInst(std::move(decoration));
}
// Looks up the entry point (plain or conditional) whose function operand is
// the result id of `fn_inst`, searching the module that owns `fn_inst`.
//
// Returns the first matching entry-point instruction, or nullptr if none.
Instruction* FindEntryPoint(const Instruction& fn_inst) {
  const uint32_t fn_id = fn_inst.result_id();
  for (auto& ep : fn_inst.context()->module()->entry_points()) {
    // A conditional entry point carries its condition <id> as an extra
    // leading operand, which shifts the function operand by one.
    const uint32_t fn_operand =
        ep.opcode() == spv::Op::OpConditionalEntryPointINTEL ? 2u : 1u;
    if (ep.GetOperand(fn_operand).AsId() == fn_id) {
      return &ep;
    }
  }
  return nullptr;
}
// Rewrites every unconditional OpEntryPoint in `module` that targets `fn` into
// an OpConditionalEntryPointINTEL guarded by `spec_const_id`. The new operand
// list is the condition <id> followed by the original operands.
//
// Entry points that are already conditional are skipped. Converting twice
// would prepend a second condition operand, so repeated calls are harmless.
// (The previous version re-converted the same entry point once per
// OpEntryPoint in the module.)
void ConvertEPToConditional(Module* module, const Function& fn,
                            uint32_t spec_const_id) {
  const uint32_t fn_id = fn.result_id();
  for (auto& entry_point : module->entry_points()) {
    // OpEntryPoint operands: execution model, function <id>, name, interface.
    if (entry_point.opcode() != spv::Op::OpEntryPoint ||
        entry_point.GetOperand(1).AsId() != fn_id) {
      continue;
    }
    std::vector<opt::Operand> old_operands(entry_point.begin(),
                                           entry_point.end());
    // ToNop() clears the operand list so it can be rebuilt in place.
    entry_point.ToNop();
    entry_point.SetOpcode(spv::Op::OpConditionalEntryPointINTEL);
    entry_point.AddOperand({SPV_OPERAND_TYPE_ID, {spec_const_id}});
    for (auto& old_operand : old_operands) {
      entry_point.AddOperand(std::move(old_operand));
    }
  }
}
// Returns the id the module's type manager reports for OpTypeBool; callers
// treat 0 as "no bool type".
// NOTE(review): TypeManager::GetBoolTypeId may register the type when it is
// absent rather than returning 0 — confirm before relying on the 0 case.
uint32_t FindIdOfBoolType(const Module* const mod) {
  auto* const type_mgr = mod->context()->get_type_mgr();
  return type_mgr->GetBoolTypeId();
}
// Left-folds `ids` into a chain of boolean OpSpecConstantOp instructions using
// the binary operation `cmp_op`, appending them to the types/values section.
//
// Returns 0 for an empty list. Returns the single id unchanged (emitting
// nothing) for a one-element list. Otherwise returns the id of the last
// OpSpecConstantOp emitted.
uint32_t CombineIds(IRContext* const context, const std::vector<uint32_t>& ids,
                    spv::Op cmp_op) {
  if (ids.empty()) {
    return 0;
  }
  if (ids.size() == 1) {
    return ids.front();
  }
  const uint32_t bool_type_id = FindIdOfBoolType(context->module());
  assert(bool_type_id != 0);
  uint32_t folded_id = ids.front();
  for (size_t idx = 1; idx < ids.size(); ++idx) {
    const uint32_t next_id = context->TakeNextId();
    context->module()->AddType(std::make_unique<Instruction>(
        context, spv::Op::OpSpecConstantOp, bool_type_id, next_id,
        std::initializer_list<opt::Operand>{
            {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER,
             {static_cast<uint32_t>(cmp_op)}},
            {SPV_OPERAND_TYPE_ID, {folded_id}},
            {SPV_OPERAND_TYPE_ID, {ids[idx]}}}));
    folded_id = next_id;
  }
  return folded_id;
}
// Returns whether instruction can be shared between variant modules and
// combined using spec constants (such as conditional capabilities).
bool CanBeFnVarCombined(const Instruction* inst) {
const spv::Op opcode = inst->opcode();
if ((opcode != spv::Op::OpExtInstImport) &&
(opcode != spv::Op::OpCapability) && (opcode != spv::Op::OpExtension) &&
!spvOpcodeGeneratesType(opcode)) {
return false;
}
if ((opcode == spv::Op::OpCapability) &&
((inst->GetSingleWordOperand(0) ==
static_cast<uint32_t>(spv::Capability::FunctionVariantsINTEL)) ||
(inst->GetSingleWordOperand(0) ==
static_cast<uint32_t>(spv::Capability::SpecConditionalINTEL)))) {
// Always enabled
return false;
}
if ((opcode == spv::Op::OpExtension) &&
(inst->GetOperand(0).AsString() == FNVAR_EXT_NAME)) {
// Always enabled
return false;
}
return true;
}
// Calculates hash of an instruction.
//
// Applicable only to instructions that can be combined (ie. with
// CanBeFnVarCombined being true) and from those, hash can be only computed for
// selected instructions. Computing hash from other instruction is unsupported.
size_t HashInst(const Instruction* inst) {
if (CanBeFnVarCombined(inst)) {
if (spvOpcodeGeneratesType(inst->opcode())) {
const Type* t =
inst->context()->get_type_mgr()->GetType(inst->result_id());
assert(t != nullptr);
return t->HashValue();
}
if (inst->opcode() == spv::Op::OpExtension) {
const auto name = inst->GetOperand(0).AsString();
return std::hash<std::string>()(name);
}
if (inst->opcode() == spv::Op::OpCapability) {
const auto cap = inst->GetSingleWordOperand(0);
return std::hash<uint32_t>()(cap);
}
if (inst->opcode() == spv::Op::OpExtInstImport) {
const auto name = inst->GetOperand(1).AsString();
return std::hash<std::string>()(name);
}
}
assert(false && "Unsupported instruction hash");
return std::hash<const Instruction*>()(inst);
}
std::string GetFnName(const Instruction& fn_inst) {
// Check entry point
const auto* ep_inst = FindEntryPoint(fn_inst);
if (ep_inst != nullptr) {
const int name_i =
ep_inst->opcode() == spv::Op::OpConditionalEntryPointINTEL ? 3 : 2;
return ep_inst->GetOperand(name_i).AsString();
}
// Check name of export linkage attribute decoration
const auto* decor_mgr = fn_inst.context()->get_decoration_mgr();
for (const auto* inst :
decor_mgr->GetDecorationsFor(fn_inst.result_id(), true)) {
const auto decoration = inst->GetOperand(1);
if ((decoration.type == SPV_OPERAND_TYPE_DECORATION) &&
(decoration.words.size() == 1) &&
(decoration.words[0] ==
static_cast<uint32_t>(spv::Decoration::LinkageAttributes))) {
const auto linkage = inst->GetOperand(3);
if ((linkage.type == SPV_OPERAND_TYPE_LINKAGE_TYPE) &&
(linkage.words.size() == 1) &&
(linkage.words[0] ==
static_cast<uint32_t>(spv::LinkageType::Export))) {
// decorates fn with LinkageAttribute and Export linkage type -> get the
// name
return inst->GetOperand(2).AsString();
}
}
}
return "";
}
uint32_t FindSpecConstByName(const Module* mod, std::string name) {
for (const auto* const_inst : mod->context()->GetConstants()) {
if (opt::IsSpecConstantInst(const_inst->opcode())) {
const auto id = const_inst->result_id();
for (const auto& name_inst : mod->debugs2()) {
if ((name_inst.opcode() == spv::Op::OpName) &&
(name_inst.GetOperand(0).AsId() == id) &&
(name_inst.GetOperand(1).AsString() == name)) {
return id;
}
}
}
}
return 0;
}
uint32_t CombineVariantDefs(const std::vector<VariantDef>& variant_defs,
const std::vector<size_t> var_ids,
IRContext* context,
std::map<std::vector<size_t>, uint32_t>& cache) {
assert(var_ids.size() <= variant_defs.size());
uint32_t spec_const_comb_id = 0;
if (var_ids.size() != variant_defs.size()) {
// if not used by all variants
if (cache.find(var_ids) == cache.end()) {
// cache variant combinations
std::vector<uint32_t> spec_const_ids;
for (const auto& var_id : var_ids) {
const auto var_name = variant_defs[var_id].GetName();
const auto var_spec_id =
FindSpecConstByName(context->module(), var_name);
spec_const_ids.push_back(var_spec_id);
}
spec_const_comb_id =
CombineIds(context, spec_const_ids, spv::Op::OpLogicalOr);
assert(spec_const_comb_id != 0);
cache.insert({var_ids, spec_const_comb_id});
} else {
spec_const_comb_id = cache[var_ids];
}
}
return spec_const_comb_id;
}
// Parses `s` as an unsigned 32-bit decimal number into `*x`. Only the digits
// 0-9 are accepted (no sign, whitespace or prefix). An empty string or a value
// that does not fit in uint32_t fails. Returns true on success.
bool strToInt(std::string s, uint32_t* x) {
  for (std::string::size_type pos = 0; pos < s.size(); ++pos) {
    const char ch = s[pos];
    if (ch < '0' || ch > '9') {
      return false;
    }
  }
  std::stringstream digits(s);
  return static_cast<bool>(digits >> *x);
}
} // anonymous namespace
bool VariantDefs::ProcessFnVar(const LinkerOptions& options,
const std::vector<Module*>& modules) {
assert(variant_defs_.empty());
assert(modules.size() == options.GetInFiles().size());
for (size_t i = 0; i < modules.size(); ++i) {
const auto* feat_mgr = modules[i]->context()->get_feature_mgr();
if ((feat_mgr->HasCapability(spv::Capability::FunctionVariantsINTEL)) ||
(feat_mgr->HasCapability(spv::Capability::SpecConditionalINTEL)) ||
(feat_mgr->HasExtension(kSPV_INTEL_function_variants))) {
// In principle, it can be done but it's complicated due to having to
// combine the existing conditionals with the new ones. For example,
// conditional capabilities would need to become "doubly-conditional".
err_ << "Creating multitarget modules from multitarget modules is not "
"supported. Offending file: "
<< options.GetInFiles()[i];
return false;
}
}
std::vector<std::vector<std::string>> target_rows;
std::vector<std::vector<std::string>> architecture_rows;
if (!options.GetFnVarTargetsCsv().empty()) {
const std::vector<std::string> tgt_cols = {"module", "target", "features"};
if (!ParseCsv(options.GetFnVarTargetsCsv(), tgt_cols, err_, target_rows)) {
return false;
}
}
if (!options.GetFnVarArchitecturesCsv().empty()) {
const std::vector<std::string> arch_cols = {"module", "category", "family",
"op", "architecture"};
if (!ParseCsv(options.GetFnVarArchitecturesCsv(), arch_cols, err_,
architecture_rows)) {
return false;
}
}
// check that all modules defined in the CSV exist
for (const auto& tgt_vals : target_rows) {
bool found = false;
for (const auto& in_file : options.GetInFiles()) {
if (tgt_vals[0] == in_file) {
found = true;
}
}
if (!found) {
err_ << "Module '" << tgt_vals[0]
<< "' found in targets CSV not passed to the CLI.";
return false;
}
}
for (const auto& arch_vals : architecture_rows) {
bool found = false;
for (const auto& in_file : options.GetInFiles()) {
if (arch_vals[0] == in_file) {
found = true;
}
}
if (!found) {
err_ << "Module '" << arch_vals[0]
<< "' found in architectures CSV not passed to the CLI.";
return false;
}
}
// create per-module variant defs
for (size_t i = 0; i < modules.size(); ++i) {
// first module passed to the CLI is considered the base module
bool is_base = i == 0;
const auto name = options.GetInFiles()[i];
auto variant_def = VariantDef(is_base, name, modules[i]);
for (const auto& arch_row : architecture_rows) {
const auto row_name = arch_row[0];
if (row_name == name) {
uint32_t category, family, op, architecture;
if (!strToInt(arch_row[1], &category)) {
err_ << "Error converting " << arch_row[1]
<< " to architecture category.";
return false;
}
if (!strToInt(arch_row[2], &family)) {
err_ << "Error converting " << arch_row[2]
<< " to architecture family.";
return false;
}
if (!strToInt(arch_row[3], &op)) {
err_ << "Error converting " << arch_row[3] << " to architecture op.";
return false;
}
if (!strToInt(arch_row[4], &architecture)) {
err_ << "Error converting " << arch_row[4] << " to architecture.";
return false;
}
variant_def.AddArchDef(category, family, op, architecture);
}
}
for (const auto& tgt_row : target_rows) {
const auto row_name = tgt_row[0];
if (row_name == name) {
uint32_t target;
std::vector<uint32_t> features;
if (!strToInt(tgt_row[1], &target)) {
err_ << "Error converting " << tgt_row[1] << " to target.";
return false;
}
// get features as FEAT_SEP-delimited integers
std::stringstream feat_stream(tgt_row[2]);
std::string feat;
while (std::getline(feat_stream, feat, FEAT_SEP)) {
uint32_t ufeat;
// if (!(std::stringstream(feat) >> ufeat)) {
if (!strToInt(feat, &ufeat)) {
err_ << "Error converting " << feat << " in " << tgt_row[2]
<< " to target feature.";
return false;
}
features.push_back(ufeat);
}
variant_def.AddTgtDef(target, features);
}
}
if (options.GetHasFnVarCapabilities()) {
variant_def.InferCapabilities();
}
variant_defs_.push_back(variant_def);
}
return true;
}
// Prepares all variant modules before linking. It makes sure each has a bool
// type, records which combinable instructions each one uses, and emits the
// per-variant spec constants and ConditionalINTEL annotations. On success it
// also records the base-module calls that have variant counterparts.
// Returns false if the spec-constant generation step failed.
bool VariantDefs::ProcessVariantDefs() {
  EnsureBoolType();
  CollectVarInsts();
  const bool constants_ok = GenerateFnVarConstants();
  if (constants_ok) {
    CollectBaseFnCalls();
  }
  return constants_ok;
}
// Adds the capabilities and extension required by the function-variants
// feature to the linked module. Also adds an OpModuleProcessed entry that
// records the registry version in use.
void VariantDefs::GenerateHeader(IRContext* linked_context) {
  linked_context->AddCapability(spv::Capability::SpecConditionalINTEL);
  linked_context->AddCapability(spv::Capability::FunctionVariantsINTEL);
  linked_context->AddExtension(std::string(FNVAR_EXT_NAME));

  std::stringstream version_text;
  version_text << "SPV_INTEL_function_variants registry version "
               << FNVAR_REGISTRY_VERSION;
  auto processed_inst =
      std::make_unique<Instruction>(linked_context, spv::Op::OpModuleProcessed);
  processed_inst->AddOperand({SPV_OPERAND_TYPE_LITERAL_STRING,
                              utils::MakeVector(version_text.str())});
  linked_context->AddDebug3Inst(std::move(processed_inst));
}
// Post-link step: first multiplex the recorded base-module calls across the
// variants, then guard the shared capabilities/extensions/imports/types with
// the matching spec constants.
void VariantDefs::CombineVariantInstructions(IRContext* linked_context) {
  this->CombineBaseFnCalls(linked_context);
  this->CombineInstructions(linked_context);
}
void VariantDefs::EnsureBoolType() {
for (auto& variant_def : variant_defs_) {
Module* module = variant_def.GetModule();
IRContext* context = module->context();
uint32_t bool_id = FindIdOfBoolType(module);
if (bool_id == 0) {
bool_id = context->TakeNextId();
auto variant_bool = std::make_unique<Instruction>(
context, spv::Op::OpTypeBool, 0, bool_id,
std::initializer_list<opt::Operand>{});
module->AddType(std::move(variant_bool));
}
}
}
// Records, for every combinable instruction (see CanBeFnVarCombined), the
// indices of the variant modules that contain it, keyed by HashInst.
//
// Each variant index is stored at most once per hash. A module may contain
// the same combinable instruction twice (e.g. a repeated OpCapability, or two
// structurally identical struct types). Recording the index twice would make
// a {0, 0} list look like "used by all variants" to CombineVariantDefs, which
// then skips the condition. The previous version also copied each VariantDef.
void VariantDefs::CollectVarInsts() {
  for (size_t i = 0; i < variant_defs_.size(); ++i) {
    const auto* var_mod = variant_defs_[i].GetModule();
    var_mod->ForEachInst([this, i](const Instruction* inst) {
      if (!CanBeFnVarCombined(inst)) {
        return;
      }
      auto& users = fnvar_usage_[HashInst(inst)];
      // Indices are visited in increasing order, so a duplicate can only be
      // the most recently added entry.
      if (!users.empty() && users.back() == i) {
        return;
      }
      assert(users.size() < variant_defs_.size());
      users.push_back(i);
    });
  }
}
bool VariantDefs::GenerateFnVarConstants() {
assert(variant_defs_.size() > 0);
assert(variant_defs_[0].IsBase());
if (variant_defs_.size() == 1) {
return true;
}
for (auto& variant_def : variant_defs_) {
Module* module = variant_def.GetModule();
IRContext* context = module->context();
uint32_t bool_id = FindIdOfBoolType(module);
if (bool_id == 0) {
// add a bool type if not present already
bool_id = context->TakeNextId();
auto variant_bool = std::make_unique<Instruction>(
context, spv::Op::OpTypeBool, 0, bool_id,
std::initializer_list<opt::Operand>{});
module->AddType(std::move(variant_bool));
}
// Spec constant architecture and target
std::vector<uint32_t> spec_const_arch_ids;
for (const auto& arch_def : variant_def.GetArchDefs()) {
const uint32_t spec_const_arch_id = context->TakeNextId();
spec_const_arch_ids.push_back(spec_const_arch_id);
auto inst = std::make_unique<Instruction>(
context, spv::Op::OpSpecConstantArchitectureINTEL, bool_id,
spec_const_arch_id,
std::initializer_list<opt::Operand>{
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.category}},
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.family}},
// Using spec op opcode here expects then next operand to be
// a type:
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.op}},
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.architecture}},
});
module->AddType(std::move(inst));
}
std::vector<uint32_t> spec_const_tgt_ids;
for (const auto& tgt_def : variant_def.GetTgtDefs()) {
const uint32_t spec_const_tgt_id = context->TakeNextId();
spec_const_tgt_ids.push_back(spec_const_tgt_id);
auto inst = std::make_unique<Instruction>(
context, spv::Op::OpSpecConstantTargetINTEL, bool_id,
spec_const_tgt_id,
std::initializer_list<opt::Operand>{
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {tgt_def.target}},
});
for (const auto& feat : tgt_def.features) {
inst->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {feat}});
}
module->AddType(std::move(inst));
}
std::vector<uint32_t> spec_const_ids;
// Spec constant capabilities
const auto variant_capabilities = variant_def.GetCapabilities();
if (!variant_capabilities.empty()) {
const uint32_t spec_const_cap_id = context->TakeNextId();
auto inst = std::make_unique<Instruction>(
context, spv::Op::OpSpecConstantCapabilitiesINTEL, bool_id,
spec_const_cap_id, std::initializer_list<opt::Operand>{});
for (const auto& cap : variant_capabilities) {
inst->AddOperand({SPV_OPERAND_TYPE_CAPABILITY, {uint32_t(cap)}});
}
module->AddType(std::move(inst));
spec_const_ids.push_back(spec_const_cap_id);
}
// Combine architectures such that, for the same module, those with the same
// category and family are combined with AND and different cat/fam are
// combined with OR.
// This lets you create combinations like "architecture between X and Y".
// map (category, family) -> IDs
std::map<std::pair<uint32_t, uint32_t>, std::vector<uint32_t>> arch_map_and;
for (size_t i = 0; i < spec_const_arch_ids.size(); ++i) {
const auto& arch_def = variant_def.GetArchDefs()[i];
const auto id = spec_const_arch_ids[i];
const auto key = std::make_pair(arch_def.category, arch_def.family);
if (arch_map_and.find(key) == arch_map_and.end()) {
arch_map_and[key] = {id};
} else {
arch_map_and[key].push_back(id);
}
}
std::vector<uint32_t> arch_ids_or;
for (const auto& it : arch_map_and) {
const auto id = CombineIds(context, it.second, spv::Op::OpLogicalAnd);
if (id > 0) {
arch_ids_or.push_back(id);
}
}
const uint32_t spec_const_arch_id =
CombineIds(context, arch_ids_or, spv::Op::OpLogicalOr);
if (spec_const_arch_id > 0) {
spec_const_ids.push_back(spec_const_arch_id);
}
const uint32_t spec_const_tgt_id =
CombineIds(context, spec_const_tgt_ids, spv::Op::OpLogicalOr);
if (spec_const_tgt_id > 0) {
spec_const_ids.push_back(spec_const_tgt_id);
}
uint32_t combined_spec_const_id =
CombineIds(context, spec_const_ids, spv::Op::OpLogicalAnd);
if (combined_spec_const_id == 0) {
// If the variant module has no constraints, use SpecConstantTrue
combined_spec_const_id = context->TakeNextId();
auto inst = std::make_unique<Instruction>(
context, spv::Op::OpSpecConstantTrue, bool_id, combined_spec_const_id,
std::initializer_list<opt::Operand>{});
context->module()->AddType(std::move(inst));
}
assert(combined_spec_const_id != 0);
// Add a name the combined boolean ID so we can look it up after the IDs are
// shifted
auto inst = std::make_unique<Instruction>(context, spv::Op::OpName);
inst->AddOperand({SPV_OPERAND_TYPE_ID, {combined_spec_const_id}});
std::vector<uint32_t> str_words;
utils::AppendToVector(variant_def.GetName(), &str_words);
inst->AddOperand({SPV_OPERAND_TYPE_LITERAL_STRING, {str_words}});
module->AddDebug2Inst(std::move(inst));
// Annotate all instructions in the types section (eg. constants) with
// ConditionalINTEL, unless they can be shared between variant_defs_ (eg.
// types). Spec constants are excluded because they might have been
// generated by this extension.
for (const auto& type_inst : module->types_values()) {
if (!CanBeFnVarCombined(&type_inst) &&
!spvOpcodeIsSpecConstant(type_inst.opcode())) {
DecorateConditional(context, type_inst.result_id(),
combined_spec_const_id);
}
}
}
// Annotate functions with ConditionalINTEL
for (const auto& base_fn : *variant_defs_[0].GetModule()) {
// For each function of the base module, find matching variant functions in
// other modules
auto base_fn_name = GetFnName(base_fn.DefInst());
if (base_fn_name.empty()) {
err_ << "Could not find name of a function " << base_fn.result_id()
<< " in a base module " << variant_defs_[0].GetName()
<< ". To be usable by SPV_INTEL_function_variants, a function "
"must either have an entry point or an export "
"LinkAttribute decoration.";
return false;
}
bool base_fn_needs_conditional = false;
for (size_t i = 1; i < variant_defs_.size(); ++i) {
const auto& variant_def = variant_defs_[i];
auto* variant_module = variant_def.GetModule();
auto* variant_context = variant_module->context();
for (const auto& var_fn : *variant_module) {
auto var_fn_name = GetFnName(var_fn.DefInst());
if (var_fn_name.empty()) {
err_ << "Could not find name of a function " << var_fn.result_id()
<< " in a base module " << variant_def.GetName()
<< ". To be usable by SPV_INTEL_function_variants, a function "
"must either have an entry point or an export "
"LinkAttribute decoration.";
return false;
}
if (base_fn_name == var_fn_name) {
base_fn_needs_conditional = true;
}
// each function in a variant module gets a ConditionalINTEL decoration
uint32_t spec_const_id =
FindSpecConstByName(variant_module, variant_def.GetName());
assert(spec_const_id != 0);
DecorateConditional(variant_context, var_fn.result_id(), spec_const_id);
ConvertEPToConditional(variant_module, var_fn, spec_const_id);
}
}
if (base_fn_needs_conditional) {
// only a base function that has a variant in another module gets a
// ConditionalINTEL decoration, the others are common for all
// variant_defs_
auto* base_module = variant_defs_[0].GetModule();
auto* base_context = base_module->context();
uint32_t spec_const_id =
FindSpecConstByName(base_module, variant_defs_[0].GetName());
assert(spec_const_id != 0);
DecorateConditional(base_context, base_fn.result_id(), spec_const_id);
ConvertEPToConditional(base_module, base_fn, spec_const_id);
}
}
return true;
}
void VariantDefs::CollectBaseFnCalls() {
auto* base_mod = variant_defs_[0].GetModule();
assert(variant_defs_[0].IsBase());
const auto* base_def_use_mgr = base_mod->context()->get_def_use_mgr();
base_mod->ForEachInst([this, &base_def_use_mgr](const Instruction* inst) {
if (inst->opcode() == spv::Op::OpFunctionCall) {
// For each function call in base module, get the function name
const auto fn_id = inst->GetOperand(2).AsId();
const auto* called_fn_inst = base_def_use_mgr->GetDef(fn_id);
assert(called_fn_inst != nullptr);
const auto called_fn_name = GetFnName(*called_fn_inst);
assert(!called_fn_name.empty());
std::vector<std::pair<std::string, const opt::Function*>> called_fns;
for (size_t i = 1; i < variant_defs_.size(); ++i) {
// ... then see in which variant the called function was defined
const auto& variant_def = variant_defs_[i];
assert(!variant_def.IsBase());
for (const auto& fn : *variant_def.GetModule()) {
const auto fn_name = GetFnName(fn.DefInst());
if (fn_name == called_fn_name) {
called_fns.push_back(std::make_pair(variant_def.GetName(), &fn));
}
}
}
if (!called_fns.empty()) {
base_fn_calls_[inst->result_id()] = called_fns;
}
}
});
}
// Post-link step that multiplexes each recorded base-module call.
//
// For every recorded call, the base call is guarded with the base condition
// and one cloned OpFunctionCall is added per variant, guarded with that
// variant's condition. For non-void calls, an OpConditionalCopyObjectINTEL
// then selects the active result, and every use of the base result in the
// function is redirected to it. Afterwards the base condition is replaced by
// "base constraints AND NOT(any variant active)" in all ConditionalINTEL
// decorations, the OpName, conditional entry points and conditional copies.
//
// Fixes over the previous version: a call that could not be processed used to
// `return`, silently skipping all later calls and the base-condition rewrite.
// A null block from get_instr_block() was dereferenced. Uses of the call
// result outside its own block were not redirected. The annotation scan read
// operands that e.g. OpDecorationGroup does not have.
void VariantDefs::CombineBaseFnCalls(IRContext* linked_context) {
  // The base condition keeps its OpName through linking; resolve it in the
  // linked module, as is done for the variant conditions below.
  const uint32_t old_base_spec_const_id = FindSpecConstByName(
      linked_context->module(), variant_defs_[0].GetName());
  for (const auto& call_entry : base_fn_calls_) {
    const uint32_t call_id = call_entry.first;
    const auto& called_fns = call_entry.second;
    if (called_fns.empty()) {
      continue;
    }
    opt::BasicBlock* fn_call_bb = linked_context->get_instr_block(call_id);
    if (fn_call_bb == nullptr) {
      continue;
    }
    Instruction* found_call_inst = nullptr;
    for (auto& bb_inst : *fn_call_bb) {
      if (bb_inst.HasResultId() && bb_inst.result_id() == call_id) {
        found_call_inst = &bb_inst;
        break;
      }
    }
    if (found_call_inst == nullptr) {
      continue;
    }
    const auto base_type_op = found_call_inst->context()
                                  ->get_def_use_mgr()
                                  ->GetDef(found_call_inst->type_id())
                                  ->opcode();
    const uint32_t base_call_id = found_call_inst->result_id();
    // decorate the base call with ConditionalINTEL
    DecorateConditional(linked_context, base_call_id, old_base_spec_const_id);
    // Add OpFunctionCall for each variant, right after the previous call.
    Instruction* last_inst = found_call_inst;
    std::vector<std::pair<uint32_t, uint32_t>> var_call_ids;
    for (const auto& called_fn : called_fns) {
      const std::string& var_name = called_fn.first;
      const opt::Function* fn = called_fn.second;
      const uint32_t spec_const_id =
          FindSpecConstByName(linked_context->module(), var_name);
      assert(spec_const_id != 0);
      const uint32_t var_call_id = linked_context->TakeNextId();
      var_call_ids.push_back(std::make_pair(spec_const_id, var_call_id));
      auto* var_call_inst = found_call_inst->Clone(linked_context);
      var_call_inst->SetResultId(var_call_id);
      var_call_inst->SetOperand(2, {fn->result_id()});
      var_call_inst->InsertAfter(last_inst);
      linked_context->set_instr_block(var_call_inst, fn_call_bb);
      last_inst = var_call_inst;
      // decorate the variant call with ConditionalINTEL
      DecorateConditional(linked_context, var_call_id, spec_const_id);
    }
    if (base_type_op == spv::Op::OpTypeVoid) {
      continue;
    }
    // Add OpConditionalCopyObjectINTEL combining the function calls.
    // NOTE: like the clones above, the node is handed to the block's
    // instruction list via InsertAfter (assumed to take ownership).
    const uint32_t result_id = linked_context->TakeNextId();
    auto* conditional_copy_inst = new Instruction(
        linked_context, spv::Op::OpConditionalCopyObjectINTEL,
        found_call_inst->type_id(), result_id,
        {{SPV_OPERAND_TYPE_ID, {old_base_spec_const_id}},
         {SPV_OPERAND_TYPE_ID, {base_call_id}}});
    for (const auto& var_call : var_call_ids) {
      conditional_copy_inst->AddOperand({SPV_OPERAND_TYPE_ID, {var_call.first}});
      conditional_copy_inst->AddOperand(
          {SPV_OPERAND_TYPE_ID, {var_call.second}});
    }
    conditional_copy_inst->InsertAfter(last_inst);
    linked_context->set_instr_block(conditional_copy_inst, fn_call_bb);
    // Replace every use of the base call result with the conditional copy.
    const auto redirect_uses = [base_call_id, result_id](Instruction* user) {
      user->ForEachInId([base_call_id, result_id](uint32_t* id) {
        if (*id == base_call_id) {
          *id = result_id;
        }
      });
    };
    if (Function* parent_fn = fn_call_bb->GetParent()) {
      // Uses may live in any block dominated by the call (or in a phi), so
      // rewrite the whole function except the copy that consumes the result.
      parent_fn->ForEachInst(
          [&redirect_uses, conditional_copy_inst](Instruction* user) {
            if (user != conditional_copy_inst) {
              redirect_uses(user);
            }
          });
    } else {
      // No parent function known: fall back to the rest of this block.
      for (Instruction* user = conditional_copy_inst->NextNode();
           user != nullptr; user = user->NextNode()) {
        redirect_uses(user);
      }
    }
  }
  // Combine spec consts for the base module (base module is activated if all
  // variant defs are inactive AND the base module constraints are satisfied)
  std::vector<uint32_t> var_spec_const_ids;
  for (const auto& variant_def : variant_defs_) {
    if (variant_def.IsBase()) {
      continue;
    }
    const auto id =
        FindSpecConstByName(linked_context->module(), variant_def.GetName());
    assert(id != 0);
    var_spec_const_ids.push_back(id);
  }
  const uint32_t base_or_id =
      CombineIds(linked_context, var_spec_const_ids, spv::Op::OpLogicalOr);
  if (base_or_id == 0) {
    return;
  }
  const uint32_t bool_id = FindIdOfBoolType(linked_context->module());
  assert(bool_id != 0);
  const uint32_t base_not_id = linked_context->TakeNextId();
  auto spec_const_op_inst = std::make_unique<Instruction>(
      linked_context, spv::Op::OpSpecConstantOp, bool_id, base_not_id,
      std::initializer_list<opt::Operand>{
          {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER,
           {(uint32_t)(spv::Op::OpLogicalNot)}},
          {SPV_OPERAND_TYPE_ID, {base_or_id}}});
  linked_context->module()->AddType(std::move(spec_const_op_inst));
  // Update any ConditionalINTEL annotations, names and entry points
  // referencing the old spec const ID to use the new one
  assert(old_base_spec_const_id != 0);
  const uint32_t new_base_spec_const_id =
      CombineIds(linked_context, {old_base_spec_const_id, base_not_id},
                 spv::Op::OpLogicalAnd);
  for (auto& annot_inst : linked_context->module()->annotations()) {
    // DecorateConditional only emits OpDecorate; other annotations (e.g.
    // OpDecorationGroup) may not even have operands 1 and 2.
    if ((annot_inst.opcode() == spv::Op::OpDecorate) &&
        (annot_inst.NumOperands() > 2) &&
        (annot_inst.GetSingleWordOperand(1) ==
         uint32_t(spv::Decoration::ConditionalINTEL)) &&
        (annot_inst.GetOperand(2).AsId() == old_base_spec_const_id)) {
      annot_inst.SetOperand(2, {new_base_spec_const_id});
    }
  }
  for (auto& name_inst : linked_context->module()->debugs2()) {
    if ((name_inst.opcode() == spv::Op::OpName) &&
        (name_inst.GetOperand(0).AsId() == old_base_spec_const_id)) {
      name_inst.SetOperand(0, {new_base_spec_const_id});
    }
  }
  for (auto& ep_inst : linked_context->module()->entry_points()) {
    if ((ep_inst.opcode() == spv::Op::OpConditionalEntryPointINTEL) &&
        (ep_inst.GetOperand(0).AsId() == old_base_spec_const_id)) {
      ep_inst.SetOperand(0, {new_base_spec_const_id});
    }
  }
  linked_context->module()->ForEachInst(
      [old_base_spec_const_id, new_base_spec_const_id](Instruction* inst) {
        if (inst->opcode() == spv::Op::OpConditionalCopyObjectINTEL) {
          inst->ForEachInId(
              [old_base_spec_const_id, new_base_spec_const_id](uint32_t* id) {
                if (*id == old_base_spec_const_id) {
                  *id = new_base_spec_const_id;
                }
              });
        }
      });
}
// Post-link step: guards each combinable instruction of the linked module
// that is not used by every variant with the OR of its users' conditions.
// Instructions with a result id get a ConditionalINTEL decoration.
// OpCapability / OpExtension are rewritten in place into their conditional
// INTEL forms.
void VariantDefs::CombineInstructions(IRContext* linked_context) {
  // Memoizes the condition id created for each variant-index combination.
  std::map<std::vector<size_t>, uint32_t> combination_cache;
  linked_context->module()->ForEachInst([this, linked_context,
                                         &combination_cache](Instruction* inst) {
    if (!CanBeFnVarCombined(inst)) {
      return;
    }
    const auto usage = fnvar_usage_.find(HashInst(inst));
    if (usage == fnvar_usage_.end()) {
      return;
    }
    const uint32_t condition_id = CombineVariantDefs(
        variant_defs_, usage->second, linked_context, combination_cache);
    if (condition_id == 0) {
      return;
    }
    if (inst->HasResultId()) {
      DecorateConditional(linked_context, inst->result_id(), condition_id);
      return;
    }
    switch (inst->opcode()) {
      case spv::Op::OpCapability: {
        const uint32_t capability = inst->GetSingleWordOperand(0);
        inst->SetOpcode(spv::Op::OpConditionalCapabilityINTEL);
        inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {condition_id}},
                             {SPV_OPERAND_TYPE_CAPABILITY, {capability}}});
        break;
      }
      case spv::Op::OpExtension: {
        const std::string extension = inst->GetOperand(0).AsString();
        inst->SetOpcode(spv::Op::OpConditionalExtensionINTEL);
        inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {condition_id}},
                             {SPV_OPERAND_TYPE_LITERAL_STRING,
                              {utils::MakeVector(extension)}}});
        break;
      }
      default:
        assert(false && "Unsupported");
        break;
    }
  });
}
} // namespace spvtools