// blob: 49db030b1bea7847c58f572325aecb0cc26d52a7 [file] [log] [blame] [edit]
// Copyright 2025 The Khronos Group Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Implementation of generating multitarget modules according to the
// *SPV_INTEL_function_variants* extension
//
// Multitarget module is generated by linking separate modules: a base module
// and variant modules containing device-specific variants of the functions in
// the base module. The behavior is controlled by Comma-Separated Values (CSV)
// files passed to the following flags:
//   --fnvar-targets: Required columns:
//       module   - module file name
//       target   - device target ISA value
//       features - feature values for the target separated by '/' (FEAT_SEP)
//   --fnvar-architectures: Required columns:
//       module       - module file name
//       category     - device category value
//       family       - device family value
//       op           - opcode of the comparison instruction
//       architecture - device architecture
// The values (except module) are decimal strings with their meaning defined in
// the 'targets registry' as described in the extension spec. The decimal
// strings may only encode unsigned 32-bit integers (characters 0-9), possibly
// with leading zeros.
//
// In addition, --fnvar-capabilities generates OpSpecConstantCapabilitiesINTEL
// for each module with operands corresponding to the module's capabilities.
//
// Each line in the targets/architectures CSV file defines one
// OpSpecConstant<Target/Architecture>INTEL instruction, the columns correspond
// to the operands of these instructions. One module can have multiple lines, in
// which case they are combined into a single boolean spec constant using
// OpSpecConstantOp and OpLogicalOr (except when category and family in the
// architectures CSV are the same, then the lines are combined with
// OpLogicalAnd). For example, the following architectures CSV
//
// module,category,family,op,architecture
// foo.spv,1,7,174,1
// foo.spv,1,7,178,3
// foo.spv,1,8,170,1
//
// is combined as follows:
//
// %53 = OpSpecConstantArchitectureINTEL %bool 1 7 174 1
// %54 = OpSpecConstantArchitectureINTEL %bool 1 7 178 3
// %55 = OpSpecConstantArchitectureINTEL %bool 1 8 170 1
// %56 = OpSpecConstantOp %bool LogicalAnd %53 %54
// %foo_spv = OpSpecConstantOp %bool LogicalOr %55 %56
//
// The %foo_spv is annotated with OpName "foo.spv" (the module's name) which
// serves as an identifier to find the constant later. We cannot use IDs for it
// because the IDs get shifted during linking.
//
// The first module passed to `spirv-link` is considered the 'base' module. For
// example, if base module defines functions 'foo' and 'bar' and the other
// modules define only 'foo', only the 'foo' is treated as a function variant
// guarded by spec constants. The 'bar' function will be untouched and therefore
// present for all variants. The function variants are matched by name, and
// therefore each of them must either be declared as an entry point or carry
// an Export linkage attribute.
#ifndef FNVAR_H
#define FNVAR_H
#include <cstdint>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "source/opt/ir_context.h"
#include "source/opt/module.h"
#include "spirv-tools/linker.hpp"
namespace spvtools {
// Make the optimizer's IRContext and Module types nameable inside the
// spvtools namespace without the opt:: qualifier.
using opt::IRContext;
using opt::Module;
// Map of instruction hash -> which variants are using the instruction (denoted
// by the index to the variants vector)
using FnVarUsage = std::unordered_map<size_t, std::vector<size_t>>;
// Map of base function call ID -> variant functions corresponding to the
// called function (along with the variant name)
//
// In each pair, `first` is the variant name and `second` is a pointer to the
// variant's function.
using BaseFnCalls =
std::map<uint32_t,
std::vector<std::pair<std::string, const opt::Function*>>>;
// Name string of the SPIR-V extension this header is about.
constexpr char FNVAR_EXT_NAME[] = "SPV_INTEL_function_variants";
// NOTE(review): presumably the version of the 'targets registry' mentioned in
// the extension spec -- confirm against the code that consumes this value.
constexpr uint32_t FNVAR_REGISTRY_VERSION = 0;
// Separator between individual feature values in the 'features' column of
// the --fnvar-targets CSV file.
constexpr char FEAT_SEP = '/';
// One row of the --fnvar-architectures CSV file for a module.
//
// The four fields mirror the CSV columns of the same name (category, family,
// op, architecture). Every field defaults to 0 so that a default-constructed
// instance never holds indeterminate values.
struct FnVarArchDef {
  uint32_t category = 0;      // device category value
  uint32_t family = 0;        // device family value
  uint32_t op = 0;            // opcode of the comparison instruction
  uint32_t architecture = 0;  // device architecture
};
// One row of the --fnvar-targets CSV file for a module.
//
// `target` defaults to 0 so that a default-constructed instance never holds
// an indeterminate value; `features` starts out empty.
struct FnVarTargetDef {
  uint32_t target = 0;             // device target ISA value
  std::vector<uint32_t> features;  // feature values for the target
};
// Definition of a variant
//
// Stores architecture and target definitions inferred from lines in the CSV
// files for a single module (as well as a pointer to the Module).
class VariantDef {
public:
VariantDef(bool isbase, std::string nm, Module* mod)
: is_base(isbase), name(nm), module(mod) {}
bool IsBase() const { return this->is_base; }
std::string GetName() const { return this->name; }
Module* GetModule() const { return this->module; }
void AddArchDef(uint32_t category, uint32_t family, uint32_t op,
uint32_t architecture) {
FnVarArchDef arch_def;
arch_def.category = category;
arch_def.family = family;
arch_def.op = op;
arch_def.architecture = architecture;
this->arch_defs.push_back(arch_def);
}
const std::vector<FnVarArchDef>& GetArchDefs() const {
return this->arch_defs;
}
void AddTgtDef(uint32_t target, std::vector<uint32_t> features) {
FnVarTargetDef tgt_def;
tgt_def.target = target;
tgt_def.features = features;
this->tgt_defs.push_back(tgt_def);
}
const std::vector<FnVarTargetDef>& GetTgtDefs() const {
return this->tgt_defs;
}
void InferCapabilities() {
for (const auto& cap_inst : module->capabilities()) {
capabilities.insert(spv::Capability(cap_inst.GetOperand(0).words[0]));
}
}
const std::set<spv::Capability>& GetCapabilities() const {
return this->capabilities;
}
private:
bool is_base;
std::string name;
Module* module;
std::vector<FnVarTargetDef> tgt_defs;
std::vector<FnVarArchDef> arch_defs;
std::set<spv::Capability> capabilities;
};
// Collection of VariantDef instances
//
// Apart from being a wrapper around a vector of VariantDef instances, it
// defines the main API for generating SPV_INTEL_function_variants instructions
// based on the CSV files.
class VariantDefs {
 public:
  // Returns all error messages accumulated so far as a single string. Does
  // not modify or clear the accumulated messages.
  std::string GetErr() const { return err_.str(); }

  // Processes CSV files passed to the CLI and populates the collection of
  // variant definitions (variant_defs_).
  //
  // Returns true on success, false on error.
  bool ProcessFnVar(const LinkerOptions& options,
                    const std::vector<Module*>& modules);

  // Analyses each variant def module and generates those instructions that are
  // module-specific, i.e., not requiring knowledge from other modules.
  //
  // Returns true on success, false on error.
  bool ProcessVariantDefs();

  // Generates basic instructions required for this extension to work.
  void GenerateHeader(IRContext* linked_context);

  // Generates instructions from this extension that result from combining
  // several variant def modules.
  void CombineVariantInstructions(IRContext* linked_context);

 private:
  // Adds a boolean type to every module if there is none.
  //
  // These are necessary for spec constants.
  void EnsureBoolType();

  // Collects which combinable instructions are defined in which modules
  void CollectVarInsts();

  // Generates OpSpecConstant<Target/Architecture/Capabilities>INTEL and
  // combines them as necessary. Also converts entry points to conditional ones
  // and decorates module-specific instructions with ConditionalINTEL.
  //
  // Returns true on success, false on error.
  bool GenerateFnVarConstants();

  // Determines which functions in the base module are called by which function
  // variants.
  void CollectBaseFnCalls();

  // Combines OpFunctionCall instructions collected with CollectBaseFnCalls()
  // using conditional copy.
  void CombineBaseFnCalls(IRContext* linked_context);

  // Decorates instructions shared between modules with ConditionalINTEL or
  // generates conditional capabilities and extensions, depending on which
  // variants are used by each.
  void CombineInstructions(IRContext* linked_context);

  // Accumulates all errors encountered during processing.
  std::stringstream err_;

  // Collection of VariantDef instances
  std::vector<VariantDef> variant_defs_;

  // Used for combining OpFunctionCall instructions
  BaseFnCalls base_fn_calls_;

  // Used for determining which function variant uses which (applicable)
  // instruction
  FnVarUsage fnvar_usage_;
};
} // namespace spvtools
#endif // FNVAR_H