spirv-fuzz: Handle invalid ids in fact manager (#3742)
Fixes #3741.
diff --git a/source/fuzz/fact_manager/data_synonym_and_id_equation_facts.cpp b/source/fuzz/fact_manager/data_synonym_and_id_equation_facts.cpp
index a046a33..d5d9f3f 100644
--- a/source/fuzz/fact_manager/data_synonym_and_id_equation_facts.cpp
+++ b/source/fuzz/fact_manager/data_synonym_and_id_equation_facts.cpp
@@ -281,14 +281,9 @@
assert(synonymous_.Exists(dd) &&
"|dd| should've been registered in the equivalence relation");
- const auto* representative = synonymous_.Find(&dd);
- assert(representative &&
- "Representative can't be null for a registered descriptor");
-
const auto* type =
context->get_type_mgr()->GetType(fuzzerutil::WalkCompositeTypeIndices(
- context, fuzzerutil::GetTypeId(context, representative->object()),
- representative->index()));
+ context, fuzzerutil::GetTypeId(context, dd.object()), dd.index()));
assert(type && "Data descriptor has invalid type");
if ((type->AsVector() && type->AsVector()->element_type()->AsInteger()) ||
@@ -300,24 +295,36 @@
std::vector<const protobufs::DataDescriptor*> convert_u_to_f_lhs;
for (const auto& fact : id_equations_) {
+ auto equivalence_class = synonymous_.GetEquivalenceClass(*fact.first);
+ auto dd_it = std::find_if(
+ equivalence_class.begin(), equivalence_class.end(),
+ [context](const protobufs::DataDescriptor* a) {
+ return context->get_def_use_mgr()->GetDef(a->object()) != nullptr;
+ });
+ if (dd_it == equivalence_class.end()) {
+ // Skip |equivalence_class| if it has no valid ids.
+ continue;
+ }
+
for (const auto& equation : fact.second) {
- if (synonymous_.IsEquivalent(*equation.operands[0], *representative)) {
+ if (synonymous_.IsEquivalent(*equation.operands[0], dd)) {
if (equation.opcode == SpvOpConvertSToF) {
- convert_s_to_f_lhs.push_back(fact.first);
+ convert_s_to_f_lhs.push_back(*dd_it);
} else if (equation.opcode == SpvOpConvertUToF) {
- convert_u_to_f_lhs.push_back(fact.first);
+ convert_u_to_f_lhs.push_back(*dd_it);
}
}
}
}
- for (const auto& synonyms :
- {std::move(convert_s_to_f_lhs), std::move(convert_u_to_f_lhs)}) {
- for (const auto* synonym_a : synonyms) {
- for (const auto* synonym_b : synonyms) {
- if (!synonymous_.IsEquivalent(*synonym_a, *synonym_b) &&
- DataDescriptorsAreWellFormedAndComparable(context, *synonym_a,
- *synonym_b)) {
+ // We use pointers in the initializer list here since otherwise we would
+ // copy memory from these vectors.
+ for (const auto* synonyms : {&convert_s_to_f_lhs, &convert_u_to_f_lhs}) {
+ for (const auto* synonym_a : *synonyms) {
+ for (const auto* synonym_b : *synonyms) {
+ // DataDescriptorsAreWellFormedAndComparable will be called in the
+ // AddDataSynonymFactRecursive method.
+ if (!synonymous_.IsEquivalent(*synonym_a, *synonym_b)) {
// |synonym_a| and |synonym_b| have compatible types - they are
// synonymous.
AddDataSynonymFactRecursive(*synonym_a, *synonym_b, context);
@@ -765,12 +772,14 @@
bool DataSynonymAndIdEquationFacts::DataDescriptorsAreWellFormedAndComparable(
opt::IRContext* context, const protobufs::DataDescriptor& dd1,
const protobufs::DataDescriptor& dd2) {
+ assert(context->get_def_use_mgr()->GetDef(dd1.object()) &&
+ context->get_def_use_mgr()->GetDef(dd2.object()) &&
+ "Both descriptors must exist in the module");
+
auto end_type_id_1 = fuzzerutil::WalkCompositeTypeIndices(
- context, context->get_def_use_mgr()->GetDef(dd1.object())->type_id(),
- dd1.index());
+ context, fuzzerutil::GetTypeId(context, dd1.object()), dd1.index());
auto end_type_id_2 = fuzzerutil::WalkCompositeTypeIndices(
- context, context->get_def_use_mgr()->GetDef(dd2.object())->type_id(),
- dd2.index());
+ context, fuzzerutil::GetTypeId(context, dd2.object()), dd2.index());
// The end types of the data descriptors must exist.
if (end_type_id_1 == 0 || end_type_id_2 == 0) {
return false;
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index e765fe9..001cd86 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -572,7 +572,9 @@
}
uint32_t GetTypeId(opt::IRContext* context, uint32_t result_id) {
- return context->get_def_use_mgr()->GetDef(result_id)->type_id();
+ const auto* inst = context->get_def_use_mgr()->GetDef(result_id);
+ assert(inst && "|result_id| is invalid");
+ return inst->type_id();
}
uint32_t GetPointeeTypeIdFromPointerType(opt::Instruction* pointer_type_inst) {
diff --git a/test/fuzz/fact_manager_test.cpp b/test/fuzz/fact_manager_test.cpp
index 26b9ecc..64104df 100644
--- a/test/fuzz/fact_manager_test.cpp
+++ b/test/fuzz/fact_manager_test.cpp
@@ -16,6 +16,7 @@
#include <limits>
+#include "source/fuzz/transformation_merge_blocks.h"
#include "source/fuzz/uniform_buffer_element_descriptor.h"
#include "test/fuzz/fuzz_test_util.h"
@@ -871,6 +872,78 @@
MakeDataDescriptor(29, {})));
}
+TEST(FactManagerTest, HandlesCorollariesWithInvalidIds) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %12 "main"
+ OpExecutionMode %12 OriginUpperLeft
+ OpSource ESSL 310
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeFloat 32
+ %8 = OpTypeInt 32 1
+ %9 = OpConstant %8 3
+ %12 = OpFunction %2 None %3
+ %13 = OpLabel
+ %14 = OpConvertSToF %6 %9
+ OpBranch %16
+ %16 = OpLabel
+ %17 = OpPhi %6 %14 %13
+ %15 = OpConvertSToF %6 %9
+ %18 = OpConvertSToF %6 %9
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_3;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ FactManager fact_manager;
+
+ // Add required facts.
+ fact_manager.AddFactIdEquation(14, SpvOpConvertSToF, {9}, context.get());
+ fact_manager.AddFactDataSynonym(MakeDataDescriptor(14, {}),
+ MakeDataDescriptor(17, {}), context.get());
+
+ // Apply TransformationMergeBlocks which will remove %17 from the module.
+ spvtools::ValidatorOptions validator_options;
+ TransformationContext transformation_context(&fact_manager,
+ validator_options);
+ TransformationMergeBlocks transformation(16);
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+ transformation.Apply(context.get(), &transformation_context);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ ASSERT_EQ(context->get_def_use_mgr()->GetDef(17), nullptr);
+
+ // Add another equation.
+ fact_manager.AddFactIdEquation(15, SpvOpConvertSToF, {9}, context.get());
+
+ // Check that two ids are synonymous even though one of them doesn't exist in
+ // the module (%17).
+ ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(15, {}),
+ MakeDataDescriptor(17, {})));
+ ASSERT_TRUE(fact_manager.IsSynonymous(MakeDataDescriptor(15, {}),
+ MakeDataDescriptor(14, {})));
+
+ // Remove some instructions from the module. At this point, the equivalence
+ // class of %14 has no valid members.
+ ASSERT_TRUE(context->KillDef(14));
+ ASSERT_TRUE(context->KillDef(15));
+
+ fact_manager.AddFactIdEquation(18, SpvOpConvertSToF, {9}, context.get());
+
+ // We don't create synonyms if at least one of the equivalence classes has no
+ // valid members.
+ ASSERT_FALSE(fact_manager.IsSynonymous(MakeDataDescriptor(14, {}),
+ MakeDataDescriptor(18, {})));
+}
+
TEST(FactManagerTest, LogicalNotEquationFacts) {
std::string shader = R"(
OpCapability Shader