WebGPU: Array size at most max signed int + 1 (#3077)


This makes it easier to clamp indices for robust-buffer-access
behaviour.

See https://github.com/gpuweb/spirv-execution-env/issues/47
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index 4d673b4..1f171cf 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -19,29 +19,40 @@
 #include "source/val/instruction.h"
 #include "source/val/validate.h"
 #include "source/val/validation_state.h"
+#include "spirv/unified1/spirv.h"
 
 namespace spvtools {
 namespace val {
 namespace {
 
-// True if the integer constant is > 0. |const_words| are words of the
-// constant-defining instruction (either OpConstant or
-// OpSpecConstant). typeWords are the words of the constant's-type-defining
-// OpTypeInt.
-bool AboveZero(const std::vector<uint32_t>& const_words,
-               const std::vector<uint32_t>& type_words) {
-  const uint32_t width = type_words[2];
-  const bool is_signed = type_words[3] > 0;
+// Returns, as an int64_t, the literal value from an OpConstant or the
+// default value of an OpSpecConstant, assuming it is an integral type.
+// For signed integers, relies the rule that literal value is sign extended
+// to fill out to word granularity.  Assumes that the constant value
+// has
+int64_t ConstantLiteralAsInt64(uint32_t width,
+                               const std::vector<uint32_t>& const_words) {
   const uint32_t lo_word = const_words[3];
-  if (width > 32) {
-    // The spec currently doesn't allow integers wider than 64 bits.
-    const uint32_t hi_word = const_words[4];  // Must exist, per spec.
-    if (is_signed && (hi_word >> 31)) return false;
-    return (lo_word | hi_word) > 0;
-  } else {
-    if (is_signed && (lo_word >> 31)) return false;
-    return lo_word > 0;
-  }
+  if (width <= 32) return int32_t(lo_word);
+  assert(width <= 64);
+  assert(const_words.size() > 4);
+  const uint32_t hi_word = const_words[4];  // Must exist, per spec.
+  return static_cast<int64_t>(uint64_t(lo_word) | uint64_t(hi_word) << 32);
+}
+
+// Returns, as an uint64_t, the literal value from an OpConstant or the
+// default value of an OpSpecConstant, assuming it is an integral type.
+// For signed integers, relies the rule that literal value is sign extended
+// to fill out to word granularity.  Assumes that the constant value
+// has
+int64_t ConstantLiteralAsUint64(uint32_t width,
+                                const std::vector<uint32_t>& const_words) {
+  const uint32_t lo_word = const_words[3];
+  if (width <= 32) return lo_word;
+  assert(width <= 64);
+  assert(const_words.size() > 4);
+  const uint32_t hi_word = const_words[4];  // Must exist, per spec.
+  return (uint64_t(lo_word) | uint64_t(hi_word) << 32);
 }
 
 // Validates that type declarations are unique, unless multiple declarations
@@ -258,14 +269,33 @@
 
   switch (length->opcode()) {
     case SpvOpSpecConstant:
-    case SpvOpConstant:
-      if (AboveZero(length->words(), const_result_type->words())) break;
-    // Else fall through!
-    case SpvOpConstantNull: {
+    case SpvOpConstant: {
+      auto& type_words = const_result_type->words();
+      const bool is_signed = type_words[3] > 0;
+      const uint32_t width = type_words[2];
+      const int64_t ivalue = ConstantLiteralAsInt64(width, length->words());
+      if (ivalue == 0 || (ivalue < 0 && is_signed)) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << "OpTypeArray Length <id> '" << _.getIdName(length_id)
+               << "' default value must be at least 1: found " << ivalue;
+      }
+      if (spvIsWebGPUEnv(_.context()->target_env)) {
+        // WebGPU has maximum integer width of 32 bits, and max array size
+        // is one more than the max signed integer representation.
+        const uint64_t max_permitted = (uint64_t(1) << 31);
+        const uint64_t uvalue = ConstantLiteralAsUint64(width, length->words());
+        if (uvalue > max_permitted) {
+          return _.diag(SPV_ERROR_INVALID_ID, inst)
+                 << "OpTypeArray Length <id> '" << _.getIdName(length_id)
+                 << "' size exceeds max value " << max_permitted
+                 << " permitted by WebGPU: got " << uvalue;
+        }
+      }
+    } break;
+    case SpvOpConstantNull:
       return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "OpTypeArray Length <id> '" << _.getIdName(length_id)
              << "' default value must be at least 1.";
-    }
     case SpvOpSpecConstantOp:
       // Assume it's OK, rather than try to evaluate the operation.
       break;
diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp
index 327aef1..019d91a 100644
--- a/test/val/val_id_test.cpp
+++ b/test/val/val_id_test.cpp
@@ -749,20 +749,40 @@
 // Signed or unsigned.
 enum Signed { kSigned, kUnsigned };
 
-// Creates an assembly snippet declaring OpTypeArray with the given length.
-std::string MakeArrayLength(const std::string& len, Signed isSigned,
-                            int width) {
+// Creates an assembly module declaring OpTypeArray with the given length.
+std::string MakeArrayLength(const std::string& len, Signed isSigned, int width,
+                            int max_int_width = 64,
+                            bool use_vulkan_memory_model = false) {
   std::ostringstream ss;
   ss << R"(
     OpCapability Shader
-    OpCapability Linkage
-    OpCapability Int16
-    OpCapability Int64
   )";
-  ss << "OpMemoryModel Logical GLSL450\n";
+  if (use_vulkan_memory_model) {
+    ss << " OpCapability VulkanMemoryModel\n";
+  }
+  if (width == 16) {
+    ss << " OpCapability Int16\n";
+  }
+  if (max_int_width > 32) {
+    ss << "\n  OpCapability Int64\n";
+  }
+  if (use_vulkan_memory_model) {
+    ss << " OpExtension \"SPV_KHR_vulkan_memory_model\"\n";
+    ss << "OpMemoryModel Logical Vulkan\n";
+  } else {
+    ss << "OpMemoryModel Logical GLSL450\n";
+  }
+  ss << "OpEntryPoint GLCompute %main \"main\"\n";
+  ss << "OpExecutionMode %main LocalSize 1 1 1\n";
   ss << " %t = OpTypeInt " << width << (isSigned == kSigned ? " 1" : " 0");
   ss << " %l = OpConstant %t " << len;
   ss << " %a = OpTypeArray %t %l";
+  ss << " %void = OpTypeVoid \n"
+        " %voidfn = OpTypeFunction %void \n"
+        " %main = OpFunction %void None %voidfn \n"
+        " %entry = OpLabel\n"
+        " OpReturn\n"
+        " OpFunctionEnd\n";
   return ss.str();
 }
 
@@ -772,7 +792,8 @@
     : public spvtest::TextToBinaryTestBase<::testing::TestWithParam<int>> {
  protected:
   OpTypeArrayLengthTest()
-      : position_(spv_position_t{0, 0, 0}),
+      : env_(SPV_ENV_UNIVERSAL_1_0),
+        position_(spv_position_t{0, 0, 0}),
         diagnostic_(spvDiagnosticCreate(&position_, "")) {}
 
   ~OpTypeArrayLengthTest() { spvDiagnosticDestroy(diagnostic_); }
@@ -783,7 +804,7 @@
     spvDiagnosticDestroy(diagnostic_);
     diagnostic_ = nullptr;
     const auto status =
-        spvValidate(ScopedContext().context, &cbinary, &diagnostic_);
+        spvValidate(ScopedContext(env_).context, &cbinary, &diagnostic_);
     if (status != SPV_SUCCESS) {
       spvDiagnosticPrint(diagnostic_);
       EXPECT_THAT(std::string(diagnostic_->error),
@@ -792,12 +813,15 @@
     return status;
   }
 
+ protected:
+  spv_target_env env_;
+
  private:
   spv_position_t position_;  // For creating diagnostic_.
   spv_diagnostic diagnostic_;
 };
 
-TEST_P(OpTypeArrayLengthTest, LengthPositive) {
+TEST_P(OpTypeArrayLengthTest, LengthPositiveSmall) {
   const int width = GetParam();
   EXPECT_EQ(SPV_SUCCESS,
             Val(CompileSuccessfully(MakeArrayLength("1", kSigned, width))));
@@ -814,20 +838,19 @@
   const std::string fpad(width / 4 - 1, 'F');
   EXPECT_EQ(
       SPV_SUCCESS,
-      Val(CompileSuccessfully(MakeArrayLength("0x7" + fpad, kSigned, width))));
-  EXPECT_EQ(SPV_SUCCESS, Val(CompileSuccessfully(
-                             MakeArrayLength("0xF" + fpad, kUnsigned, width))));
+      Val(CompileSuccessfully(MakeArrayLength("0x7" + fpad, kSigned, width))))
+      << MakeArrayLength("0x7" + fpad, kSigned, width);
 }
 
 TEST_P(OpTypeArrayLengthTest, LengthZero) {
   const int width = GetParam();
   EXPECT_EQ(SPV_ERROR_INVALID_ID,
             Val(CompileSuccessfully(MakeArrayLength("0", kSigned, width)),
-                "OpTypeArray Length <id> '2\\[%.*\\]' default value must be at "
+                "OpTypeArray Length <id> '3\\[%.*\\]' default value must be at "
                 "least 1."));
   EXPECT_EQ(SPV_ERROR_INVALID_ID,
             Val(CompileSuccessfully(MakeArrayLength("0", kUnsigned, width)),
-                "OpTypeArray Length <id> '2\\[%.*\\]' default value must be at "
+                "OpTypeArray Length <id> '3\\[%.*\\]' default value must be at "
                 "least 1."));
 }
 
@@ -835,23 +858,88 @@
   const int width = GetParam();
   EXPECT_EQ(SPV_ERROR_INVALID_ID,
             Val(CompileSuccessfully(MakeArrayLength("-1", kSigned, width)),
-                "OpTypeArray Length <id> '2\\[%.*\\]' default value must be at "
+                "OpTypeArray Length <id> '3\\[%.*\\]' default value must be at "
                 "least 1."));
   EXPECT_EQ(SPV_ERROR_INVALID_ID,
             Val(CompileSuccessfully(MakeArrayLength("-2", kSigned, width)),
-                "OpTypeArray Length <id> '2\\[%.*\\]' default value must be at "
+                "OpTypeArray Length <id> '3\\[%.*\\]' default value must be at "
                 "least 1."));
   EXPECT_EQ(SPV_ERROR_INVALID_ID,
             Val(CompileSuccessfully(MakeArrayLength("-123", kSigned, width)),
-                "OpTypeArray Length <id> '2\\[%.*\\]' default value must be at "
+                "OpTypeArray Length <id> '3\\[%.*\\]' default value must be at "
                 "least 1."));
   const std::string neg_max = "0x8" + std::string(width / 4 - 1, '0');
   EXPECT_EQ(SPV_ERROR_INVALID_ID,
             Val(CompileSuccessfully(MakeArrayLength(neg_max, kSigned, width)),
-                "OpTypeArray Length <id> '2\\[%.*\\]' default value must be at "
+                "OpTypeArray Length <id> '3\\[%.*\\]' default value must be at "
                 "least 1."));
 }
 
+// Returns the string form of an integer of the form 0x80....0 of the
+// given bit width.
+std::string big_num_ending_0(int bit_width) {
+  return "0x8" + std::string(bit_width / 4 - 1, '0');
+}
+
+// Returns the string form of an integer of the form 0x80..001 of the
+// given bit width.
+std::string big_num_ending_1(int bit_width) {
+  return "0x8" + std::string(bit_width / 4 - 2, '0') + "1";
+}
+
+TEST_P(OpTypeArrayLengthTest, LengthPositiveHugeEnding0InVulkan) {
+  env_ = SPV_ENV_VULKAN_1_0;
+  const int width = GetParam();
+  for (int max_int_width : {32, 64}) {
+    if (width > max_int_width) {
+      // Not valid to even make the OpConstant in this case.
+      continue;
+    }
+    const auto module = CompileSuccessfully(MakeArrayLength(
+        big_num_ending_0(width), kUnsigned, width, max_int_width));
+    EXPECT_EQ(SPV_SUCCESS, Val(module));
+  }
+}
+
+TEST_P(OpTypeArrayLengthTest, LengthPositiveHugeEnding1InVulkan) {
+  env_ = SPV_ENV_VULKAN_1_0;
+  const int width = GetParam();
+  for (int max_int_width : {32, 64}) {
+    if (width > max_int_width) {
+      // Not valid to even make the OpConstant in this case.
+      continue;
+    }
+    const auto module = CompileSuccessfully(MakeArrayLength(
+        big_num_ending_1(width), kUnsigned, width, max_int_width));
+    EXPECT_EQ(SPV_SUCCESS, Val(module));
+  }
+}
+
+TEST_P(OpTypeArrayLengthTest, LengthPositiveHugeEnding0InWebGPU) {
+  env_ = SPV_ENV_WEBGPU_0;
+  const int width = GetParam();
+  // WebGPU only has 32 bit integers.
+  if (width != 32) return;
+  const int max_int_width = 32;
+  const auto module = CompileSuccessfully(MakeArrayLength(
+      big_num_ending_0(width), kUnsigned, width, max_int_width, true));
+  EXPECT_EQ(SPV_SUCCESS, Val(module));
+}
+
+TEST_P(OpTypeArrayLengthTest, LengthPositiveHugeEnding1InWebGPU) {
+  env_ = SPV_ENV_WEBGPU_0;
+  const int width = GetParam();
+  // WebGPU only has 32 bit integers.
+  if (width != 32) return;
+  const int max_int_width = 32;
+  const auto module = CompileSuccessfully(MakeArrayLength(
+      big_num_ending_1(width), kUnsigned, width, max_int_width, true));
+  EXPECT_EQ(SPV_ERROR_INVALID_ID,
+            Val(module,
+                "OpTypeArray Length <id> '3\\[%.*\\]' size exceeds max value "
+                "2147483648 permitted by WebGPU: got 2147483649"));
+}
+
 // The only valid widths for integers are 8, 16, 32, and 64.
 // Since the Int8 capability requires the Kernel capability, and the Kernel
 // capability prohibits usage of signed integers, we can skip 8-bit integers