spirv-fuzz: Transformation to add wrappers for OpKill and similar (#3881)

Part of #3717.
diff --git a/source/fuzz/CMakeLists.txt b/source/fuzz/CMakeLists.txt
index 050bd96..5d4f428 100644
--- a/source/fuzz/CMakeLists.txt
+++ b/source/fuzz/CMakeLists.txt
@@ -145,6 +145,7 @@
         transformation_add_dead_block.h
         transformation_add_dead_break.h
         transformation_add_dead_continue.h
+        transformation_add_early_terminator_wrapper.h
         transformation_add_function.h
         transformation_add_global_undef.h
         transformation_add_global_variable.h
@@ -324,6 +325,7 @@
         transformation_add_dead_block.cpp
         transformation_add_dead_break.cpp
         transformation_add_dead_continue.cpp
+        transformation_add_early_terminator_wrapper.cpp
         transformation_add_function.cpp
         transformation_add_global_undef.cpp
         transformation_add_global_variable.cpp
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index 99d64eb..95ead39 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -1011,6 +1011,11 @@
   return ir_context->get_type_mgr()->GetId(&type);
 }
 
+uint32_t MaybeGetVoidType(opt::IRContext* ir_context) {
+  opt::analysis::Void type;
+  return ir_context->get_type_mgr()->GetId(&type);
+}
+
 uint32_t MaybeGetZeroConstant(
     opt::IRContext* ir_context,
     const TransformationContext& transformation_context,
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h
index b0141f8..6a6efb8 100644
--- a/source/fuzz/fuzzer_util.h
+++ b/source/fuzz/fuzzer_util.h
@@ -379,6 +379,10 @@
 uint32_t MaybeGetStructType(opt::IRContext* ir_context,
                             const std::vector<uint32_t>& component_type_ids);
 
+// Returns a result id of an OpTypeVoid instruction if present. Returns 0
+// otherwise.
+uint32_t MaybeGetVoidType(opt::IRContext* ir_context);
+
 // Recursive definition is the following:
 // - if |scalar_or_composite_type_id| is a result id of a scalar type - returns
 //   a result id of the following constants (depending on the type): int -> 0,
diff --git a/source/fuzz/protobufs/spvtoolsfuzz.proto b/source/fuzz/protobufs/spvtoolsfuzz.proto
index bebe848..eb01b97 100644
--- a/source/fuzz/protobufs/spvtoolsfuzz.proto
+++ b/source/fuzz/protobufs/spvtoolsfuzz.proto
@@ -551,6 +551,7 @@
     TransformationAddBitInstructionSynonym add_bit_instruction_synonym = 77;
     TransformationAddLoopToCreateIntConstantSynonym add_loop_to_create_int_constant_synonym = 78;
     TransformationWrapRegionInSelection wrap_region_in_selection = 79;
+    TransformationAddEarlyTerminatorWrapper add_early_terminator_wrapper = 80;
     // Add additional option using the next available number.
   }
 }
@@ -764,6 +765,28 @@
 
 }
 
+message TransformationAddEarlyTerminatorWrapper {
+
+  // Adds a function to the module containing a single block with a single non-
+  // label instruction that is either OpKill, OpUnreachable, or
+  // OpTerminateInvocation.  The purpose of this is to allow such instructions
+  // to be subsequently replaced with wrapper functions, which can then enable
+  // transformations (such as inlining) that are hard in the direct presence
+  // of these instructions.
+
+  // Fresh id for the function.
+  uint32 function_fresh_id = 1;
+
+  // Fresh id for the single basic block in the function.
+  uint32 label_fresh_id = 2;
+
+  // One of OpKill, OpUnreachable, OpTerminateInvocation.  If additional early
+  // termination instructions are added to SPIR-V they should also be handled
+  // here.
+  uint32 opcode = 3;
+
+}
+
 message TransformationAddFunction {
 
   // Adds a SPIR-V function to the module.
diff --git a/source/fuzz/transformation.cpp b/source/fuzz/transformation.cpp
index ea7c97c..f03d6a9 100644
--- a/source/fuzz/transformation.cpp
+++ b/source/fuzz/transformation.cpp
@@ -27,6 +27,7 @@
 #include "source/fuzz/transformation_add_dead_block.h"
 #include "source/fuzz/transformation_add_dead_break.h"
 #include "source/fuzz/transformation_add_dead_continue.h"
+#include "source/fuzz/transformation_add_early_terminator_wrapper.h"
 #include "source/fuzz/transformation_add_function.h"
 #include "source/fuzz/transformation_add_global_undef.h"
 #include "source/fuzz/transformation_add_global_variable.h"
@@ -133,6 +134,10 @@
     case protobufs::Transformation::TransformationCase::kAddDeadContinue:
       return MakeUnique<TransformationAddDeadContinue>(
           message.add_dead_continue());
+    case protobufs::Transformation::TransformationCase::
+        kAddEarlyTerminatorWrapper:
+      return MakeUnique<TransformationAddEarlyTerminatorWrapper>(
+          message.add_early_terminator_wrapper());
     case protobufs::Transformation::TransformationCase::kAddFunction:
       return MakeUnique<TransformationAddFunction>(message.add_function());
     case protobufs::Transformation::TransformationCase::kAddGlobalUndef:
diff --git a/source/fuzz/transformation_add_early_terminator_wrapper.cpp b/source/fuzz/transformation_add_early_terminator_wrapper.cpp
new file mode 100644
index 0000000..0aa1214
--- /dev/null
+++ b/source/fuzz/transformation_add_early_terminator_wrapper.cpp
@@ -0,0 +1,110 @@
+// Copyright (c) 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/fuzz/transformation_add_early_terminator_wrapper.h"
+
+#include "source/fuzz/fuzzer_util.h"
+#include "source/util/make_unique.h"
+
+namespace spvtools {
+namespace fuzz {
+
+TransformationAddEarlyTerminatorWrapper::
+    TransformationAddEarlyTerminatorWrapper(
+        const spvtools::fuzz::protobufs::
+            TransformationAddEarlyTerminatorWrapper& message)
+    : message_(message) {}
+
+TransformationAddEarlyTerminatorWrapper::
+    TransformationAddEarlyTerminatorWrapper(uint32_t function_fresh_id,
+                                            uint32_t label_fresh_id,
+                                            SpvOp opcode) {
+  message_.set_function_fresh_id(function_fresh_id);
+  message_.set_label_fresh_id(label_fresh_id);
+  message_.set_opcode(opcode);
+}
+
+bool TransformationAddEarlyTerminatorWrapper::IsApplicable(
+    opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
+  assert((message_.opcode() == SpvOpKill ||
+          message_.opcode() == SpvOpUnreachable ||
+          message_.opcode() == SpvOpTerminateInvocation) &&
+         "Invalid opcode.");
+
+  if (!fuzzerutil::IsFreshId(ir_context, message_.function_fresh_id())) {
+    return false;
+  }
+  if (!fuzzerutil::IsFreshId(ir_context, message_.label_fresh_id())) {
+    return false;
+  }
+  if (message_.function_fresh_id() == message_.label_fresh_id()) {
+    return false;
+  }
+  uint32_t void_type_id = fuzzerutil::MaybeGetVoidType(ir_context);
+  if (!void_type_id) {
+    return false;
+  }
+  return fuzzerutil::FindFunctionType(ir_context, {void_type_id});
+}
+
+void TransformationAddEarlyTerminatorWrapper::Apply(
+    opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
+  fuzzerutil::UpdateModuleIdBound(ir_context, message_.function_fresh_id());
+  fuzzerutil::UpdateModuleIdBound(ir_context, message_.label_fresh_id());
+
+  // Create a basic block of the form:
+  // %label_fresh_id = OpLabel
+  //                   OpKill|Unreachable|TerminateInvocation
+  auto basic_block = MakeUnique<opt::BasicBlock>(MakeUnique<opt::Instruction>(
+      ir_context, SpvOpLabel, 0, message_.label_fresh_id(),
+      opt::Instruction::OperandList()));
+  basic_block->AddInstruction(MakeUnique<opt::Instruction>(
+      ir_context, static_cast<SpvOp>(message_.opcode()), 0, 0,
+      opt::Instruction::OperandList()));
+
+  // Create a zero-argument void function.
+  auto void_type_id = fuzzerutil::MaybeGetVoidType(ir_context);
+  auto function = MakeUnique<opt::Function>(MakeUnique<opt::Instruction>(
+      ir_context, SpvOpFunction, void_type_id, message_.function_fresh_id(),
+      opt::Instruction::OperandList(
+          {{SPV_OPERAND_TYPE_FUNCTION_CONTROL, {SpvFunctionControlMaskNone}},
+           {SPV_OPERAND_TYPE_TYPE_ID,
+            {fuzzerutil::FindFunctionType(ir_context, {void_type_id})}}})));
+
+  // Add the basic block to the function as the sole block, and add the function
+  // to the module.
+  basic_block->SetParent(function.get());
+  function->AddBasicBlock(std::move(basic_block));
+  function->SetFunctionEnd(MakeUnique<opt::Instruction>(
+      ir_context, SpvOpFunctionEnd, 0, 0, opt::Instruction::OperandList()));
+  ir_context->module()->AddFunction(std::move(function));
+
+  ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
+}
+
+std::unordered_set<uint32_t>
+TransformationAddEarlyTerminatorWrapper::GetFreshIds() const {
+  return std::unordered_set<uint32_t>(
+      {message_.function_fresh_id(), message_.label_fresh_id()});
+}
+
+protobufs::Transformation TransformationAddEarlyTerminatorWrapper::ToMessage()
+    const {
+  protobufs::Transformation result;
+  *result.mutable_add_early_terminator_wrapper() = message_;
+  return result;
+}
+
+}  // namespace fuzz
+}  // namespace spvtools
diff --git a/source/fuzz/transformation_add_early_terminator_wrapper.h b/source/fuzz/transformation_add_early_terminator_wrapper.h
new file mode 100644
index 0000000..273037e
--- /dev/null
+++ b/source/fuzz/transformation_add_early_terminator_wrapper.h
@@ -0,0 +1,63 @@
+// Copyright (c) 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_FUZZ_TRANSFORMATION_add_early_terminator_wrapper_H_
+#define SOURCE_FUZZ_TRANSFORMATION_add_early_terminator_wrapper_H_
+
+#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
+#include "source/fuzz/transformation.h"
+#include "source/fuzz/transformation_context.h"
+#include "source/opt/ir_context.h"
+
+namespace spvtools {
+namespace fuzz {
+
+class TransformationAddEarlyTerminatorWrapper : public Transformation {
+ public:
+  explicit TransformationAddEarlyTerminatorWrapper(
+      const protobufs::TransformationAddEarlyTerminatorWrapper& message);
+
+  TransformationAddEarlyTerminatorWrapper(uint32_t function_fresh_id,
+                                          uint32_t label_fresh_id,
+                                          SpvOp opcode);
+
+  // - |message_.function_fresh_id| and |message_.label_fresh_id| must be fresh
+  //   and distinct.
+  // - OpTypeVoid must be declared in the module.
+  // - The module must contain a type for a zero-argument void function.
+  bool IsApplicable(
+      opt::IRContext* ir_context,
+      const TransformationContext& transformation_context) const override;
+
+  // Adds a function to the module of the form:
+  //
+  // |message_.function_fresh_id| = OpFunction %void None %zero_args_return_void
+  //    |message_.label_fresh_id| = OpLabel
+  //                                |message_.opcode|
+  //                                OpFunctionEnd
+  void Apply(opt::IRContext* ir_context,
+             TransformationContext* transformation_context) const override;
+
+  std::unordered_set<uint32_t> GetFreshIds() const override;
+
+  protobufs::Transformation ToMessage() const override;
+
+ private:
+  protobufs::TransformationAddEarlyTerminatorWrapper message_;
+};
+
+}  // namespace fuzz
+}  // namespace spvtools
+
+#endif  // SOURCE_FUZZ_TRANSFORMATION_add_early_terminator_wrapper_H_
diff --git a/test/fuzz/CMakeLists.txt b/test/fuzz/CMakeLists.txt
index 7d08506..ecccaa3 100644
--- a/test/fuzz/CMakeLists.txt
+++ b/test/fuzz/CMakeLists.txt
@@ -44,6 +44,7 @@
           transformation_add_dead_block_test.cpp
           transformation_add_dead_break_test.cpp
           transformation_add_dead_continue_test.cpp
+          transformation_add_early_terminator_wrapper_test.cpp
           transformation_add_function_test.cpp
           transformation_add_global_undef_test.cpp
           transformation_add_global_variable_test.cpp
diff --git a/test/fuzz/transformation_add_early_terminator_wrapper_test.cpp b/test/fuzz/transformation_add_early_terminator_wrapper_test.cpp
new file mode 100644
index 0000000..8006770
--- /dev/null
+++ b/test/fuzz/transformation_add_early_terminator_wrapper_test.cpp
@@ -0,0 +1,160 @@
+// Copyright (c) 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/fuzz/transformation_add_early_terminator_wrapper.h"
+
+#include "test/fuzz/fuzz_test_util.h"
+
+namespace spvtools {
+namespace fuzz {
+namespace {
+
+TEST(TransformationAddEarlyTerminatorWrapperTest, NoVoidType) {
+  std::string shader = R"(
+               OpCapability Shader
+               OpCapability Linkage
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpSource ESSL 320
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_4;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(
+      MakeUnique<FactManager>(context.get()), validator_options);
+
+  ASSERT_FALSE(TransformationAddEarlyTerminatorWrapper(100, 101, SpvOpKill)
+                   .IsApplicable(context.get(), transformation_context));
+}
+
+TEST(TransformationAddEarlyTerminatorWrapperTest, NoVoidFunctionType) {
+  std::string shader = R"(
+               OpCapability Shader
+               OpCapability Linkage
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpSource ESSL 320
+          %2 = OpTypeVoid
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_4;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(
+      MakeUnique<FactManager>(context.get()), validator_options);
+
+  ASSERT_FALSE(TransformationAddEarlyTerminatorWrapper(100, 101, SpvOpKill)
+                   .IsApplicable(context.get(), transformation_context));
+}
+
+TEST(TransformationAddEarlyTerminatorWrapperTest, BasicTest) {
+  std::string shader = R"(
+               OpCapability Shader
+               OpExtension "SPV_KHR_terminate_invocation"
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 320
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_5;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(
+      MakeUnique<FactManager>(context.get()), validator_options);
+
+  ASSERT_FALSE(TransformationAddEarlyTerminatorWrapper(2, 101, SpvOpKill)
+                   .IsApplicable(context.get(), transformation_context));
+  ASSERT_FALSE(TransformationAddEarlyTerminatorWrapper(100, 4, SpvOpKill)
+                   .IsApplicable(context.get(), transformation_context));
+  ASSERT_FALSE(TransformationAddEarlyTerminatorWrapper(100, 100, SpvOpKill)
+                   .IsApplicable(context.get(), transformation_context));
+
+#ifndef NDEBUG
+  ASSERT_DEATH(TransformationAddEarlyTerminatorWrapper(100, 101, SpvOpReturn)
+                   .IsApplicable(context.get(), transformation_context),
+               "Invalid opcode.");
+#endif
+
+  auto transformation1 =
+      TransformationAddEarlyTerminatorWrapper(100, 101, SpvOpKill);
+  auto transformation2 =
+      TransformationAddEarlyTerminatorWrapper(102, 103, SpvOpUnreachable);
+  auto transformation3 = TransformationAddEarlyTerminatorWrapper(
+      104, 105, SpvOpTerminateInvocation);
+
+  ASSERT_TRUE(
+      transformation1.IsApplicable(context.get(), transformation_context));
+  ApplyAndCheckFreshIds(transformation1, context.get(),
+                        &transformation_context);
+  ASSERT_TRUE(
+      transformation2.IsApplicable(context.get(), transformation_context));
+  ApplyAndCheckFreshIds(transformation2, context.get(),
+                        &transformation_context);
+  ASSERT_TRUE(
+      transformation3.IsApplicable(context.get(), transformation_context));
+  ApplyAndCheckFreshIds(transformation3, context.get(),
+                        &transformation_context);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  std::string after_transformation = R"(
+               OpCapability Shader
+               OpExtension "SPV_KHR_terminate_invocation"
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 320
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+        %100 = OpFunction %2 None %3
+        %101 = OpLabel
+               OpKill
+               OpFunctionEnd
+        %102 = OpFunction %2 None %3
+        %103 = OpLabel
+               OpUnreachable
+               OpFunctionEnd
+        %104 = OpFunction %2 None %3
+        %105 = OpLabel
+               OpTerminateInvocation
+               OpFunctionEnd
+  )";
+  ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
+}  // namespace
+}  // namespace fuzz
+}  // namespace spvtools