Require ColMajor or RowMajor for matrices (#4878)

Fixes #4875

* Require that matrices in laid out structs have RowMajor or ColMajor
  set as per SPIR-V section 2.16.2 (shader validation)
diff --git a/source/val/validate_decorations.cpp b/source/val/validate_decorations.cpp
index a66c6b0..4e4f108 100644
--- a/source/val/validate_decorations.cpp
+++ b/source/val/validate_decorations.cpp
@@ -660,7 +660,8 @@
 }
 
 // Returns true if all ids of given type have a specified decoration.
-bool checkForRequiredDecoration(uint32_t struct_id, SpvDecoration decoration,
+bool checkForRequiredDecoration(uint32_t struct_id,
+                                std::function<bool(SpvDecoration)> checker,
                                 SpvOp type, ValidationState_t& vstate) {
   const auto& members = getStructMembers(struct_id, vstate);
   for (size_t memberIdx = 0; memberIdx < members.size(); memberIdx++) {
@@ -668,10 +669,10 @@
     if (type != vstate.FindDef(id)->opcode()) continue;
     bool found = false;
     for (auto& dec : vstate.id_decorations(id)) {
-      if (decoration == dec.dec_type()) found = true;
+      if (checker(dec.dec_type())) found = true;
     }
     for (auto& dec : vstate.id_decorations(struct_id)) {
-      if (decoration == dec.dec_type() &&
+      if (checker(dec.dec_type()) &&
           (int)memberIdx == dec.struct_member_index()) {
         found = true;
       }
@@ -681,7 +682,7 @@
     }
   }
   for (auto id : getStructMembers(struct_id, SpvOpTypeStruct, vstate)) {
-    if (!checkForRequiredDecoration(id, decoration, type, vstate)) {
+    if (!checkForRequiredDecoration(id, checker, type, vstate)) {
       return false;
     }
   }
@@ -1223,30 +1224,48 @@
               return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id))
                      << "Structure id " << id << " decorated as " << deco_str
                      << " must not use GLSLPacked decoration.";
-            } else if (!checkForRequiredDecoration(id, SpvDecorationArrayStride,
-                                                   SpvOpTypeArray, vstate)) {
+            } else if (!checkForRequiredDecoration(
+                           id,
+                           [](SpvDecoration d) {
+                             return d == SpvDecorationArrayStride;
+                           },
+                           SpvOpTypeArray, vstate)) {
               return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id))
                      << "Structure id " << id << " decorated as " << deco_str
                      << " must be explicitly laid out with ArrayStride "
                         "decorations.";
-            } else if (!checkForRequiredDecoration(id,
-                                                   SpvDecorationMatrixStride,
-                                                   SpvOpTypeMatrix, vstate)) {
+            } else if (!checkForRequiredDecoration(
+                           id,
+                           [](SpvDecoration d) {
+                             return d == SpvDecorationMatrixStride;
+                           },
+                           SpvOpTypeMatrix, vstate)) {
               return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id))
                      << "Structure id " << id << " decorated as " << deco_str
                      << " must be explicitly laid out with MatrixStride "
                         "decorations.";
+            } else if (!checkForRequiredDecoration(
+                           id,
+                           [](SpvDecoration d) {
+                             return d == SpvDecorationRowMajor ||
+                                    d == SpvDecorationColMajor;
+                           },
+                           SpvOpTypeMatrix, vstate)) {
+              return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id))
+                     << "Structure id " << id << " decorated as " << deco_str
+                     << " must be explicitly laid out with RowMajor or "
+                        "ColMajor decorations.";
             } else if (blockRules &&
-                       (SPV_SUCCESS != (recursive_status = checkLayout(
-                                            id, sc_str, deco_str, true,
-                                            scalar_block_layout, 0,
-                                            constraints, vstate)))) {
+                       (SPV_SUCCESS !=
+                        (recursive_status = checkLayout(
+                             id, sc_str, deco_str, true, scalar_block_layout, 0,
+                             constraints, vstate)))) {
               return recursive_status;
             } else if (bufferRules &&
-                       (SPV_SUCCESS != (recursive_status = checkLayout(
-                                            id, sc_str, deco_str, false,
-                                            scalar_block_layout, 0,
-                                            constraints, vstate)))) {
+                       (SPV_SUCCESS !=
+                        (recursive_status = checkLayout(
+                             id, sc_str, deco_str, false, scalar_block_layout,
+                             0, constraints, vstate)))) {
               return recursive_status;
             }
           }
diff --git a/test/val/val_decoration_test.cpp b/test/val/val_decoration_test.cpp
index 7bed986..77526bf 100644
--- a/test/val/val_decoration_test.cpp
+++ b/test/val/val_decoration_test.cpp
@@ -8023,6 +8023,7 @@
 OpDecorate %block Block
 OpMemberDecorate %block 0 Offset 0
 OpMemberDecorate %block 0 MatrixStride 3
+OpMemberDecorate %block 0 ColMajor
 OpDecorate %var DescriptorSet 0
 OpDecorate %var Binding 0
 %void = OpTypeVoid
@@ -8059,6 +8060,7 @@
 OpDecorate %block Block
 OpMemberDecorate %block 0 Offset 0
 OpMemberDecorate %block 0 MatrixStride 3
+OpMemberDecorate %block 0 ColMajor
 OpDecorate %var DescriptorSet 0
 OpDecorate %var Binding 0
 %void = OpTypeVoid
@@ -8094,6 +8096,7 @@
 OpDecorate %block Block
 OpMemberDecorate %block 0 Offset 0
 OpMemberDecorate %block 0 MatrixStride 3
+OpMemberDecorate %block 0 ColMajor
 OpDecorate %var DescriptorSet 0
 OpDecorate %var Binding 0
 %void = OpTypeVoid
@@ -8130,6 +8133,7 @@
 OpDecorate %block Block
 OpMemberDecorate %block 0 Offset 0
 OpMemberDecorate %block 0 MatrixStride 3
+OpMemberDecorate %block 0 RowMajor
 OpDecorate %var DescriptorSet 0
 OpDecorate %var Binding 0
 %void = OpTypeVoid
@@ -8871,6 +8875,135 @@
                         "alignment to 16"));
 }
 
+TEST_F(ValidateDecorations, MatrixMissingMajornessUniform) {
+  const std::string spirv = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpDecorate %block Block
+OpMemberDecorate %block 0 Offset 0
+OpMemberDecorate %block 0 MatrixStride 16
+OpDecorate %var DescriptorSet 0
+OpDecorate %var Binding 0
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%float = OpTypeFloat 32
+%float2 = OpTypeVector %float 2
+%matrix = OpTypeMatrix %float2 2
+%block = OpTypeStruct %matrix
+%ptr_block = OpTypePointer Uniform %block
+%var = OpVariable %ptr_block Uniform
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "must be explicitly laid out with RowMajor or ColMajor decorations"));
+}
+
+TEST_F(ValidateDecorations, MatrixMissingMajornessStorageBuffer) {
+  const std::string spirv = R"(
+OpCapability Shader
+OpExtension "SPV_KHR_storage_buffer_storage_class"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpDecorate %block Block
+OpMemberDecorate %block 0 Offset 0
+OpMemberDecorate %block 0 MatrixStride 16
+OpDecorate %var DescriptorSet 0
+OpDecorate %var Binding 0
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%float = OpTypeFloat 32
+%float2 = OpTypeVector %float 2
+%matrix = OpTypeMatrix %float2 2
+%block = OpTypeStruct %matrix
+%ptr_block = OpTypePointer StorageBuffer %block
+%var = OpVariable %ptr_block StorageBuffer
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "must be explicitly laid out with RowMajor or ColMajor decorations"));
+}
+
+TEST_F(ValidateDecorations, MatrixMissingMajornessPushConstant) {
+  const std::string spirv = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpDecorate %block Block
+OpMemberDecorate %block 0 Offset 0
+OpMemberDecorate %block 0 MatrixStride 16
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%float = OpTypeFloat 32
+%float2 = OpTypeVector %float 2
+%matrix = OpTypeMatrix %float2 2
+%block = OpTypeStruct %matrix
+%ptr_block = OpTypePointer PushConstant %block
+%var = OpVariable %ptr_block PushConstant
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "must be explicitly laid out with RowMajor or ColMajor decorations"));
+}
+
+TEST_F(ValidateDecorations, StructWithRowAndColMajor) {
+  const std::string spirv = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpDecorate %block Block
+OpMemberDecorate %block 0 Offset 0
+OpMemberDecorate %block 0 MatrixStride 16
+OpMemberDecorate %block 0 ColMajor
+OpMemberDecorate %block 1 Offset 32
+OpMemberDecorate %block 1 MatrixStride 16
+OpMemberDecorate %block 1 RowMajor
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%float = OpTypeFloat 32
+%float2 = OpTypeVector %float 2
+%matrix = OpTypeMatrix %float2 2
+%block = OpTypeStruct %matrix %matrix
+%ptr_block = OpTypePointer PushConstant %block
+%var = OpVariable %ptr_block PushConstant
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+}
+
 }  // namespace
 }  // namespace val
 }  // namespace spvtools