// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>

#include "source/opcode.h"
#include "source/val/instruction.h"
#include "source/val/validate.h"
#include "source/val/validation_state.h"

namespace spvtools {
namespace val {
namespace {

// Returns true if |a| and |b| are both OpTypePointer instructions such that
// every decoration applied to |b| is also applied to |a|, and their pointee
// types are either the same <id> or logically match according to
// ValidationState_t::LogicallyMatch (with decoration checking enabled).
//
// Returns false if either pointer instruction is null (e.g. the result of a
// failed FindDef lookup) or if either pointee type has no definition.
bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
                              ValidationState_t& _) {
  // Callers may hand us unresolved definitions; treat them as "no match"
  // instead of dereferencing null.
  if (!a || !b) {
    return false;
  }
  if (a->opcode() != SpvOpTypePointer || b->opcode() != SpvOpTypePointer) {
    return false;
  }

  // The decorations on |b| must be a subset of the decorations on |a|.
  const auto& dec_a = _.id_decorations(a->id());
  const auto& dec_b = _.id_decorations(b->id());
  for (const auto& dec : dec_b) {
    if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
      return false;
    }
  }

  // Operand 2 of OpTypePointer is the pointee type <id>.
  const uint32_t a_type = a->GetOperandAs<uint32_t>(2);
  const uint32_t b_type = b->GetOperandAs<uint32_t>(2);

  if (a_type == b_type) {
    return true;
  }

  Instruction* a_type_inst = _.FindDef(a_type);
  Instruction* b_type_inst = _.FindDef(b_type);
  // NOTE(review): LogicallyMatch is not shown to tolerate null arguments, so
  // undefined pointee types are reported as a mismatch here.
  if (!a_type_inst || !b_type_inst) {
    return false;
  }

  return _.LogicallyMatch(a_type_inst, b_type_inst, true);
}

spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
  const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
  const auto function_type = _.FindDef(function_type_id);
  if (!function_type || SpvOpTypeFunction != function_type->opcode()) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunction Function Type <id> '" << _.getIdName(function_type_id)
           << "' is not a function type.";
  }

  const auto return_id = function_type->GetOperandAs<uint32_t>(1);
  if (return_id != inst->type_id()) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id())
           << "' does not match the Function Type's return type <id> '"
           << _.getIdName(return_id) << "'.";
  }

  const std::vector<SpvOp> acceptable = {
      SpvOpDecorate,
      SpvOpEnqueueKernel,
      SpvOpEntryPoint,
      SpvOpExecutionMode,
      SpvOpExecutionModeId,
      SpvOpFunctionCall,
      SpvOpGetKernelNDrangeSubGroupCount,
      SpvOpGetKernelNDrangeMaxSubGroupSize,
      SpvOpGetKernelWorkGroupSize,
      SpvOpGetKernelPreferredWorkGroupSizeMultiple,
      SpvOpGetKernelLocalSizeForSubgroupCount,
      SpvOpGetKernelMaxNumSubgroups,
      SpvOpName};
  for (auto& pair : inst->uses()) {
    const auto* use = pair.first;
    if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
        acceptable.end()) {
      return _.diag(SPV_ERROR_INVALID_ID, use)
             << "Invalid use of function result id " << _.getIdName(inst->id())
             << ".";
    }
  }

  return SPV_SUCCESS;
}

// Checks the aliasing decorations on OpFunctionParameter |inst|, whose type
// involves a PhysicalStorageBufferEXT pointer: exactly one of the decorations
// |aliased| and |restrict_dec| must be applied to it. |aliased_name| and
// |restrict_name| are the decoration spellings used in the diagnostics.
spv_result_t ValidatePhysicalStorageBufferAliasing(
    ValidationState_t& _, const Instruction* inst, SpvDecoration aliased,
    SpvDecoration restrict_dec, const char* aliased_name,
    const char* restrict_name) {
  const auto& decorations = _.id_decorations(inst->id());
  const auto has_decoration = [&decorations](SpvDecoration wanted) {
    return std::any_of(
        decorations.begin(), decorations.end(),
        [wanted](const Decoration& d) { return d.dec_type() == wanted; });
  };

  const bool found_aliased = has_decoration(aliased);
  const bool found_restrict = has_decoration(restrict_dec);

  if (!found_aliased && !found_restrict) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunctionParameter " << inst->id() << ": expected "
           << aliased_name << " or " << restrict_name
           << " for PhysicalStorageBufferEXT pointer.";
  }
  if (found_aliased && found_restrict) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunctionParameter " << inst->id() << ": can't specify both "
           << aliased_name << " and " << restrict_name
           << " for PhysicalStorageBufferEXT pointer.";
  }
  return SPV_SUCCESS;
}

// Validates an OpFunctionParameter instruction |inst|:
//  - it must be preceded by an OpFunction (searching backwards),
//  - the OpFunction's type must be an OpTypeFunction with more parameters than
//    the number of OpFunctionParameters seen between it and |inst|,
//  - |inst|'s result type must equal the corresponding parameter type,
//  - PhysicalStorageBufferEXT pointers (possibly wrapped in arrays) need
//    Aliased/Restrict, and pointers to such pointers need
//    AliasedPointerEXT/RestrictPointerEXT (exactly one of each pair).
spv_result_t ValidateFunctionParameter(ValidationState_t& _,
                                       const Instruction* inst) {
  // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
  // LineNum() is treated as 1-based, so |inst_num| is |inst|'s index into
  // ordered_instructions().
  size_t param_index = 0;
  const size_t inst_num = inst->LineNum() - 1;
  if (inst_num == 0) {
    return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
           << "Function parameter cannot be the first instruction.";
  }

  // Walk backwards over every earlier instruction, including index 0, until
  // the owning OpFunction is found; count the parameters passed on the way.
  const Instruction* func_inst = nullptr;
  for (size_t i = inst_num; i-- > 0;) {
    const Instruction* candidate = &_.ordered_instructions()[i];
    if (candidate->opcode() == SpvOpFunction) {
      func_inst = candidate;
      break;
    }
    if (candidate->opcode() == SpvOpFunctionParameter) {
      ++param_index;
    }
  }

  if (!func_inst) {
    return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
           << "Function parameter must be preceded by a function.";
  }

  const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
  const auto function_type = _.FindDef(function_type_id);
  // Require an OpTypeFunction (as ValidateFunctionCall does) so that the word
  // arithmetic below is well defined.
  if (!function_type || function_type->opcode() != SpvOpTypeFunction) {
    return _.diag(SPV_ERROR_INVALID_ID, func_inst)
           << "Missing function type definition.";
  }

  // OpTypeFunction words: opcode/word count, result <id>, return type, then
  // one word per parameter type.
  const size_t param_count = function_type->words().size() - 3;
  if (param_index >= param_count) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "Too many OpFunctionParameters for " << func_inst->id()
           << ": expected " << param_count << " based on the function's type";
  }

  // Parameter types start at operand 2 of OpTypeFunction.
  const auto param_type =
      _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
  if (!param_type || inst->type_id() != param_type->id()) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunctionParameter Result Type <id> '"
           << _.getIdName(inst->type_id())
           << "' does not match the OpTypeFunction parameter "
              "type of the same index.";
  }

  // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased,
  // RestrictPointerEXT, or AliasedPointerEXT. Arrays are peeled off first.
  auto param_nonarray_type_id = param_type->id();
  while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) {
    param_nonarray_type_id =
        _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
  }
  if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) {
    const auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
    if (param_nonarray_type->GetOperandAs<uint32_t>(1u) ==
        SpvStorageClassPhysicalStorageBufferEXT) {
      // The parameter itself is a PhysicalStorageBufferEXT pointer.
      return ValidatePhysicalStorageBufferAliasing(
          _, inst, SpvDecorationAliased, SpvDecorationRestrict, "Aliased",
          "Restrict");
    }

    // Otherwise check whether it points at a PhysicalStorageBufferEXT pointer.
    const auto pointee_type =
        _.FindDef(param_nonarray_type->GetOperandAs<uint32_t>(2));
    if (pointee_type && SpvOpTypePointer == pointee_type->opcode() &&
        pointee_type->GetOperandAs<uint32_t>(1u) ==
            SpvStorageClassPhysicalStorageBufferEXT) {
      return ValidatePhysicalStorageBufferAliasing(
          _, inst, SpvDecorationAliasedPointerEXT,
          SpvDecorationRestrictPointerEXT, "AliasedPointerEXT",
          "RestrictPointerEXT");
    }
  }

  return SPV_SUCCESS;
}

// Validates an OpFunctionCall instruction |inst|:
//  - operand 2 must name an OpFunction whose result type equals |inst|'s,
//  - the callee's type must be an OpTypeFunction with as many parameters as
//    the call has arguments,
//  - each argument's type must equal the matching parameter type (or, before
//    HLSL legalization, be a pointer whose pointee logically matches),
//  - under the Logical addressing model (unless relax_logical_pointer is set),
//    pointer arguments are restricted by storage class and must be memory
//    object declarations unless variable pointers allow otherwise.
spv_result_t ValidateFunctionCall(ValidationState_t& _,
                                  const Instruction* inst) {
  const auto function_id = inst->GetOperandAs<uint32_t>(2);
  const auto function = _.FindDef(function_id);
  if (!function || SpvOpFunction != function->opcode()) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunctionCall Function <id> '" << _.getIdName(function_id)
           << "' is not a function.";
  }

  const auto return_type = _.FindDef(function->type_id());
  if (!return_type || return_type->id() != inst->type_id()) {
    // |return_type| may be null here, so name the <id> via |function|.
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunctionCall Result Type <id> '"
           << _.getIdName(inst->type_id())
           << "'s type does not match Function <id> '"
           << _.getIdName(function->type_id()) << "'s return type.";
  }

  const auto function_type_id = function->GetOperandAs<uint32_t>(3);
  const auto function_type = _.FindDef(function_type_id);
  if (!function_type || function_type->opcode() != SpvOpTypeFunction) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "Missing function type definition.";
  }

  // OpFunctionCall: 4 fixed words before the arguments; OpTypeFunction: 3
  // fixed words before the parameter types.
  const auto function_call_arg_count = inst->words().size() - 4;
  const auto function_param_count = function_type->words().size() - 3;
  if (function_param_count != function_call_arg_count) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunctionCall Function <id>'s parameter count does not match "
              "the argument count.";
  }

  for (size_t argument_index = 3, param_index = 2;
       argument_index < inst->operands().size();
       argument_index++, param_index++) {
    const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
    const auto argument = _.FindDef(argument_id);
    if (!argument) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << "Missing argument " << argument_index - 3 << " definition.";
    }

    const auto argument_type = _.FindDef(argument->type_id());
    if (!argument_type) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << "Missing argument " << argument_index - 3
             << " type definition.";
    }

    const auto parameter_type_id =
        function_type->GetOperandAs<uint32_t>(param_index);
    const auto parameter_type = _.FindDef(parameter_type_id);
    if (!parameter_type || argument_type->id() != parameter_type->id()) {
      // Before HLSL legalization a logically matching pointee is tolerated,
      // but that comparison is only attempted when the parameter type exists.
      const bool tolerated_before_legalization =
          parameter_type && _.options()->before_hlsl_legalization &&
          DoPointeesLogicallyMatch(argument_type, parameter_type, _);
      if (!tolerated_before_legalization) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id)
               << "'s type does not match Function <id> '"
               << _.getIdName(parameter_type_id) << "'s parameter type.";
      }
    }

    // From here on |parameter_type| is non-null: a null type returned above.
    if (_.addressing_model() == SpvAddressingModelLogical) {
      if (parameter_type->opcode() == SpvOpTypePointer &&
          !_.options()->relax_logical_pointer) {
        SpvStorageClass sc = parameter_type->GetOperandAs<SpvStorageClass>(1u);
        // Validate which storage classes can be pointer operands.
        switch (sc) {
          case SpvStorageClassUniformConstant:
          case SpvStorageClassFunction:
          case SpvStorageClassPrivate:
          case SpvStorageClassWorkgroup:
          case SpvStorageClassAtomicCounter:
            // These are always allowed.
            break;
          case SpvStorageClassStorageBuffer:
            if (!_.features().variable_pointers_storage_buffer) {
              return _.diag(SPV_ERROR_INVALID_ID, inst)
                     << "StorageBuffer pointer operand "
                     << _.getIdName(argument_id)
                     << " requires a variable pointers capability";
            }
            break;
          default:
            return _.diag(SPV_ERROR_INVALID_ID, inst)
                   << "Invalid storage class for pointer operand "
                   << _.getIdName(argument_id);
        }

        // Validate memory object declaration requirements.
        if (argument->opcode() != SpvOpVariable &&
            argument->opcode() != SpvOpFunctionParameter) {
          const bool ssbo_vptr =
              _.features().variable_pointers_storage_buffer &&
              sc == SpvStorageClassStorageBuffer;
          const bool wg_vptr =
              _.features().variable_pointers && sc == SpvStorageClassWorkgroup;
          const bool uc_ptr = sc == SpvStorageClassUniformConstant;
          if (!ssbo_vptr && !wg_vptr && !uc_ptr) {
            return _.diag(SPV_ERROR_INVALID_ID, inst)
                   << "Pointer operand " << _.getIdName(argument_id)
                   << " must be a memory object declaration";
          }
        }
      }
    }
  }
  return SPV_SUCCESS;
}

}  // namespace

// Entry point of the function validation pass. OpFunction,
// OpFunctionParameter and OpFunctionCall are handed to their dedicated
// validators, whose result is returned as-is; every other opcode passes.
spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
  const SpvOp opcode = inst->opcode();

  if (opcode == SpvOpFunction) {
    return ValidateFunction(_, inst);
  }
  if (opcode == SpvOpFunctionParameter) {
    return ValidateFunctionParameter(_, inst);
  }
  if (opcode == SpvOpFunctionCall) {
    return ValidateFunctionCall(_, inst);
  }

  // Nothing for this pass to check on any other instruction.
  return SPV_SUCCESS;
}

}  // namespace val
}  // namespace spvtools
