spirv-val: Make Constant evaluation consistent (#5587)

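Bring 64-bit constant evaluation in line with 32-bit evaluation: for scalar
integers, OpConstantNull now evaluates to zero, and spec constants are no
longer evaluated during static validation because their values cannot be
relied upon. GetConstantValUint64 is renamed to EvalConstantValUint64, and
a signed EvalConstantValInt64 is added.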
diff --git a/source/val/validate_builtins.cpp b/source/val/validate_builtins.cpp
index 42fbc52..a7e9942 100644
--- a/source/val/validate_builtins.cpp
+++ b/source/val/validate_builtins.cpp
@@ -1120,7 +1120,7 @@
 
   if (num_components != 0) {
     uint64_t actual_num_components = 0;
-    if (!_.GetConstantValUint64(type_inst->word(3), &actual_num_components)) {
+    if (!_.EvalConstantValUint64(type_inst->word(3), &actual_num_components)) {
       assert(0 && "Array type definition is corrupt");
     }
     if (actual_num_components != num_components) {
diff --git a/source/val/validate_composites.cpp b/source/val/validate_composites.cpp
index ed043b6..26486da 100644
--- a/source/val/validate_composites.cpp
+++ b/source/val/validate_composites.cpp
@@ -94,7 +94,7 @@
           break;
         }
 
-        if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) {
+        if (!_.EvalConstantValUint64(type_inst->word(3), &array_size)) {
           assert(0 && "Array type definition is corrupt");
         }
         if (component_index >= array_size) {
@@ -289,7 +289,7 @@
       }
 
       uint64_t array_size = 0;
-      if (!_.GetConstantValUint64(array_inst->word(3), &array_size)) {
+      if (!_.EvalConstantValUint64(array_inst->word(3), &array_size)) {
         assert(0 && "Array type definition is corrupt");
       }
 
diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp
index 0334b60..7b73c9c 100644
--- a/source/val/validate_extensions.cpp
+++ b/source/val/validate_extensions.cpp
@@ -3100,7 +3100,7 @@
 
           uint32_t vector_count = inst->word(6);
           uint64_t const_val;
-          if (!_.GetConstantValUint64(vector_count, &const_val)) {
+          if (!_.EvalConstantValUint64(vector_count, &const_val)) {
             return _.diag(SPV_ERROR_INVALID_DATA, inst)
                    << ext_inst_name()
                    << ": Vector Count must be 32-bit integer OpConstant";
@@ -3191,7 +3191,7 @@
           uint32_t component_count = inst->word(6);
           if (vulkanDebugInfo) {
             uint64_t const_val;
-            if (!_.GetConstantValUint64(component_count, &const_val)) {
+            if (!_.EvalConstantValUint64(component_count, &const_val)) {
               return _.diag(SPV_ERROR_INVALID_DATA, inst)
                      << ext_inst_name()
                      << ": Component Count must be 32-bit integer OpConstant";
diff --git a/source/val/validate_image.cpp b/source/val/validate_image.cpp
index 46a32f2..543f345 100644
--- a/source/val/validate_image.cpp
+++ b/source/val/validate_image.cpp
@@ -495,7 +495,7 @@
     }
 
     uint64_t array_size = 0;
-    if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) {
+    if (!_.EvalConstantValUint64(type_inst->word(3), &array_size)) {
       assert(0 && "Array type definition is corrupt");
     }
 
@@ -1210,7 +1210,7 @@
 
   if (info.multisampled == 0) {
     uint64_t ms = 0;
-    if (!_.GetConstantValUint64(inst->GetOperandAs<uint32_t>(4), &ms) ||
+    if (!_.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(4), &ms) ||
         ms != 0) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Expected Sample for Image with MS 0 to be a valid <id> for "
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index 5b25eeb..41dd71e 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -1374,22 +1374,18 @@
       case spv::Op::OpTypeStruct: {
         // In case of structures, there is an additional constraint on the
         // index: the index must be an OpConstant.
-        if (spv::Op::OpConstant != cur_word_instr->opcode()) {
+        int64_t cur_index;
+        if (!_.EvalConstantValInt64(cur_word, &cur_index)) {
           return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
                  << "The <id> passed to " << instr_name
                  << " to index into a "
                     "structure must be an OpConstant.";
         }
-        // Get the index value from the OpConstant (word 3 of OpConstant).
-        // OpConstant could be a signed integer. But it's okay to treat it as
-        // unsigned because a negative constant int would never be seen as
-        // correct as a struct offset, since structs can't have more than 2
-        // billion members.
-        const uint32_t cur_index = cur_word_instr->word(3);
+
-        // The index points to the struct member we want, therefore, the index
-        // should be less than the number of struct members.
+        // The index selects the struct member we want. EvalConstantValInt64
+        // sign-extends, so reject negative indices as well as too-large ones.
-        const uint32_t num_struct_members =
-            static_cast<uint32_t>(type_pointee->words().size() - 2);
+        const int64_t num_struct_members =
+            static_cast<int64_t>(type_pointee->words().size() - 2);
-        if (cur_index >= num_struct_members) {
+        if (cur_index < 0 || cur_index >= num_struct_members) {
           return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
                  << "Index is out of bounds: " << instr_name
@@ -1400,7 +1396,8 @@
                  << num_struct_members - 1 << ".";
         }
         // Struct members IDs start at word 2 of OpTypeStruct.
-        auto structMemberId = type_pointee->word(cur_index + 2);
+        const size_t word_index = static_cast<size_t>(cur_index) + 2;
+        auto structMemberId = type_pointee->word(word_index);
         type_pointee = _.FindDef(structMemberId);
         break;
       }
diff --git a/source/val/validate_non_uniform.cpp b/source/val/validate_non_uniform.cpp
index 74449e9..75967d2 100644
--- a/source/val/validate_non_uniform.cpp
+++ b/source/val/validate_non_uniform.cpp
@@ -389,20 +389,25 @@
 
   if (inst->words().size() > 6) {
     const uint32_t cluster_size_op_id = inst->GetOperandAs<uint32_t>(5);
-    const uint32_t cluster_size_type = _.GetTypeId(cluster_size_op_id);
+    const Instruction* cluster_size_inst = _.FindDef(cluster_size_op_id);
+    const uint32_t cluster_size_type =
+        cluster_size_inst ? cluster_size_inst->type_id() : 0;
     if (!_.IsUnsignedIntScalarType(cluster_size_type)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "ClusterSize must be a scalar of integer type, whose "
                 "Signedness operand is 0.";
     }
 
-    uint64_t cluster_size;
-    if (!_.GetConstantValUint64(cluster_size_op_id, &cluster_size)) {
+    if (!spvOpcodeIsConstant(cluster_size_inst->opcode())) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "ClusterSize must come from a constant instruction.";
     }
 
-    if ((cluster_size == 0) || ((cluster_size & (cluster_size - 1)) != 0)) {
+    uint64_t cluster_size;
+    const bool valid_const =
+        _.EvalConstantValUint64(cluster_size_op_id, &cluster_size);
+    if (valid_const &&
+        ((cluster_size == 0) || ((cluster_size & (cluster_size - 1)) != 0))) {
       return _.diag(SPV_WARNING, inst)
              << "Behavior is undefined unless ClusterSize is at least 1 and a "
                 "power of 2.";
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index 7edd12f..cb26a52 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -24,21 +24,6 @@
 namespace val {
 namespace {
 
-// Returns, as an int64_t, the literal value from an OpConstant or the
-// default value of an OpSpecConstant, assuming it is an integral type.
-// For signed integers, relies the rule that literal value is sign extended
-// to fill out to word granularity.  Assumes that the constant value
-// has
-int64_t ConstantLiteralAsInt64(uint32_t width,
-                               const std::vector<uint32_t>& const_words) {
-  const uint32_t lo_word = const_words[3];
-  if (width <= 32) return int32_t(lo_word);
-  assert(width <= 64);
-  assert(const_words.size() > 4);
-  const uint32_t hi_word = const_words[4];  // Must exist, per spec.
-  return static_cast<int64_t>(uint64_t(lo_word) | uint64_t(hi_word) << 32);
-}
-
 // Validates that type declarations are unique, unless multiple declarations
 // of the same data type are allowed by the specification.
 // (see section 2.8 Types and Variables)
@@ -252,29 +237,17 @@
            << " is not a constant integer type.";
   }
 
-  switch (length->opcode()) {
-    case spv::Op::OpSpecConstant:
-    case spv::Op::OpConstant: {
-      auto& type_words = const_result_type->words();
-      const bool is_signed = type_words[3] > 0;
-      const uint32_t width = type_words[2];
-      const int64_t ivalue = ConstantLiteralAsInt64(width, length->words());
-      if (ivalue == 0 || (ivalue < 0 && is_signed)) {
-        return _.diag(SPV_ERROR_INVALID_ID, inst)
-               << "OpTypeArray Length <id> " << _.getIdName(length_id)
-               << " default value must be at least 1: found " << ivalue;
-      }
-    } break;
-    case spv::Op::OpConstantNull:
+  int64_t length_value;
+  if (_.EvalConstantValInt64(length_id, &length_value)) {
+    auto& type_words = const_result_type->words();
+    const bool is_signed = type_words[3] > 0;
+    if (length_value == 0 || (length_value < 0 && is_signed)) {
       return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "OpTypeArray Length <id> " << _.getIdName(length_id)
-             << " default value must be at least 1.";
-    case spv::Op::OpSpecConstantOp:
-      // Assume it's OK, rather than try to evaluate the operation.
-      break;
-    default:
-      assert(0 && "bug in spvOpcodeIsConstant() or result type isn't int");
+             << " default value must be at least 1: found " << length_value;
+    }
   }
+
   return SPV_SUCCESS;
 }
 
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 25b374d..fa5ae3e 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -1209,7 +1209,7 @@
   if (!IsCooperativeMatrixKHRType(id)) return false;
   const Instruction* inst = FindDef(id);
   uint64_t matrixUse = 0;
-  if (GetConstantValUint64(inst->word(6), &matrixUse)) {
+  if (EvalConstantValUint64(inst->word(6), &matrixUse)) {
     return matrixUse ==
            static_cast<uint64_t>(spv::CooperativeMatrixUse::MatrixAKHR);
   }
@@ -1220,7 +1220,7 @@
   if (!IsCooperativeMatrixKHRType(id)) return false;
   const Instruction* inst = FindDef(id);
   uint64_t matrixUse = 0;
-  if (GetConstantValUint64(inst->word(6), &matrixUse)) {
+  if (EvalConstantValUint64(inst->word(6), &matrixUse)) {
     return matrixUse ==
            static_cast<uint64_t>(spv::CooperativeMatrixUse::MatrixBKHR);
   }
@@ -1230,7 +1230,7 @@
   if (!IsCooperativeMatrixKHRType(id)) return false;
   const Instruction* inst = FindDef(id);
   uint64_t matrixUse = 0;
-  if (GetConstantValUint64(inst->word(6), &matrixUse)) {
+  if (EvalConstantValUint64(inst->word(6), &matrixUse)) {
     return matrixUse == static_cast<uint64_t>(
                             spv::CooperativeMatrixUse::MatrixAccumulatorKHR);
   }
@@ -1340,20 +1340,23 @@
   return GetTypeId(inst->GetOperandAs<uint32_t>(operand_index));
 }
 
-bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const {
+bool ValidationState_t::EvalConstantValUint64(uint32_t id,
+                                              uint64_t* val) const {
   const Instruction* inst = FindDef(id);
   if (!inst) {
     assert(0 && "Instruction not found");
     return false;
   }
 
-  if (inst->opcode() != spv::Op::OpConstant &&
-      inst->opcode() != spv::Op::OpSpecConstant)
-    return false;
-
   if (!IsIntScalarType(inst->type_id())) return false;
 
-  if (inst->words().size() == 4) {
+  if (inst->opcode() == spv::Op::OpConstantNull) {
+    *val = 0;
+  } else if (inst->opcode() != spv::Op::OpConstant) {
+    // Spec constant values cannot be evaluated so don't consider constant for
+    // static validation
+    return false;
+  } else if (inst->words().size() == 4) {
     *val = inst->word(3);
   } else {
     assert(inst->words().size() == 5);
@@ -1363,6 +1366,32 @@
   return true;
 }
 
+bool ValidationState_t::EvalConstantValInt64(uint32_t id, int64_t* val) const {
+  const Instruction* const def = FindDef(id);
+  if (def == nullptr) {
+    assert(0 && "Instruction not found");
+    return false;
+  }
+  if (!IsIntScalarType(def->type_id())) return false;
+  const spv::Op op = def->opcode();
+  if (op == spv::Op::OpConstantNull) {
+    *val = 0;
+    return true;
+  }
+  // Spec constant values cannot be relied upon during static validation,
+  // so anything other than a plain OpConstant is treated as unknown.
+  if (op != spv::Op::OpConstant) return false;
+  const auto& words = def->words();
+  if (words.size() == 4) {
+    *val = static_cast<int32_t>(words[3]);  // sign-extend the 32-bit literal
+    return true;
+  }
+  assert(words.size() == 5);
+  const uint64_t bits = uint64_t(words[3]) | (uint64_t(words[4]) << 32);
+  *val = static_cast<int64_t>(bits);
+  return true;
+}
+
 std::tuple<bool, bool, uint32_t> ValidationState_t::EvalInt32IfConst(
     uint32_t id) const {
   const Instruction* const inst = FindDef(id);
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index 46a8cbf..27acdcc 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -648,10 +648,6 @@
                     const std::function<bool(const Instruction*)>& f,
                     bool traverse_all_types = true) const;
 
-  // Gets value from OpConstant and OpSpecConstant as uint64.
-  // Returns false on failure (no instruction, wrong instruction, not int).
-  bool GetConstantValUint64(uint32_t id, uint64_t* val) const;
-
   // Returns type_id if id has type or zero otherwise.
   uint32_t GetTypeId(uint32_t id) const;
 
@@ -726,6 +722,14 @@
     pointer_to_storage_image_.insert(type_id);
   }
 
+  // Tries to evaluate any scalar integer OpConstant as uint64.
+  // OpConstantNull is defined as zero for scalar int (will return true)
+  // OpSpecConstant* return false since their values cannot be relied upon
+  // during validation.
+  bool EvalConstantValUint64(uint32_t id, uint64_t* val) const;
+  // Same as EvalConstantValUint64 but returns int64 (sign-extended).
+  bool EvalConstantValInt64(uint32_t id, int64_t* val) const;
+
   // Tries to evaluate a 32-bit signed or unsigned scalar integer constant.
   // Returns tuple <is_int32, is_const_int32, value>.
   // OpSpecConstant* return |is_const_int32| as false since their values cannot
diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp
index 7acac56..e236134 100644
--- a/test/val/val_id_test.cpp
+++ b/test/val/val_id_test.cpp
@@ -1056,7 +1056,7 @@
   EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
   EXPECT_THAT(getDiagnosticString(),
               HasSubstr(make_message("OpTypeArray Length <id> '2[%2]' default "
-                                     "value must be at least 1.")));
+                                     "value must be at least 1: found 0")));
 }
 
 TEST_P(ValidateIdWithMessage, OpTypeArrayLengthSpecConst) {