Add validation for Subgroup builtins (#2637)

Fixes #2611

* Validates builtins in the Vulkan environment:
  * NumSubgroups
  * SubgroupId
  * SubgroupEqMask
  * SubgroupGeMask
  * SubgroupGtMask
  * SubgroupLeMask
  * SubgroupLtMask
  * SubgroupLocalInvocationId
  * SubgroupSize
diff --git a/source/val/validate_builtins.cpp b/source/val/validate_builtins.cpp
index 0656065..a9bf716 100644
--- a/source/val/validate_builtins.cpp
+++ b/source/val/validate_builtins.cpp
@@ -211,6 +211,17 @@
   spv_result_t ValidateSMBuiltinsAtDefinition(const Decoration& decoration,
                                               const Instruction& inst);
 
+  // Used for SubgroupEqMask, SubgroupGeMask, SubgroupGtMask, SubgroupLtMask,
+  // SubgroupLeMask.
+  spv_result_t ValidateI32Vec4InputAtDefinition(const Decoration& decoration,
+                                                const Instruction& inst);
+  // Used for SubgroupLocalInvocationId, SubgroupSize.
+  spv_result_t ValidateI32InputAtDefinition(const Decoration& decoration,
+                                            const Instruction& inst);
+  // Used for SubgroupId, NumSubgroups.
+  spv_result_t ValidateComputeI32InputAtDefinition(const Decoration& decoration,
+                                                   const Instruction& inst);
+
   // The following section contains functions which are called when id defined
   // by |referenced_inst| is
   // 1. referenced by |referenced_from_inst|
@@ -333,6 +344,11 @@
       const Decoration& decoration, const Instruction& built_in_inst,
       const Instruction& referenced_inst,
       const Instruction& referenced_from_inst);
+  // Used for SubgroupId and NumSubgroups.
+  spv_result_t ValidateComputeI32InputAtReference(
+      const Decoration& decoration, const Instruction& built_in_inst,
+      const Instruction& referenced_inst,
+      const Instruction& referenced_from_inst);
 
   spv_result_t ValidateSMBuiltinsAtReference(
       const Decoration& decoration, const Instruction& built_in_inst,
@@ -2603,6 +2619,171 @@
   return SPV_SUCCESS;
 }
 
+spv_result_t BuiltInsValidator::ValidateComputeI32InputAtDefinition(
+    const Decoration& decoration, const Instruction& inst) {
+  if (spvIsVulkanEnv(_.context()->target_env)) {
+    if (decoration.struct_member_index() != Decoration::kInvalidMember) {
+      return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+             << "BuiltIn "
+             << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                              decoration.params()[0])
+             << " cannot be used as a member decoration ";
+    }
+    if (spv_result_t error = ValidateI32(
+            decoration, inst,
+            [this, &decoration,
+             &inst](const std::string& message) -> spv_result_t {
+              return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+                     << "According to the "
+                     << spvLogStringForEnv(_.context()->target_env)
+                     << " spec BuiltIn "
+                     << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                                      decoration.params()[0])
+                     << " variable needs to be a 32-bit int "
+                        "vector. "
+                     << message;
+            })) {
+      return error;
+    }
+  }
+
+  // Seed at reference checks with this built-in.
+  return ValidateComputeI32InputAtReference(decoration, inst, inst, inst);
+}
+
+spv_result_t BuiltInsValidator::ValidateComputeI32InputAtReference(
+    const Decoration& decoration, const Instruction& built_in_inst,
+    const Instruction& referenced_inst,
+    const Instruction& referenced_from_inst) {
+  if (spvIsVulkanEnv(_.context()->target_env)) {
+    const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst);
+    if (storage_class != SpvStorageClassMax &&
+        storage_class != SpvStorageClassInput) {
+      return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
+             << spvLogStringForEnv(_.context()->target_env)
+             << " spec allows BuiltIn "
+             << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                              decoration.params()[0])
+             << " to be only used for variables with Input storage class. "
+             << GetReferenceDesc(decoration, built_in_inst, referenced_inst,
+                                 referenced_from_inst)
+             << " " << GetStorageClassDesc(referenced_from_inst);
+    }
+
+    for (const SpvExecutionModel execution_model : execution_models_) {
+      bool has_vulkan_model = execution_model == SpvExecutionModelGLCompute ||
+                              execution_model == SpvExecutionModelTaskNV ||
+                              execution_model == SpvExecutionModelMeshNV;
+      if (spvIsVulkanEnv(_.context()->target_env) && !has_vulkan_model) {
+        return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
+               << spvLogStringForEnv(_.context()->target_env)
+               << " spec allows BuiltIn "
+               << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                                decoration.params()[0])
+               << " to be used only with GLCompute execution model. "
+               << GetReferenceDesc(decoration, built_in_inst, referenced_inst,
+                                   referenced_from_inst, execution_model);
+      }
+    }
+  }
+
+  if (function_id_ == 0) {
+    // Propagate this rule to all dependant ids in the global scope.
+    id_to_at_reference_checks_[referenced_from_inst.id()].push_back(
+        std::bind(&BuiltInsValidator::ValidateComputeI32InputAtReference, this,
+                  decoration, built_in_inst, referenced_from_inst,
+                  std::placeholders::_1));
+  }
+
+  return SPV_SUCCESS;
+}
+
+spv_result_t BuiltInsValidator::ValidateI32InputAtDefinition(
+    const Decoration& decoration, const Instruction& inst) {
+  if (spvIsVulkanEnv(_.context()->target_env)) {
+    if (decoration.struct_member_index() != Decoration::kInvalidMember) {
+      return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+             << "BuiltIn "
+             << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                              decoration.params()[0])
+             << " cannot be used as a member decoration ";
+    }
+    if (spv_result_t error = ValidateI32(
+            decoration, inst,
+            [this, &decoration,
+             &inst](const std::string& message) -> spv_result_t {
+              return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+                     << "According to the "
+                     << spvLogStringForEnv(_.context()->target_env)
+                     << " spec BuiltIn "
+                     << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                                      decoration.params()[0])
+                     << " variable needs to be a 32-bit int. " << message;
+            })) {
+      return error;
+    }
+
+    const SpvStorageClass storage_class = GetStorageClass(inst);
+    if (storage_class != SpvStorageClassMax &&
+        storage_class != SpvStorageClassInput) {
+      return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+             << spvLogStringForEnv(_.context()->target_env)
+             << " spec allows BuiltIn "
+             << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                              decoration.params()[0])
+             << " to be only used for variables with Input storage class. "
+             << GetReferenceDesc(decoration, inst, inst, inst) << " "
+             << GetStorageClassDesc(inst);
+    }
+  }
+
+  return SPV_SUCCESS;
+}
+
+spv_result_t BuiltInsValidator::ValidateI32Vec4InputAtDefinition(
+    const Decoration& decoration, const Instruction& inst) {
+  if (spvIsVulkanEnv(_.context()->target_env)) {
+    if (decoration.struct_member_index() != Decoration::kInvalidMember) {
+      return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+             << "BuiltIn "
+             << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                              decoration.params()[0])
+             << " cannot be used as a member decoration ";
+    }
+    if (spv_result_t error = ValidateI32Vec(
+            decoration, inst, 4,
+            [this, &decoration,
+             &inst](const std::string& message) -> spv_result_t {
+              return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+                     << "According to the "
+                     << spvLogStringForEnv(_.context()->target_env)
+                     << " spec BuiltIn "
+                     << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                                      decoration.params()[0])
+                     << " variable needs to be a 4-component 32-bit int "
+                        "vector. "
+                     << message;
+            })) {
+      return error;
+    }
+
+    const SpvStorageClass storage_class = GetStorageClass(inst);
+    if (storage_class != SpvStorageClassMax &&
+        storage_class != SpvStorageClassInput) {
+      return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+             << spvLogStringForEnv(_.context()->target_env)
+             << " spec allows BuiltIn "
+             << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                              decoration.params()[0])
+             << " to be only used for variables with Input storage class. "
+             << GetReferenceDesc(decoration, inst, inst, inst) << " "
+             << GetStorageClassDesc(inst);
+    }
+  }
+
+  return SPV_SUCCESS;
+}
+
 spv_result_t BuiltInsValidator::ValidateWorkgroupSizeAtDefinition(
     const Decoration& decoration, const Instruction& inst) {
   if (spvIsVulkanOrWebGPUEnv(_.context()->target_env)) {
@@ -2788,6 +2969,21 @@
     case SpvBuiltInSamplePosition: {
       return ValidateSamplePositionAtDefinition(decoration, inst);
     }
+    case SpvBuiltInSubgroupId:
+    case SpvBuiltInNumSubgroups: {
+      return ValidateComputeI32InputAtDefinition(decoration, inst);
+    }
+    case SpvBuiltInSubgroupLocalInvocationId:
+    case SpvBuiltInSubgroupSize: {
+      return ValidateI32InputAtDefinition(decoration, inst);
+    }
+    case SpvBuiltInSubgroupEqMask:
+    case SpvBuiltInSubgroupGeMask:
+    case SpvBuiltInSubgroupGtMask:
+    case SpvBuiltInSubgroupLeMask:
+    case SpvBuiltInSubgroupLtMask: {
+      return ValidateI32Vec4InputAtDefinition(decoration, inst);
+    }
     case SpvBuiltInTessCoord: {
       return ValidateTessCoordAtDefinition(decoration, inst);
     }
@@ -2821,17 +3017,8 @@
     case SpvBuiltInEnqueuedWorkgroupSize:
     case SpvBuiltInGlobalOffset:
     case SpvBuiltInGlobalLinearId:
-    case SpvBuiltInSubgroupSize:
     case SpvBuiltInSubgroupMaxSize:
-    case SpvBuiltInNumSubgroups:
     case SpvBuiltInNumEnqueuedSubgroups:
-    case SpvBuiltInSubgroupId:
-    case SpvBuiltInSubgroupLocalInvocationId:
-    case SpvBuiltInSubgroupEqMaskKHR:
-    case SpvBuiltInSubgroupGeMaskKHR:
-    case SpvBuiltInSubgroupGtMaskKHR:
-    case SpvBuiltInSubgroupLeMaskKHR:
-    case SpvBuiltInSubgroupLtMaskKHR:
     case SpvBuiltInBaseVertex:
     case SpvBuiltInBaseInstance:
     case SpvBuiltInDrawIndex:
diff --git a/test/val/val_builtins_test.cpp b/test/val/val_builtins_test.cpp
index df048ad..e3eca09 100644
--- a/test/val/val_builtins_test.cpp
+++ b/test/val/val_builtins_test.cpp
@@ -52,6 +52,8 @@
 using ::testing::ValuesIn;
 
 using ValidateBuiltIns = spvtest::ValidateBase<bool>;
+using ValidateVulkanSubgroupBuiltIns = spvtest::ValidateBase<
+    std::tuple<const char*, const char*, const char*, const char*, TestResult>>;
 using ValidateVulkanCombineBuiltInExecutionModelDataTypeResult =
     spvtest::ValidateBase<std::tuple<const char*, const char*, const char*,
                                      const char*, TestResult>>;
@@ -3080,6 +3082,226 @@
                         "type"));
 }
 
+TEST_P(ValidateVulkanSubgroupBuiltIns, InMain) {
+  const char* const built_in = std::get<0>(GetParam());
+  const char* const execution_model = std::get<1>(GetParam());
+  const char* const storage_class = std::get<2>(GetParam());
+  const char* const data_type = std::get<3>(GetParam());
+  const TestResult& test_result = std::get<4>(GetParam());
+
+  CodeGenerator generator = CodeGenerator::GetDefaultShaderCodeGenerator();
+  generator.capabilities_ += R"(
+OpCapability GroupNonUniformBallot
+)";
+
+  generator.before_types_ = "OpDecorate %built_in_var BuiltIn ";
+  generator.before_types_ += built_in;
+  generator.before_types_ += "\n";
+
+  std::ostringstream after_types;
+  after_types << "%built_in_ptr = OpTypePointer " << storage_class << " "
+              << data_type << "\n";
+  after_types << "%built_in_var = OpVariable %built_in_ptr " << storage_class;
+  after_types << "\n";
+  generator.after_types_ = after_types.str();
+
+  EntryPoint entry_point;
+  entry_point.name = "main";
+  entry_point.execution_model = execution_model;
+  if (strncmp(storage_class, "Input", 5) == 0 ||
+      strncmp(storage_class, "Output", 6) == 0) {
+    entry_point.interfaces = "%built_in_var";
+  }
+  entry_point.body =
+      std::string("%ld = OpLoad ") + data_type + " %built_in_var\n";
+
+  std::ostringstream execution_modes;
+  if (0 == std::strcmp(execution_model, "Fragment")) {
+    execution_modes << "OpExecutionMode %" << entry_point.name
+                    << " OriginUpperLeft\n";
+    if (0 == std::strcmp(built_in, "FragDepth")) {
+      execution_modes << "OpExecutionMode %" << entry_point.name
+                      << " DepthReplacing\n";
+    }
+  }
+  if (0 == std::strcmp(execution_model, "Geometry")) {
+    execution_modes << "OpExecutionMode %" << entry_point.name
+                    << " InputPoints\n";
+    execution_modes << "OpExecutionMode %" << entry_point.name
+                    << " OutputPoints\n";
+  }
+  if (0 == std::strcmp(execution_model, "GLCompute")) {
+    execution_modes << "OpExecutionMode %" << entry_point.name
+                    << " LocalSize 1 1 1\n";
+  }
+  entry_point.execution_modes = execution_modes.str();
+
+  generator.entry_points_.push_back(std::move(entry_point));
+
+  CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_1);
+  ASSERT_EQ(test_result.validation_result,
+            ValidateInstructions(SPV_ENV_VULKAN_1_1));
+  if (test_result.error_str) {
+    EXPECT_THAT(getDiagnosticString(), HasSubstr(test_result.error_str));
+  }
+  if (test_result.error_str2) {
+    EXPECT_THAT(getDiagnosticString(), HasSubstr(test_result.error_str2));
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    SubgroupMaskNotVec4, ValidateVulkanSubgroupBuiltIns,
+    Combine(Values("SubgroupEqMask", "SubgroupGeMask", "SubgroupGtMask",
+                   "SubgroupLeMask", "SubgroupLtMask"),
+            Values("GLCompute"), Values("Input"), Values("%u32vec3"),
+            Values(TestResult(SPV_ERROR_INVALID_DATA,
+                              "needs to be a 4-component 32-bit int vector"))));
+
+INSTANTIATE_TEST_SUITE_P(
+    SubgroupMaskNotU32, ValidateVulkanSubgroupBuiltIns,
+    Combine(Values("SubgroupEqMask", "SubgroupGeMask", "SubgroupGtMask",
+                   "SubgroupLeMask", "SubgroupLtMask"),
+            Values("GLCompute"), Values("Input"), Values("%f32vec4"),
+            Values(TestResult(SPV_ERROR_INVALID_DATA,
+                              "needs to be a 4-component 32-bit int vector"))));
+
+INSTANTIATE_TEST_SUITE_P(
+    SubgroupMaskNotInput, ValidateVulkanSubgroupBuiltIns,
+    Combine(Values("SubgroupEqMask", "SubgroupGeMask", "SubgroupGtMask",
+                   "SubgroupLeMask", "SubgroupLtMask"),
+            Values("GLCompute"), Values("Output", "Workgroup", "Private"),
+            Values("%u32vec4"),
+            Values(TestResult(
+                SPV_ERROR_INVALID_DATA,
+                "to be only used for variables with Input storage class"))));
+
+INSTANTIATE_TEST_SUITE_P(SubgroupMaskOk, ValidateVulkanSubgroupBuiltIns,
+                         Combine(Values("SubgroupEqMask", "SubgroupGeMask",
+                                        "SubgroupGtMask", "SubgroupLeMask",
+                                        "SubgroupLtMask"),
+                                 Values("GLCompute"), Values("Input"),
+                                 Values("%u32vec4"),
+                                 Values(TestResult(SPV_SUCCESS, ""))));
+
+TEST_F(ValidateBuiltIns, SubgroupMaskMemberDecorate) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability GroupNonUniformBallot
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %foo "foo"
+OpExecutionMode %foo LocalSize 1 1 1
+OpMemberDecorate %struct 0 BuiltIn SubgroupEqMask
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%struct = OpTypeStruct %int
+%void_fn = OpTypeFunction %void
+%foo = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text, SPV_ENV_VULKAN_1_1);
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "BuiltIn SubgroupEqMask cannot be used as a member decoration"));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    SubgroupInvocationIdAndSizeNotU32, ValidateVulkanSubgroupBuiltIns,
+    Combine(Values("SubgroupLocalInvocationId", "SubgroupSize"),
+            Values("GLCompute"), Values("Input"), Values("%f32"),
+            Values(TestResult(SPV_ERROR_INVALID_DATA,
+                              "needs to be a 32-bit int"))));
+
+INSTANTIATE_TEST_SUITE_P(
+    SubgroupInvocationIdAndSizeNotInput, ValidateVulkanSubgroupBuiltIns,
+    Combine(Values("SubgroupLocalInvocationId", "SubgroupSize"),
+            Values("GLCompute"), Values("Output", "Workgroup", "Private"),
+            Values("%u32"),
+            Values(TestResult(
+                SPV_ERROR_INVALID_DATA,
+                "to be only used for variables with Input storage class"))));
+
+INSTANTIATE_TEST_SUITE_P(
+    SubgroupInvocationIdAndSizeOk, ValidateVulkanSubgroupBuiltIns,
+    Combine(Values("SubgroupLocalInvocationId", "SubgroupSize"),
+            Values("GLCompute"), Values("Input"), Values("%u32"),
+            Values(TestResult(SPV_SUCCESS, ""))));
+
+TEST_F(ValidateBuiltIns, SubgroupSizeMemberDecorate) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability GroupNonUniform
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %foo "foo"
+OpExecutionMode %foo LocalSize 1 1 1
+OpMemberDecorate %struct 0 BuiltIn SubgroupSize
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%struct = OpTypeStruct %int
+%void_fn = OpTypeFunction %void
+%foo = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text, SPV_ENV_VULKAN_1_1);
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("BuiltIn SubgroupSize cannot be used as a member decoration"));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    SubgroupNumAndIdNotU32, ValidateVulkanSubgroupBuiltIns,
+    Combine(Values("SubgroupId", "NumSubgroups"), Values("GLCompute"),
+            Values("Input"), Values("%f32"),
+            Values(TestResult(SPV_ERROR_INVALID_DATA,
+                              "needs to be a 32-bit int"))));
+
+INSTANTIATE_TEST_SUITE_P(
+    SubgroupNumAndIdNotInput, ValidateVulkanSubgroupBuiltIns,
+    Combine(Values("SubgroupId", "NumSubgroups"), Values("GLCompute"),
+            Values("Output", "Workgroup", "Private"), Values("%u32"),
+            Values(TestResult(
+                SPV_ERROR_INVALID_DATA,
+                "to be only used for variables with Input storage class"))));
+
+INSTANTIATE_TEST_SUITE_P(SubgroupNumAndIdOk, ValidateVulkanSubgroupBuiltIns,
+                         Combine(Values("SubgroupId", "NumSubgroups"),
+                                 Values("GLCompute"), Values("Input"),
+                                 Values("%u32"),
+                                 Values(TestResult(SPV_SUCCESS, ""))));
+
+TEST_F(ValidateBuiltIns, SubgroupIdMemberDecorate) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability GroupNonUniform
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %foo "foo"
+OpExecutionMode %foo LocalSize 1 1 1
+OpMemberDecorate %struct 0 BuiltIn SubgroupId
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%struct = OpTypeStruct %int
+%void_fn = OpTypeFunction %void
+%foo = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text, SPV_ENV_VULKAN_1_1);
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("BuiltIn SubgroupId cannot be used as a member decoration"));
+}
+
 }  // namespace
 }  // namespace val
 }  // namespace spvtools