spirv-fuzz: Remove OpFunctionCall operands in correct order (#3630)
Fixes #3629.
diff --git a/source/fuzz/transformation_replace_params_with_struct.cpp b/source/fuzz/transformation_replace_params_with_struct.cpp
index 7e76415..8c683a3 100644
--- a/source/fuzz/transformation_replace_params_with_struct.cpp
+++ b/source/fuzz/transformation_replace_params_with_struct.cpp
@@ -101,8 +101,9 @@
return false;
}
- auto caller_id_to_fresh_composite_id = fuzzerutil::RepeatedUInt32PairToMap(
- message_.caller_id_to_fresh_composite_id());
+ const auto caller_id_to_fresh_composite_id =
+ fuzzerutil::RepeatedUInt32PairToMap(
+ message_.caller_id_to_fresh_composite_id());
// Check that |callee_id_to_fresh_composite_id| is valid.
for (const auto* inst :
@@ -151,24 +152,12 @@
// Compute indices of replaced parameters. This will be used to adjust
// OpFunctionCall instructions and create OpCompositeConstruct instructions at
// every call site.
- std::vector<uint32_t> indices_of_replaced_params;
- {
- // We want to destroy |params| after the loop because it will contain
- // dangling pointers when we remove parameters from the function.
- auto params = fuzzerutil::GetParameters(ir_context, function->result_id());
- for (auto id : message_.parameter_id()) {
- auto it = std::find_if(params.begin(), params.end(),
- [id](const opt::Instruction* param) {
- return param->result_id() == id;
- });
- assert(it != params.end() && "Parameter's id is invalid");
- indices_of_replaced_params.push_back(
- static_cast<uint32_t>(it - params.begin()));
- }
- }
+ const auto indices_of_replaced_params =
+ ComputeIndicesOfReplacedParameters(ir_context);
- auto caller_id_to_fresh_composite_id = fuzzerutil::RepeatedUInt32PairToMap(
- message_.caller_id_to_fresh_composite_id());
+ const auto caller_id_to_fresh_composite_id =
+ fuzzerutil::RepeatedUInt32PairToMap(
+ message_.caller_id_to_fresh_composite_id());
// Update all function calls.
for (auto* inst : fuzzerutil::GetCallers(ir_context, function->result_id())) {
@@ -182,12 +171,13 @@
}
// Remove arguments from the function call. We do it in a separate loop
- // and in reverse order to make sure we have removed correct operands.
- for (auto it = indices_of_replaced_params.rbegin();
- it != indices_of_replaced_params.rend(); ++it) {
+ // and in decreasing order to make sure we have removed correct operands.
+ for (auto index : std::set<uint32_t, std::greater<uint32_t>>(
+ indices_of_replaced_params.begin(),
+ indices_of_replaced_params.end())) {
// +1 since the first in operand to OpFunctionCall is the result id of
// the function.
- inst->RemoveInOperand(*it + 1);
+ inst->RemoveInOperand(index + 1);
}
// Insert OpCompositeConstruct before the function call.
@@ -305,5 +295,30 @@
return fuzzerutil::MaybeGetStructType(ir_context, component_type_ids);
}
+std::vector<uint32_t>
+TransformationReplaceParamsWithStruct::ComputeIndicesOfReplacedParameters(
+ opt::IRContext* ir_context) const {
+ assert(!message_.parameter_id().empty() &&
+ "There must be at least one parameter to replace");
+
+ const auto* function = fuzzerutil::GetFunctionFromParameterId(
+ ir_context, message_.parameter_id(0));
+ assert(function && "|parameter_id|s are invalid");
+
+ std::vector<uint32_t> result;
+
+ auto params = fuzzerutil::GetParameters(ir_context, function->result_id());
+ for (auto id : message_.parameter_id()) {
+ auto it = std::find_if(params.begin(), params.end(),
+ [id](const opt::Instruction* param) {
+ return param->result_id() == id;
+ });
+ assert(it != params.end() && "Parameter's id is invalid");
+ result.push_back(static_cast<uint32_t>(it - params.begin()));
+ }
+
+ return result;
+}
+
} // namespace fuzz
} // namespace spvtools
diff --git a/source/fuzz/transformation_replace_params_with_struct.h b/source/fuzz/transformation_replace_params_with_struct.h
index d2ce204..7e40de8 100644
--- a/source/fuzz/transformation_replace_params_with_struct.h
+++ b/source/fuzz/transformation_replace_params_with_struct.h
@@ -73,6 +73,12 @@
// transformation (see docs on the IsApplicable method to learn more).
uint32_t MaybeGetRequiredStructType(opt::IRContext* ir_context) const;
+ // Returns a vector of indices of parameters to replace. Concretely, i'th
+ // element is the index of the parameter with result id |parameter_id[i]| in
+ // its function.
+ std::vector<uint32_t> ComputeIndicesOfReplacedParameters(
+ opt::IRContext* ir_context) const;
+
protobufs::TransformationReplaceParamsWithStruct message_;
};
diff --git a/test/fuzz/transformation_replace_params_with_struct_test.cpp b/test/fuzz/transformation_replace_params_with_struct_test.cpp
index e59f6ea..a198b42 100644
--- a/test/fuzz/transformation_replace_params_with_struct_test.cpp
+++ b/test/fuzz/transformation_replace_params_with_struct_test.cpp
@@ -333,6 +333,147 @@
ASSERT_TRUE(IsEqual(env, expected_shader, context.get()));
}
+TEST(TransformationReplaceParamsWithStructTest, ParametersRemainValid) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ %2 = OpTypeVoid
+ %6 = OpTypeInt 32 1
+ %3 = OpTypeFunction %2
+ %8 = OpTypeFloat 32
+ %10 = OpTypeVector %8 2
+ %12 = OpTypeBool
+ %40 = OpTypePointer Function %12
+ %13 = OpTypeStruct %6 %8
+ %45 = OpTypeStruct %6 %8 %13
+ %47 = OpTypeStruct %45 %12 %10
+ %15 = OpTypeFunction %2 %6 %8 %10 %13 %40 %12
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %20 = OpFunction %2 None %15
+ %16 = OpFunctionParameter %6
+ %17 = OpFunctionParameter %8
+ %18 = OpFunctionParameter %10
+ %19 = OpFunctionParameter %13
+ %42 = OpFunctionParameter %40
+ %43 = OpFunctionParameter %12
+ %21 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_3;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ FactManager fact_manager;
+ spvtools::ValidatorOptions validator_options;
+ TransformationContext transformation_context(&fact_manager,
+ validator_options);
+
+ {
+ // Try to replace parameters in "increasing" order of their declaration.
+ TransformationReplaceParamsWithStruct transformation({16, 17, 19}, 70, 71,
+ {{}});
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+ transformation.Apply(context.get(), &transformation_context);
+ ASSERT_TRUE(IsValid(env, context.get()));
+ }
+
+ std::string after_transformation = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ %2 = OpTypeVoid
+ %6 = OpTypeInt 32 1
+ %3 = OpTypeFunction %2
+ %8 = OpTypeFloat 32
+ %10 = OpTypeVector %8 2
+ %12 = OpTypeBool
+ %40 = OpTypePointer Function %12
+ %13 = OpTypeStruct %6 %8
+ %45 = OpTypeStruct %6 %8 %13
+ %47 = OpTypeStruct %45 %12 %10
+ %15 = OpTypeFunction %2 %10 %40 %12 %45
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %20 = OpFunction %2 None %15
+ %18 = OpFunctionParameter %10
+ %42 = OpFunctionParameter %40
+ %43 = OpFunctionParameter %12
+ %71 = OpFunctionParameter %45
+ %21 = OpLabel
+ %19 = OpCompositeExtract %13 %71 2
+ %17 = OpCompositeExtract %8 %71 1
+ %16 = OpCompositeExtract %6 %71 0
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+
+ {
+ // Try to replace parameters in "decreasing" order of their declaration.
+ TransformationReplaceParamsWithStruct transformation({71, 43, 18}, 72, 73,
+ {{}});
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+ transformation.Apply(context.get(), &transformation_context);
+ ASSERT_TRUE(IsValid(env, context.get()));
+ }
+
+ after_transformation = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ %2 = OpTypeVoid
+ %6 = OpTypeInt 32 1
+ %3 = OpTypeFunction %2
+ %8 = OpTypeFloat 32
+ %10 = OpTypeVector %8 2
+ %12 = OpTypeBool
+ %40 = OpTypePointer Function %12
+ %13 = OpTypeStruct %6 %8
+ %45 = OpTypeStruct %6 %8 %13
+ %47 = OpTypeStruct %45 %12 %10
+ %15 = OpTypeFunction %2 %40 %47
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %20 = OpFunction %2 None %15
+ %42 = OpFunctionParameter %40
+ %73 = OpFunctionParameter %47
+ %21 = OpLabel
+ %18 = OpCompositeExtract %10 %73 2
+ %43 = OpCompositeExtract %12 %73 1
+ %71 = OpCompositeExtract %45 %73 0
+ %19 = OpCompositeExtract %13 %71 2
+ %17 = OpCompositeExtract %8 %71 1
+ %16 = OpCompositeExtract %6 %71 0
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
} // namespace
} // namespace fuzz
} // namespace spvtools