spirv-fuzz: Return IR and transformation context after replay (#3846)

Before this change, the replayer returned a SPIR-V binary. This did not
allow the resulting module to be transformed further: the binary would
need to be re-parsed, and the transformation context arising from the
replayed transformations was not available. With this change, a
successful replay instead returns an IR context and a transformation
context; the IR context can subsequently be turned into a binary if
desired. If replay fails, both pointers are null, so callers must check
the result status before using them.

This change paves the way for an upcoming PR to integrate spirv-reduce
with the spirv-fuzz shrinker.
diff --git a/source/fuzz/replayer.cpp b/source/fuzz/replayer.cpp
index 380e58b..a5f311f 100644
--- a/source/fuzz/replayer.cpp
+++ b/source/fuzz/replayer.cpp
@@ -59,7 +59,7 @@
               "The number of transformations to be replayed must not "
               "exceed the size of the transformation sequence.");
     return {Replayer::ReplayerResultStatus::kTooManyTransformationsRequested,
-            std::vector<uint32_t>(), protobufs::TransformationSequence()};
+            nullptr, nullptr, protobufs::TransformationSequence()};
   }
 
   spvtools::SpirvTools tools(target_env_);
@@ -67,15 +67,15 @@
     consumer_(SPV_MSG_ERROR, nullptr, {},
               "Failed to create SPIRV-Tools interface; stopping.");
     return {Replayer::ReplayerResultStatus::kFailedToCreateSpirvToolsInterface,
-            std::vector<uint32_t>(), protobufs::TransformationSequence()};
+            nullptr, nullptr, protobufs::TransformationSequence()};
   }
 
   // Initial binary should be valid.
   if (!tools.Validate(&binary_in_[0], binary_in_.size(), validator_options_)) {
     consumer_(SPV_MSG_INFO, nullptr, {},
               "Initial binary is invalid; stopping.");
-    return {Replayer::ReplayerResultStatus::kInitialBinaryInvalid,
-            std::vector<uint32_t>(), protobufs::TransformationSequence()};
+    return {Replayer::ReplayerResultStatus::kInitialBinaryInvalid, nullptr,
+            nullptr, protobufs::TransformationSequence()};
   }
 
   // Build the module from the input binary.
@@ -140,7 +140,7 @@
                     "Binary became invalid during replay (set a "
                     "breakpoint to inspect); stopping.");
           return {Replayer::ReplayerResultStatus::kReplayValidationFailure,
-                  std::vector<uint32_t>(), protobufs::TransformationSequence()};
+                  nullptr, nullptr, protobufs::TransformationSequence()};
         }
 
         // The binary was valid, so it becomes the latest valid binary.
@@ -149,10 +149,8 @@
     }
   }
 
-  // Write out the module as a binary.
-  std::vector<uint32_t> binary_out;
-  ir_context->module()->ToBinary(&binary_out, false);
-  return {Replayer::ReplayerResultStatus::kComplete, std::move(binary_out),
+  return {Replayer::ReplayerResultStatus::kComplete, std::move(ir_context),
+          std::move(transformation_context),
           std::move(transformation_sequence_out)};
 }
 
diff --git a/source/fuzz/replayer.h b/source/fuzz/replayer.h
index 5bc62d9..8e0303e 100644
--- a/source/fuzz/replayer.h
+++ b/source/fuzz/replayer.h
@@ -19,6 +19,8 @@
 #include <vector>
 
 #include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
+#include "source/fuzz/transformation_context.h"
+#include "source/opt/ir_context.h"
 #include "spirv-tools/libspirv.hpp"
 
 namespace spvtools {
@@ -39,7 +41,8 @@
 
   struct ReplayerResult {
     ReplayerResultStatus status;
-    std::vector<uint32_t> transformed_binary;
+    std::unique_ptr<opt::IRContext> transformed_module;
+    std::unique_ptr<TransformationContext> transformation_context;
     protobufs::TransformationSequence applied_transformations;
   };
 
@@ -70,9 +73,10 @@
   // ids will be available during replay starting from this value.
   //
   // On success, returns a successful result status together with the
-  // transformations that were successfully applied and the binary resulting
-  // from applying them.  Otherwise, returns an appropriate result status
-  // together with an empty binary and empty transformation sequence.
+  // transformations that were applied, the IR for the transformed module, and
+  // the transformation context that arises from applying these transformations.
+  // Otherwise, returns an appropriate result status, an empty transformation
+  // sequence, and null pointers for the IR context and transformation context.
   ReplayerResult Run();
 
  private:
diff --git a/source/fuzz/shrinker.cpp b/source/fuzz/shrinker.cpp
index a88a1ea..ef6a990 100644
--- a/source/fuzz/shrinker.cpp
+++ b/source/fuzz/shrinker.cpp
@@ -70,7 +70,7 @@
     uint32_t step_limit, bool validate_during_replay,
     spv_validator_options validator_options)
     : target_env_(target_env),
-      consumer_(consumer),
+      consumer_(std::move(consumer)),
       binary_in_(binary_in),
       initial_facts_(initial_facts),
       transformation_sequence_in_(transformation_sequence_in),
@@ -120,8 +120,9 @@
   // Get the binary that results from running these transformations, and the
   // subsequence of the initial transformations that actually apply (in
   // principle this could be a strict subsequence).
-  std::vector<uint32_t> current_best_binary =
-      std::move(initial_replay_result.transformed_binary);
+  std::vector<uint32_t> current_best_binary;
+  initial_replay_result.transformed_module->module()->ToBinary(
+      &current_best_binary, false);
   protobufs::TransformationSequence current_best_transformations =
       std::move(initial_replay_result.applied_transformations);
 
@@ -215,12 +216,14 @@
           "Removing this chunk of transformations should not have an effect "
           "on earlier chunks.");
 
-      if (interestingness_function_(replay_result.transformed_binary,
-                                    attempt)) {
+      std::vector<uint32_t> transformed_binary;
+      replay_result.transformed_module->module()->ToBinary(&transformed_binary,
+                                                           false);
+      if (interestingness_function_(transformed_binary, attempt)) {
         // If the binary arising from the smaller transformation sequence is
         // interesting, this becomes our current best binary and transformation
         // sequence.
-        current_best_binary = std::move(replay_result.transformed_binary);
+        current_best_binary = std::move(transformed_binary);
         current_best_transformations =
             std::move(replay_result.applied_transformations);
         progress_this_round = true;
diff --git a/test/fuzz/fuzz_test_util.cpp b/test/fuzz/fuzz_test_util.cpp
index c717961..c874f7a 100644
--- a/test/fuzz/fuzz_test_util.cpp
+++ b/test/fuzz/fuzz_test_util.cpp
@@ -72,6 +72,13 @@
   return IsEqual(env, binary_1, binary_2);
 }
 
+bool IsEqual(const spv_target_env env, const std::vector<uint32_t>& binary_1,
+             const opt::IRContext* ir_2) {
+  std::vector<uint32_t> binary_2;
+  ir_2->module()->ToBinary(&binary_2, false);
+  return IsEqual(env, binary_1, binary_2);
+}
+
 bool IsValid(spv_target_env env, const opt::IRContext* ir) {
   std::vector<uint32_t> binary;
   ir->module()->ToBinary(&binary, false);
diff --git a/test/fuzz/fuzz_test_util.h b/test/fuzz/fuzz_test_util.h
index 19d1918..1126de1 100644
--- a/test/fuzz/fuzz_test_util.h
+++ b/test/fuzz/fuzz_test_util.h
@@ -45,6 +45,11 @@
 bool IsEqual(spv_target_env env, const opt::IRContext* ir_1,
              const opt::IRContext* ir_2);
 
+// Turns |ir_2| into a binary, then returns true if and only if the resulting
+// binary is bit-wise equal to |binary_1|.
+bool IsEqual(spv_target_env env, const std::vector<uint32_t>& binary_1,
+             const opt::IRContext* ir_2);
+
 // Assembles the given IR context and returns true if and only if
 // the resulting binary is valid.
 bool IsValid(spv_target_env env, const opt::IRContext* ir);
diff --git a/test/fuzz/fuzzer_replayer_test.cpp b/test/fuzz/fuzzer_replayer_test.cpp
index bfcf4ea..fc81713 100644
--- a/test/fuzz/fuzzer_replayer_test.cpp
+++ b/test/fuzz/fuzzer_replayer_test.cpp
@@ -1684,8 +1684,8 @@
     replayer_result.applied_transformations.SerializeToString(
         &replayer_transformations_string);
     ASSERT_EQ(fuzzer_transformations_string, replayer_transformations_string);
-    ASSERT_EQ(fuzzer_result.transformed_binary,
-              replayer_result.transformed_binary);
+    ASSERT_TRUE(IsEqual(env, fuzzer_result.transformed_binary,
+                        replayer_result.transformed_module.get()));
   }
 }
 
diff --git a/test/fuzz/replayer_test.cpp b/test/fuzz/replayer_test.cpp
index 2444e9f..6a956b0 100644
--- a/test/fuzz/replayer_test.cpp
+++ b/test/fuzz/replayer_test.cpp
@@ -14,7 +14,12 @@
 
 #include "source/fuzz/replayer.h"
 
+#include "source/fuzz/data_descriptor.h"
 #include "source/fuzz/instruction_descriptor.h"
+#include "source/fuzz/transformation_add_constant_scalar.h"
+#include "source/fuzz/transformation_add_global_variable.h"
+#include "source/fuzz/transformation_add_parameter.h"
+#include "source/fuzz/transformation_add_synonym.h"
 #include "source/fuzz/transformation_split_block.h"
 #include "test/fuzz/fuzz_test_util.h"
 
@@ -169,8 +174,8 @@
                OpReturn
                OpFunctionEnd
     )";
-    ASSERT_TRUE(
-        IsEqual(env, kFullySplitShader, replayer_result.transformed_binary));
+    ASSERT_TRUE(IsEqual(env, kFullySplitShader,
+                        replayer_result.transformed_module.get()));
   }
 
   {
@@ -247,8 +252,8 @@
                OpReturn
                OpFunctionEnd
     )";
-    ASSERT_TRUE(
-        IsEqual(env, kHalfSplitShader, replayer_result.transformed_binary));
+    ASSERT_TRUE(IsEqual(env, kHalfSplitShader,
+                        replayer_result.transformed_module.get()));
   }
 
   {
@@ -263,7 +268,8 @@
               replayer_result.status);
     // No transformations should be applied
     ASSERT_EQ(0, replayer_result.applied_transformations.transformation_size());
-    ASSERT_TRUE(IsEqual(env, kTestShader, replayer_result.transformed_binary));
+    ASSERT_TRUE(
+        IsEqual(env, kTestShader, replayer_result.transformed_module.get()));
   }
 
   {
@@ -282,10 +288,122 @@
     // No transformations should be applied
     ASSERT_EQ(0, replayer_result.applied_transformations.transformation_size());
     // The output binary should be empty
-    ASSERT_TRUE(replayer_result.transformed_binary.empty());
+    ASSERT_EQ(nullptr, replayer_result.transformed_module);
   }
 }
 
+TEST(ReplayerTest, CheckFactsAfterReplay) {
+  const std::string kTestShader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 320
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %8 = OpTypeInt 32 1
+          %9 = OpTypePointer Function %8
+         %50 = OpTypePointer Private %8
+         %11 = OpConstant %8 1
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+         %10 = OpVariable %9 Function
+               OpStore %10 %11
+         %12 = OpFunctionCall %2 %6
+               OpReturn
+               OpFunctionEnd
+          %6 = OpFunction %2 None %3
+          %7 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_3;
+  spvtools::ValidatorOptions validator_options;
+
+  std::vector<uint32_t> binary_in;
+  SpirvTools t(env);
+  t.SetMessageConsumer(kSilentConsumer);
+  ASSERT_TRUE(t.Assemble(kTestShader, &binary_in, kFuzzAssembleOption));
+  ASSERT_TRUE(t.Validate(binary_in));
+
+  protobufs::TransformationSequence transformations;
+  *transformations.add_transformation() =
+      TransformationAddConstantScalar(100, 8, {42}, true).ToMessage();
+  *transformations.add_transformation() =
+      TransformationAddGlobalVariable(101, 50, SpvStorageClassPrivate, 100,
+                                      true)
+          .ToMessage();
+  *transformations.add_transformation() =
+      TransformationAddParameter(6, 102, 8, {{12, 100}}, 103).ToMessage();
+  *transformations.add_transformation() =
+      TransformationAddSynonym(
+          11,
+          protobufs::TransformationAddSynonym::SynonymType::
+              TransformationAddSynonym_SynonymType_COPY_OBJECT,
+          104, MakeInstructionDescriptor(12, SpvOpFunctionCall, 0))
+          .ToMessage();
+
+  // Full replay
+  protobufs::FactSequence empty_facts;
+  auto replayer_result =
+      Replayer(env, kSilentConsumer, binary_in, empty_facts, transformations,
+               transformations.transformation_size(), 0, true,
+               validator_options)
+          .Run();
+  // Replay should succeed.
+  ASSERT_EQ(Replayer::ReplayerResultStatus::kComplete, replayer_result.status);
+  // All transformations should be applied.
+  ASSERT_TRUE(google::protobuf::util::MessageDifferencer::Equals(
+      transformations, replayer_result.applied_transformations));
+
+  const std::string kExpected = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 320
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %8 = OpTypeInt 32 1
+          %9 = OpTypePointer Function %8
+         %50 = OpTypePointer Private %8
+         %11 = OpConstant %8 1
+        %100 = OpConstant %8 42
+        %101 = OpVariable %50 Private %100
+        %103 = OpTypeFunction %2 %8
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+         %10 = OpVariable %9 Function
+               OpStore %10 %11
+        %104 = OpCopyObject %8 %11
+         %12 = OpFunctionCall %2 %6 %100
+               OpReturn
+               OpFunctionEnd
+          %6 = OpFunction %2 None %103
+        %102 = OpFunctionParameter %8
+          %7 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+  ASSERT_TRUE(
+      IsEqual(env, kExpected, replayer_result.transformed_module.get()));
+
+  ASSERT_TRUE(
+      replayer_result.transformation_context->GetFactManager()->IdIsIrrelevant(
+          100));
+  ASSERT_TRUE(replayer_result.transformation_context->GetFactManager()
+                  ->PointeeValueIsIrrelevant(101));
+  ASSERT_TRUE(
+      replayer_result.transformation_context->GetFactManager()->IdIsIrrelevant(
+          102));
+  ASSERT_TRUE(
+      replayer_result.transformation_context->GetFactManager()->IsSynonymous(
+          MakeDataDescriptor(11, {}), MakeDataDescriptor(104, {})));
+}
+
 }  // namespace
 }  // namespace fuzz
 }  // namespace spvtools
diff --git a/tools/fuzz/fuzz.cpp b/tools/fuzz/fuzz.cpp
index 80ac9f5..d19904b 100644
--- a/tools/fuzz/fuzz.cpp
+++ b/tools/fuzz/fuzz.cpp
@@ -480,8 +480,7 @@
           initial_facts, transformation_sequence, num_transformations_to_apply,
           0, fuzzer_options->replay_validation_enabled, validator_options)
           .Run();
-
-  *binary_out = std::move(replay_result.transformed_binary);
+  replay_result.transformed_module->module()->ToBinary(binary_out, false);
   *transformations_applied = std::move(replay_result.applied_transformations);
   return replay_result.status ==
          spvtools::fuzz::Replayer::ReplayerResultStatus::kComplete;