spirv-fuzz: Use overflow ids when inlining functions (#3880)
Fixes #3751.
diff --git a/source/fuzz/transformation_inline_function.cpp b/source/fuzz/transformation_inline_function.cpp
index 31f6fb3..cdcee1b 100644
--- a/source/fuzz/transformation_inline_function.cpp
+++ b/source/fuzz/transformation_inline_function.cpp
@@ -33,7 +33,8 @@
}
bool TransformationInlineFunction::IsApplicable(
- opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
+ opt::IRContext* ir_context,
+ const TransformationContext& transformation_context) const {
// The values in the |message_.result_id_map| must be all fresh and all
// distinct.
const auto result_id_map =
@@ -71,7 +72,8 @@
// Since the entry block label will not be inlined, only the remaining
// labels must have a corresponding value in the map.
if (&block != &*called_function->entry() &&
- !result_id_map.count(block.GetLabel()->result_id())) {
+ !result_id_map.count(block.id()) &&
+ !transformation_context.GetOverflowIdSource()->HasOverflowIds()) {
return false;
}
@@ -81,7 +83,8 @@
// If |instruction| has result id, then it must have a mapped id in
// |result_id_map|.
if (instruction.HasResultId() &&
- !result_id_map.count(instruction.result_id())) {
+ !result_id_map.count(instruction.result_id()) &&
+ !transformation_context.GetOverflowIdSource()->HasOverflowIds()) {
return false;
}
}
@@ -100,15 +103,37 @@
}
void TransformationInlineFunction::Apply(
- opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
+ opt::IRContext* ir_context,
+ TransformationContext* transformation_context) const {
auto* function_call_instruction =
ir_context->get_def_use_mgr()->GetDef(message_.function_call_id());
auto* caller_function =
ir_context->get_instr_block(function_call_instruction)->GetParent();
auto* called_function = fuzzerutil::FindFunction(
ir_context, function_call_instruction->GetSingleWordInOperand(0));
- const auto result_id_map =
+ std::map<uint32_t, uint32_t> result_id_map =
fuzzerutil::RepeatedUInt32PairToMap(message_.result_id_map());
+
+ // If there are gaps in the result id map, fill them using overflow ids.
+ for (auto& block : *called_function) {
+ if (&block != &*called_function->entry() &&
+ !result_id_map.count(block.id())) {
+ result_id_map.insert(
+ {block.id(),
+ transformation_context->GetOverflowIdSource()->GetNextOverflowId()});
+ }
+ for (auto& instruction : block) {
+ // If |instruction| has result id, then it must have a mapped id in
+ // |result_id_map|.
+ if (instruction.HasResultId() &&
+ !result_id_map.count(instruction.result_id())) {
+ result_id_map.insert({instruction.result_id(),
+ transformation_context->GetOverflowIdSource()
+ ->GetNextOverflowId()});
+ }
+ }
+ }
+
auto* successor_block = ir_context->cfg()->block(
ir_context->get_instr_block(function_call_instruction)
->terminator()
@@ -128,7 +153,7 @@
MakeUnique<opt::Instruction>(entry_block_instruction));
}
- AdaptInlinedInstruction(ir_context, inlined_instruction);
+ AdaptInlinedInstruction(result_id_map, ir_context, inlined_instruction);
}
// Inline the |called_function| non-entry blocks.
@@ -141,13 +166,11 @@
cloned_block = caller_function->InsertBasicBlockBefore(
std::unique_ptr<opt::BasicBlock>(cloned_block), successor_block);
cloned_block->SetParent(caller_function);
- cloned_block->GetLabel()->SetResultId(
- result_id_map.at(cloned_block->GetLabel()->result_id()));
- fuzzerutil::UpdateModuleIdBound(ir_context,
- cloned_block->GetLabel()->result_id());
+ cloned_block->GetLabel()->SetResultId(result_id_map.at(cloned_block->id()));
+ fuzzerutil::UpdateModuleIdBound(ir_context, cloned_block->id());
for (auto& inlined_instruction : *cloned_block) {
- AdaptInlinedInstruction(ir_context, &inlined_instruction);
+ AdaptInlinedInstruction(result_id_map, ir_context, &inlined_instruction);
}
}
@@ -202,14 +225,13 @@
}
void TransformationInlineFunction::AdaptInlinedInstruction(
+ const std::map<uint32_t, uint32_t>& result_id_map,
opt::IRContext* ir_context,
opt::Instruction* instruction_to_be_inlined) const {
auto* function_call_instruction =
ir_context->get_def_use_mgr()->GetDef(message_.function_call_id());
auto* called_function = fuzzerutil::FindFunction(
ir_context, function_call_instruction->GetSingleWordInOperand(0));
- const auto result_id_map =
- fuzzerutil::RepeatedUInt32PairToMap(message_.result_id_map());
const auto* function_call_block =
ir_context->get_instr_block(function_call_instruction);
diff --git a/source/fuzz/transformation_inline_function.h b/source/fuzz/transformation_inline_function.h
index 272024a..8105d92 100644
--- a/source/fuzz/transformation_inline_function.h
+++ b/source/fuzz/transformation_inline_function.h
@@ -33,7 +33,7 @@
const std::map<uint32_t, uint32_t>& result_id_map);
// - |message_.result_id_map| must map the instructions of the called function
- // to fresh ids.
+ // to fresh ids, unless overflow ids are available.
// - |message_.function_call_id| must be an OpFunctionCall instruction.
// It must not have an early return and must not use OpUnreachable or
// OpKill. This is to guard against making the module invalid when the
@@ -67,8 +67,9 @@
// Inline |instruction_to_be_inlined| by setting its ids to the corresponding
// ids in |result_id_map|.
- void AdaptInlinedInstruction(opt::IRContext* ir_context,
- opt::Instruction* instruction) const;
+ void AdaptInlinedInstruction(
+ const std::map<uint32_t, uint32_t>& result_id_map,
+ opt::IRContext* ir_context, opt::Instruction* instruction) const;
};
} // namespace fuzz
diff --git a/test/fuzz/transformation_inline_function_test.cpp b/test/fuzz/transformation_inline_function_test.cpp
index 89887be..09cf936 100644
--- a/test/fuzz/transformation_inline_function_test.cpp
+++ b/test/fuzz/transformation_inline_function_test.cpp
@@ -14,6 +14,7 @@
#include "source/fuzz/transformation_inline_function.h"
+#include "source/fuzz/counter_overflow_id_source.h"
#include "source/fuzz/instruction_descriptor.h"
#include "test/fuzz/fuzz_test_util.h"
@@ -533,6 +534,7 @@
ASSERT_FALSE(
transformation.IsApplicable(context.get(), transformation_context));
+#ifndef NDEBUG
// Tests the id of the returned value not included in the id map.
transformation = TransformationInlineFunction(25, {{56, 69},
{57, 70},
@@ -544,8 +546,10 @@
{64, 76},
{65, 77},
{66, 78}});
- ASSERT_FALSE(
- transformation.IsApplicable(context.get(), transformation_context));
+ ASSERT_DEATH(
+ transformation.IsApplicable(context.get(), transformation_context),
+ "Bad attempt to query whether overflow ids are available.");
+#endif
transformation = TransformationInlineFunction(25, {{57, 69},
{58, 70},
@@ -819,6 +823,198 @@
ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
}
+TEST(TransformationInlineFunctionTest, OverflowIds) {
+ std::string reference_shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Vertex %39 "main"
+
+; Types
+ %2 = OpTypeFloat 32
+ %3 = OpTypeVector %2 4
+ %4 = OpTypePointer Function %3
+ %5 = OpTypeVoid
+ %6 = OpTypeFunction %5
+ %7 = OpTypeFunction %2 %4 %4
+
+; Constant scalars
+ %8 = OpConstant %2 1
+ %9 = OpConstant %2 2
+ %10 = OpConstant %2 3
+ %11 = OpConstant %2 4
+ %12 = OpConstant %2 5
+ %13 = OpConstant %2 6
+ %14 = OpConstant %2 7
+ %15 = OpConstant %2 8
+
+; Constant vectors
+ %16 = OpConstantComposite %3 %8 %9 %10 %11
+ %17 = OpConstantComposite %3 %12 %13 %14 %15
+
+; dot product function
+ %18 = OpFunction %2 None %7
+ %19 = OpFunctionParameter %4
+ %20 = OpFunctionParameter %4
+ %21 = OpLabel
+ %22 = OpLoad %3 %19
+ %23 = OpLoad %3 %20
+ %24 = OpCompositeExtract %2 %22 0
+ %25 = OpCompositeExtract %2 %23 0
+ %26 = OpFMul %2 %24 %25
+ %27 = OpCompositeExtract %2 %22 1
+ %28 = OpCompositeExtract %2 %23 1
+ %29 = OpFMul %2 %27 %28
+ OpBranch %100
+ %100 = OpLabel
+ %30 = OpCompositeExtract %2 %22 2
+ %31 = OpCompositeExtract %2 %23 2
+ %32 = OpFMul %2 %30 %31
+ %33 = OpCompositeExtract %2 %22 3
+ %34 = OpCompositeExtract %2 %23 3
+ %35 = OpFMul %2 %33 %34
+ %36 = OpFAdd %2 %26 %29
+ %37 = OpFAdd %2 %32 %36
+ %38 = OpFAdd %2 %35 %37
+ OpReturnValue %38
+ OpFunctionEnd
+
+; main function
+ %39 = OpFunction %5 None %6
+ %40 = OpLabel
+ %41 = OpVariable %4 Function
+ %42 = OpVariable %4 Function
+ OpStore %41 %16
+ OpStore %42 %17
+ %43 = OpFunctionCall %2 %18 %41 %42 ; dot product function call
+ OpBranch %44
+ %44 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_5;
+ const auto consumer = nullptr;
+ const auto context =
+ BuildModule(env, consumer, reference_shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ spvtools::ValidatorOptions validator_options;
+ auto overflow_ids_unique_ptr = MakeUnique<CounterOverflowIdSource>(1000);
+ auto overflow_ids_ptr = overflow_ids_unique_ptr.get();
+ TransformationContext transformation_context(
+ MakeUnique<FactManager>(context.get()), validator_options,
+ std::move(overflow_ids_unique_ptr));
+ auto transformation = TransformationInlineFunction(43, {{22, 45},
+ {23, 46},
+ {24, 47},
+ {25, 48},
+ {26, 49},
+ {27, 50},
+ {28, 51},
+ {29, 52}});
+
+ // The following ids are left un-mapped; overflow ids will be required for
+ // them: 30, 31, 32, 33, 34, 35, 36, 37, 38, 100
+
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+
+ ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context,
+ overflow_ids_ptr->GetIssuedOverflowIds());
+
+ std::string variant_shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Vertex %39 "main"
+
+; Types
+ %2 = OpTypeFloat 32
+ %3 = OpTypeVector %2 4
+ %4 = OpTypePointer Function %3
+ %5 = OpTypeVoid
+ %6 = OpTypeFunction %5
+ %7 = OpTypeFunction %2 %4 %4
+
+; Constant scalars
+ %8 = OpConstant %2 1
+ %9 = OpConstant %2 2
+ %10 = OpConstant %2 3
+ %11 = OpConstant %2 4
+ %12 = OpConstant %2 5
+ %13 = OpConstant %2 6
+ %14 = OpConstant %2 7
+ %15 = OpConstant %2 8
+
+; Constant vectors
+ %16 = OpConstantComposite %3 %8 %9 %10 %11
+ %17 = OpConstantComposite %3 %12 %13 %14 %15
+
+; dot product function
+ %18 = OpFunction %2 None %7
+ %19 = OpFunctionParameter %4
+ %20 = OpFunctionParameter %4
+ %21 = OpLabel
+ %22 = OpLoad %3 %19
+ %23 = OpLoad %3 %20
+ %24 = OpCompositeExtract %2 %22 0
+ %25 = OpCompositeExtract %2 %23 0
+ %26 = OpFMul %2 %24 %25
+ %27 = OpCompositeExtract %2 %22 1
+ %28 = OpCompositeExtract %2 %23 1
+ %29 = OpFMul %2 %27 %28
+ OpBranch %100
+ %100 = OpLabel
+ %30 = OpCompositeExtract %2 %22 2
+ %31 = OpCompositeExtract %2 %23 2
+ %32 = OpFMul %2 %30 %31
+ %33 = OpCompositeExtract %2 %22 3
+ %34 = OpCompositeExtract %2 %23 3
+ %35 = OpFMul %2 %33 %34
+ %36 = OpFAdd %2 %26 %29
+ %37 = OpFAdd %2 %32 %36
+ %38 = OpFAdd %2 %35 %37
+ OpReturnValue %38
+ OpFunctionEnd
+
+; main function
+ %39 = OpFunction %5 None %6
+ %40 = OpLabel
+ %41 = OpVariable %4 Function
+ %42 = OpVariable %4 Function
+ OpStore %41 %16
+ OpStore %42 %17
+ %45 = OpLoad %3 %41
+ %46 = OpLoad %3 %42
+ %47 = OpCompositeExtract %2 %45 0
+ %48 = OpCompositeExtract %2 %46 0
+ %49 = OpFMul %2 %47 %48
+ %50 = OpCompositeExtract %2 %45 1
+ %51 = OpCompositeExtract %2 %46 1
+ %52 = OpFMul %2 %50 %51
+ OpBranch %1000
+ %1000 = OpLabel
+ %1001 = OpCompositeExtract %2 %45 2
+ %1002 = OpCompositeExtract %2 %46 2
+ %1003 = OpFMul %2 %1001 %1002
+ %1004 = OpCompositeExtract %2 %45 3
+ %1005 = OpCompositeExtract %2 %46 3
+ %1006 = OpFMul %2 %1004 %1005
+ %1007 = OpFAdd %2 %49 %52
+ %1008 = OpFAdd %2 %1003 %1007
+ %1009 = OpFAdd %2 %1006 %1008
+ %43 = OpCopyObject %2 %1009
+ OpBranch %44
+ %44 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ ASSERT_TRUE(IsValid(env, context.get()));
+ ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
+}
+
} // namespace
} // namespace fuzz
} // namespace spvtools