Implement the OpMatrixTimesVector linear algebra case (#3500)

This PR implements the OpMatrixTimesVector case for the
replace linear algebra instruction transformation.
diff --git a/source/fuzz/fuzzer_pass_replace_linear_algebra_instructions.cpp b/source/fuzz/fuzzer_pass_replace_linear_algebra_instructions.cpp
index 7116002..43fba52 100644
--- a/source/fuzz/fuzzer_pass_replace_linear_algebra_instructions.cpp
+++ b/source/fuzz/fuzzer_pass_replace_linear_algebra_instructions.cpp
@@ -44,6 +44,7 @@
     if (instruction->opcode() != SpvOpVectorTimesScalar &&
         instruction->opcode() != SpvOpMatrixTimesScalar &&
         instruction->opcode() != SpvOpVectorTimesMatrix &&
+        instruction->opcode() != SpvOpMatrixTimesVector &&
         instruction->opcode() != SpvOpDot) {
       return;
     }
diff --git a/source/fuzz/protobufs/spvtoolsfuzz.proto b/source/fuzz/protobufs/spvtoolsfuzz.proto
index ce95f97..6cd964a 100644
--- a/source/fuzz/protobufs/spvtoolsfuzz.proto
+++ b/source/fuzz/protobufs/spvtoolsfuzz.proto
@@ -1190,13 +1190,13 @@
   //   OpVectorTimesScalar
   //   OpMatrixTimesScalar
   //   OpVectorTimesMatrix
+  //   OpMatrixTimesVector
   //   OpDot
   // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354):
   // Right now we only support certain operations. When this issue is addressed
   // the supporting comments can be removed.
   // To be supported in the future:
   //   OpTranspose
-  //   OpMatrixTimesVector
   //   OpMatrixTimesMatrix
   //   OpOuterProduct
   InstructionDescriptor instruction_descriptor = 2;
diff --git a/source/fuzz/transformation_replace_linear_algebra_instruction.cpp b/source/fuzz/transformation_replace_linear_algebra_instruction.cpp
index 73977ad..cebb6ef 100644
--- a/source/fuzz/transformation_replace_linear_algebra_instruction.cpp
+++ b/source/fuzz/transformation_replace_linear_algebra_instruction.cpp
@@ -48,6 +48,7 @@
   if (instruction->opcode() != SpvOpVectorTimesScalar &&
       instruction->opcode() != SpvOpMatrixTimesScalar &&
       instruction->opcode() != SpvOpVectorTimesMatrix &&
+      instruction->opcode() != SpvOpMatrixTimesVector &&
       instruction->opcode() != SpvOpDot) {
     return false;
   }
@@ -84,6 +85,9 @@
     case SpvOpVectorTimesMatrix:
       ReplaceOpVectorTimesMatrix(ir_context, linear_algebra_instruction);
       break;
+    case SpvOpMatrixTimesVector:
+      ReplaceOpMatrixTimesVector(ir_context, linear_algebra_instruction);
+      break;
     case SpvOpDot:
       ReplaceOpDot(ir_context, linear_algebra_instruction);
       break;
@@ -152,6 +156,28 @@
               ->element_count();
       return vector_component_count * (3 * matrix_column_count + 1);
     }
+    case SpvOpMatrixTimesVector: {
+      // For each matrix column, |1 + matrix_row_count| OpCompositeExtract
+      // will be inserted. For each matrix row, |matrix_column_count| OpFMul and
+      // |matrix_column_count - 1| OpFAdd instructions will be inserted. For
+      // each vector component, 1 OpCompositeExtract instruction will be
+      // inserted.
+      auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
+          instruction->GetSingleWordInOperand(0));
+      uint32_t matrix_column_count =
+          ir_context->get_type_mgr()
+              ->GetType(matrix_instruction->type_id())
+              ->AsMatrix()
+              ->element_count();
+      uint32_t matrix_row_count = ir_context->get_type_mgr()
+                                      ->GetType(matrix_instruction->type_id())
+                                      ->AsMatrix()
+                                      ->element_type()
+                                      ->AsVector()
+                                      ->element_count();
+      return 3 * matrix_column_count * matrix_row_count +
+             2 * matrix_column_count - matrix_row_count;
+    }
     case SpvOpDot:
       // For each pair of vector components, 2 OpCompositeExtract and 1 OpFMul
       // will be inserted. The first two OpFMul instructions will result the
@@ -419,6 +445,121 @@
       ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
 }
 
+void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesVector(
+    opt::IRContext* ir_context,
+    opt::Instruction* linear_algebra_instruction) const {
+  // Gets the matrix operand and its column count, column type and row count.
+  auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
+      linear_algebra_instruction->GetSingleWordInOperand(0));
+  uint32_t matrix_column_count = ir_context->get_type_mgr()
+                                     ->GetType(matrix_instruction->type_id())
+                                     ->AsMatrix()
+                                     ->element_count();
+  auto matrix_column_type = ir_context->get_type_mgr()
+                                ->GetType(matrix_instruction->type_id())
+                                ->AsMatrix()
+                                ->element_type();
+  uint32_t matrix_row_count = matrix_column_type->AsVector()->element_count();
+
+  // Extracts each matrix column into a fresh id (one id per column).
+  uint32_t fresh_id_index = 0;
+  std::vector<uint32_t> matrix_column_ids(matrix_column_count);
+  for (uint32_t i = 0; i < matrix_column_count; i++) {
+    matrix_column_ids[i] = message_.fresh_ids(fresh_id_index++);
+    linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
+        ir_context, SpvOpCompositeExtract,
+        ir_context->get_type_mgr()->GetId(matrix_column_type),
+        matrix_column_ids[i],
+        opt::Instruction::OperandList(
+            {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
+             {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
+  }
+
+  // Gets the vector operand and the type of its components.
+  auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
+      linear_algebra_instruction->GetSingleWordInOperand(1));
+  auto vector_component_type = ir_context->get_type_mgr()
+                                   ->GetType(vector_instruction->type_id())
+                                   ->AsVector()
+                                   ->element_type();
+
+  // Extracts one vector component per matrix column, each into a fresh id.
+  std::vector<uint32_t> vector_component_ids(matrix_column_count);
+  for (uint32_t i = 0; i < matrix_column_count; i++) {
+    vector_component_ids[i] = message_.fresh_ids(fresh_id_index++);
+    linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
+        ir_context, SpvOpCompositeExtract,
+        ir_context->get_type_mgr()->GetId(vector_component_type),
+        vector_component_ids[i],
+        opt::Instruction::OperandList(
+            {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}},
+             {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
+  }
+
+  std::vector<uint32_t> result_component_ids(matrix_row_count);
+  for (uint32_t i = 0; i < matrix_row_count; i++) {
+    std::vector<uint32_t> float_multiplication_ids(matrix_column_count);
+    for (uint32_t j = 0; j < matrix_column_count; j++) {
+      // Extracts component |i| (i.e. row |i|) of matrix column |j|.
+      uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
+      linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
+          ir_context, SpvOpCompositeExtract,
+          ir_context->get_type_mgr()->GetId(vector_component_type),
+          column_extract_id,
+          opt::Instruction::OperandList(
+              {{SPV_OPERAND_TYPE_ID, {matrix_column_ids[j]}},
+               {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
+
+      // Multiplies the extracted matrix entry by vector component |j|.
+      float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
+      linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
+          ir_context, SpvOpFMul,
+          ir_context->get_type_mgr()->GetId(vector_component_type),
+          float_multiplication_ids[j],
+          opt::Instruction::OperandList(
+              {{SPV_OPERAND_TYPE_ID, {column_extract_id}},
+               {SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}}})));
+    }
+
+    // Adds the products left to right (at least two columns are assumed).
+    std::vector<uint32_t> float_add_ids;
+    uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
+    float_add_ids.push_back(float_add_id);
+    linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
+        ir_context, SpvOpFAdd,
+        ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id,
+        opt::Instruction::OperandList(
+            {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
+             {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
+    for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) {
+      float_add_id = message_.fresh_ids(fresh_id_index++);
+      float_add_ids.push_back(float_add_id);
+      linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
+          ir_context, SpvOpFAdd,
+          ir_context->get_type_mgr()->GetId(vector_component_type),
+          float_add_id,
+          opt::Instruction::OperandList(
+              {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}},
+               {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}})));
+    }
+
+    result_component_ids[i] = float_add_ids.back();
+  }
+
+  // Turns the OpMatrixTimesVector into an OpCompositeConstruct of the row sums,
+  // reusing its two operand slots first (at least two rows are assumed).
+  linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
+  linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]});
+  linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]});
+  for (uint32_t i = 2; i < result_component_ids.size(); i++) {
+    linear_algebra_instruction->AddOperand(
+        {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}});
+  }
+
+  fuzzerutil::UpdateModuleIdBound(
+      ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
+}
+
 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpDot(
     opt::IRContext* ir_context,
     opt::Instruction* linear_algebra_instruction) const {
diff --git a/source/fuzz/transformation_replace_linear_algebra_instruction.h b/source/fuzz/transformation_replace_linear_algebra_instruction.h
index 352463c..39dc589 100644
--- a/source/fuzz/transformation_replace_linear_algebra_instruction.h
+++ b/source/fuzz/transformation_replace_linear_algebra_instruction.h
@@ -64,6 +64,10 @@
   void ReplaceOpVectorTimesMatrix(opt::IRContext* ir_context,
                                   opt::Instruction* instruction) const;
 
+  // Replaces an OpMatrixTimesVector instruction with per-component arithmetic.
+  void ReplaceOpMatrixTimesVector(opt::IRContext* ir_context,
+                                  opt::Instruction* instruction) const;
+
   // Replaces an OpDot instruction.
   void ReplaceOpDot(opt::IRContext* ir_context,
                     opt::Instruction* instruction) const;
diff --git a/test/fuzz/transformation_replace_linear_algebra_instruction_test.cpp b/test/fuzz/transformation_replace_linear_algebra_instruction_test.cpp
index 7562262..148b1a9 100644
--- a/test/fuzz/transformation_replace_linear_algebra_instruction_test.cpp
+++ b/test/fuzz/transformation_replace_linear_algebra_instruction_test.cpp
@@ -859,6 +859,339 @@
   ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
 }
 
+TEST(TransformationReplaceLinearAlgebraInstructionTest,
+     ReplaceOpMatrixTimesVector) {
+  std::string reference_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %54 "main"
+               OpExecutionMode %54 OriginUpperLeft
+               OpSource ESSL 310
+               OpName %54 "main"
+
+; Types
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %4 = OpTypeFloat 32
+          %5 = OpTypeVector %4 2
+          %6 = OpTypeVector %4 3
+          %7 = OpTypeVector %4 4
+          %8 = OpTypeMatrix %5 2
+          %9 = OpTypeMatrix %5 3
+         %10 = OpTypeMatrix %5 4
+         %11 = OpTypeMatrix %6 2
+         %12 = OpTypeMatrix %6 3
+         %13 = OpTypeMatrix %6 4
+         %14 = OpTypeMatrix %7 2
+         %15 = OpTypeMatrix %7 3
+         %16 = OpTypeMatrix %7 4
+
+; Constant scalars
+         %17 = OpConstant %4 1
+         %18 = OpConstant %4 2
+         %19 = OpConstant %4 3
+         %20 = OpConstant %4 4
+         %21 = OpConstant %4 5
+         %22 = OpConstant %4 6
+         %23 = OpConstant %4 7
+         %24 = OpConstant %4 8
+         %25 = OpConstant %4 9
+         %26 = OpConstant %4 10
+         %27 = OpConstant %4 11
+         %28 = OpConstant %4 12
+         %29 = OpConstant %4 13
+         %30 = OpConstant %4 14
+         %31 = OpConstant %4 15
+         %32 = OpConstant %4 16
+
+; Constant vectors
+         %33 = OpConstantComposite %5 %17 %18
+         %34 = OpConstantComposite %5 %19 %20
+         %35 = OpConstantComposite %5 %21 %22
+         %36 = OpConstantComposite %5 %23 %24
+         %37 = OpConstantComposite %6 %17 %18 %19
+         %38 = OpConstantComposite %6 %20 %21 %22
+         %39 = OpConstantComposite %6 %23 %24 %25
+         %40 = OpConstantComposite %6 %26 %27 %28
+         %41 = OpConstantComposite %7 %17 %18 %19 %20
+         %42 = OpConstantComposite %7 %21 %22 %23 %24
+         %43 = OpConstantComposite %7 %25 %26 %27 %28
+         %44 = OpConstantComposite %7 %29 %30 %31 %32
+
+; Constant matrices
+         %45 = OpConstantComposite %8 %33 %34
+         %46 = OpConstantComposite %9 %33 %34 %35
+         %47 = OpConstantComposite %10 %33 %34 %35 %36
+         %48 = OpConstantComposite %11 %37 %38
+         %49 = OpConstantComposite %12 %37 %38 %39
+         %50 = OpConstantComposite %13 %37 %38 %39 %40
+         %51 = OpConstantComposite %14 %41 %42
+         %52 = OpConstantComposite %15 %41 %42 %43
+         %53 = OpConstantComposite %16 %41 %42 %43 %44
+
+; main function
+         %54 = OpFunction %2 None %3
+         %55 = OpLabel
+
+; Multiplying 2x2 matrix by 2-dimensional vector
+         %56 = OpMatrixTimesVector %5 %45 %33
+
+; Multiplying 3x2 matrix by 2-dimensional vector
+         %57 = OpMatrixTimesVector %6 %48 %34
+
+; Multiplying 4x2 matrix by 2-dimensional vector
+         %58 = OpMatrixTimesVector %7 %51 %35
+
+; Multiplying 2x3 matrix by 3-dimensional vector
+         %59 = OpMatrixTimesVector %5 %46 %37
+
+; Multiplying 3x3 matrix by 3-dimensional vector
+         %60 = OpMatrixTimesVector %6 %49 %38
+
+; Multiplying 4x3 matrix by 3-dimensional vector
+         %61 = OpMatrixTimesVector %7 %52 %39
+
+; Multiplying 2x4 matrix by 4-dimensional vector
+         %62 = OpMatrixTimesVector %5 %47 %41
+
+; Multiplying 3x4 matrix by 4-dimensional vector
+         %63 = OpMatrixTimesVector %6 %50 %42
+
+; Multiplying 4x4 matrix by 4-dimensional vector
+         %64 = OpMatrixTimesVector %7 %53 %43
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_5;
+  const auto consumer = nullptr;
+  const auto context =
+      BuildModule(env, consumer, reference_shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  FactManager fact_manager;
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(&fact_manager,
+                                               validator_options);
+
+  auto instruction_descriptor =
+      MakeInstructionDescriptor(56, SpvOpMatrixTimesVector, 0);
+  auto transformation = TransformationReplaceLinearAlgebraInstruction(
+      {65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78},
+      instruction_descriptor);
+  transformation.Apply(context.get(), &transformation_context);
+
+  instruction_descriptor =
+      MakeInstructionDescriptor(57, SpvOpMatrixTimesVector, 0);
+  transformation = TransformationReplaceLinearAlgebraInstruction(
+      {79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96,
+       97},
+      instruction_descriptor);
+  transformation.Apply(context.get(), &transformation_context);
+
+  instruction_descriptor =
+      MakeInstructionDescriptor(58, SpvOpMatrixTimesVector, 0);
+  transformation = TransformationReplaceLinearAlgebraInstruction(
+      {98,  99,  100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
+       110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121},
+      instruction_descriptor);
+  transformation.Apply(context.get(), &transformation_context);
+
+  instruction_descriptor =
+      MakeInstructionDescriptor(59, SpvOpMatrixTimesVector, 0);
+  transformation = TransformationReplaceLinearAlgebraInstruction(
+      {122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132,
+       133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143},
+      instruction_descriptor);
+  transformation.Apply(context.get(), &transformation_context);
+
+  std::string variant_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %54 "main"
+               OpExecutionMode %54 OriginUpperLeft
+               OpSource ESSL 310
+               OpName %54 "main"
+
+; Types
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %4 = OpTypeFloat 32
+          %5 = OpTypeVector %4 2
+          %6 = OpTypeVector %4 3
+          %7 = OpTypeVector %4 4
+          %8 = OpTypeMatrix %5 2
+          %9 = OpTypeMatrix %5 3
+         %10 = OpTypeMatrix %5 4
+         %11 = OpTypeMatrix %6 2
+         %12 = OpTypeMatrix %6 3
+         %13 = OpTypeMatrix %6 4
+         %14 = OpTypeMatrix %7 2
+         %15 = OpTypeMatrix %7 3
+         %16 = OpTypeMatrix %7 4
+
+; Constant scalars
+         %17 = OpConstant %4 1
+         %18 = OpConstant %4 2
+         %19 = OpConstant %4 3
+         %20 = OpConstant %4 4
+         %21 = OpConstant %4 5
+         %22 = OpConstant %4 6
+         %23 = OpConstant %4 7
+         %24 = OpConstant %4 8
+         %25 = OpConstant %4 9
+         %26 = OpConstant %4 10
+         %27 = OpConstant %4 11
+         %28 = OpConstant %4 12
+         %29 = OpConstant %4 13
+         %30 = OpConstant %4 14
+         %31 = OpConstant %4 15
+         %32 = OpConstant %4 16
+
+; Constant vectors
+         %33 = OpConstantComposite %5 %17 %18
+         %34 = OpConstantComposite %5 %19 %20
+         %35 = OpConstantComposite %5 %21 %22
+         %36 = OpConstantComposite %5 %23 %24
+         %37 = OpConstantComposite %6 %17 %18 %19
+         %38 = OpConstantComposite %6 %20 %21 %22
+         %39 = OpConstantComposite %6 %23 %24 %25
+         %40 = OpConstantComposite %6 %26 %27 %28
+         %41 = OpConstantComposite %7 %17 %18 %19 %20
+         %42 = OpConstantComposite %7 %21 %22 %23 %24
+         %43 = OpConstantComposite %7 %25 %26 %27 %28
+         %44 = OpConstantComposite %7 %29 %30 %31 %32
+
+; Constant matrices
+         %45 = OpConstantComposite %8 %33 %34
+         %46 = OpConstantComposite %9 %33 %34 %35
+         %47 = OpConstantComposite %10 %33 %34 %35 %36
+         %48 = OpConstantComposite %11 %37 %38
+         %49 = OpConstantComposite %12 %37 %38 %39
+         %50 = OpConstantComposite %13 %37 %38 %39 %40
+         %51 = OpConstantComposite %14 %41 %42
+         %52 = OpConstantComposite %15 %41 %42 %43
+         %53 = OpConstantComposite %16 %41 %42 %43 %44
+
+; main function
+         %54 = OpFunction %2 None %3
+         %55 = OpLabel
+
+; Multiplying 2x2 matrix by 2-dimensional vector
+         %65 = OpCompositeExtract %5 %45 0
+         %66 = OpCompositeExtract %5 %45 1
+         %67 = OpCompositeExtract %4 %33 0
+         %68 = OpCompositeExtract %4 %33 1
+         %69 = OpCompositeExtract %4 %65 0
+         %70 = OpFMul %4 %69 %67
+         %71 = OpCompositeExtract %4 %66 0
+         %72 = OpFMul %4 %71 %68
+         %73 = OpFAdd %4 %70 %72
+         %74 = OpCompositeExtract %4 %65 1
+         %75 = OpFMul %4 %74 %67
+         %76 = OpCompositeExtract %4 %66 1
+         %77 = OpFMul %4 %76 %68
+         %78 = OpFAdd %4 %75 %77
+         %56 = OpCompositeConstruct %5 %73 %78
+
+; Multiplying 3x2 matrix by 2-dimensional vector
+         %79 = OpCompositeExtract %6 %48 0
+         %80 = OpCompositeExtract %6 %48 1
+         %81 = OpCompositeExtract %4 %34 0
+         %82 = OpCompositeExtract %4 %34 1
+         %83 = OpCompositeExtract %4 %79 0
+         %84 = OpFMul %4 %83 %81
+         %85 = OpCompositeExtract %4 %80 0
+         %86 = OpFMul %4 %85 %82
+         %87 = OpFAdd %4 %84 %86
+         %88 = OpCompositeExtract %4 %79 1
+         %89 = OpFMul %4 %88 %81
+         %90 = OpCompositeExtract %4 %80 1
+         %91 = OpFMul %4 %90 %82
+         %92 = OpFAdd %4 %89 %91
+         %93 = OpCompositeExtract %4 %79 2
+         %94 = OpFMul %4 %93 %81
+         %95 = OpCompositeExtract %4 %80 2
+         %96 = OpFMul %4 %95 %82
+         %97 = OpFAdd %4 %94 %96
+         %57 = OpCompositeConstruct %6 %87 %92 %97
+
+; Multiplying 4x2 matrix by 2-dimensional vector
+         %98 = OpCompositeExtract %7 %51 0
+         %99 = OpCompositeExtract %7 %51 1
+        %100 = OpCompositeExtract %4 %35 0
+        %101 = OpCompositeExtract %4 %35 1
+        %102 = OpCompositeExtract %4 %98 0
+        %103 = OpFMul %4 %102 %100
+        %104 = OpCompositeExtract %4 %99 0
+        %105 = OpFMul %4 %104 %101
+        %106 = OpFAdd %4 %103 %105
+        %107 = OpCompositeExtract %4 %98 1
+        %108 = OpFMul %4 %107 %100
+        %109 = OpCompositeExtract %4 %99 1
+        %110 = OpFMul %4 %109 %101
+        %111 = OpFAdd %4 %108 %110
+        %112 = OpCompositeExtract %4 %98 2
+        %113 = OpFMul %4 %112 %100
+        %114 = OpCompositeExtract %4 %99 2
+        %115 = OpFMul %4 %114 %101
+        %116 = OpFAdd %4 %113 %115
+        %117 = OpCompositeExtract %4 %98 3
+        %118 = OpFMul %4 %117 %100
+        %119 = OpCompositeExtract %4 %99 3
+        %120 = OpFMul %4 %119 %101
+        %121 = OpFAdd %4 %118 %120
+         %58 = OpCompositeConstruct %7 %106 %111 %116 %121
+
+; Multiplying 2x3 matrix by 3-dimensional vector
+        %122 = OpCompositeExtract %5 %46 0
+        %123 = OpCompositeExtract %5 %46 1
+        %124 = OpCompositeExtract %5 %46 2
+        %125 = OpCompositeExtract %4 %37 0
+        %126 = OpCompositeExtract %4 %37 1
+        %127 = OpCompositeExtract %4 %37 2
+        %128 = OpCompositeExtract %4 %122 0
+        %129 = OpFMul %4 %128 %125
+        %130 = OpCompositeExtract %4 %123 0
+        %131 = OpFMul %4 %130 %126
+        %132 = OpCompositeExtract %4 %124 0
+        %133 = OpFMul %4 %132 %127
+        %134 = OpFAdd %4 %129 %131
+        %135 = OpFAdd %4 %133 %134
+        %136 = OpCompositeExtract %4 %122 1
+        %137 = OpFMul %4 %136 %125
+        %138 = OpCompositeExtract %4 %123 1
+        %139 = OpFMul %4 %138 %126
+        %140 = OpCompositeExtract %4 %124 1
+        %141 = OpFMul %4 %140 %127
+        %142 = OpFAdd %4 %137 %139
+        %143 = OpFAdd %4 %141 %142
+         %59 = OpCompositeConstruct %5 %135 %143
+
+; Multiplying 3x3 matrix by 3-dimensional vector
+         %60 = OpMatrixTimesVector %6 %49 %38
+
+; Multiplying 4x3 matrix by 3-dimensional vector
+         %61 = OpMatrixTimesVector %7 %52 %39
+
+; Multiplying 2x4 matrix by 4-dimensional vector
+         %62 = OpMatrixTimesVector %5 %47 %41
+
+; Multiplying 3x4 matrix by 4-dimensional vector
+         %63 = OpMatrixTimesVector %6 %50 %42
+
+; Multiplying 4x4 matrix by 4-dimensional vector
+         %64 = OpMatrixTimesVector %7 %53 %43
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  ASSERT_TRUE(IsValid(env, context.get()));
+  ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
+}
+
 TEST(TransformationReplaceLinearAlgebraInstructionTest, ReplaceOpDot) {
   std::string reference_shader = R"(
                OpCapability Shader