spirv-fuzz: Transformation to add OpConstantNull (#3273)

Adds a transformation for adding OpConstantNull to a module, for
appropriate data types.
diff --git a/source/fuzz/CMakeLists.txt b/source/fuzz/CMakeLists.txt
index cd53aea..0f0581f 100644
--- a/source/fuzz/CMakeLists.txt
+++ b/source/fuzz/CMakeLists.txt
@@ -79,6 +79,7 @@
         transformation_access_chain.h
         transformation_add_constant_boolean.h
         transformation_add_constant_composite.h
+        transformation_add_constant_null.h
         transformation_add_constant_scalar.h
         transformation_add_dead_block.h
         transformation_add_dead_break.h
@@ -171,6 +172,7 @@
         transformation_access_chain.cpp
         transformation_add_constant_boolean.cpp
         transformation_add_constant_composite.cpp
+        transformation_add_constant_null.cpp
         transformation_add_constant_scalar.cpp
         transformation_add_dead_block.cpp
         transformation_add_dead_break.cpp
diff --git a/source/fuzz/fuzzer_pass_donate_modules.cpp b/source/fuzz/fuzzer_pass_donate_modules.cpp
index 63ce7a6..4ba5305 100644
--- a/source/fuzz/fuzzer_pass_donate_modules.cpp
+++ b/source/fuzz/fuzzer_pass_donate_modules.cpp
@@ -22,6 +22,7 @@
 #include "source/fuzz/instruction_message.h"
 #include "source/fuzz/transformation_add_constant_boolean.h"
 #include "source/fuzz/transformation_add_constant_composite.h"
+#include "source/fuzz/transformation_add_constant_null.h"
 #include "source/fuzz/transformation_add_constant_scalar.h"
 #include "source/fuzz/transformation_add_function.h"
 #include "source/fuzz/transformation_add_global_undef.h"
@@ -394,6 +395,20 @@
             original_id_to_donated_id->at(type_or_value.type_id()),
             constituent_ids));
       } break;
+      case SpvOpConstantNull: {
+        if (!original_id_to_donated_id->count(type_or_value.type_id())) {
+          // We did not donate the type associated with this null constant, so
+          // we cannot donate the null constant.
+          continue;
+        }
+
+        // It is fine to have multiple OpConstantNull instructions of the same
+        // type, so we just add this to the recipient module.
+        new_result_id = GetFuzzerContext()->GetFreshId();
+        ApplyTransformation(TransformationAddConstantNull(
+            new_result_id,
+            original_id_to_donated_id->at(type_or_value.type_id())));
+      } break;
       case SpvOpVariable: {
         // This is a global variable that could have one of various storage
         // classes.  However, we change all global variable pointer storage
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index 90cf9fe..4d85984 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -537,6 +537,13 @@
   return 0;
 }
 
+bool IsNullConstantSupported(const opt::analysis::Type& type) {
+  return type.AsBool() || type.AsInteger() || type.AsFloat() ||
+         type.AsMatrix() || type.AsVector() || type.AsArray() ||
+         type.AsStruct() || type.AsPointer() || type.AsEvent() ||
+         type.AsDeviceEvent() || type.AsReserveId() || type.AsQueue();
+}
+
 }  // namespace fuzzerutil
 
 }  // namespace fuzz
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h
index 08edfc5..886029a 100644
--- a/source/fuzz/fuzzer_util.h
+++ b/source/fuzz/fuzzer_util.h
@@ -210,6 +210,10 @@
 uint32_t MaybeGetPointerType(opt::IRContext* context, uint32_t pointee_type_id,
                              SpvStorageClass storage_class);
 
+// Returns true if and only if |type| is one of the types for which it is legal
+// to have an OpConstantNull value.
+bool IsNullConstantSupported(const opt::analysis::Type& type);
+
 }  // namespace fuzzerutil
 
 }  // namespace fuzz
diff --git a/source/fuzz/protobufs/spvtoolsfuzz.proto b/source/fuzz/protobufs/spvtoolsfuzz.proto
index b816e3b..5dc70c3 100644
--- a/source/fuzz/protobufs/spvtoolsfuzz.proto
+++ b/source/fuzz/protobufs/spvtoolsfuzz.proto
@@ -372,6 +372,7 @@
     TransformationSwapCommutableOperands swap_commutable_operands = 41;
     TransformationPermuteFunctionParameters permute_function_parameters = 42;
     TransformationToggleAccessChainInstruction toggle_access_chain_instruction = 43;
+    TransformationAddConstantNull add_constant_null = 44;
     // Add additional option using the next available number.
   }
 }
@@ -422,6 +423,18 @@
 
 }
 
+message TransformationAddConstantNull {
+
+  // Adds a null constant.
+
+  // Id for the constant
+  uint32 fresh_id = 1;
+
+  // Type of the constant
+  uint32 type_id = 2;
+
+}
+
 message TransformationAddConstantScalar {
 
   // Adds a constant of the given scalar type.
diff --git a/source/fuzz/transformation.cpp b/source/fuzz/transformation.cpp
index 6f008fc..40d2010 100644
--- a/source/fuzz/transformation.cpp
+++ b/source/fuzz/transformation.cpp
@@ -20,6 +20,7 @@
 #include "source/fuzz/transformation_access_chain.h"
 #include "source/fuzz/transformation_add_constant_boolean.h"
 #include "source/fuzz/transformation_add_constant_composite.h"
+#include "source/fuzz/transformation_add_constant_null.h"
 #include "source/fuzz/transformation_add_constant_scalar.h"
 #include "source/fuzz/transformation_add_dead_block.h"
 #include "source/fuzz/transformation_add_dead_break.h"
@@ -78,6 +79,9 @@
     case protobufs::Transformation::TransformationCase::kAddConstantComposite:
       return MakeUnique<TransformationAddConstantComposite>(
           message.add_constant_composite());
+    case protobufs::Transformation::TransformationCase::kAddConstantNull:
+      return MakeUnique<TransformationAddConstantNull>(
+          message.add_constant_null());
     case protobufs::Transformation::TransformationCase::kAddConstantScalar:
       return MakeUnique<TransformationAddConstantScalar>(
           message.add_constant_scalar());
diff --git a/source/fuzz/transformation_add_constant_null.cpp b/source/fuzz/transformation_add_constant_null.cpp
new file mode 100644
index 0000000..dedbc21
--- /dev/null
+++ b/source/fuzz/transformation_add_constant_null.cpp
@@ -0,0 +1,66 @@
+// Copyright (c) 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/fuzz/transformation_add_constant_null.h"
+
+#include "source/fuzz/fuzzer_util.h"
+
+namespace spvtools {
+namespace fuzz {
+
+TransformationAddConstantNull::TransformationAddConstantNull(
+    const spvtools::fuzz::protobufs::TransformationAddConstantNull& message)
+    : message_(message) {}
+
+TransformationAddConstantNull::TransformationAddConstantNull(uint32_t fresh_id,
+                                                             uint32_t type_id) {
+  message_.set_fresh_id(fresh_id);
+  message_.set_type_id(type_id);
+}
+
+bool TransformationAddConstantNull::IsApplicable(
+    opt::IRContext* context, const TransformationContext& /*unused*/) const {
+  // A fresh id is required.
+  if (!fuzzerutil::IsFreshId(context, message_.fresh_id())) {
+    return false;
+  }
+  auto type = context->get_type_mgr()->GetType(message_.type_id());
+  // The type must exist.
+  if (!type) {
+    return false;
+  }
+  // The type must be one of the types for which null constants are allowed,
+  // according to the SPIR-V spec.
+  return fuzzerutil::IsNullConstantSupported(*type);
+}
+
+void TransformationAddConstantNull::Apply(
+    opt::IRContext* context, TransformationContext* /*unused*/) const {
+  context->module()->AddGlobalValue(MakeUnique<opt::Instruction>(
+      context, SpvOpConstantNull, message_.type_id(), message_.fresh_id(),
+      opt::Instruction::OperandList()));
+  fuzzerutil::UpdateModuleIdBound(context, message_.fresh_id());
+  // We have added an instruction to the module, so need to be careful about the
+  // validity of existing analyses.
+  context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone);
+}
+
+protobufs::Transformation TransformationAddConstantNull::ToMessage() const {
+  protobufs::Transformation result;
+  *result.mutable_add_constant_null() = message_;
+  return result;
+}
+
+}  // namespace fuzz
+}  // namespace spvtools
diff --git a/source/fuzz/transformation_add_constant_null.h b/source/fuzz/transformation_add_constant_null.h
new file mode 100644
index 0000000..590fc0d
--- /dev/null
+++ b/source/fuzz/transformation_add_constant_null.h
@@ -0,0 +1,54 @@
+// Copyright (c) 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_CONSTANT_NULL_H_
+#define SOURCE_FUZZ_TRANSFORMATION_ADD_CONSTANT_NULL_H_
+
+#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
+#include "source/fuzz/transformation.h"
+#include "source/fuzz/transformation_context.h"
+#include "source/opt/ir_context.h"
+
+namespace spvtools {
+namespace fuzz {
+
+class TransformationAddConstantNull : public Transformation {
+ public:
+  explicit TransformationAddConstantNull(
+      const protobufs::TransformationAddConstantNull& message);
+
+  TransformationAddConstantNull(uint32_t fresh_id, uint32_t type_id);
+
+  // - |message_.fresh_id| must be fresh
+  // - |message_.type_id| must be the id of a type for which it is acceptable
+  //   to create a null constant
+  bool IsApplicable(
+      opt::IRContext* context,
+      const TransformationContext& transformation_context) const override;
+
+  // Adds an OpConstantNull instruction to the module, with |message_.type_id|
+  // as its type.  The instruction has result id |message_.fresh_id|.
+  void Apply(opt::IRContext* context,
+             TransformationContext* transformation_context) const override;
+
+  protobufs::Transformation ToMessage() const override;
+
+ private:
+  protobufs::TransformationAddConstantNull message_;
+};
+
+}  // namespace fuzz
+}  // namespace spvtools
+
+#endif  // SOURCE_FUZZ_TRANSFORMATION_ADD_CONSTANT_NULL_H_
diff --git a/test/fuzz/CMakeLists.txt b/test/fuzz/CMakeLists.txt
index a8b09b9..679a61b 100644
--- a/test/fuzz/CMakeLists.txt
+++ b/test/fuzz/CMakeLists.txt
@@ -28,6 +28,7 @@
           transformation_access_chain_test.cpp
           transformation_add_constant_boolean_test.cpp
           transformation_add_constant_composite_test.cpp
+          transformation_add_constant_null_test.cpp
           transformation_add_constant_scalar_test.cpp
           transformation_add_dead_block_test.cpp
           transformation_add_dead_break_test.cpp
diff --git a/test/fuzz/fuzzer_pass_donate_modules_test.cpp b/test/fuzz/fuzzer_pass_donate_modules_test.cpp
index 549dd13..40d7d24 100644
--- a/test/fuzz/fuzzer_pass_donate_modules_test.cpp
+++ b/test/fuzz/fuzzer_pass_donate_modules_test.cpp
@@ -508,6 +508,72 @@
   ASSERT_TRUE(IsValid(env, recipient_context.get()));
 }
 
+TEST(FuzzerPassDonateModulesTest, DonateOpConstantNull) {
+  std::string recipient_shader = R"(
+               OpCapability Shader
+               OpCapability ImageQuery
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 320
+               OpSourceExtension "GL_EXT_samplerless_texture_functions"
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  std::string donor_shader = R"(
+               OpCapability Shader
+               OpCapability ImageQuery
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 320
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeFloat 32
+          %7 = OpTypePointer Private %6
+          %8 = OpConstantNull %7
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_3;
+  const auto consumer = nullptr;
+  const auto recipient_context =
+      BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, recipient_context.get()));
+
+  const auto donor_context =
+      BuildModule(env, consumer, donor_shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, donor_context.get()));
+
+  FactManager fact_manager;
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(&fact_manager,
+                                               validator_options);
+
+  FuzzerContext fuzzer_context(MakeUnique<PseudoRandomGenerator>(0).get(), 100);
+  protobufs::TransformationSequence transformation_sequence;
+
+  FuzzerPassDonateModules fuzzer_pass(recipient_context.get(),
+                                      &transformation_context, &fuzzer_context,
+                                      &transformation_sequence, {});
+
+  fuzzer_pass.DonateSingleModule(donor_context.get(), false);
+
+  // We just check that the result is valid.  Checking to what it should be
+  // exactly equal to would be very fragile.
+  ASSERT_TRUE(IsValid(env, recipient_context.get()));
+}
+
 TEST(FuzzerPassDonateModulesTest, Miscellaneous1) {
   std::string recipient_shader = R"(
                OpCapability Shader
diff --git a/test/fuzz/transformation_add_constant_null_test.cpp b/test/fuzz/transformation_add_constant_null_test.cpp
new file mode 100644
index 0000000..0bfee34
--- /dev/null
+++ b/test/fuzz/transformation_add_constant_null_test.cpp
@@ -0,0 +1,140 @@
+// Copyright (c) 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/fuzz/transformation_add_constant_null.h"
+#include "test/fuzz/fuzz_test_util.h"
+
+namespace spvtools {
+namespace fuzz {
+namespace {
+
+TEST(TransformationAddConstantNullTest, BasicTest) {
+  std::string shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeFloat 32
+          %7 = OpTypeInt 32 1
+          %8 = OpTypeVector %6 2
+          %9 = OpTypeVector %6 3
+         %10 = OpTypeVector %6 4
+         %11 = OpTypeVector %7 2
+         %20 = OpTypeSampler
+         %21 = OpTypeImage %6 2D 0 0 0 0 Rgba32f
+         %22 = OpTypeSampledImage %21
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_4;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  FactManager fact_manager;
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(&fact_manager,
+                                               validator_options);
+
+  // Id already in use
+  ASSERT_FALSE(TransformationAddConstantNull(4, 11).IsApplicable(
+      context.get(), transformation_context));
+  // %1 is not a type
+  ASSERT_FALSE(TransformationAddConstantNull(100, 1).IsApplicable(
+      context.get(), transformation_context));
+
+  // %3 is a function type
+  ASSERT_FALSE(TransformationAddConstantNull(100, 3).IsApplicable(
+      context.get(), transformation_context));
+
+  // %20 is a sampler type
+  ASSERT_FALSE(TransformationAddConstantNull(100, 20).IsApplicable(
+      context.get(), transformation_context));
+
+  // %21 is an image type
+  ASSERT_FALSE(TransformationAddConstantNull(100, 21).IsApplicable(
+      context.get(), transformation_context));
+
+  // %22 is a sampled image type
+  ASSERT_FALSE(TransformationAddConstantNull(100, 22).IsApplicable(
+      context.get(), transformation_context));
+
+  TransformationAddConstantNull transformations[] = {
+      // %100 = OpConstantNull %6
+      TransformationAddConstantNull(100, 6),
+
+      // %101 = OpConstantNull %7
+      TransformationAddConstantNull(101, 7),
+
+      // %102 = OpConstantNull %8
+      TransformationAddConstantNull(102, 8),
+
+      // %103 = OpConstantNull %9
+      TransformationAddConstantNull(103, 9),
+
+      // %104 = OpConstantNull %10
+      TransformationAddConstantNull(104, 10),
+
+      // %105 = OpConstantNull %11
+      TransformationAddConstantNull(105, 11)};
+
+  for (auto& transformation : transformations) {
+    ASSERT_TRUE(
+        transformation.IsApplicable(context.get(), transformation_context));
+    transformation.Apply(context.get(), &transformation_context);
+  }
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  std::string after_transformation = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeFloat 32
+          %7 = OpTypeInt 32 1
+          %8 = OpTypeVector %6 2
+          %9 = OpTypeVector %6 3
+         %10 = OpTypeVector %6 4
+         %11 = OpTypeVector %7 2
+         %20 = OpTypeSampler
+         %21 = OpTypeImage %6 2D 0 0 0 0 Rgba32f
+         %22 = OpTypeSampledImage %21
+        %100 = OpConstantNull %6
+        %101 = OpConstantNull %7
+        %102 = OpConstantNull %8
+        %103 = OpConstantNull %9
+        %104 = OpConstantNull %10
+        %105 = OpConstantNull %11
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+  ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
+}  // namespace
+}  // namespace fuzz
+}  // namespace spvtools