spirv-fuzz: Transformation to add wrappers for OpKill and similar (#3881)
Part of #3717.
diff --git a/source/fuzz/CMakeLists.txt b/source/fuzz/CMakeLists.txt
index 050bd96..5d4f428 100644
--- a/source/fuzz/CMakeLists.txt
+++ b/source/fuzz/CMakeLists.txt
@@ -145,6 +145,7 @@
transformation_add_dead_block.h
transformation_add_dead_break.h
transformation_add_dead_continue.h
+ transformation_add_early_terminator_wrapper.h
transformation_add_function.h
transformation_add_global_undef.h
transformation_add_global_variable.h
@@ -324,6 +325,7 @@
transformation_add_dead_block.cpp
transformation_add_dead_break.cpp
transformation_add_dead_continue.cpp
+ transformation_add_early_terminator_wrapper.cpp
transformation_add_function.cpp
transformation_add_global_undef.cpp
transformation_add_global_variable.cpp
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index 99d64eb..95ead39 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -1011,6 +1011,11 @@
return ir_context->get_type_mgr()->GetId(&type);
}
+uint32_t MaybeGetVoidType(opt::IRContext* ir_context) {
+ opt::analysis::Void type;
+ return ir_context->get_type_mgr()->GetId(&type);
+}
+
uint32_t MaybeGetZeroConstant(
opt::IRContext* ir_context,
const TransformationContext& transformation_context,
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h
index b0141f8..6a6efb8 100644
--- a/source/fuzz/fuzzer_util.h
+++ b/source/fuzz/fuzzer_util.h
@@ -379,6 +379,10 @@
uint32_t MaybeGetStructType(opt::IRContext* ir_context,
const std::vector<uint32_t>& component_type_ids);
+// Returns a result id of an OpTypeVoid instruction if present. Returns 0
+// otherwise.
+uint32_t MaybeGetVoidType(opt::IRContext* ir_context);
+
// Recursive definition is the following:
// - if |scalar_or_composite_type_id| is a result id of a scalar type - returns
// a result id of the following constants (depending on the type): int -> 0,
diff --git a/source/fuzz/protobufs/spvtoolsfuzz.proto b/source/fuzz/protobufs/spvtoolsfuzz.proto
index bebe848..eb01b97 100644
--- a/source/fuzz/protobufs/spvtoolsfuzz.proto
+++ b/source/fuzz/protobufs/spvtoolsfuzz.proto
@@ -551,6 +551,7 @@
TransformationAddBitInstructionSynonym add_bit_instruction_synonym = 77;
TransformationAddLoopToCreateIntConstantSynonym add_loop_to_create_int_constant_synonym = 78;
TransformationWrapRegionInSelection wrap_region_in_selection = 79;
+ TransformationAddEarlyTerminatorWrapper add_early_terminator_wrapper = 80;
// Add additional option using the next available number.
}
}
@@ -764,6 +765,28 @@
}
+message TransformationAddEarlyTerminatorWrapper {
+
+ // Adds a function to the module containing a single block with a single non-
+ // label instruction that is either OpKill, OpUnreachable, or
+ // OpTerminateInvocation. The purpose of this is to allow such instructions
+ // to be subsequently replaced with wrapper functions, which can then enable
+ // transformations (such as inlining) that are hard in the direct presence
+ // of these instructions.
+
+ // Fresh id for the function.
+ uint32 function_fresh_id = 1;
+
+ // Fresh id for the single basic block in the function.
+ uint32 label_fresh_id = 2;
+
+ // One of OpKill, OpUnreachable, OpTerminateInvocation. If additional early
+ // termination instructions are added to SPIR-V they should also be handled
+ // here.
+ uint32 opcode = 3;
+
+}
+
message TransformationAddFunction {
// Adds a SPIR-V function to the module.
diff --git a/source/fuzz/transformation.cpp b/source/fuzz/transformation.cpp
index ea7c97c..f03d6a9 100644
--- a/source/fuzz/transformation.cpp
+++ b/source/fuzz/transformation.cpp
@@ -27,6 +27,7 @@
#include "source/fuzz/transformation_add_dead_block.h"
#include "source/fuzz/transformation_add_dead_break.h"
#include "source/fuzz/transformation_add_dead_continue.h"
+#include "source/fuzz/transformation_add_early_terminator_wrapper.h"
#include "source/fuzz/transformation_add_function.h"
#include "source/fuzz/transformation_add_global_undef.h"
#include "source/fuzz/transformation_add_global_variable.h"
@@ -133,6 +134,10 @@
case protobufs::Transformation::TransformationCase::kAddDeadContinue:
return MakeUnique<TransformationAddDeadContinue>(
message.add_dead_continue());
+ case protobufs::Transformation::TransformationCase::
+ kAddEarlyTerminatorWrapper:
+ return MakeUnique<TransformationAddEarlyTerminatorWrapper>(
+ message.add_early_terminator_wrapper());
case protobufs::Transformation::TransformationCase::kAddFunction:
return MakeUnique<TransformationAddFunction>(message.add_function());
case protobufs::Transformation::TransformationCase::kAddGlobalUndef:
diff --git a/source/fuzz/transformation_add_early_terminator_wrapper.cpp b/source/fuzz/transformation_add_early_terminator_wrapper.cpp
new file mode 100644
index 0000000..0aa1214
--- /dev/null
+++ b/source/fuzz/transformation_add_early_terminator_wrapper.cpp
@@ -0,0 +1,110 @@
+// Copyright (c) 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/fuzz/transformation_add_early_terminator_wrapper.h"
+
+#include "source/fuzz/fuzzer_util.h"
+#include "source/util/make_unique.h"
+
+namespace spvtools {
+namespace fuzz {
+
+TransformationAddEarlyTerminatorWrapper::
+ TransformationAddEarlyTerminatorWrapper(
+ const spvtools::fuzz::protobufs::
+ TransformationAddEarlyTerminatorWrapper& message)
+ : message_(message) {}
+
+TransformationAddEarlyTerminatorWrapper::
+ TransformationAddEarlyTerminatorWrapper(uint32_t function_fresh_id,
+ uint32_t label_fresh_id,
+ SpvOp opcode) {
+ message_.set_function_fresh_id(function_fresh_id);
+ message_.set_label_fresh_id(label_fresh_id);
+ message_.set_opcode(opcode);
+}
+
+bool TransformationAddEarlyTerminatorWrapper::IsApplicable(
+ opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
+ assert((message_.opcode() == SpvOpKill ||
+ message_.opcode() == SpvOpUnreachable ||
+ message_.opcode() == SpvOpTerminateInvocation) &&
+ "Invalid opcode.");
+
+ if (!fuzzerutil::IsFreshId(ir_context, message_.function_fresh_id())) {
+ return false;
+ }
+ if (!fuzzerutil::IsFreshId(ir_context, message_.label_fresh_id())) {
+ return false;
+ }
+ if (message_.function_fresh_id() == message_.label_fresh_id()) {
+ return false;
+ }
+ uint32_t void_type_id = fuzzerutil::MaybeGetVoidType(ir_context);
+ if (!void_type_id) {
+ return false;
+ }
+ return fuzzerutil::FindFunctionType(ir_context, {void_type_id});
+}
+
+void TransformationAddEarlyTerminatorWrapper::Apply(
+ opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
+ fuzzerutil::UpdateModuleIdBound(ir_context, message_.function_fresh_id());
+ fuzzerutil::UpdateModuleIdBound(ir_context, message_.label_fresh_id());
+
+ // Create a basic block of the form:
+ // %label_fresh_id = OpLabel
+ // OpKill|Unreachable|TerminateInvocation
+ auto basic_block = MakeUnique<opt::BasicBlock>(MakeUnique<opt::Instruction>(
+ ir_context, SpvOpLabel, 0, message_.label_fresh_id(),
+ opt::Instruction::OperandList()));
+ basic_block->AddInstruction(MakeUnique<opt::Instruction>(
+ ir_context, static_cast<SpvOp>(message_.opcode()), 0, 0,
+ opt::Instruction::OperandList()));
+
+ // Create a zero-argument void function.
+ auto void_type_id = fuzzerutil::MaybeGetVoidType(ir_context);
+ auto function = MakeUnique<opt::Function>(MakeUnique<opt::Instruction>(
+ ir_context, SpvOpFunction, void_type_id, message_.function_fresh_id(),
+ opt::Instruction::OperandList(
+ {{SPV_OPERAND_TYPE_FUNCTION_CONTROL, {SpvFunctionControlMaskNone}},
+ {SPV_OPERAND_TYPE_TYPE_ID,
+ {fuzzerutil::FindFunctionType(ir_context, {void_type_id})}}})));
+
+ // Add the basic block to the function as the sole block, and add the function
+ // to the module.
+ basic_block->SetParent(function.get());
+ function->AddBasicBlock(std::move(basic_block));
+ function->SetFunctionEnd(MakeUnique<opt::Instruction>(
+ ir_context, SpvOpFunctionEnd, 0, 0, opt::Instruction::OperandList()));
+ ir_context->module()->AddFunction(std::move(function));
+
+ ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
+}
+
+std::unordered_set<uint32_t>
+TransformationAddEarlyTerminatorWrapper::GetFreshIds() const {
+ return std::unordered_set<uint32_t>(
+ {message_.function_fresh_id(), message_.label_fresh_id()});
+}
+
+protobufs::Transformation TransformationAddEarlyTerminatorWrapper::ToMessage()
+ const {
+ protobufs::Transformation result;
+ *result.mutable_add_early_terminator_wrapper() = message_;
+ return result;
+}
+
+} // namespace fuzz
+} // namespace spvtools
diff --git a/source/fuzz/transformation_add_early_terminator_wrapper.h b/source/fuzz/transformation_add_early_terminator_wrapper.h
new file mode 100644
index 0000000..273037e
--- /dev/null
+++ b/source/fuzz/transformation_add_early_terminator_wrapper.h
@@ -0,0 +1,63 @@
+// Copyright (c) 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_FUZZ_TRANSFORMATION_add_early_terminator_wrapper_H_
+#define SOURCE_FUZZ_TRANSFORMATION_add_early_terminator_wrapper_H_
+
+#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
+#include "source/fuzz/transformation.h"
+#include "source/fuzz/transformation_context.h"
+#include "source/opt/ir_context.h"
+
+namespace spvtools {
+namespace fuzz {
+
+class TransformationAddEarlyTerminatorWrapper : public Transformation {
+ public:
+ explicit TransformationAddEarlyTerminatorWrapper(
+ const protobufs::TransformationAddEarlyTerminatorWrapper& message);
+
+ TransformationAddEarlyTerminatorWrapper(uint32_t function_fresh_id,
+ uint32_t label_fresh_id,
+ SpvOp opcode);
+
+ // - |message_.function_fresh_id| and |message_.label_fresh_id| must be fresh
+ // and distinct.
+ // - OpTypeVoid must be declared in the module.
+ // - The module must contain a type for a zero-argument void function.
+ bool IsApplicable(
+ opt::IRContext* ir_context,
+ const TransformationContext& transformation_context) const override;
+
+ // Adds a function to the module of the form:
+ //
+ // |message_.function_fresh_id| = OpFunction %void None %zero_args_return_void
+ // |message_.label_fresh_id| = OpLabel
+ // |message_.opcode|
+ // OpFunctionEnd
+ void Apply(opt::IRContext* ir_context,
+ TransformationContext* transformation_context) const override;
+
+ std::unordered_set<uint32_t> GetFreshIds() const override;
+
+ protobufs::Transformation ToMessage() const override;
+
+ private:
+ protobufs::TransformationAddEarlyTerminatorWrapper message_;
+};
+
+} // namespace fuzz
+} // namespace spvtools
+
+#endif // SOURCE_FUZZ_TRANSFORMATION_add_early_terminator_wrapper_H_
diff --git a/test/fuzz/CMakeLists.txt b/test/fuzz/CMakeLists.txt
index 7d08506..ecccaa3 100644
--- a/test/fuzz/CMakeLists.txt
+++ b/test/fuzz/CMakeLists.txt
@@ -44,6 +44,7 @@
transformation_add_dead_block_test.cpp
transformation_add_dead_break_test.cpp
transformation_add_dead_continue_test.cpp
+ transformation_add_early_terminator_wrapper_test.cpp
transformation_add_function_test.cpp
transformation_add_global_undef_test.cpp
transformation_add_global_variable_test.cpp
diff --git a/test/fuzz/transformation_add_early_terminator_wrapper_test.cpp b/test/fuzz/transformation_add_early_terminator_wrapper_test.cpp
new file mode 100644
index 0000000..8006770
--- /dev/null
+++ b/test/fuzz/transformation_add_early_terminator_wrapper_test.cpp
@@ -0,0 +1,160 @@
+// Copyright (c) 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/fuzz/transformation_add_early_terminator_wrapper.h"
+
+#include "test/fuzz/fuzz_test_util.h"
+
+namespace spvtools {
+namespace fuzz {
+namespace {
+
+TEST(TransformationAddEarlyTerminatorWrapperTest, NoVoidType) {
+ std::string shader = R"(
+ OpCapability Shader
+ OpCapability Linkage
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpSource ESSL 320
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_4;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ spvtools::ValidatorOptions validator_options;
+ TransformationContext transformation_context(
+ MakeUnique<FactManager>(context.get()), validator_options);
+
+ ASSERT_FALSE(TransformationAddEarlyTerminatorWrapper(100, 101, SpvOpKill)
+ .IsApplicable(context.get(), transformation_context));
+}
+
+TEST(TransformationAddEarlyTerminatorWrapperTest, NoVoidFunctionType) {
+ std::string shader = R"(
+ OpCapability Shader
+ OpCapability Linkage
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpSource ESSL 320
+ %2 = OpTypeVoid
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_4;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ spvtools::ValidatorOptions validator_options;
+ TransformationContext transformation_context(
+ MakeUnique<FactManager>(context.get()), validator_options);
+
+ ASSERT_FALSE(TransformationAddEarlyTerminatorWrapper(100, 101, SpvOpKill)
+ .IsApplicable(context.get(), transformation_context));
+}
+
+TEST(TransformationAddEarlyTerminatorWrapperTest, BasicTest) {
+ std::string shader = R"(
+ OpCapability Shader
+ OpExtension "SPV_KHR_terminate_invocation"
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 320
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_5;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ spvtools::ValidatorOptions validator_options;
+ TransformationContext transformation_context(
+ MakeUnique<FactManager>(context.get()), validator_options);
+
+ ASSERT_FALSE(TransformationAddEarlyTerminatorWrapper(2, 101, SpvOpKill)
+ .IsApplicable(context.get(), transformation_context));
+ ASSERT_FALSE(TransformationAddEarlyTerminatorWrapper(100, 4, SpvOpKill)
+ .IsApplicable(context.get(), transformation_context));
+ ASSERT_FALSE(TransformationAddEarlyTerminatorWrapper(100, 100, SpvOpKill)
+ .IsApplicable(context.get(), transformation_context));
+
+#ifndef NDEBUG
+ ASSERT_DEATH(TransformationAddEarlyTerminatorWrapper(100, 101, SpvOpReturn)
+ .IsApplicable(context.get(), transformation_context),
+ "Invalid opcode.");
+#endif
+
+ auto transformation1 =
+ TransformationAddEarlyTerminatorWrapper(100, 101, SpvOpKill);
+ auto transformation2 =
+ TransformationAddEarlyTerminatorWrapper(102, 103, SpvOpUnreachable);
+ auto transformation3 = TransformationAddEarlyTerminatorWrapper(
+ 104, 105, SpvOpTerminateInvocation);
+
+ ASSERT_TRUE(
+ transformation1.IsApplicable(context.get(), transformation_context));
+ ApplyAndCheckFreshIds(transformation1, context.get(),
+ &transformation_context);
+ ASSERT_TRUE(
+ transformation2.IsApplicable(context.get(), transformation_context));
+ ApplyAndCheckFreshIds(transformation2, context.get(),
+ &transformation_context);
+ ASSERT_TRUE(
+ transformation3.IsApplicable(context.get(), transformation_context));
+ ApplyAndCheckFreshIds(transformation3, context.get(),
+ &transformation_context);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ std::string after_transformation = R"(
+ OpCapability Shader
+ OpExtension "SPV_KHR_terminate_invocation"
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 320
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %100 = OpFunction %2 None %3
+ %101 = OpLabel
+ OpKill
+ OpFunctionEnd
+ %102 = OpFunction %2 None %3
+ %103 = OpLabel
+ OpUnreachable
+ OpFunctionEnd
+ %104 = OpFunction %2 None %3
+ %105 = OpLabel
+ OpTerminateInvocation
+ OpFunctionEnd
+ )";
+ ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
+} // namespace
+} // namespace fuzz
+} // namespace spvtools