Basic validation for Component decorations (#2679)


* Add basic validation for Component decoration
* Add validator tests for Component decoration

diff --git a/source/val/validate_decorations.cpp b/source/val/validate_decorations.cpp
index a6a7ce6..5174600 100644
--- a/source/val/validate_decorations.cpp
+++ b/source/val/validate_decorations.cpp
@@ -1393,6 +1393,81 @@
          << spvOpcodeString(inst.opcode());
 }
 
+// Returns SPV_SUCCESS if validation rules are satisfied for the Component
+// decoration.  Otherwise emits a diagnostic and returns something other than
+// SPV_SUCCESS.
+spv_result_t CheckComponentDecoration(ValidationState_t& vstate,
+                                      const Instruction& inst,
+                                      const Decoration& decoration) {
+  assert(inst.id() && "Parser ensures the target of the decoration has an ID");
+
+  uint32_t type_id;
+  if (decoration.struct_member_index() == Decoration::kInvalidMember) {
+    // The target must be a memory object declaration.
+    const auto opcode = inst.opcode();
+    if (opcode != SpvOpVariable && opcode != SpvOpFunctionParameter) {
+      return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+             << "Target of Component decoration must be a memory object "
+                "declaration (a variable or a function parameter)";
+    }
+
+    // Only valid for the Input and Output Storage Classes.
+    const auto storage_class = opcode == SpvOpVariable
+                                   ? inst.GetOperandAs<SpvStorageClass>(2)
+                                   : SpvStorageClassMax;
+    if (storage_class != SpvStorageClassInput &&
+        storage_class != SpvStorageClassOutput &&
+        storage_class != SpvStorageClassMax) {
+      return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+             << "Target of Component decoration is invalid: must point to a "
+                "Storage Class of Input(1) or Output(3). Found Storage "
+                "Class "
+             << storage_class;
+    }
+
+    type_id = inst.type_id();
+    if (vstate.IsPointerType(type_id)) {
+      const auto pointer = vstate.FindDef(type_id);
+      type_id = pointer->GetOperandAs<uint32_t>(2);
+    }
+  } else {
+    if (inst.opcode() != SpvOpTypeStruct) {
+      return vstate.diag(SPV_ERROR_INVALID_DATA, &inst)
+             << "Attempted to get underlying data type via member index for "
+                "non-struct type.";
+    }
+    type_id = inst.word(decoration.struct_member_index() + 2);
+  }
+
+  if (spvIsVulkanEnv(vstate.context()->target_env)) {
+    if (!vstate.IsIntScalarOrVectorType(type_id) &&
+        !vstate.IsFloatScalarOrVectorType(type_id)) {
+      return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+             << "Component decoration specified for type "
+             << vstate.getIdName(type_id) << " that is not a scalar or vector";
+    }
+
+    // For 16-, and 32-bit types, it is invalid if this sequence of components
+    // gets larger than 3.
+    const auto bit_width = vstate.GetBitWidth(type_id);
+    if (bit_width == 16 || bit_width == 32) {
+      assert(decoration.params().size() == 1 &&
+             "Grammar ensures Component has one parameter");
+
+      const auto component = decoration.params()[0];
+      const auto last_component = component + vstate.GetDimension(type_id) - 1;
+      if (last_component > 3) {
+        return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+               << "Sequence of components starting with " << component
+               << " and ending with " << last_component
+               << " gets larger than 3";
+      }
+    }
+  }
+
+  return SPV_SUCCESS;
+}
+
 #define PASS_OR_BAIL_AT_LINE(X, LINE)           \
   {                                             \
     spv_result_t e##LINE = (X);                 \
@@ -1421,6 +1496,9 @@
 
     for (const auto& decoration : decorations) {
       switch (decoration.dec_type()) {
+        case SpvDecorationComponent:
+          PASS_OR_BAIL(CheckComponentDecoration(vstate, *inst, decoration));
+          break;
         case SpvDecorationFPRoundingMode:
           if (is_shader)
             PASS_OR_BAIL(CheckFPRoundingModeForShaders(vstate, *inst));
diff --git a/test/val/val_capability_test.cpp b/test/val/val_capability_test.cpp
index 4fb2a7c..5a6e751 100644
--- a/test/val/val_capability_test.cpp
+++ b/test/val/val_capability_test.cpp
@@ -1168,8 +1168,10 @@
           ShaderDependencies()),
 std::make_pair(std::string(kOpenCLMemoryModel) +
           "OpEntryPoint Kernel %func \"compute\" \n"
-          "OpDecorate %intt Component 0\n"
-          "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid),
+          "OpDecorate %var Component 0\n"
+          "%intt = OpTypeInt 32 0\n"
+          "%ptr = OpTypePointer Input %intt\n"
+          "%var = OpVariable %ptr Input\n" + std::string(kVoidFVoid),
           ShaderDependencies()),
 std::make_pair(std::string(kOpenCLMemoryModel) +
           "OpEntryPoint Kernel %func \"compute\" \n"
diff --git a/test/val/val_decoration_test.cpp b/test/val/val_decoration_test.cpp
index c454bed..5ea1d87 100644
--- a/test/val/val_decoration_test.cpp
+++ b/test/val/val_decoration_test.cpp
@@ -6400,6 +6400,314 @@
                         "requires SPIR-V version 1.3 or earlier"));
 }
 
+// Component
+
+TEST_F(ValidateDecorations, ComponentDecorationBadTarget) {
+  std::string spirv = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Vertex %main "main"
+OpDecorate %t Component 0
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%t = OpTypeVector %float 2
+%main = OpFunction %void None %3
+%5 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Target of Component decoration must be "
+                        "a memory object declaration"));
+}
+
+TEST_F(ValidateDecorations, ComponentDecorationBadStorageClass) {
+  std::string spirv = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Vertex %main "main"
+OpDecorate %v Component 0
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%t = OpTypeVector %float 2
+%ptr_private = OpTypePointer Private %t
+%v = OpVariable %ptr_private Private
+%main = OpFunction %void None %3
+%5 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Target of Component decoration is invalid: must "
+                        "point to a Storage Class of Input(1) or Output(3)"));
+}
+
+TEST_F(ValidateDecorations, ComponentDecorationBadTypeVulkan) {
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  std::string spirv = R"(
+OpCapability Shader
+OpCapability Matrix
+OpMemoryModel Logical GLSL450
+OpEntryPoint Vertex %main "main"
+OpDecorate %v Component 0
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%vtype = OpTypeVector %float 4
+%t = OpTypeMatrix %vtype 4
+%ptr_input = OpTypePointer Input %t
+%v = OpVariable %ptr_input Input
+%main = OpFunction %void None %3
+%5 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv, env);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState(env));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Component decoration specified for type"));
+  EXPECT_THAT(getDiagnosticString(), HasSubstr("is not a scalar or vector"));
+}
+
+std::string ShaderWithComponentDecoration(const std::string& type,
+                                          const std::string& decoration) {
+  return R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %entryPointOutput
+OpExecutionMode %main OriginUpperLeft
+OpDecorate %entryPointOutput Location 0
+OpDecorate %entryPointOutput )" +
+         decoration + R"(
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%vtype = )" + type + R"(
+%float_0 = OpConstant %float 0
+%_ptr_Output_vtype = OpTypePointer Output %vtype
+%entryPointOutput = OpVariable %_ptr_Output_vtype Output
+%main = OpFunction %void None %3
+%5 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+}
+
+TEST_F(ValidateDecorations, ComponentDecorationIntGood0Vulkan) {
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  std::string spirv =
+      ShaderWithComponentDecoration("OpTypeInt 32 0", "Component 0");
+
+  CompileSuccessfully(spirv, env);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(env));
+  EXPECT_THAT(getDiagnosticString(), Eq(""));
+}
+
+TEST_F(ValidateDecorations, ComponentDecorationIntGood1Vulkan) {
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  std::string spirv =
+      ShaderWithComponentDecoration("OpTypeInt 32 0", "Component 1");
+
+  CompileSuccessfully(spirv, env);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(env));
+  EXPECT_THAT(getDiagnosticString(), Eq(""));
+}
+
+TEST_F(ValidateDecorations, ComponentDecorationIntGood2Vulkan) {
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  std::string spirv =
+      ShaderWithComponentDecoration("OpTypeInt 32 0", "Component 2");
+
+  CompileSuccessfully(spirv, env);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(env));
+  EXPECT_THAT(getDiagnosticString(), Eq(""));
+}
+
+TEST_F(ValidateDecorations, ComponentDecorationIntGood3Vulkan) {
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  std::string spirv =
+      ShaderWithComponentDecoration("OpTypeInt 32 0", "Component 3");
+
+  CompileSuccessfully(spirv, env);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(env));
+  EXPECT_THAT(getDiagnosticString(), Eq(""));
+}
+
+TEST_F(ValidateDecorations, ComponentDecorationIntBad4Vulkan) {
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  std::string spirv =
+      ShaderWithComponentDecoration("OpTypeInt 32 0", "Component 4");
+
+  CompileSuccessfully(spirv, env);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState(env));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Sequence of components starting with 4 "
+                        "and ending with 4 gets larger than 3"));
+}
+
+TEST_F(ValidateDecorations, ComponentDecorationVector3GoodVulkan) {
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  std::string spirv =
+      ShaderWithComponentDecoration("OpTypeVector %float 3", "Component 1");
+
+  CompileSuccessfully(spirv, env);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(env));
+  EXPECT_THAT(getDiagnosticString(), Eq(""));
+}
+
+TEST_F(ValidateDecorations, ComponentDecorationVector4GoodVulkan) {
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  std::string spirv =
+      ShaderWithComponentDecoration("OpTypeVector %float 4", "Component 0");
+
+  CompileSuccessfully(spirv, env);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(env));
+  EXPECT_THAT(getDiagnosticString(), Eq(""));
+}
+
+TEST_F(ValidateDecorations, ComponentDecorationVector4Bad1Vulkan) {
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  std::string spirv =
+      ShaderWithComponentDecoration("OpTypeVector %float 4", "Component 1");
+
+  CompileSuccessfully(spirv, env);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState(env));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Sequence of components starting with 1 "
+                        "and ending with 4 gets larger than 3"));
+}
+
+TEST_F(ValidateDecorations, ComponentDecorationVector4Bad3Vulkan) {
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  std::string spirv =
+      ShaderWithComponentDecoration("OpTypeVector %float 4", "Component 3");
+
+  CompileSuccessfully(spirv, env);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState(env));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Sequence of components starting with 3 "
+                        "and ending with 6 gets larger than 3"));
+}
+
+TEST_F(ValidateDecorations, ComponentDecorationBlockGood) {
+  std::string spirv = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main" %9 %12
+OpExecutionMode %4 OriginUpperLeft
+OpDecorate %9 Location 0
+OpMemberDecorate %block 0 Location 2
+OpMemberDecorate %block 0 Component 1
+OpDecorate %block Block
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%float = OpTypeFloat 32
+%vec3 = OpTypeVector %float 3
+%8 = OpTypePointer Output %vec3
+%9 = OpVariable %8 Output
+%block = OpTypeStruct %vec3
+%11 = OpTypePointer Input %block
+%12 = OpVariable %11 Input
+%int = OpTypeInt 32 1
+%14 = OpConstant %int 0
+%15 = OpTypePointer Input %vec3
+%4 = OpFunction %2 None %3
+%5 = OpLabel
+%16 = OpAccessChain %15 %12 %14
+%17 = OpLoad %vec3 %16
+OpStore %9 %17
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+  EXPECT_THAT(getDiagnosticString(), Eq(""));
+}
+
+TEST_F(ValidateDecorations, ComponentDecorationBlockBadVulkan) {
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  std::string spirv = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main" %9 %12
+OpExecutionMode %4 OriginUpperLeft
+OpDecorate %9 Location 0
+OpMemberDecorate %block 0 Location 2
+OpMemberDecorate %block 0 Component 2
+OpDecorate %block Block
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%float = OpTypeFloat 32
+%vec3 = OpTypeVector %float 3
+%8 = OpTypePointer Output %vec3
+%9 = OpVariable %8 Output
+%block = OpTypeStruct %vec3
+%11 = OpTypePointer Input %block
+%12 = OpVariable %11 Input
+%int = OpTypeInt 32 1
+%14 = OpConstant %int 0
+%15 = OpTypePointer Input %vec3
+%4 = OpFunction %2 None %3
+%5 = OpLabel
+%16 = OpAccessChain %15 %12 %14
+%17 = OpLoad %vec3 %16
+OpStore %9 %17
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv, env);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState(env));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Sequence of components starting with 2 "
+                        "and ending with 4 gets larger than 3"));
+}
+
+TEST_F(ValidateDecorations, ComponentDecorationFunctionParameter) {
+  std::string spirv = R"(
+              OpCapability Shader
+              OpMemoryModel Logical GLSL450
+              OpEntryPoint Vertex %main "main"
+
+              OpDecorate %param_f Component 0
+
+      %void = OpTypeVoid
+   %void_fn = OpTypeFunction %void
+     %float = OpTypeFloat 32
+   %float_0 = OpConstant %float 0
+   %int     = OpTypeInt 32 0
+   %int_2   = OpConstant %int 2
+  %struct_b = OpTypeStruct %float
+
+%extra_fn = OpTypeFunction %void %float
+
+  %helper = OpFunction %void None %extra_fn
+ %param_f = OpFunctionParameter %float
+%helper_label = OpLabel
+            OpReturn
+            OpFunctionEnd
+
+    %main = OpFunction %void None %void_fn
+   %label = OpLabel
+            OpReturn
+            OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+  EXPECT_THAT(getDiagnosticString(), Eq(""));
+}
+
 }  // namespace
 }  // namespace val
 }  // namespace spvtools