spirv-fuzz: Refactoring and type-related fixes (#3144)
This change refactors some code for walking access chain indexes to
make it mirror the structure of other code (to improve readability in
the first instance and potentially enable a future refactoring to
extract common code), and fixes a problem related to module donation
and function types.
diff --git a/source/fuzz/fuzzer_pass_donate_modules.cpp b/source/fuzz/fuzzer_pass_donate_modules.cpp
index 33813d2..0587a50 100644
--- a/source/fuzz/fuzzer_pass_donate_modules.cpp
+++ b/source/fuzz/fuzzer_pass_donate_modules.cpp
@@ -299,43 +299,41 @@
} break;
case SpvOpTypeFunction: {
// It is not OK to have multiple function types that use identical ids
- // for their return and parameter types. We thus first look for a
- // matching function type in the recipient module and use the id of this
- // type if a match is found. Otherwise we add a remapped version of the
- // function type.
+ // for their return and parameter types. We thus go through all
+ // existing function types to look for a match. We do not use the
+ // type manager here because we want to regard two function types that
+ // are structurally identical but that differ with respect to the
+ // actual ids used for pointer types as different.
+ //
+ // Example:
+ //
+ // %1 = OpTypeVoid
+ // %2 = OpTypeInt 32 0
+ // %3 = OpTypePointer Function %2
+ // %4 = OpTypePointer Function %2
+ // %5 = OpTypeFunction %1 %3
+ // %6 = OpTypeFunction %1 %4
+ //
+ // We regard %5 and %6 as distinct function types here, even though
+ // they both have the form "uint32* -> void"
- // Build a sequence of types used as parameters for the function type.
- std::vector<const opt::analysis::Type*> parameter_types;
- // We start iterating at 1 because 0 is the function's return type.
- for (uint32_t index = 1; index < type_or_value.NumInOperands();
- index++) {
- parameter_types.push_back(GetIRContext()->get_type_mgr()->GetType(
- original_id_to_donated_id->at(
- type_or_value.GetSingleWordInOperand(index))));
+ std::vector<uint32_t> return_and_parameter_types;
+ for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) {
+ return_and_parameter_types.push_back(original_id_to_donated_id->at(
+ type_or_value.GetSingleWordInOperand(i)));
}
- // Make a type object corresponding to the function type.
- opt::analysis::Function function_type(
- GetIRContext()->get_type_mgr()->GetType(
- original_id_to_donated_id->at(
- type_or_value.GetSingleWordInOperand(0))),
- parameter_types);
-
- // Check whether a function type corresponding to this this type object
- // is already declared by the module.
- auto function_type_id =
- GetIRContext()->get_type_mgr()->GetId(&function_type);
- if (function_type_id) {
- // A suitable existing function was found - use its id.
- new_result_id = function_type_id;
+ uint32_t existing_function_id = fuzzerutil::FindFunctionType(
+ GetIRContext(), return_and_parameter_types);
+ if (existing_function_id) {
+ new_result_id = existing_function_id;
} else {
// No match was found, so add a remapped version of the function type
// to the module, with a fresh id.
new_result_id = GetFuzzerContext()->GetFreshId();
std::vector<uint32_t> argument_type_ids;
- for (uint32_t index = 1; index < type_or_value.NumInOperands();
- index++) {
+ for (uint32_t i = 1; i < type_or_value.NumInOperands(); i++) {
argument_type_ids.push_back(original_id_to_donated_id->at(
- type_or_value.GetSingleWordInOperand(index)));
+ type_or_value.GetSingleWordInOperand(i)));
}
ApplyTransformation(TransformationAddTypeFunction(
new_result_id,
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index 82d761c..085246e 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -258,33 +258,36 @@
auto should_be_composite_type =
context->get_def_use_mgr()->GetDef(sub_object_type_id);
assert(should_be_composite_type && "The type should exist.");
- if (SpvOpTypeArray == should_be_composite_type->opcode()) {
- auto array_length = GetArraySize(*should_be_composite_type, context);
- if (array_length == 0 || index >= array_length) {
- return 0;
+ switch (should_be_composite_type->opcode()) {
+ case SpvOpTypeArray: {
+ auto array_length = GetArraySize(*should_be_composite_type, context);
+ if (array_length == 0 || index >= array_length) {
+ return 0;
+ }
+ sub_object_type_id =
+ should_be_composite_type->GetSingleWordInOperand(0);
+ break;
}
- sub_object_type_id = should_be_composite_type->GetSingleWordInOperand(0);
- } else if (SpvOpTypeMatrix == should_be_composite_type->opcode()) {
- auto matrix_column_count =
- should_be_composite_type->GetSingleWordInOperand(1);
- if (index >= matrix_column_count) {
- return 0;
+ case SpvOpTypeMatrix:
+ case SpvOpTypeVector: {
+ auto count = should_be_composite_type->GetSingleWordInOperand(1);
+ if (index >= count) {
+ return 0;
+ }
+ sub_object_type_id =
+ should_be_composite_type->GetSingleWordInOperand(0);
+ break;
}
- sub_object_type_id = should_be_composite_type->GetSingleWordInOperand(0);
- } else if (SpvOpTypeStruct == should_be_composite_type->opcode()) {
- if (index >= GetNumberOfStructMembers(*should_be_composite_type)) {
- return 0;
+ case SpvOpTypeStruct: {
+ if (index >= GetNumberOfStructMembers(*should_be_composite_type)) {
+ return 0;
+ }
+ sub_object_type_id =
+ should_be_composite_type->GetSingleWordInOperand(index);
+ break;
}
- sub_object_type_id =
- should_be_composite_type->GetSingleWordInOperand(index);
- } else if (SpvOpTypeVector == should_be_composite_type->opcode()) {
- auto vector_length = should_be_composite_type->GetSingleWordInOperand(1);
- if (index >= vector_length) {
+ default:
return 0;
- }
- sub_object_type_id = should_be_composite_type->GetSingleWordInOperand(0);
- } else {
- return 0;
}
}
return sub_object_type_id;
@@ -347,6 +350,35 @@
return result;
}
+uint32_t FindFunctionType(opt::IRContext* ir_context,
+ const std::vector<uint32_t>& type_ids) {
+ // Look through the existing types for a match.
+ for (auto& type_or_value : ir_context->types_values()) {
+ if (type_or_value.opcode() != SpvOpTypeFunction) {
+ // We are only interested in function types.
+ continue;
+ }
+ if (type_or_value.NumInOperands() != type_ids.size()) {
+ // Not a match: different numbers of arguments.
+ continue;
+ }
+ // Check whether the return type and argument types match.
+ bool input_operands_match = true;
+ for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) {
+ if (type_ids[i] != type_or_value.GetSingleWordInOperand(i)) {
+ input_operands_match = false;
+ break;
+ }
+ }
+ if (input_operands_match) {
+ // Everything matches.
+ return type_or_value.result_id();
+ }
+ }
+ // No match was found.
+ return 0;
+}
+
} // namespace fuzzerutil
} // namespace fuzz
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h
index cbd81cd..f0a2953 100644
--- a/source/fuzz/fuzzer_util.h
+++ b/source/fuzz/fuzzer_util.h
@@ -131,6 +131,12 @@
// Returns true if and only if |block_id| is a merge block or continue target
bool IsMergeOrContinue(opt::IRContext* ir_context, uint32_t block_id);
+// Returns the result id of an instruction of the form:
+// %id = OpTypeFunction |type_ids|
+// or 0 if no such instruction exists.
+uint32_t FindFunctionType(opt::IRContext* ir_context,
+ const std::vector<uint32_t>& type_ids);
+
} // namespace fuzzerutil
} // namespace fuzz
diff --git a/source/fuzz/transformation.cpp b/source/fuzz/transformation.cpp
index c7aae58..8037af1 100644
--- a/source/fuzz/transformation.cpp
+++ b/source/fuzz/transformation.cpp
@@ -16,6 +16,7 @@
#include <cassert>
+#include "source/fuzz/fuzzer_util.h"
#include "source/fuzz/transformation_add_constant_boolean.h"
#include "source/fuzz/transformation_add_constant_composite.h"
#include "source/fuzz/transformation_add_constant_scalar.h"
@@ -159,5 +160,18 @@
return nullptr;
}
+bool Transformation::CheckIdIsFreshAndNotUsedByThisTransformation(
+ uint32_t id, opt::IRContext* context,
+ std::set<uint32_t>* ids_used_by_this_transformation) {
+ if (!fuzzerutil::IsFreshId(context, id)) {
+ return false;
+ }
+ if (ids_used_by_this_transformation->count(id) != 0) {
+ return false;
+ }
+ ids_used_by_this_transformation->insert(id);
+ return true;
+}
+
} // namespace fuzz
} // namespace spvtools
diff --git a/source/fuzz/transformation.h b/source/fuzz/transformation.h
index c6b852f..dbe803f 100644
--- a/source/fuzz/transformation.h
+++ b/source/fuzz/transformation.h
@@ -83,6 +83,15 @@
// representation of a transformation given by |message|.
static std::unique_ptr<Transformation> FromMessage(
const protobufs::Transformation& message);
+
+ // Helper that returns true if and only if (a) |id| is a fresh id for the
+ // module, and (b) |id| is not in |ids_used_by_this_transformation|, a set of
+ // ids already known to be in use by a transformation. This is useful when
+ // checking id freshness for a transformation that uses many ids, all of which
+ // must be distinct.
+ static bool CheckIdIsFreshAndNotUsedByThisTransformation(
+ uint32_t id, opt::IRContext* context,
+ std::set<uint32_t>* ids_used_by_this_transformation);
};
} // namespace fuzz
diff --git a/source/fuzz/transformation_outline_function.cpp b/source/fuzz/transformation_outline_function.cpp
index b50b9c5..1b308c4 100644
--- a/source/fuzz/transformation_outline_function.cpp
+++ b/source/fuzz/transformation_outline_function.cpp
@@ -368,20 +368,6 @@
return result;
}
-bool TransformationOutlineFunction::
- CheckIdIsFreshAndNotUsedByThisTransformation(
- uint32_t id, opt::IRContext* context,
- std::set<uint32_t>* ids_used_by_this_transformation) const {
- if (!fuzzerutil::IsFreshId(context, id)) {
- return false;
- }
- if (ids_used_by_this_transformation->count(id) != 0) {
- return false;
- }
- ids_used_by_this_transformation->insert(id);
- return true;
-}
-
std::vector<uint32_t> TransformationOutlineFunction::GetRegionInputIds(
opt::IRContext* context, const std::set<opt::BasicBlock*>& region_set,
opt::BasicBlock* region_exit_block) {
@@ -540,15 +526,16 @@
// not exist there cannot already be a function type with this struct as its
// return type.
if (region_output_ids.empty()) {
+ std::vector<uint32_t> return_and_parameter_types;
opt::analysis::Void void_type;
return_type_id = context->get_type_mgr()->GetId(&void_type);
- std::vector<const opt::analysis::Type*> argument_types;
+ return_and_parameter_types.push_back(return_type_id);
for (auto id : region_input_ids) {
- argument_types.push_back(context->get_type_mgr()->GetType(
- context->get_def_use_mgr()->GetDef(id)->type_id()));
+ return_and_parameter_types.push_back(
+ context->get_def_use_mgr()->GetDef(id)->type_id());
}
- opt::analysis::Function function_type(&void_type, argument_types);
- function_type_id = context->get_type_mgr()->GetId(&function_type);
+ function_type_id =
+ fuzzerutil::FindFunctionType(context, return_and_parameter_types);
}
// If no existing function type was found, we need to create one.
diff --git a/source/fuzz/transformation_outline_function.h b/source/fuzz/transformation_outline_function.h
index b59e660..43bdf3b 100644
--- a/source/fuzz/transformation_outline_function.h
+++ b/source/fuzz/transformation_outline_function.h
@@ -128,15 +128,6 @@
opt::BasicBlock* region_exit_block);
private:
- // A helper method for the applicability check. Returns true if and only if
- // |id| is (a) a fresh id for the module, and (b) an id that has not
- // previously been subject to this check. We use this to check whether the
- // ids given for the transformation are not only fresh but also different from
- // one another.
- bool CheckIdIsFreshAndNotUsedByThisTransformation(
- uint32_t id, opt::IRContext* context,
- std::set<uint32_t>* ids_used_by_this_transformation) const;
-
// Ensures that the module's id bound is at least the maximum of any fresh id
// associated with the transformation.
void UpdateModuleIdBoundForFreshIds(
diff --git a/test/fuzz/fuzzer_pass_donate_modules_test.cpp b/test/fuzz/fuzzer_pass_donate_modules_test.cpp
index 988b675..7342dd4 100644
--- a/test/fuzz/fuzzer_pass_donate_modules_test.cpp
+++ b/test/fuzz/fuzzer_pass_donate_modules_test.cpp
@@ -438,6 +438,58 @@
ASSERT_TRUE(IsEqual(env, after_transformation, recipient_context.get()));
}
+TEST(FuzzerPassDonateModulesTest, DonateFunctionTypeWithDifferentPointers) {
+ std::string recipient_and_donor_shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeInt 32 0
+ %7 = OpTypePointer Function %6
+ %8 = OpTypeFunction %2 %7
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %9 = OpVariable %7 Function
+ %10 = OpFunctionCall %2 %11 %9
+ OpReturn
+ OpFunctionEnd
+ %11 = OpFunction %2 None %8
+ %12 = OpFunctionParameter %7
+ %13 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_5;
+ const auto consumer = nullptr;
+ const auto recipient_context = BuildModule(
+ env, consumer, recipient_and_donor_shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, recipient_context.get()));
+
+ const auto donor_context = BuildModule(
+ env, consumer, recipient_and_donor_shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, donor_context.get()));
+
+ FactManager fact_manager;
+
+ FuzzerContext fuzzer_context(MakeUnique<PseudoRandomGenerator>(0).get(), 100);
+ protobufs::TransformationSequence transformation_sequence;
+
+ FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), &fact_manager,
+ &fuzzer_context, &transformation_sequence,
+ {});
+
+ fuzzer_pass.DonateSingleModule(donor_context.get());
+
+ // We just check that the result is valid. Checking to what it should be
+ // exactly equal to would be very fragile.
+ ASSERT_TRUE(IsValid(env, recipient_context.get()));
+}
+
TEST(FuzzerPassDonateModulesTest, Miscellaneous1) {
std::string recipient_shader = R"(
OpCapability Shader
diff --git a/test/fuzz/transformation_outline_function_test.cpp b/test/fuzz/transformation_outline_function_test.cpp
index 4f828b6..5cd1437 100644
--- a/test/fuzz/transformation_outline_function_test.cpp
+++ b/test/fuzz/transformation_outline_function_test.cpp
@@ -2040,6 +2040,91 @@
ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
}
+TEST(TransformationOutlineFunctionTest, Miscellaneous4) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %6 "main"
+ OpExecutionMode %6 OriginUpperLeft
+ OpSource ESSL 310
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %21 = OpTypeBool
+ %100 = OpTypeInt 32 0
+ %101 = OpTypePointer Function %100
+ %102 = OpTypePointer Function %100
+ %103 = OpTypeFunction %2 %101
+ %6 = OpFunction %2 None %3
+ %7 = OpLabel
+ %104 = OpVariable %102 Function
+ OpBranch %80
+ %80 = OpLabel
+ %105 = OpLoad %100 %104
+ OpBranch %106
+ %106 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_5;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ FactManager fact_manager;
+
+ TransformationOutlineFunction transformation(
+ /*entry_block*/ 80,
+ /*exit_block*/ 106,
+ /*new_function_struct_return_type_id*/ 300,
+ /*new_function_type_id*/ 301,
+ /*new_function_id*/ 302,
+ /*new_function_region_entry_block*/ 304,
+ /*new_caller_result_id*/ 305,
+ /*new_callee_result_id*/ 306,
+ /*input_id_to_fresh_id*/ {{104, 307}},
+ /*output_id_to_fresh_id*/ {});
+
+ ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager));
+ transformation.Apply(context.get(), &fact_manager);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ std::string after_transformation = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %6 "main"
+ OpExecutionMode %6 OriginUpperLeft
+ OpSource ESSL 310
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %21 = OpTypeBool
+ %100 = OpTypeInt 32 0
+ %101 = OpTypePointer Function %100
+ %102 = OpTypePointer Function %100
+ %103 = OpTypeFunction %2 %101
+ %301 = OpTypeFunction %2 %102
+ %6 = OpFunction %2 None %3
+ %7 = OpLabel
+ %104 = OpVariable %102 Function
+ OpBranch %80
+ %80 = OpLabel
+ %305 = OpFunctionCall %2 %302 %104
+ OpReturn
+ OpFunctionEnd
+ %302 = OpFunction %2 None %301
+ %307 = OpFunctionParameter %102
+ %304 = OpLabel
+ %105 = OpLoad %100 %307
+ OpBranch %106
+ %106 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+ ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
} // namespace
} // namespace fuzz
} // namespace spvtools