spirv-fuzz: Handle OpPhi during constant obfuscation (#3640)

Fixes #3639.
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index 15d1057..f9564bc 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -1317,6 +1317,31 @@
   return result;
 }
 
+opt::Instruction* GetLastInsertBeforeInstruction(opt::IRContext* ir_context,
+                                                 uint32_t block_id,
+                                                 SpvOp opcode) {
+  // CFG::block uses std::map::at which throws an exception when |block_id| is
+  // invalid. The error message is unhelpful, though. Thus, we test that
+  // |block_id| is valid here.
+  const auto* label_inst = ir_context->get_def_use_mgr()->GetDef(block_id);
+  (void)label_inst;  // Make compilers happy in release mode.
+  assert(label_inst && label_inst->opcode() == SpvOpLabel &&
+         "|block_id| is invalid");
+
+  auto* block = ir_context->cfg()->block(block_id);
+  auto it = block->rbegin();
+  assert(it != block->rend() && "Basic block can't be empty");
+
+  if (block->GetMergeInst()) {
+    ++it;
+    assert(it != block->rend() &&
+           "|block| must have at least two instructions:"
+           "terminator and a merge instruction");
+  }
+
+  return CanInsertOpcodeBeforeInstruction(opcode, &*it) ? &*it : nullptr;
+}
+
 }  // namespace fuzzerutil
 
 }  // namespace fuzz
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h
index 6058022..cb01ca7 100644
--- a/source/fuzz/fuzzer_util.h
+++ b/source/fuzz/fuzzer_util.h
@@ -487,6 +487,12 @@
 google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>
 MapToRepeatedUInt32Pair(const std::map<uint32_t, uint32_t>& data);
 
+// Returns the last instruction in |block_id| before which an instruction with
+// opcode |opcode| can be inserted, or nullptr if there is no such instruction.
+opt::Instruction* GetLastInsertBeforeInstruction(opt::IRContext* ir_context,
+                                                 uint32_t block_id,
+                                                 SpvOp opcode);
+
 }  // namespace fuzzerutil
 
 }  // namespace fuzz
diff --git a/source/fuzz/transformation_replace_constant_with_uniform.cpp b/source/fuzz/transformation_replace_constant_with_uniform.cpp
index a8f9495..8de7201 100644
--- a/source/fuzz/transformation_replace_constant_with_uniform.cpp
+++ b/source/fuzz/transformation_replace_constant_with_uniform.cpp
@@ -90,6 +90,40 @@
                                       operands_for_load);
 }
 
+opt::Instruction*
+TransformationReplaceConstantWithUniform::GetInsertBeforeInstruction(
+    opt::IRContext* ir_context) const {
+  auto* result =
+      FindInstructionContainingUse(message_.id_use_descriptor(), ir_context);
+  if (!result) {
+    return nullptr;
+  }
+
+  // The use might be in an OpPhi instruction.
+  if (result->opcode() == SpvOpPhi) {
+    // OpPhi instructions must be the first instructions in a block. Thus, we
+    // can't insert above the OpPhi instruction. Given the predecessor block
+    // that corresponds to the id use, get the last instruction in that block
+    // above which we can insert OpAccessChain and OpLoad.
+    return fuzzerutil::GetLastInsertBeforeInstruction(
+        ir_context,
+        result->GetSingleWordInOperand(
+            message_.id_use_descriptor().in_operand_index() + 1),
+        SpvOpLoad);
+  }
+
+  // The only operand that we could've replaced in the OpBranchConditional is
+  // the condition id. But that operand has a boolean type and uniform variables
+  // can't store booleans (see the spec on OpTypeBool). Thus, |result| can't be
+  // an OpBranchConditional.
+  assert(result->opcode() != SpvOpBranchConditional &&
+         "OpBranchConditional has no operands to replace");
+
+  assert(fuzzerutil::CanInsertOpcodeBeforeInstruction(SpvOpLoad, result) &&
+         "We should be able to insert OpLoad and OpAccessChain at this point");
+  return result;
+}
+
 bool TransformationReplaceConstantWithUniform::IsApplicable(
     opt::IRContext* ir_context,
     const TransformationContext& transformation_context) const {
@@ -188,6 +222,12 @@
     }
   }
 
+  // Once all checks are completed, we should be able to safely insert
+  // OpAccessChain and OpLoad into the module.
+  assert(GetInsertBeforeInstruction(ir_context) &&
+         "There must exist an instruction that we can use to insert "
+         "OpAccessChain and OpLoad above");
+
   return true;
 }
 
@@ -195,7 +235,7 @@
     spvtools::opt::IRContext* ir_context,
     TransformationContext* /*unused*/) const {
   // Get the instruction that contains the id use we wish to replace.
-  auto instruction_containing_constant_use =
+  auto* instruction_containing_constant_use =
       FindInstructionContainingUse(message_.id_use_descriptor(), ir_context);
   assert(instruction_containing_constant_use &&
          "Precondition requires that the id use can be found.");
@@ -210,12 +250,17 @@
           ->GetDef(message_.id_use_descriptor().id_of_interest())
           ->type_id();
 
+  // Get an instruction that will be used to insert OpAccessChain and OpLoad.
+  auto* insert_before_inst = GetInsertBeforeInstruction(ir_context);
+  assert(insert_before_inst &&
+         "There must exist an insertion point for OpAccessChain and OpLoad");
+
   // Add an access chain instruction to target the uniform element.
-  instruction_containing_constant_use->InsertBefore(
+  insert_before_inst->InsertBefore(
       MakeAccessChainInstruction(ir_context, constant_type_id));
 
   // Add a load from this access chain.
-  instruction_containing_constant_use->InsertBefore(
+  insert_before_inst->InsertBefore(
       MakeLoadInstruction(ir_context, constant_type_id));
 
   // Adjust the instruction containing the usage of the constant so that this
diff --git a/source/fuzz/transformation_replace_constant_with_uniform.h b/source/fuzz/transformation_replace_constant_with_uniform.h
index b72407c..c507c32 100644
--- a/source/fuzz/transformation_replace_constant_with_uniform.h
+++ b/source/fuzz/transformation_replace_constant_with_uniform.h
@@ -84,6 +84,11 @@
   std::unique_ptr<opt::Instruction> MakeLoadInstruction(
       spvtools::opt::IRContext* ir_context, uint32_t constant_type_id) const;
 
+  // OpAccessChain and OpLoad will be inserted above the instruction returned
+  // by this function. Returns nullptr if no such instruction is present.
+  opt::Instruction* GetInsertBeforeInstruction(
+      opt::IRContext* ir_context) const;
+
   protobufs::TransformationReplaceConstantWithUniform message_;
 };
 
diff --git a/test/fuzz/transformation_replace_constant_with_uniform_test.cpp b/test/fuzz/transformation_replace_constant_with_uniform_test.cpp
index 8cbba46..79757b3 100644
--- a/test/fuzz/transformation_replace_constant_with_uniform_test.cpp
+++ b/test/fuzz/transformation_replace_constant_with_uniform_test.cpp
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "source/fuzz/transformation_replace_constant_with_uniform.h"
+
 #include "source/fuzz/instruction_descriptor.h"
 #include "source/fuzz/uniform_buffer_element_descriptor.h"
 #include "test/fuzz/fuzz_test_util.h"
@@ -1548,6 +1549,111 @@
                    .IsApplicable(context.get(), transformation_context));
 }
 
+TEST(TransformationReplaceConstantWithUniformTest, ReplaceOpPhiOperand) {
+  std::string shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 320
+               OpDecorate %32 DescriptorSet 0
+               OpDecorate %32 Binding 0
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 1
+          %7 = OpConstant %6 2
+         %13 = OpConstant %6 4
+         %21 = OpConstant %6 1
+         %34 = OpConstant %6 0
+         %10 = OpTypeBool
+         %30 = OpTypeStruct %6
+         %31 = OpTypePointer Uniform %30
+         %32 = OpVariable %31 Uniform
+         %33 = OpTypePointer Uniform %6
+          %4 = OpFunction %2 None %3
+         %11 = OpLabel
+               OpBranch %5
+          %5 = OpLabel
+         %23 = OpPhi %6 %7 %11 %20 %15
+          %9 = OpSLessThan %10 %23 %13
+               OpLoopMerge %8 %15 None
+               OpBranchConditional %9 %15 %8
+         %15 = OpLabel
+         %20 = OpIAdd %6 %23 %21
+               OpBranch %5
+          %8 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_3;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  FactManager fact_manager;
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(&fact_manager,
+                                               validator_options);
+
+  auto int_descriptor = MakeUniformBufferElementDescriptor(0, 0, {0});
+
+  ASSERT_TRUE(
+      AddFactHelper(&transformation_context, context.get(), 2, int_descriptor));
+
+  {
+    TransformationReplaceConstantWithUniform transformation(
+        MakeIdUseDescriptor(7, MakeInstructionDescriptor(23, SpvOpPhi, 0), 0),
+        int_descriptor, 50, 51);
+    ASSERT_TRUE(
+        transformation.IsApplicable(context.get(), transformation_context));
+    transformation.Apply(context.get(), &transformation_context);
+    ASSERT_TRUE(IsValid(env, context.get()));
+  }
+
+  std::string after_transformation = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 320
+               OpDecorate %32 DescriptorSet 0
+               OpDecorate %32 Binding 0
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 1
+          %7 = OpConstant %6 2
+         %13 = OpConstant %6 4
+         %21 = OpConstant %6 1
+         %34 = OpConstant %6 0
+         %10 = OpTypeBool
+         %30 = OpTypeStruct %6
+         %31 = OpTypePointer Uniform %30
+         %32 = OpVariable %31 Uniform
+         %33 = OpTypePointer Uniform %6
+          %4 = OpFunction %2 None %3
+         %11 = OpLabel
+         %50 = OpAccessChain %33 %32 %34
+         %51 = OpLoad %6 %50
+               OpBranch %5
+          %5 = OpLabel
+         %23 = OpPhi %6 %51 %11 %20 %15
+          %9 = OpSLessThan %10 %23 %13
+               OpLoopMerge %8 %15 None
+               OpBranchConditional %9 %15 %8
+         %15 = OpLabel
+         %20 = OpIAdd %6 %23 %21
+               OpBranch %5
+          %8 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
 }  // namespace
 }  // namespace fuzz
 }  // namespace spvtools