// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/fuzz/transformation_add_function.h"

#include "source/fuzz/fuzzer_util.h"
#include "source/fuzz/instruction_message.h"

namespace spvtools {
namespace fuzz {

// Wraps an existing protobuf description of the transformation.  The given
// message is copied into |message_|.
TransformationAddFunction::TransformationAddFunction(
    const protobufs::TransformationAddFunction& message)
    : message_(message) {}

// Builds a transformation that adds a function which is *not* livesafe.  The
// instruction messages are copied into |message_| in the order given.
TransformationAddFunction::TransformationAddFunction(
    const std::vector<protobufs::Instruction>& instructions) {
  message_.set_is_livesafe(false);
  for (const auto& instruction_message : instructions) {
    *message_.add_instruction() = instruction_message;
  }
}

// Builds a transformation that adds a livesafe function.  Besides the
// function's instructions, the message records the ids needed to make the
// function livesafe: the loop limiter variable and limit constant, per-loop
// limiter info, the value to return in place of OpKill / OpUnreachable, and
// per-access-chain clamping info.
TransformationAddFunction::TransformationAddFunction(
    const std::vector<protobufs::Instruction>& instructions,
    uint32_t loop_limiter_variable_id, uint32_t loop_limit_constant_id,
    const std::vector<protobufs::LoopLimiterInfo>& loop_limiters,
    uint32_t kill_unreachable_return_value_id,
    const std::vector<protobufs::AccessChainClampingInfo>&
        access_chain_clampers) {
  // Scalar fields.
  message_.set_is_livesafe(true);
  message_.set_loop_limiter_variable_id(loop_limiter_variable_id);
  message_.set_loop_limit_constant_id(loop_limit_constant_id);
  message_.set_kill_unreachable_return_value_id(
      kill_unreachable_return_value_id);

  // Repeated fields; each is populated in the order the caller supplied.
  for (const auto& instruction_message : instructions) {
    *message_.add_instruction() = instruction_message;
  }
  for (const auto& limiter : loop_limiters) {
    *message_.add_loop_limiter_info() = limiter;
  }
  for (const auto& clamper : access_chain_clampers) {
    *message_.add_access_chain_clamping_info() = clamper;
  }
}

// Returns true if and only if:
// - every id that the transformation would introduce is fresh and is used at
//   most once by the transformation,
// - the function can be built from |message_.instruction| and added to a clone
//   of the module, leaving the clone valid, and
// - when the function is to be livesafe, it can be made livesafe in the clone
//   and the clone is still valid afterwards.
// |ir_context| itself is never modified; all experiments happen on a clone.
bool TransformationAddFunction::IsApplicable(
    opt::IRContext* ir_context,
    const TransformationContext& transformation_context) const {
  // Ids claimed so far by this transformation.  An id may be claimed only if
  // it is fresh with respect to the module and has not been claimed already.
  std::set<uint32_t> claimed_ids;
  const auto claim_fresh_id = [&](uint32_t id) -> bool {
    return CheckIdIsFreshAndNotUsedByThisTransformation(id, ir_context,
                                                        &claimed_ids);
  };

  // Result ids of the instructions that make up the new function.
  for (const auto& instruction_message : message_.instruction()) {
    const uint32_t result_id = instruction_message.result_id();
    if (result_id != 0 && !claim_fresh_id(result_id)) {
      return false;
    }
  }

  // Additional ids that are only relevant when making the function livesafe.
  if (message_.is_livesafe()) {
    if (!claim_fresh_id(message_.loop_limiter_variable_id())) {
      return false;
    }

    for (const auto& limiter : message_.loop_limiter_info()) {
      const bool limiter_ids_ok = claim_fresh_id(limiter.load_id()) &&
                                  claim_fresh_id(limiter.increment_id()) &&
                                  claim_fresh_id(limiter.compare_id()) &&
                                  claim_fresh_id(limiter.logical_op_id());
      if (!limiter_ids_ok) {
        return false;
      }
    }

    for (const auto& clamping_info : message_.access_chain_clamping_info()) {
      for (const auto& compare_and_select :
           clamping_info.compare_and_select_ids()) {
        if (!claim_fresh_id(compare_and_select.first()) ||
            !claim_fresh_id(compare_and_select.second())) {
          return false;
        }
      }
    }
  }

  // Deciding up front whether the function would be valid is a large job that
  // the SPIR-V validator already does, so instead the function is added to a
  // clone of the module and the validator is asked about the result.
  auto cloned_module = fuzzerutil::CloneIRContext(ir_context);
  const auto clone_is_valid = [&]() -> bool {
    return fuzzerutil::IsValid(cloned_module.get(),
                               transformation_context.GetValidatorOptions(),
                               fuzzerutil::kSilentMessageConsumer);
  };

  // Adding the function fails if |message_.instruction| is not sufficiently
  // well-formed; otherwise the clone must validate with the function in place.
  if (!TryToAddFunction(cloned_module.get()) || !clone_is_valid()) {
    return false;
  }

  if (!message_.is_livesafe()) {
    return true;
  }

  if (!TryToMakeFunctionLivesafe(cloned_module.get(),
                                 transformation_context)) {
    return false;
  }

  // Making the function livesafe turns OpKill, OpUnreachable and OpReturn
  // instructions into branches, which changes control flow graph reachability
  // and can invalidate a previously valid module.  Relying on the validator a
  // second time is simpler than reasoning about every such scenario.
  return clone_is_valid();
}

// Adds the function to |ir_context| (making it livesafe if requested) and
// records the associated facts in the fact manager of
// |transformation_context|:
// - a livesafe function is recorded as livesafe; otherwise every block of the
//   function is recorded as dead;
// - every variable declared in the function, and every parameter of pointer
//   type, is recorded as having an irrelevant pointee value.
void TransformationAddFunction::Apply(
    opt::IRContext* ir_context,
    TransformationContext* transformation_context) const {
  const bool livesafe = message_.is_livesafe();

  // The transformation is applicable, so building the function should work.
  bool success = TryToAddFunction(ir_context);
  assert(success && "The function should be successfully added.");
  (void)(success);  // Silences unused-variable complaints in release builds.

  if (livesafe) {
    // Likewise, making the function livesafe should work.
    success = TryToMakeFunctionLivesafe(ir_context, *transformation_context);
    assert(success && "It should be possible to make the function livesafe.");
    (void)(success);  // Silences unused-variable complaints in release builds.
  }
  ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);

  assert(message_.instruction(0).opcode() == SpvOpFunction &&
         "The first instruction of an 'add function' transformation must be "
         "OpFunction.");

  auto* fact_manager = transformation_context->GetFactManager();

  if (livesafe) {
    // The function as a whole is livesafe.
    fact_manager->AddFactFunctionIsLivesafe(
        message_.instruction(0).result_id());
  } else {
    // Every block of the function is dead.
    for (const auto& instruction_message : message_.instruction()) {
      if (instruction_message.opcode() == SpvOpLabel) {
        fact_manager->AddFactBlockIsDead(instruction_message.result_id());
      }
    }
  }

  // Pointer parameters and variables declared in the function are regarded as
  // having irrelevant values.  Other passes may therefore store to them
  // arbitrarily, and pass them freely to other functions, without worrying
  // about them being over-written.
  for (const auto& instruction_message : message_.instruction()) {
    const auto opcode = instruction_message.opcode();
    bool pointee_is_irrelevant = false;
    if (opcode == SpvOpVariable) {
      pointee_is_irrelevant = true;
    } else if (opcode == SpvOpFunctionParameter) {
      // Only parameters whose type is a pointer type qualify.
      pointee_is_irrelevant =
          ir_context->get_def_use_mgr()
              ->GetDef(instruction_message.result_type_id())
              ->opcode() == SpvOpTypePointer;
    }
    if (pointee_is_irrelevant) {
      fact_manager->AddFactValueOfPointeeIsIrrelevant(
          instruction_message.result_id());
    }
  }
}

// Returns a generic transformation message whose |add_function| field holds a
// copy of this transformation's message.
protobufs::Transformation TransformationAddFunction::ToMessage() const {
  protobufs::Transformation wrapper;
  *wrapper.mutable_add_function() = message_;
  return wrapper;
}

// Attempts to build a function from |message_.instruction| and add it to
// |ir_context|.
//
// The instruction sequence must have the shape:
//   OpFunction, zero or more OpFunctionParameter, then one or more blocks
//   (each starting with OpLabel), then OpFunctionEnd as the final instruction.
//
// Returns false, leaving the module without the new function, if the sequence
// does not have this shape.  On success the function is added to the module,
// all analyses are invalidated, and true is returned.  No other well-formedness
// checks are performed here.
bool TransformationAddFunction::TryToAddFunction(
    opt::IRContext* ir_context) const {
  // A function must have at least some instructions.
  if (message_.instruction().empty()) {
    return false;
  }

  // A function must start with OpFunction.  The message is bound by reference:
  // it is only read here, and copying a protobuf message is a deep copy.
  const auto& function_begin = message_.instruction(0);
  if (function_begin.opcode() != SpvOpFunction) {
    return false;
  }

  // Make a function, headed by the OpFunction instruction.
  std::unique_ptr<opt::Function> new_function = MakeUnique<opt::Function>(
      InstructionFromMessage(ir_context, function_begin));

  // Keeps track of which instruction protobuf message we are currently
  // considering.
  uint32_t instruction_index = 1;
  const auto num_instructions =
      static_cast<uint32_t>(message_.instruction().size());

  // Iterate through all function parameter instructions, adding parameters to
  // the new function.
  while (instruction_index < num_instructions &&
         message_.instruction(instruction_index).opcode() ==
             SpvOpFunctionParameter) {
    new_function->AddParameter(InstructionFromMessage(
        ir_context, message_.instruction(instruction_index)));
    instruction_index++;
  }

  // After the parameters, there needs to be a label.
  if (instruction_index == num_instructions ||
      message_.instruction(instruction_index).opcode() != SpvOpLabel) {
    return false;
  }

  // Iterate through the instructions block by block until the end of the
  // function is reached.
  while (instruction_index < num_instructions &&
         message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
    // Invariant: we should always be at a label instruction at this point.
    assert(message_.instruction(instruction_index).opcode() == SpvOpLabel);

    // Make a basic block using the label instruction, with the new function
    // as its parent.
    std::unique_ptr<opt::BasicBlock> block =
        MakeUnique<opt::BasicBlock>(InstructionFromMessage(
            ir_context, message_.instruction(instruction_index)));
    block->SetParent(new_function.get());

    // Consider successive instructions until we hit another label or the end
    // of the function, adding each such instruction to the block.
    instruction_index++;
    while (instruction_index < num_instructions &&
           message_.instruction(instruction_index).opcode() !=
               SpvOpFunctionEnd &&
           message_.instruction(instruction_index).opcode() != SpvOpLabel) {
      block->AddInstruction(InstructionFromMessage(
          ir_context, message_.instruction(instruction_index)));
      instruction_index++;
    }
    // Add the block to the new function.
    new_function->AddBasicBlock(std::move(block));
  }

  // Having considered all the blocks, we should be at the last instruction and
  // it needs to be OpFunctionEnd.  The index comparison comes first so that
  // the message is never indexed out of range when OpFunctionEnd is missing.
  if (instruction_index != num_instructions - 1 ||
      message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
    return false;
  }

  // Set the function's final instruction, add the function to the module and
  // report success.
  new_function->SetFunctionEnd(InstructionFromMessage(
      ir_context, message_.instruction(instruction_index)));
  ir_context->AddFunction(std::move(new_function));

  ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);

  return true;
}

// Attempts to make the function that was added by TryToAddFunction livesafe.
//
// Precondition: |message_.is_livesafe()| holds and the function whose result
// id is that of |message_.instruction(0)| is present in |ir_context|.
//
// Loop limiters are added first.  Then every instruction of the function is
// visited: OpKill / OpUnreachable are turned into returns, access chains have
// their indices clamped, and every OpFunctionCall must target a function that
// the fact manager of |transformation_context| knows to be livesafe.
//
// Returns false as soon as any of these steps fails; |ir_context| may have
// been partially modified in that case.
bool TransformationAddFunction::TryToMakeFunctionLivesafe(
    opt::IRContext* ir_context,
    const TransformationContext& transformation_context) const {
  assert(message_.is_livesafe() && "Precondition: is_livesafe must hold.");

  // Get a pointer to the added function.
  opt::Function* added_function = nullptr;
  for (auto& function : *ir_context->module()) {
    if (function.result_id() == message_.instruction(0).result_id()) {
      added_function = &function;
      break;
    }
  }
  assert(added_function && "The added function should have been found.");

  if (!TryToAddLoopLimiters(ir_context, added_function)) {
    // Adding loop limiters did not work; bail out.
    return false;
  }

  // Consider all the instructions in the function, and:
  // - attempt to replace OpKill and OpUnreachable with return instructions
  // - attempt to clamp access chains to be within bounds
  // - check that OpFunctionCall instructions are only to livesafe functions
  for (auto& block : *added_function) {
    for (auto& inst : block) {
      switch (inst.opcode()) {
        case SpvOpKill:
        case SpvOpUnreachable:
          if (!TryToTurnKillOrUnreachableIntoReturn(ir_context, added_function,
                                                    &inst)) {
            return false;
          }
          break;
        case SpvOpAccessChain:
        case SpvOpInBoundsAccessChain:
          if (!TryToClampAccessChainIndices(ir_context, &inst)) {
            return false;
          }
          break;
        case SpvOpFunctionCall:
          // A livesafe function may only call other livesafe functions.
          if (!transformation_context.GetFactManager()->FunctionIsLivesafe(
                  inst.GetSingleWordInOperand(0))) {
            return false;
          }
          // Explicit break: do not rely on falling through into |default|.
          break;
        default:
          break;
      }
    }
  }
  return true;
}

// Returns the id of the first predecessor of the block with id
// |loop_header_block_id| that is dominated by that block, i.e. the source of a
// back edge to the header.  Returns 0 if no predecessor is dominated by the
// header.
uint32_t TransformationAddFunction::GetBackEdgeBlockId(
    opt::IRContext* ir_context, uint32_t loop_header_block_id) {
  const auto* header = ir_context->cfg()->block(loop_header_block_id);
  assert(header && "|loop_header_block_id| is invalid");

  // True if the header dominates the block with id |candidate_id|.  The
  // dominator analysis is requested on each query, so that it is not
  // requested at all when the header has no predecessors.
  const auto header_dominates = [ir_context, header,
                                 loop_header_block_id](uint32_t candidate_id) {
    return ir_context->GetDominatorAnalysis(header->GetParent())
        ->Dominates(loop_header_block_id, candidate_id);
  };

  for (const uint32_t predecessor_id :
       ir_context->cfg()->preds(loop_header_block_id)) {
    if (header_dominates(predecessor_id)) {
      return predecessor_id;
    }
  }

  // No predecessor is dominated by the header.
  return 0;
}

// Attempts to add loop-limiting logic to every loop of |added_function|.
//
// If the function has no loop headers, nothing is changed and true is
// returned.  Otherwise a single loop limiter variable (with id
// |message_.loop_limiter_variable_id()|) is declared at the start of the
// function's entry block, and for each loop with a back-edge block, that
// block is extended to load, increment and store the limiter and to exit to
// the loop's merge block once the loaded value has reached the constant
// |message_.loop_limit_constant_id()|.
//
// Returns false if:
// - the loop limit constant does not exist, is not an OpConstant, or does not
//   have 32-bit integer type;
// - the module lacks 32-bit unsigned int, constants 0 or 1 of that type, a
//   Function-storage-class pointer to that type, or bool;
// - a loop header does not dominate its merge block;
// - no LoopLimiterInfo is provided for a loop header; or
// - the phi ids provided for a loop are not suitable for the new edge from
//   the back-edge block to the merge block.
// |ir_context| may already have been modified when false is returned.
bool TransformationAddFunction::TryToAddLoopLimiters(
    opt::IRContext* ir_context, opt::Function* added_function) const {
  // Collect up all the loop headers so that we can subsequently add loop
  // limiting logic.
  std::vector<opt::BasicBlock*> loop_headers;
  for (auto& block : *added_function) {
    if (block.IsLoopHeader()) {
      loop_headers.push_back(&block);
    }
  }

  if (loop_headers.empty()) {
    // There are no loops, so no need to add any loop limiters.
    return true;
  }

  // Check that the module contains appropriate ingredients for declaring and
  // manipulating a loop limiter.

  auto loop_limit_constant_id_instr =
      ir_context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id());
  if (!loop_limit_constant_id_instr ||
      loop_limit_constant_id_instr->opcode() != SpvOpConstant) {
    // The loop limit constant id instruction must exist and have an
    // appropriate opcode.
    return false;
  }

  auto loop_limit_type = ir_context->get_def_use_mgr()->GetDef(
      loop_limit_constant_id_instr->type_id());
  if (loop_limit_type->opcode() != SpvOpTypeInt ||
      loop_limit_type->GetSingleWordInOperand(0) != 32) {
    // The type of the loop limit constant must be 32-bit integer.  It
    // doesn't actually matter whether the integer is signed or not.
    return false;
  }

  // Find the id of the "unsigned int" type.
  opt::analysis::Integer unsigned_int_type(32, false);
  uint32_t unsigned_int_type_id =
      ir_context->get_type_mgr()->GetId(&unsigned_int_type);
  if (!unsigned_int_type_id) {
    // Unsigned int is not available; we need this type in order to add loop
    // limiters.
    return false;
  }
  auto registered_unsigned_int_type =
      ir_context->get_type_mgr()->GetRegisteredType(&unsigned_int_type);

  // Look for 0 of type unsigned int.
  opt::analysis::IntConstant zero(registered_unsigned_int_type->AsInteger(),
                                  {0});
  auto registered_zero = ir_context->get_constant_mgr()->FindConstant(&zero);
  if (!registered_zero) {
    // We need 0 in order to be able to initialize loop limiters.
    return false;
  }
  uint32_t zero_id = ir_context->get_constant_mgr()
                         ->GetDefiningInstruction(registered_zero)
                         ->result_id();

  // Look for 1 of type unsigned int.
  opt::analysis::IntConstant one(registered_unsigned_int_type->AsInteger(),
                                 {1});
  auto registered_one = ir_context->get_constant_mgr()->FindConstant(&one);
  if (!registered_one) {
    // We need 1 in order to be able to increment loop limiters.
    return false;
  }
  uint32_t one_id = ir_context->get_constant_mgr()
                        ->GetDefiningInstruction(registered_one)
                        ->result_id();

  // Look for pointer-to-unsigned int type.
  opt::analysis::Pointer pointer_to_unsigned_int_type(
      registered_unsigned_int_type, SpvStorageClassFunction);
  uint32_t pointer_to_unsigned_int_type_id =
      ir_context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type);
  if (!pointer_to_unsigned_int_type_id) {
    // We need pointer-to-unsigned int in order to declare the loop limiter
    // variable.
    return false;
  }

  // Look for bool type.
  opt::analysis::Bool bool_type;
  uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
  if (!bool_type_id) {
    // We need bool in order to compare the loop limiter's value with the loop
    // limit constant.
    return false;
  }

  // Declare the loop limiter variable at the start of the function's entry
  // block, via an instruction of the form:
  //   %loop_limiter_var = SpvOpVariable %ptr_to_uint Function %zero
  // A single variable is shared by all loops of the function.
  added_function->begin()->begin()->InsertBefore(MakeUnique<opt::Instruction>(
      ir_context, SpvOpVariable, pointer_to_unsigned_int_type_id,
      message_.loop_limiter_variable_id(),
      opt::Instruction::OperandList(
          {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
           {SPV_OPERAND_TYPE_ID, {zero_id}}})));
  // Update the module's id bound since we have added the loop limiter
  // variable id.
  fuzzerutil::UpdateModuleIdBound(ir_context,
                                  message_.loop_limiter_variable_id());

  // Consider each loop in turn.
  for (auto loop_header : loop_headers) {
    // Look for the loop's back-edge block.  This is a predecessor of the loop
    // header that is dominated by the loop header.
    const auto back_edge_block_id =
        GetBackEdgeBlockId(ir_context, loop_header->id());
    if (!back_edge_block_id) {
      // The loop's back-edge block must be unreachable.  This means that the
      // loop cannot iterate, so there is no need to make it livesafe; we can
      // move on from this loop.
      continue;
    }

    // If the loop's merge block is unreachable, then there are no constraints
    // on where the merge block appears in relation to the blocks of the loop.
    // This means we need to be careful when adding a branch from the back-edge
    // block to the merge block: the branch might make the loop merge reachable,
    // and it might then be dominated by the loop header and possibly by other
    // blocks in the loop. Since a block needs to appear before those blocks it
    // strictly dominates, this could make the module invalid. To avoid this
    // problem we bail out in the case where the loop header does not dominate
    // the loop merge.
    if (!ir_context->GetDominatorAnalysis(added_function)
             ->Dominates(loop_header->id(), loop_header->MergeBlockId())) {
      return false;
    }

    // Go through the sequence of loop limiter infos and find the one
    // corresponding to this loop.  The first info whose loop header id matches
    // is used.
    bool found = false;
    protobufs::LoopLimiterInfo loop_limiter_info;
    for (auto& info : message_.loop_limiter_info()) {
      if (info.loop_header_id() == loop_header->id()) {
        loop_limiter_info = info;
        found = true;
        break;
      }
    }
    if (!found) {
      // We don't have loop limiter info for this loop header.
      return false;
    }

    // The back-edge block either has the form:
    //
    // (1)
    //
    // %l = OpLabel
    //      ... instructions ...
    //      OpBranch %loop_header
    //
    // (2)
    //
    // %l = OpLabel
    //      ... instructions ...
    //      OpBranchConditional %c %loop_header %loop_merge
    //
    // (3)
    //
    // %l = OpLabel
    //      ... instructions ...
    //      OpBranchConditional %c %loop_merge %loop_header
    //
    // We turn these into the following:
    //
    // (1)
    //
    //  %l = OpLabel
    //       ... instructions ...
    // %t1 = OpLoad %uint32 %loop_limiter
    // %t2 = OpIAdd %uint32 %t1 %one
    //       OpStore %loop_limiter %t2
    // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
    //       OpBranchConditional %t3 %loop_merge %loop_header
    //
    // (2)
    //
    //  %l = OpLabel
    //       ... instructions ...
    // %t1 = OpLoad %uint32 %loop_limiter
    // %t2 = OpIAdd %uint32 %t1 %one
    //       OpStore %loop_limiter %t2
    // %t3 = OpULessThan %bool %t1 %loop_limit
    // %t4 = OpLogicalAnd %bool %c %t3
    //       OpBranchConditional %t4 %loop_header %loop_merge
    //
    // (3)
    //
    //  %l = OpLabel
    //       ... instructions ...
    // %t1 = OpLoad %uint32 %loop_limiter
    // %t2 = OpIAdd %uint32 %t1 %one
    //       OpStore %loop_limiter %t2
    // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
    // %t4 = OpLogicalOr %bool %c %t3
    //       OpBranchConditional %t4 %loop_merge %loop_header

    auto back_edge_block = ir_context->cfg()->block(back_edge_block_id);
    auto back_edge_block_terminator = back_edge_block->terminator();
    // True for forms (1) and (3), where the "true" target of the resulting
    // conditional branch is the merge block; false for form (2).
    bool compare_using_greater_than_equal;
    if (back_edge_block_terminator->opcode() == SpvOpBranch) {
      compare_using_greater_than_equal = true;
    } else {
      assert(back_edge_block_terminator->opcode() == SpvOpBranchConditional);
      assert(((back_edge_block_terminator->GetSingleWordInOperand(1) ==
                   loop_header->id() &&
               back_edge_block_terminator->GetSingleWordInOperand(2) ==
                   loop_header->MergeBlockId()) ||
              (back_edge_block_terminator->GetSingleWordInOperand(2) ==
                   loop_header->id() &&
               back_edge_block_terminator->GetSingleWordInOperand(1) ==
                   loop_header->MergeBlockId())) &&
             "A back edge edge block must branch to"
             " either the loop header or merge");
      compare_using_greater_than_equal =
          back_edge_block_terminator->GetSingleWordInOperand(1) ==
          loop_header->MergeBlockId();
    }

    std::vector<std::unique_ptr<opt::Instruction>> new_instructions;

    // Add a load from the loop limiter variable, of the form:
    //   %t1 = OpLoad %uint32 %loop_limiter
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        ir_context, SpvOpLoad, unsigned_int_type_id,
        loop_limiter_info.load_id(),
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}})));

    // Increment the loaded value:
    //   %t2 = OpIAdd %uint32 %t1 %one
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        ir_context, SpvOpIAdd, unsigned_int_type_id,
        loop_limiter_info.increment_id(),
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
             {SPV_OPERAND_TYPE_ID, {one_id}}})));

    // Store the incremented value back to the loop limiter variable:
    //   OpStore %loop_limiter %t2
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        ir_context, SpvOpStore, 0, 0,
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}},
             {SPV_OPERAND_TYPE_ID, {loop_limiter_info.increment_id()}}})));

    // Compare the loaded value with the loop limit; either:
    //   %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
    // or
    //   %t3 = OpULessThan %bool %t1 %loop_limit
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        ir_context,
        compare_using_greater_than_equal ? SpvOpUGreaterThanEqual
                                         : SpvOpULessThan,
        bool_type_id, loop_limiter_info.compare_id(),
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
             {SPV_OPERAND_TYPE_ID, {message_.loop_limit_constant_id()}}})));

    // For a conditional back edge, combine the original branch condition %c
    // with the comparison result:
    //   %t4 = OpLogicalOr %bool %c %t3    (form (3): merge is the true target)
    // or
    //   %t4 = OpLogicalAnd %bool %c %t3   (form (2): header is the true target)
    if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
      new_instructions.push_back(MakeUnique<opt::Instruction>(
          ir_context,
          compare_using_greater_than_equal ? SpvOpLogicalOr : SpvOpLogicalAnd,
          bool_type_id, loop_limiter_info.logical_op_id(),
          opt::Instruction::OperandList(
              {{SPV_OPERAND_TYPE_ID,
                {back_edge_block_terminator->GetSingleWordInOperand(0)}},
               {SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}}})));
    }

    // Add the new instructions at the end of the back edge block, before the
    // terminator and any loop merge instruction (as the back edge block can
    // be the loop header).
    if (back_edge_block->GetLoopMergeInst()) {
      back_edge_block->GetLoopMergeInst()->InsertBefore(
          std::move(new_instructions));
    } else {
      back_edge_block_terminator->InsertBefore(std::move(new_instructions));
    }

    if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
      // Forms (2) and (3): the branch targets are unchanged; only the
      // condition is replaced by the combined condition %t4.
      back_edge_block_terminator->SetInOperand(
          0, {loop_limiter_info.logical_op_id()});
    } else {
      assert(back_edge_block_terminator->opcode() == SpvOpBranch &&
             "Back-edge terminator must be OpBranch or OpBranchConditional");

      // Check that, if the merge block starts with OpPhi instructions, suitable
      // ids have been provided to give these instructions a value corresponding
      // to the new incoming edge from the back edge block.
      auto merge_block = ir_context->cfg()->block(loop_header->MergeBlockId());
      if (!fuzzerutil::PhiIdsOkForNewEdge(ir_context, back_edge_block,
                                          merge_block,
                                          loop_limiter_info.phi_id())) {
        return false;
      }

      // Augment OpPhi instructions at the loop merge with the given ids.  The
      // i-th OpPhi receives the i-th provided phi id, paired with the
      // back-edge block as the incoming block.
      uint32_t phi_index = 0;
      for (auto& inst : *merge_block) {
        if (inst.opcode() != SpvOpPhi) {
          break;
        }
        assert(phi_index <
                   static_cast<uint32_t>(loop_limiter_info.phi_id().size()) &&
               "There should be at least one phi id per OpPhi instruction.");
        inst.AddOperand(
            {SPV_OPERAND_TYPE_ID, {loop_limiter_info.phi_id(phi_index)}});
        inst.AddOperand({SPV_OPERAND_TYPE_ID, {back_edge_block_id}});
        phi_index++;
      }

      // Add the new edge, by changing OpBranch to OpBranchConditional.
      back_edge_block_terminator->SetOpcode(SpvOpBranchConditional);
      back_edge_block_terminator->SetInOperands(opt::Instruction::OperandList(
          {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}},
           {SPV_OPERAND_TYPE_ID, {loop_header->MergeBlockId()}},
           {SPV_OPERAND_TYPE_ID, {loop_header->id()}}}));
    }

    // Update the module's id bound with respect to the various ids that
    // have been used for loop limiter manipulation.  The bound is updated for
    // |logical_op_id| as well, whether or not a logical operation was emitted
    // for this loop.
    fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.load_id());
    fuzzerutil::UpdateModuleIdBound(ir_context,
                                    loop_limiter_info.increment_id());
    fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.compare_id());
    fuzzerutil::UpdateModuleIdBound(ir_context,
                                    loop_limiter_info.logical_op_id());
  }
  return true;
}

// Turns |kill_or_unreachable_inst|, which must be an OpKill or OpUnreachable
// instruction of |added_function|, into a return instruction.
//
// If the function returns void, the instruction becomes OpReturn.  Otherwise
// it becomes OpReturnValue %id, where %id is
// |message_.kill_unreachable_return_value_id()|.
//
// Returns false, leaving the instruction unchanged, if the function has a
// non-void return type and the provided id either has no definition in the
// module or has a type that differs from the function's return type.
bool TransformationAddFunction::TryToTurnKillOrUnreachableIntoReturn(
    opt::IRContext* ir_context, opt::Function* added_function,
    opt::Instruction* kill_or_unreachable_inst) const {
  assert((kill_or_unreachable_inst->opcode() == SpvOpKill ||
          kill_or_unreachable_inst->opcode() == SpvOpUnreachable) &&
         "Precondition: instruction must be OpKill or OpUnreachable.");

  // Get the function's return type.
  auto function_return_type_inst =
      ir_context->get_def_use_mgr()->GetDef(added_function->type_id());

  if (function_return_type_inst->opcode() == SpvOpTypeVoid) {
    // The function has void return type, so change this instruction to
    // OpReturn.
    kill_or_unreachable_inst->SetOpcode(SpvOpReturn);
    return true;
  }

  // The function has non-void return type, so change this instruction to
  // OpReturnValue, using the value id provided with the transformation.

  // The provided id need not be defined in the module (for example, it is 0 if
  // it was never set), in which case the lookup yields null.  This must make
  // the transformation inapplicable rather than crash.
  auto return_value_inst = ir_context->get_def_use_mgr()->GetDef(
      message_.kill_unreachable_return_value_id());
  if (!return_value_inst) {
    return false;
  }

  // The provided value must have the same type as the function's return type.
  if (return_value_inst->type_id() != function_return_type_inst->result_id()) {
    return false;
  }

  kill_or_unreachable_inst->SetOpcode(SpvOpReturnValue);
  kill_or_unreachable_inst->SetInOperands(
      {{SPV_OPERAND_TYPE_ID, {message_.kill_unreachable_return_value_id()}}});
  return true;
}

// Attempts to rewrite |access_chain_inst|, which must be an OpAccessChain or
// OpInBoundsAccessChain, so that every index that is not already an in-bounds
// constant is replaced by a clamped value.  The fresh ids used for the
// inserted comparison and select instructions are taken from the
// AccessChainClampingInfo entry of |message_| whose access_chain_id equals the
// result id of |access_chain_inst|.
//
// Returns false if:
// - no clamping info entry exists for the access chain;
// - the number of (compare, select) id pairs differs from the number of
//   indices;
// - an index needs clamping but no integer constant with value (bound - 1) of
//   the index's type is available, or no bool type is declared.
// Returns true otherwise.
//
// Indices are processed in order and the IR is modified as each one is
// handled, so when false is returned from inside the loop, instructions
// inserted for earlier indices remain in place.
bool TransformationAddFunction::TryToClampAccessChainIndices(
    opt::IRContext* ir_context, opt::Instruction* access_chain_inst) const {
  assert((access_chain_inst->opcode() == SpvOpAccessChain ||
          access_chain_inst->opcode() == SpvOpInBoundsAccessChain) &&
         "Precondition: instruction must be OpAccessChain or "
         "OpInBoundsAccessChain.");

  // Find the AccessChainClampingInfo associated with this access chain.
  const protobufs::AccessChainClampingInfo* access_chain_clamping_info =
      nullptr;
  for (auto& clamping_info : message_.access_chain_clamping_info()) {
    if (clamping_info.access_chain_id() == access_chain_inst->result_id()) {
      access_chain_clamping_info = &clamping_info;
      break;
    }
  }
  if (!access_chain_clamping_info) {
    // No access chain clamping information was found; the function cannot be
    // made livesafe.
    return false;
  }

  // Check that there is a (compare_id, select_id) pair for every
  // index associated with the instruction.  Input operand 0 is the base
  // object, hence the "- 1".
  if (static_cast<uint32_t>(
          access_chain_clamping_info->compare_and_select_ids().size()) !=
      access_chain_inst->NumInOperands() - 1) {
    return false;
  }

  // Walk the access chain, clamping each index to be within bounds if it is
  // not an in-bounds constant.  |should_be_composite_type| tracks the type
  // being indexed into at each step, starting with the pointee type of the
  // base object's pointer type.
  auto base_object = ir_context->get_def_use_mgr()->GetDef(
      access_chain_inst->GetSingleWordInOperand(0));
  assert(base_object && "The base object must exist.");
  auto pointer_type =
      ir_context->get_def_use_mgr()->GetDef(base_object->type_id());
  assert(pointer_type && pointer_type->opcode() == SpvOpTypePointer &&
         "The base object must have pointer type.");
  auto should_be_composite_type = ir_context->get_def_use_mgr()->GetDef(
      pointer_type->GetSingleWordInOperand(1));

  // Consider each index input operand in turn (operand 0 is the base object).
  for (uint32_t index = 1; index < access_chain_inst->NumInOperands();
       index++) {
    // We are going to turn:
    //
    // %result = OpAccessChain %type %object ... %index ...
    //
    // into:
    //
    // %t1 = OpULessThanEqual %bool %index %bound_minus_one
    // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
    // %result = OpAccessChain %type %object ... %t2 ...
    //
    // ... unless %index is already a constant that is within bounds.

    // Get the bound for the composite being indexed into; e.g. the number of
    // columns of matrix or the size of an array.
    uint32_t bound = fuzzerutil::GetBoundForCompositeIndex(
        *should_be_composite_type, ir_context);

    // Get the instruction associated with the index and figure out its integer
    // type.  Only 32-bit integer indices are supported.
    const uint32_t index_id = access_chain_inst->GetSingleWordInOperand(index);
    auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
    auto index_type_inst =
        ir_context->get_def_use_mgr()->GetDef(index_inst->type_id());
    assert(index_type_inst->opcode() == SpvOpTypeInt);
    assert(index_type_inst->GetSingleWordInOperand(0) == 32);
    opt::analysis::Integer* index_int_type =
        ir_context->get_type_mgr()
            ->GetType(index_type_inst->result_id())
            ->AsInteger();

    if (index_inst->opcode() != SpvOpConstant ||
        index_inst->GetSingleWordInOperand(0) >= bound) {
      // The index is either non-constant or an out-of-bounds constant, so we
      // need to clamp it.
      assert(should_be_composite_type->opcode() != SpvOpTypeStruct &&
             "Access chain indices into structures are required to be "
             "constants.");
      // The clamp target is a constant of the same integer type as the index,
      // with value |bound| - 1; it must already be known to the constant
      // manager.
      opt::analysis::IntConstant bound_minus_one(index_int_type, {bound - 1});
      if (!ir_context->get_constant_mgr()->FindConstant(&bound_minus_one)) {
        // We do not have an integer constant whose value is |bound| - 1.
        return false;
      }

      opt::analysis::Bool bool_type;
      uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
      if (!bool_type_id) {
        // Bool type is not declared; we cannot do a comparison.
        return false;
      }

      uint32_t bound_minus_one_id =
          ir_context->get_constant_mgr()
              ->GetDefiningInstruction(&bound_minus_one)
              ->result_id();

      // The fresh ids for this index; pairs are stored per index, starting
      // from the first index (input operand 1), hence |index| - 1.
      uint32_t compare_id =
          access_chain_clamping_info->compare_and_select_ids(index - 1).first();
      uint32_t select_id =
          access_chain_clamping_info->compare_and_select_ids(index - 1)
              .second();
      std::vector<std::unique_ptr<opt::Instruction>> new_instructions;

      // Compare the index with the bound via an instruction of the form:
      //   %t1 = OpULessThanEqual %bool %index %bound_minus_one
      new_instructions.push_back(MakeUnique<opt::Instruction>(
          ir_context, SpvOpULessThanEqual, bool_type_id, compare_id,
          opt::Instruction::OperandList(
              {{SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
               {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));

      // Select the index if in-bounds, otherwise one less than the bound:
      //   %t2 = OpSelect %int_type %t1 %index %bound_minus_one
      new_instructions.push_back(MakeUnique<opt::Instruction>(
          ir_context, SpvOpSelect, index_type_inst->result_id(), select_id,
          opt::Instruction::OperandList(
              {{SPV_OPERAND_TYPE_ID, {compare_id}},
               {SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
               {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));

      // Add the new instructions before the access chain.
      access_chain_inst->InsertBefore(std::move(new_instructions));

      // Replace %index with %t2, and make sure the module's id bound accounts
      // for the two ids that were just used.
      access_chain_inst->SetInOperand(index, {select_id});
      fuzzerutil::UpdateModuleIdBound(ir_context, compare_id);
      fuzzerutil::UpdateModuleIdBound(ir_context, select_id);
    }
    // Move on to the type of the sub-object selected by this index.  The
    // original |index_id| is passed rather than the select id: the index value
    // is only inspected when indexing into a struct, in which case the index
    // is required to be a constant and is never clamped.
    should_be_composite_type =
        FollowCompositeIndex(ir_context, *should_be_composite_type, index_id);
  }
  return true;
}

// Returns the type instruction of the sub-object obtained by indexing one
// level into a composite of type |composite_type_inst|.
//
// For array, runtime array, matrix and vector types the sub-object type id is
// read from input operand 0 of the type instruction, and |index_id| is not
// consulted.  For struct types, |index_id| must be the id of a 32-bit integer
// OpConstant, and its value selects the input operand of the struct type
// instruction that holds the sub-object type id.
opt::Instruction* TransformationAddFunction::FollowCompositeIndex(
    opt::IRContext* ir_context, const opt::Instruction& composite_type_inst,
    uint32_t index_id) {
  auto def_use_mgr = ir_context->get_def_use_mgr();

  // Stays 0 if the composite type is not one of the recognised kinds.
  uint32_t component_type_id = 0;
  switch (composite_type_inst.opcode()) {
    case SpvOpTypeArray:
    case SpvOpTypeRuntimeArray:
    case SpvOpTypeMatrix:
    case SpvOpTypeVector:
      // All of these keep the sub-object type id in input operand 0.
      component_type_id = composite_type_inst.GetSingleWordInOperand(0);
      break;
    case SpvOpTypeStruct: {
      // The member to follow is determined by the value of the constant
      // index.
      auto member_index_inst = def_use_mgr->GetDef(index_id);
      assert(member_index_inst->opcode() == SpvOpConstant);
      assert(def_use_mgr->GetDef(member_index_inst->type_id())->opcode() ==
             SpvOpTypeInt);
      assert(def_use_mgr->GetDef(member_index_inst->type_id())
                 ->GetSingleWordInOperand(0) == 32);
      component_type_id = composite_type_inst.GetSingleWordInOperand(
          member_index_inst->GetSingleWordInOperand(0));
      break;
    }
    default:
      assert(false && "Unknown composite type.");
      break;
  }
  assert(component_type_id && "No sub-object found.");
  return def_use_mgr->GetDef(component_type_id);
}

// Collects the ids that |message_| designates as fresh: the result id of each
// of its instructions and, when the message is marked livesafe, the loop
// limiter variable id, the four ids of every loop limiter info entry, and both
// ids of every (compare, select) pair in every access chain clamping info
// entry.
std::unordered_set<uint32_t> TransformationAddFunction::GetFreshIds() const {
  std::unordered_set<uint32_t> fresh_ids;

  // Result ids recorded for the instructions of the function being added.
  for (const auto& added_instruction : message_.instruction()) {
    fresh_ids.insert(added_instruction.result_id());
  }

  // The remaining ids only matter for the livesafe form of the
  // transformation.
  if (!message_.is_livesafe()) {
    return fresh_ids;
  }

  fresh_ids.insert(message_.loop_limiter_variable_id());

  // Ids used by the loop limiting instructions.
  for (const auto& limiter : message_.loop_limiter_info()) {
    fresh_ids.insert(limiter.load_id());
    fresh_ids.insert(limiter.increment_id());
    fresh_ids.insert(limiter.compare_id());
    fresh_ids.insert(limiter.logical_op_id());
  }

  // Ids used by the access chain clamping instructions.
  for (const auto& clamper : message_.access_chain_clamping_info()) {
    for (const auto& compare_and_select : clamper.compare_and_select_ids()) {
      fresh_ids.insert(compare_and_select.first());
      fresh_ids.insert(compare_and_select.second());
    }
  }
  return fresh_ids;
}

}  // namespace fuzz
}  // namespace spvtools
