spirv-fuzz: Refactor boilerplate in TransformationAddParameter (#3625)
Part of #3534. I forgot to implement this functionality in the original PR.
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index b2d5681..15d1057 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -826,6 +826,12 @@
operand_ids.insert(operand_ids.end(), parameter_type_ids.begin(),
parameter_type_ids.end());
+ // A trivial case - we change nothing.
+ if (FindFunctionType(ir_context, operand_ids) ==
+ old_function_type->result_id()) {
+ return old_function_type->result_id();
+ }
+
if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1 &&
FindFunctionType(ir_context, operand_ids) == 0) {
// We can change |old_function_type| only if it's used once in the module
@@ -849,17 +855,16 @@
// existing one or create a new one.
auto type_id = FindOrCreateFunctionType(
ir_context, new_function_type_result_id, operand_ids);
+ assert(type_id != old_function_type->result_id() &&
+ "We should've handled this case above");
- if (type_id != old_function_type->result_id()) {
- function->DefInst().SetInOperand(1, {type_id});
+ function->DefInst().SetInOperand(1, {type_id});
- // DefUseManager hasn't been updated yet, so if the following condition is
- // true, then |old_function_type| will have no users when this function
- // returns. We might as well remove it.
- if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1) {
- old_function_type->RemoveFromList();
- delete old_function_type;
- }
+ // DefUseManager hasn't been updated yet, so if the following condition is
+ // true, then |old_function_type| will have no users when this function
+ // returns. We might as well remove it.
+ if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1) {
+ ir_context->KillInst(old_function_type);
}
return type_id;
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h
index 343c94e..6058022 100644
--- a/source/fuzz/fuzzer_util.h
+++ b/source/fuzz/fuzzer_util.h
@@ -316,6 +316,10 @@
// more users, it is removed from the module. Returns the result id of the
// OpTypeFunction instruction that is used as a type of the function with
// |function_id|.
+//
+// CAUTION: When the old type of the function is removed from the module, its
+// memory is deallocated. Be sure not to use any pointers to the old
+// type when this function returns.
uint32_t UpdateFunctionType(opt::IRContext* ir_context, uint32_t function_id,
uint32_t new_function_type_result_id,
uint32_t return_type_id,
diff --git a/source/fuzz/transformation_add_parameter.cpp b/source/fuzz/transformation_add_parameter.cpp
index cc32362..6c0ab28 100644
--- a/source/fuzz/transformation_add_parameter.cpp
+++ b/source/fuzz/transformation_add_parameter.cpp
@@ -14,8 +14,6 @@
#include "source/fuzz/transformation_add_parameter.h"
-#include <source/spirv_constant.h>
-
#include "source/fuzz/fuzzer_util.h"
namespace spvtools {
@@ -72,13 +70,13 @@
auto* function = fuzzerutil::FindFunction(ir_context, message_.function_id());
assert(function && "Can't find the function");
- auto parameter_type_id =
+ const auto new_parameter_type_id =
fuzzerutil::GetTypeId(ir_context, message_.initializer_id());
- assert(parameter_type_id != 0 && "Initializer has invalid type");
+ assert(new_parameter_type_id != 0 && "Initializer has invalid type");
// Add new parameters to the function.
function->AddParameter(MakeUnique<opt::Instruction>(
- ir_context, SpvOpFunctionParameter, parameter_type_id,
+ ir_context, SpvOpFunctionParameter, new_parameter_type_id,
message_.parameter_fresh_id(), opt::Instruction::OperandList()));
fuzzerutil::UpdateModuleIdBound(ir_context, message_.parameter_fresh_id());
@@ -92,43 +90,30 @@
message_.parameter_fresh_id());
// Fix all OpFunctionCall instructions.
- ir_context->get_def_use_mgr()->ForEachUser(
- &function->DefInst(), [this](opt::Instruction* call) {
- if (call->opcode() != SpvOpFunctionCall ||
- call->GetSingleWordInOperand(0) != message_.function_id()) {
- return;
- }
+ for (auto* inst : fuzzerutil::GetCallers(ir_context, function->result_id())) {
+ inst->AddOperand({SPV_OPERAND_TYPE_ID, {message_.initializer_id()}});
+ }
- call->AddOperand({SPV_OPERAND_TYPE_ID, {message_.initializer_id()}});
- });
+ // Update function's type.
+ {
+ // We use a separate scope here since |old_function_type| might become a
+ // dangling pointer after the call to the fuzzerutil::UpdateFunctionType.
- auto* old_function_type = fuzzerutil::GetFunctionType(ir_context, function);
- assert(old_function_type && "Function must have a valid type");
+ const auto* old_function_type =
+ fuzzerutil::GetFunctionType(ir_context, function);
+ assert(old_function_type && "Function must have a valid type");
- if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1) {
- // Adjust existing function type if it is used only by this function.
- old_function_type->AddOperand({SPV_OPERAND_TYPE_ID, {parameter_type_id}});
-
- // We must make sure that all dependencies of |old_function_type| are
- // evaluated before |old_function_type| (i.e. the domination rules are not
- // broken). Thus, we move |old_function_type| to the end of the list of all
- // types in the module.
- old_function_type->RemoveFromList();
- ir_context->AddType(std::unique_ptr<opt::Instruction>(old_function_type));
- } else {
- // Otherwise, either create a new type or use an existing one.
- std::vector<uint32_t> type_ids;
- type_ids.reserve(old_function_type->NumInOperands() + 1);
-
- for (uint32_t i = 0, n = old_function_type->NumInOperands(); i < n; ++i) {
- type_ids.push_back(old_function_type->GetSingleWordInOperand(i));
+ std::vector<uint32_t> parameter_type_ids;
+ for (uint32_t i = 1; i < old_function_type->NumInOperands(); ++i) {
+ parameter_type_ids.push_back(
+ old_function_type->GetSingleWordInOperand(i));
}
- type_ids.push_back(parameter_type_id);
+ parameter_type_ids.push_back(new_parameter_type_id);
- function->DefInst().SetInOperand(
- 1, {fuzzerutil::FindOrCreateFunctionType(
- ir_context, message_.function_type_fresh_id(), type_ids)});
+ fuzzerutil::UpdateFunctionType(
+ ir_context, function->result_id(), message_.function_type_fresh_id(),
+ old_function_type->GetSingleWordInOperand(0), parameter_type_ids);
}
// Make sure our changes are analyzed.
diff --git a/test/fuzz/transformation_add_parameter_test.cpp b/test/fuzz/transformation_add_parameter_test.cpp
index 6593d00..c57c738 100644
--- a/test/fuzz/transformation_add_parameter_test.cpp
+++ b/test/fuzz/transformation_add_parameter_test.cpp
@@ -32,20 +32,69 @@
%2 = OpTypeVoid
%7 = OpTypeBool
%11 = OpTypeInt 32 1
+ %16 = OpTypeFloat 32
%3 = OpTypeFunction %2
%6 = OpTypeFunction %7 %7
%8 = OpConstant %11 23
%12 = OpConstantTrue %7
+ %15 = OpTypeFunction %2 %16
+ %24 = OpTypeFunction %2 %16 %7
+ %31 = OpTypeStruct %7 %11
+ %32 = OpConstant %16 23
+ %33 = OpConstantComposite %31 %12 %8
+ %41 = OpTypeStruct %11 %16
+ %42 = OpConstantComposite %41 %8 %32
+ %43 = OpTypeFunction %2 %41
+ %44 = OpTypeFunction %2 %41 %7
%4 = OpFunction %2 None %3
%5 = OpLabel
%13 = OpFunctionCall %7 %9 %12
OpReturn
OpFunctionEnd
+
+ ; adjust type of the function in-place
%9 = OpFunction %7 None %6
%14 = OpFunctionParameter %7
%10 = OpLabel
OpReturnValue %12
OpFunctionEnd
+
+ ; reuse an existing function type
+ %17 = OpFunction %2 None %15
+ %18 = OpFunctionParameter %16
+ %19 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %20 = OpFunction %2 None %15
+ %21 = OpFunctionParameter %16
+ %22 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %25 = OpFunction %2 None %24
+ %26 = OpFunctionParameter %16
+ %27 = OpFunctionParameter %7
+ %28 = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ ; create a new function type
+ %29 = OpFunction %2 None %3
+ %30 = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ ; don't adjust the type of the function if it creates a duplicate
+ %34 = OpFunction %2 None %43
+ %35 = OpFunctionParameter %41
+ %36 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %37 = OpFunction %2 None %44
+ %38 = OpFunctionParameter %41
+ %39 = OpFunctionParameter %7
+ %40 = OpLabel
+ OpReturn
+ OpFunctionEnd
)";
const auto env = SPV_ENV_UNIVERSAL_1_3;
@@ -59,38 +108,58 @@
validator_options);
// Can't modify entry point function.
- ASSERT_FALSE(TransformationAddParameter(4, 15, 12, 16)
+ ASSERT_FALSE(TransformationAddParameter(4, 60, 12, 61)
.IsApplicable(context.get(), transformation_context));
// There is no function with result id 29.
- ASSERT_FALSE(TransformationAddParameter(29, 15, 8, 16)
+ ASSERT_FALSE(TransformationAddParameter(60, 60, 8, 61)
.IsApplicable(context.get(), transformation_context));
// Parameter id is not fresh.
- ASSERT_FALSE(TransformationAddParameter(9, 14, 8, 16)
+ ASSERT_FALSE(TransformationAddParameter(9, 14, 8, 61)
.IsApplicable(context.get(), transformation_context));
// Function type id is not fresh.
- ASSERT_FALSE(TransformationAddParameter(9, 15, 8, 14)
+ ASSERT_FALSE(TransformationAddParameter(9, 60, 8, 14)
.IsApplicable(context.get(), transformation_context));
// Function type id and parameter type id are equal.
- ASSERT_FALSE(TransformationAddParameter(9, 15, 8, 15)
+ ASSERT_FALSE(TransformationAddParameter(9, 60, 8, 60)
.IsApplicable(context.get(), transformation_context));
// Parameter's initializer doesn't exist.
- ASSERT_FALSE(TransformationAddParameter(9, 15, 15, 16)
+ ASSERT_FALSE(TransformationAddParameter(9, 60, 60, 61)
.IsApplicable(context.get(), transformation_context));
- // Correct transformation.
- TransformationAddParameter correct(9, 15, 8, 16);
- ASSERT_TRUE(correct.IsApplicable(context.get(), transformation_context));
- correct.Apply(context.get(), &transformation_context);
-
- // The module remains valid.
- ASSERT_TRUE(IsValid(env, context.get()));
-
- ASSERT_TRUE(fact_manager.IdIsIrrelevant(15));
+ // Correct transformations.
+ {
+ TransformationAddParameter correct(9, 60, 8, 61);
+ ASSERT_TRUE(correct.IsApplicable(context.get(), transformation_context));
+ correct.Apply(context.get(), &transformation_context);
+ ASSERT_TRUE(IsValid(env, context.get()));
+ ASSERT_TRUE(fact_manager.IdIsIrrelevant(60));
+ }
+ {
+ TransformationAddParameter correct(17, 62, 12, 63);
+ ASSERT_TRUE(correct.IsApplicable(context.get(), transformation_context));
+ correct.Apply(context.get(), &transformation_context);
+ ASSERT_TRUE(IsValid(env, context.get()));
+ ASSERT_TRUE(fact_manager.IdIsIrrelevant(62));
+ }
+ {
+ TransformationAddParameter correct(29, 64, 33, 65);
+ ASSERT_TRUE(correct.IsApplicable(context.get(), transformation_context));
+ correct.Apply(context.get(), &transformation_context);
+ ASSERT_TRUE(IsValid(env, context.get()));
+ ASSERT_TRUE(fact_manager.IdIsIrrelevant(64));
+ }
+ {
+ TransformationAddParameter correct(34, 66, 12, 67);
+ ASSERT_TRUE(correct.IsApplicable(context.get(), transformation_context));
+ correct.Apply(context.get(), &transformation_context);
+ ASSERT_TRUE(IsValid(env, context.get()));
+ ASSERT_TRUE(fact_manager.IdIsIrrelevant(66));
+ }
std::string expected_shader = R"(
OpCapability Shader
@@ -103,21 +172,73 @@
%2 = OpTypeVoid
%7 = OpTypeBool
%11 = OpTypeInt 32 1
+ %16 = OpTypeFloat 32
%3 = OpTypeFunction %2
%8 = OpConstant %11 23
%12 = OpConstantTrue %7
+ %15 = OpTypeFunction %2 %16
+ %24 = OpTypeFunction %2 %16 %7
+ %31 = OpTypeStruct %7 %11
+ %32 = OpConstant %16 23
+ %33 = OpConstantComposite %31 %12 %8
+ %41 = OpTypeStruct %11 %16
+ %42 = OpConstantComposite %41 %8 %32
+ %44 = OpTypeFunction %2 %41 %7
%6 = OpTypeFunction %7 %7 %11
+ %65 = OpTypeFunction %2 %31
%4 = OpFunction %2 None %3
%5 = OpLabel
%13 = OpFunctionCall %7 %9 %12 %8
OpReturn
OpFunctionEnd
+
+ ; adjust type of the function in-place
%9 = OpFunction %7 None %6
%14 = OpFunctionParameter %7
- %15 = OpFunctionParameter %11
+ %60 = OpFunctionParameter %11
%10 = OpLabel
OpReturnValue %12
OpFunctionEnd
+
+ ; reuse an existing function type
+ %17 = OpFunction %2 None %24
+ %18 = OpFunctionParameter %16
+ %62 = OpFunctionParameter %7
+ %19 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %20 = OpFunction %2 None %15
+ %21 = OpFunctionParameter %16
+ %22 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %25 = OpFunction %2 None %24
+ %26 = OpFunctionParameter %16
+ %27 = OpFunctionParameter %7
+ %28 = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ ; create a new function type
+ %29 = OpFunction %2 None %65
+ %64 = OpFunctionParameter %31
+ %30 = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ ; don't adjust the type of the function if it creates a duplicate
+ %34 = OpFunction %2 None %44
+ %35 = OpFunctionParameter %41
+ %66 = OpFunctionParameter %7
+ %36 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %37 = OpFunction %2 None %44
+ %38 = OpFunctionParameter %41
+ %39 = OpFunctionParameter %7
+ %40 = OpLabel
+ OpReturn
+ OpFunctionEnd
)";
ASSERT_TRUE(IsEqual(env, expected_shader, context.get()));