blob: 843a8c032caf35a772bb29cd20801a22d91051fb [file] [log] [blame]
/*
* Copyright 2024 Google LLC
*
* Use of this source code is governed by a BSD-style license that can be
* found in the LICENSE file.
*/
#include "src/sksl/analysis/SkSLSpecialization.h"
#include "include/private/base/SkAssert.h"
#include "include/private/base/SkSpan_impl.h"
#include "src/sksl/SkSLAnalysis.h"
#include "src/sksl/SkSLDefines.h"
#include "src/sksl/analysis/SkSLProgramVisitor.h"
#include "src/sksl/ir/SkSLExpression.h"
#include "src/sksl/ir/SkSLFieldAccess.h"
#include "src/sksl/ir/SkSLFunctionCall.h"
#include "src/sksl/ir/SkSLFunctionDeclaration.h"
#include "src/sksl/ir/SkSLFunctionDefinition.h"
#include "src/sksl/ir/SkSLProgram.h"
#include "src/sksl/ir/SkSLProgramElement.h"
#include "src/sksl/ir/SkSLVariable.h"
#include "src/sksl/ir/SkSLVariableReference.h"
#include <algorithm>
#include <memory>
using namespace skia_private;
namespace SkSL::Analysis {
static bool parameter_mappings_are_equal(const SpecializedParameters& left,
const SpecializedParameters& right) {
if (left.count() != right.count()) {
return false;
}
for (const auto& [key, leftExpr] : left) {
const Expression** rightExpr = right.find(key);
if (!rightExpr) {
return false;
}
if (!Analysis::IsSameExpressionTree(*leftExpr, **rightExpr)) {
return false;
}
}
return true;
}
// Walks the call graph reachable from main() and records into `info`:
//  - fSpecializationMap: for each called non-intrinsic function that needs it, the list of
//    distinct parameter->expression mappings (specializations) it must be emitted with;
//  - fSpecializedCallMap: for each (call site, specialization index of the function containing
//    the call) pair, the index of the callee specialization that call must invoke.
//
// A callee parameter is only specialized when `parameterMatchesFn` accepts it and the argument
// passed to it is a bare variable reference, or a field access directly on a variable reference.
// If that variable is a global, the argument expression itself is recorded. If it is a parameter
// of the calling function, the expression that parameter was itself specialized with is forwarded.
// Any other storage kind aborts.
//
// Every function that has a parameter accepted by `parameterMatchesFn` is added to
// fSpecializationMap, even when it is unreachable from main(). In that case its list of
// specializations is empty.
void FindFunctionsToSpecialize(const Program& program,
                               SpecializationInfo* info,
                               const ParameterMatchesFn& parameterMatchesFn) {
    class Searcher : public ProgramVisitor {
    public:
        using ProgramVisitor::visitProgramElement;
        using INHERITED = ProgramVisitor;

        Searcher(SpecializationInfo& info, const ParameterMatchesFn& parameterMatchesFn)
                : fSpecializationMap(info.fSpecializationMap)
                , fSpecializedCallMap(info.fSpecializedCallMap)
                , fParameterMatchesFn(parameterMatchesFn) {}

        bool visitExpression(const Expression& expr) override {
            if (expr.is<FunctionCall>()) {
                const FunctionCall& call = expr.as<FunctionCall>();
                const FunctionDeclaration& decl = call.function();

                if (!decl.isIntrinsic()) {
                    SpecializedParameters specialization;
                    const int numParams = decl.parameters().size();
                    SkASSERT(call.arguments().size() == numParams);

                    for (int i = 0; i < numParams; i++) {
                        const Expression& arg = *call.arguments()[i];

                        // Specializations can only be made on arguments that are not complex
                        // expressions but only a variable reference or field access since these
                        // references will be inlined in the generated specialized functions.
                        const Variable* argBase = nullptr;
                        if (arg.is<VariableReference>()) {
                            argBase = arg.as<VariableReference>().variable();
                        } else if (arg.is<FieldAccess>() &&
                                   arg.as<FieldAccess>().base()->is<VariableReference>()) {
                            argBase =
                                    arg.as<FieldAccess>().base()->as<VariableReference>().variable();
                        } else {
                            continue;
                        }
                        SkASSERT(argBase);

                        const Variable* param = decl.parameters()[i];

                        // Check that this parameter fits the criteria to create a specialization.
                        if (!fParameterMatchesFn(*param)) {
                            continue;
                        }

                        if (argBase->storage() == Variable::Storage::kGlobal) {
                            // A global can be referenced from the specialized callee directly,
                            // so the argument expression itself is recorded.
                            specialization[param] = &arg;
                        } else if (argBase->storage() == Variable::Storage::kParameter) {
                            // The argument is a parameter of the function currently being walked.
                            // Forward whatever expression that parameter was specialized with.
                            const Expression** uniformExpr =
                                    fInheritedSpecializations.find(argBase);
                            if (!uniformExpr) {
                                // The caller's parameter was never specialized, so there is
                                // nothing to forward. SkASSERT alone is compiled out of release
                                // builds and would leave a null dereference here. Abort
                                // explicitly, matching the unsupported-storage case below.
                                // TODO(b/353532475): Report an error instead of aborting.
                                SK_ABORT("Specialization requires a specialized parameter variable");
                            }
                            specialization[param] = *uniformExpr;
                        } else {
                            // TODO(b/353532475): Report an error instead of aborting.
                            SK_ABORT("Specialization requires a uniform or parameter variable");
                        }
                    }

                    // Only create a specialization for this function if there are
                    // variables to specialize on.
                    if (specialization.count() > 0) {
                        Specializations& specializations = fSpecializationMap[&decl];
                        SpecializedCallKey callKey{call.stablePointer(),
                                                   fInheritedSpecializationIndex};

                        for (int i = 0; i < specializations.size(); i++) {
                            const SpecializedParameters& entry = specializations[i];
                            if (parameter_mappings_are_equal(specialization, entry)) {
                                // This specialization has already been tracked.
                                fSpecializedCallMap[callKey] = i;
                                return INHERITED::visitExpression(expr);
                            }
                        }

                        // Set the index to the corresponding specialization this function call
                        // requires, also tracking the inherited specialization this function
                        // call is in so the right specialized function can be called.
                        SpecializationIndex specializationIndex = specializations.size();
                        fSpecializedCallMap[callKey] = specializationIndex;
                        specializations.push_back(specialization);

                        // Make the new specialization the inherited context while walking the
                        // callee's body, then swap the caller's context back in afterwards.
                        // `specializations` must not be used after this point: walking the
                        // callee may insert into fSpecializationMap.
                        fInheritedSpecializations.swap(specialization);
                        std::swap(fInheritedSpecializationIndex, specializationIndex);
                        this->visitProgramElement(*decl.definition());
                        std::swap(fInheritedSpecializationIndex, specializationIndex);
                        fInheritedSpecializations.swap(specialization);
                    } else {
                        // The function being called isn't specialized, but we need to walk the
                        // entire call graph or we may miss a specialized call entirely. Since
                        // nothing is specialized, it is safe to skip over repeated traversals.
                        // NOTE(review): fInheritedSpecializations and
                        // fInheritedSpecializationIndex are not reset here. Calls inside this
                        // callee are therefore keyed by the caller's specialization index.
                        // Confirm that matches how the code generator looks them up.
                        if (!fVisitedFunctions.find(&decl)) {
                            fVisitedFunctions.add(&decl);
                            this->visitProgramElement(*decl.definition());
                        }
                    }
                }
            }
            return INHERITED::visitExpression(expr);
        }

    private:
        SpecializationMap& fSpecializationMap;
        SpecializedCallMap& fSpecializedCallMap;
        const ParameterMatchesFn& fParameterMatchesFn;
        // Callees already walked through the non-specialized path in visitExpression.
        THashSet<const FunctionDeclaration*> fVisitedFunctions;
        // Mapping and index of the specialization the current function body is being walked
        // under. Starts out empty / kUnspecialized.
        SpecializedParameters fInheritedSpecializations;
        SpecializationIndex fInheritedSpecializationIndex = kUnspecialized;
    };

    for (const ProgramElement* elem : program.elements()) {
        if (elem->is<FunctionDefinition>()) {
            const FunctionDeclaration& decl = elem->as<FunctionDefinition>().declaration();

            if (decl.isMain()) {
                // Walk the program's call graph starting at main() and aggregate any necessary
                // function specializations.
                Searcher(*info, parameterMatchesFn).visitProgramElement(*elem);
                continue;
            }

            // Look for any function parameter which needs specialization.
            for (const Variable* param : decl.parameters()) {
                if (parameterMatchesFn(*param)) {
                    // We found a function that requires specialization. Ensure that this function
                    // ends up in the specialization map, whether or not it is reachable from
                    // main().
                    //
                    // Doing this here allows unreachable specialized functions to be discarded,
                    // because it will be in the specialization map with an array of zero necessary
                    // specializations to emit. If we didn't add this function to the specialization
                    // map at all, the code generator would try to emit it without applying
                    // specializations, and generally this would lead to invalid code.
                    info->fSpecializationMap[&decl];
                    break;
                }
            }
        }
    }
}
}
// Returns the index of the callee specialization that `call` must invoke. The lookup is keyed
// on the call site plus the specialization index of the function containing the call.
// Returns kUnspecialized when no entry was recorded for that pair.
SpecializationIndex FindSpecializationIndexForCall(const FunctionCall& call,
                                                   const SpecializationInfo& info,
                                                   SpecializationIndex parentSpecializationIndex) {
    const SpecializedCallKey key{call.stablePointer(), parentSpecializationIndex};
    if (const SpecializationIndex* recorded = info.fSpecializedCallMap.find(key)) {
        return *recorded;
    }
    return kUnspecialized;
}
// Returns a bitset with one bit per parameter of `func`. Bit i is set when parameter i is bound
// by `func`'s recorded specializations in `info`.
//
// Only the first recorded specialization is inspected.
// NOTE(review): this presumes every specialization of `func` binds the same set of parameters.
// Nothing here enforces that; confirm it holds for the producers of `info`.
//
// Nothing is set on the returned bitset in two cases: `func` is absent from the specialization
// map, or it is present with an empty specialization list. FindFunctionsToSpecialize records
// functions unreachable from main() that way.
SkBitSet FindSpecializedParametersForFunction(const FunctionDeclaration& func,
                                              const SpecializationInfo& info) {
    SkBitSet result(func.parameters().size());
    const Specializations* specializations = info.fSpecializationMap.find(&func);
    // An empty list is a legitimate map entry, and calling front() on it would read out of
    // bounds, so bail out before touching it.
    if (!specializations || specializations->empty()) {
        return result;
    }
    const Analysis::SpecializedParameters& specializedParams = specializations->front();
    const SkSpan<Variable* const> funcParams = func.parameters();
    for (size_t index = 0; index < funcParams.size(); ++index) {
        if (specializedParams.find(funcParams[index])) {
            result.set(index);
        }
    }
    return result;
}
void GetParameterMappingsForFunction(const FunctionDeclaration& func,
const SpecializationInfo& info,
SpecializationIndex specializationIndex,
const ParameterMappingCallback& callback) {
if (specializationIndex != Analysis::kUnspecialized) {
if (const Specializations* specializations = info.fSpecializationMap.find(&func)) {
const Analysis::SpecializedParameters& specializedParams =
specializations->at(specializationIndex);
const SkSpan<Variable* const> funcParams = func.parameters();
for (size_t index = 0; index < funcParams.size(); ++index) {
const Variable* param = funcParams[index];
if (const Expression** expr = specializedParams.find(param)) {
callback(index, param, *expr);
}
}
}
}
}
} // namespace SkSL::Analysis