spirv-fuzz: Implement the OpVectorTimesMatrix linear algebra case (#3489)
This PR implements the OpVectorTimesMatrix case for the
replace linear algebra instruction transformation.
diff --git a/source/fuzz/fuzzer_pass_replace_linear_algebra_instructions.cpp b/source/fuzz/fuzzer_pass_replace_linear_algebra_instructions.cpp
index 7f9b848..7116002 100644
--- a/source/fuzz/fuzzer_pass_replace_linear_algebra_instructions.cpp
+++ b/source/fuzz/fuzzer_pass_replace_linear_algebra_instructions.cpp
@@ -43,6 +43,7 @@
// |spvOpcodeIsLinearAlgebra|.
if (instruction->opcode() != SpvOpVectorTimesScalar &&
instruction->opcode() != SpvOpMatrixTimesScalar &&
+ instruction->opcode() != SpvOpVectorTimesMatrix &&
instruction->opcode() != SpvOpDot) {
return;
}
diff --git a/source/fuzz/protobufs/spvtoolsfuzz.proto b/source/fuzz/protobufs/spvtoolsfuzz.proto
index a7cc7c9..b913d02 100644
--- a/source/fuzz/protobufs/spvtoolsfuzz.proto
+++ b/source/fuzz/protobufs/spvtoolsfuzz.proto
@@ -1176,13 +1176,13 @@
// Supported:
// OpVectorTimesScalar
// OpMatrixTimesScalar
+ // OpVectorTimesMatrix
// OpDot
// TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354):
// Right now we only support certain operations. When this issue is addressed
// the supporting comments can be removed.
// To be supported in the future:
// OpTranspose
- // OpVectorTimesMatrix
// OpMatrixTimesVector
// OpMatrixTimesMatrix
// OpOuterProduct
diff --git a/source/fuzz/transformation_replace_linear_algebra_instruction.cpp b/source/fuzz/transformation_replace_linear_algebra_instruction.cpp
index 1c7d0c9..20135ab 100644
--- a/source/fuzz/transformation_replace_linear_algebra_instruction.cpp
+++ b/source/fuzz/transformation_replace_linear_algebra_instruction.cpp
@@ -47,6 +47,7 @@
// It must be a supported linear algebra instruction.
if (instruction->opcode() != SpvOpVectorTimesScalar &&
instruction->opcode() != SpvOpMatrixTimesScalar &&
+ instruction->opcode() != SpvOpVectorTimesMatrix &&
instruction->opcode() != SpvOpDot) {
return false;
}
@@ -81,6 +82,9 @@
case SpvOpMatrixTimesScalar:
ReplaceOpMatrixTimesScalar(ir_context, linear_algebra_instruction);
break;
+ case SpvOpVectorTimesMatrix:
+ ReplaceOpVectorTimesMatrix(ir_context, linear_algebra_instruction);
+ break;
case SpvOpDot:
ReplaceOpDot(ir_context, linear_algebra_instruction);
break;
@@ -128,6 +132,27 @@
->AsVector()
->element_count());
}
+ case SpvOpVectorTimesMatrix: {
+ // For each vector component, 1 OpCompositeExtract instruction will be
+ // inserted. For each matrix column, |1 + vector_component_count|
+ // OpCompositeExtract, |vector_component_count| OpFMul and
+ // |vector_component_count - 1| OpFAdd instructions will be inserted.
+ auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
+ instruction->GetSingleWordInOperand(0));
+ auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
+ instruction->GetSingleWordInOperand(1));
+ uint32_t vector_component_count =
+ ir_context->get_type_mgr()
+ ->GetType(vector_instruction->type_id())
+ ->AsVector()
+ ->element_count();
+ uint32_t matrix_column_count =
+ ir_context->get_type_mgr()
+ ->GetType(matrix_instruction->type_id())
+ ->AsMatrix()
+ ->element_count();
+ return vector_component_count * (3 * matrix_column_count + 1);
+ }
case SpvOpDot:
// For each pair of vector components, 2 OpCompositeExtract and 1 OpFMul
// will be inserted. The first two OpFMul instructions will result the
@@ -280,6 +305,121 @@
}
}
+void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesMatrix(
+ opt::IRContext* ir_context,
+ opt::Instruction* linear_algebra_instruction) const {
+ // Gets vector information.
+ auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
+ linear_algebra_instruction->GetSingleWordInOperand(0));
+ uint32_t vector_component_count = ir_context->get_type_mgr()
+ ->GetType(vector_instruction->type_id())
+ ->AsVector()
+ ->element_count();
+ auto vector_component_type = ir_context->get_type_mgr()
+ ->GetType(vector_instruction->type_id())
+ ->AsVector()
+ ->element_type();
+
+ // Extracts vector components.
+ uint32_t fresh_id_index = 0;
+ std::vector<uint32_t> vector_component_ids(vector_component_count);
+ for (uint32_t i = 0; i < vector_component_count; i++) {
+ vector_component_ids[i] = message_.fresh_ids(fresh_id_index++);
+ linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
+ ir_context, SpvOpCompositeExtract,
+ ir_context->get_type_mgr()->GetId(vector_component_type),
+ vector_component_ids[i],
+ opt::Instruction::OperandList(
+ {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}},
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
+ }
+
+ // Gets matrix information.
+ auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
+ linear_algebra_instruction->GetSingleWordInOperand(1));
+ uint32_t matrix_column_count = ir_context->get_type_mgr()
+ ->GetType(matrix_instruction->type_id())
+ ->AsMatrix()
+ ->element_count();
+ auto matrix_column_type = ir_context->get_type_mgr()
+ ->GetType(matrix_instruction->type_id())
+ ->AsMatrix()
+ ->element_type();
+
+ std::vector<uint32_t> result_component_ids(matrix_column_count);
+ for (uint32_t i = 0; i < matrix_column_count; i++) {
+ // Extracts matrix column.
+ uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++);
+ linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
+ ir_context, SpvOpCompositeExtract,
+ ir_context->get_type_mgr()->GetId(matrix_column_type),
+ matrix_extract_id,
+ opt::Instruction::OperandList(
+ {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
+
+ std::vector<uint32_t> float_multiplication_ids(vector_component_count);
+ for (uint32_t j = 0; j < vector_component_count; j++) {
+ // Extracts column component.
+ uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
+ linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
+ ir_context, SpvOpCompositeExtract,
+ ir_context->get_type_mgr()->GetId(vector_component_type),
+ column_extract_id,
+ opt::Instruction::OperandList(
+ {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}},
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
+
+ // Multiplies corresponding vector and column components.
+ float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
+ linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
+ ir_context, SpvOpFMul,
+ ir_context->get_type_mgr()->GetId(vector_component_type),
+ float_multiplication_ids[j],
+ opt::Instruction::OperandList(
+ {{SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}},
+ {SPV_OPERAND_TYPE_ID, {column_extract_id}}})));
+ }
+
+ // Adds the multiplication results.
+ std::vector<uint32_t> float_add_ids;
+ uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
+ float_add_ids.push_back(float_add_id);
+ linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
+ ir_context, SpvOpFAdd,
+ ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id,
+ opt::Instruction::OperandList(
+ {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
+ {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
+ for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) {
+ float_add_id = message_.fresh_ids(fresh_id_index++);
+ float_add_ids.push_back(float_add_id);
+ linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
+ ir_context, SpvOpFAdd,
+ ir_context->get_type_mgr()->GetId(vector_component_type),
+ float_add_id,
+ opt::Instruction::OperandList(
+ {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}},
+ {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}})));
+ }
+
+ result_component_ids[i] = float_add_ids.back();
+ }
+
+ // The OpVectorTimesMatrix instruction is changed to an OpCompositeConstruct
+ // instruction.
+ linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
+ linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]});
+ linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]});
+ for (uint32_t i = 2; i < result_component_ids.size(); i++) {
+ linear_algebra_instruction->AddOperand(
+ {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}});
+ }
+
+ fuzzerutil::UpdateModuleIdBound(
+ ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
+}
+
void TransformationReplaceLinearAlgebraInstruction::ReplaceOpDot(
opt::IRContext* ir_context,
opt::Instruction* linear_algebra_instruction) const {
diff --git a/source/fuzz/transformation_replace_linear_algebra_instruction.h b/source/fuzz/transformation_replace_linear_algebra_instruction.h
index 45b1262..352463c 100644
--- a/source/fuzz/transformation_replace_linear_algebra_instruction.h
+++ b/source/fuzz/transformation_replace_linear_algebra_instruction.h
@@ -60,6 +60,10 @@
void ReplaceOpMatrixTimesScalar(opt::IRContext* ir_context,
opt::Instruction* instruction) const;
+ // Replaces an OpVectorTimesMatrix instruction.
+ void ReplaceOpVectorTimesMatrix(opt::IRContext* ir_context,
+ opt::Instruction* instruction) const;
+
// Replaces an OpDot instruction.
void ReplaceOpDot(opt::IRContext* ir_context,
opt::Instruction* instruction) const;
diff --git a/test/fuzz/transformation_replace_linear_algebra_instruction_test.cpp b/test/fuzz/transformation_replace_linear_algebra_instruction_test.cpp
index c9a1aee..7562262 100644
--- a/test/fuzz/transformation_replace_linear_algebra_instruction_test.cpp
+++ b/test/fuzz/transformation_replace_linear_algebra_instruction_test.cpp
@@ -524,6 +524,341 @@
ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
}
+TEST(TransformationReplaceLinearAlgebraInstructionTest,
+ ReplaceOpVectorTimesMatrix) {
+ std::string reference_shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %54 "main"
+ OpExecutionMode %54 OriginUpperLeft
+ OpSource ESSL 310
+ OpName %54 "main"
+
+; Types
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %4 = OpTypeFloat 32
+ %5 = OpTypeVector %4 2
+ %6 = OpTypeVector %4 3
+ %7 = OpTypeVector %4 4
+ %8 = OpTypeMatrix %5 2
+ %9 = OpTypeMatrix %5 3
+ %10 = OpTypeMatrix %5 4
+ %11 = OpTypeMatrix %6 2
+ %12 = OpTypeMatrix %6 3
+ %13 = OpTypeMatrix %6 4
+ %14 = OpTypeMatrix %7 2
+ %15 = OpTypeMatrix %7 3
+ %16 = OpTypeMatrix %7 4
+
+; Constant scalars
+ %17 = OpConstant %4 1
+ %18 = OpConstant %4 2
+ %19 = OpConstant %4 3
+ %20 = OpConstant %4 4
+ %21 = OpConstant %4 5
+ %22 = OpConstant %4 6
+ %23 = OpConstant %4 7
+ %24 = OpConstant %4 8
+ %25 = OpConstant %4 9
+ %26 = OpConstant %4 10
+ %27 = OpConstant %4 11
+ %28 = OpConstant %4 12
+ %29 = OpConstant %4 13
+ %30 = OpConstant %4 14
+ %31 = OpConstant %4 15
+ %32 = OpConstant %4 16
+
+; Constant vectors
+ %33 = OpConstantComposite %5 %17 %18
+ %34 = OpConstantComposite %5 %19 %20
+ %35 = OpConstantComposite %5 %21 %22
+ %36 = OpConstantComposite %5 %23 %24
+ %37 = OpConstantComposite %6 %17 %18 %19
+ %38 = OpConstantComposite %6 %20 %21 %22
+ %39 = OpConstantComposite %6 %23 %24 %25
+ %40 = OpConstantComposite %6 %26 %27 %28
+ %41 = OpConstantComposite %7 %17 %18 %19 %20
+ %42 = OpConstantComposite %7 %21 %22 %23 %24
+ %43 = OpConstantComposite %7 %25 %26 %27 %28
+ %44 = OpConstantComposite %7 %29 %30 %31 %32
+
+; Constant matrices
+ %45 = OpConstantComposite %8 %33 %34
+ %46 = OpConstantComposite %9 %33 %34 %35
+ %47 = OpConstantComposite %10 %33 %34 %35 %36
+ %48 = OpConstantComposite %11 %37 %38
+ %49 = OpConstantComposite %12 %37 %38 %39
+ %50 = OpConstantComposite %13 %37 %38 %39 %40
+ %51 = OpConstantComposite %14 %41 %42
+ %52 = OpConstantComposite %15 %41 %42 %43
+ %53 = OpConstantComposite %16 %41 %42 %43 %44
+
+; main function
+ %54 = OpFunction %2 None %3
+ %55 = OpLabel
+
+; Multiplying 2-dimensional vector by 2x2 matrix
+ %56 = OpVectorTimesMatrix %5 %33 %45
+
+; Multiplying 2-dimensional vector by 2x3 matrix
+ %57 = OpVectorTimesMatrix %6 %34 %46
+
+; Multiplying 2-dimensional vector by 2x4 matrix
+ %58 = OpVectorTimesMatrix %7 %35 %47
+
+; Multiplying 3-dimensional vector by 3x2 matrix
+ %59 = OpVectorTimesMatrix %5 %37 %48
+
+; Multiplying 3-dimensional vector by 3x3 matrix
+ %60 = OpVectorTimesMatrix %6 %38 %49
+
+; Multiplying 3-dimensional vector by 3x4 matrix
+ %61 = OpVectorTimesMatrix %7 %39 %50
+
+; Multiplying 4-dimensional vector by 4x2 matrix
+ %62 = OpVectorTimesMatrix %5 %41 %51
+
+; Multiplying 4-dimensional vector by 4x3 matrix
+ %63 = OpVectorTimesMatrix %6 %42 %52
+
+; Multiplying 4-dimensional vector by 4x4 matrix
+ %64 = OpVectorTimesMatrix %7 %43 %53
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_5;
+ const auto consumer = nullptr;
+ const auto context =
+ BuildModule(env, consumer, reference_shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ FactManager fact_manager;
+ spvtools::ValidatorOptions validator_options;
+ TransformationContext transformation_context(&fact_manager,
+ validator_options);
+
+ auto instruction_descriptor =
+ MakeInstructionDescriptor(56, SpvOpVectorTimesMatrix, 0);
+ auto transformation = TransformationReplaceLinearAlgebraInstruction(
+ {65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78},
+ instruction_descriptor);
+ transformation.Apply(context.get(), &transformation_context);
+
+ instruction_descriptor =
+ MakeInstructionDescriptor(57, SpvOpVectorTimesMatrix, 0);
+ transformation = TransformationReplaceLinearAlgebraInstruction(
+ {79, 80, 81, 82, 83, 84, 85, 86, 87, 88,
+ 89, 90, 91, 92, 93, 94, 95, 96, 97, 98},
+ instruction_descriptor);
+ transformation.Apply(context.get(), &transformation_context);
+
+ instruction_descriptor =
+ MakeInstructionDescriptor(58, SpvOpVectorTimesMatrix, 0);
+ transformation = TransformationReplaceLinearAlgebraInstruction(
+ {99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
+ 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124},
+ instruction_descriptor);
+ transformation.Apply(context.get(), &transformation_context);
+
+ instruction_descriptor =
+ MakeInstructionDescriptor(59, SpvOpVectorTimesMatrix, 0);
+ transformation = TransformationReplaceLinearAlgebraInstruction(
+ {125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135,
+ 136, 137, 138, 139, 140, 141, 142, 143, 144, 145},
+ instruction_descriptor);
+ transformation.Apply(context.get(), &transformation_context);
+
+ std::string variant_shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %54 "main"
+ OpExecutionMode %54 OriginUpperLeft
+ OpSource ESSL 310
+ OpName %54 "main"
+
+; Types
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %4 = OpTypeFloat 32
+ %5 = OpTypeVector %4 2
+ %6 = OpTypeVector %4 3
+ %7 = OpTypeVector %4 4
+ %8 = OpTypeMatrix %5 2
+ %9 = OpTypeMatrix %5 3
+ %10 = OpTypeMatrix %5 4
+ %11 = OpTypeMatrix %6 2
+ %12 = OpTypeMatrix %6 3
+ %13 = OpTypeMatrix %6 4
+ %14 = OpTypeMatrix %7 2
+ %15 = OpTypeMatrix %7 3
+ %16 = OpTypeMatrix %7 4
+
+; Constant scalars
+ %17 = OpConstant %4 1
+ %18 = OpConstant %4 2
+ %19 = OpConstant %4 3
+ %20 = OpConstant %4 4
+ %21 = OpConstant %4 5
+ %22 = OpConstant %4 6
+ %23 = OpConstant %4 7
+ %24 = OpConstant %4 8
+ %25 = OpConstant %4 9
+ %26 = OpConstant %4 10
+ %27 = OpConstant %4 11
+ %28 = OpConstant %4 12
+ %29 = OpConstant %4 13
+ %30 = OpConstant %4 14
+ %31 = OpConstant %4 15
+ %32 = OpConstant %4 16
+
+; Constant vectors
+ %33 = OpConstantComposite %5 %17 %18
+ %34 = OpConstantComposite %5 %19 %20
+ %35 = OpConstantComposite %5 %21 %22
+ %36 = OpConstantComposite %5 %23 %24
+ %37 = OpConstantComposite %6 %17 %18 %19
+ %38 = OpConstantComposite %6 %20 %21 %22
+ %39 = OpConstantComposite %6 %23 %24 %25
+ %40 = OpConstantComposite %6 %26 %27 %28
+ %41 = OpConstantComposite %7 %17 %18 %19 %20
+ %42 = OpConstantComposite %7 %21 %22 %23 %24
+ %43 = OpConstantComposite %7 %25 %26 %27 %28
+ %44 = OpConstantComposite %7 %29 %30 %31 %32
+
+; Constant matrices
+ %45 = OpConstantComposite %8 %33 %34
+ %46 = OpConstantComposite %9 %33 %34 %35
+ %47 = OpConstantComposite %10 %33 %34 %35 %36
+ %48 = OpConstantComposite %11 %37 %38
+ %49 = OpConstantComposite %12 %37 %38 %39
+ %50 = OpConstantComposite %13 %37 %38 %39 %40
+ %51 = OpConstantComposite %14 %41 %42
+ %52 = OpConstantComposite %15 %41 %42 %43
+ %53 = OpConstantComposite %16 %41 %42 %43 %44
+
+; main function
+ %54 = OpFunction %2 None %3
+ %55 = OpLabel
+
+; Multiplying 2-dimensional vector by 2x2 matrix
+ %65 = OpCompositeExtract %4 %33 0
+ %66 = OpCompositeExtract %4 %33 1
+ %67 = OpCompositeExtract %5 %45 0
+ %68 = OpCompositeExtract %4 %67 0
+ %69 = OpFMul %4 %65 %68
+ %70 = OpCompositeExtract %4 %67 1
+ %71 = OpFMul %4 %66 %70
+ %72 = OpFAdd %4 %69 %71
+ %73 = OpCompositeExtract %5 %45 1
+ %74 = OpCompositeExtract %4 %73 0
+ %75 = OpFMul %4 %65 %74
+ %76 = OpCompositeExtract %4 %73 1
+ %77 = OpFMul %4 %66 %76
+ %78 = OpFAdd %4 %75 %77
+ %56 = OpCompositeConstruct %5 %72 %78
+
+; Multiplying 2-dimensional vector by 2x3 matrix
+ %79 = OpCompositeExtract %4 %34 0
+ %80 = OpCompositeExtract %4 %34 1
+ %81 = OpCompositeExtract %5 %46 0
+ %82 = OpCompositeExtract %4 %81 0
+ %83 = OpFMul %4 %79 %82
+ %84 = OpCompositeExtract %4 %81 1
+ %85 = OpFMul %4 %80 %84
+ %86 = OpFAdd %4 %83 %85
+ %87 = OpCompositeExtract %5 %46 1
+ %88 = OpCompositeExtract %4 %87 0
+ %89 = OpFMul %4 %79 %88
+ %90 = OpCompositeExtract %4 %87 1
+ %91 = OpFMul %4 %80 %90
+ %92 = OpFAdd %4 %89 %91
+ %93 = OpCompositeExtract %5 %46 2
+ %94 = OpCompositeExtract %4 %93 0
+ %95 = OpFMul %4 %79 %94
+ %96 = OpCompositeExtract %4 %93 1
+ %97 = OpFMul %4 %80 %96
+ %98 = OpFAdd %4 %95 %97
+ %57 = OpCompositeConstruct %6 %86 %92 %98
+
+; Multiplying 2-dimensional vector by 2x4 matrix
+ %99 = OpCompositeExtract %4 %35 0
+ %100 = OpCompositeExtract %4 %35 1
+ %101 = OpCompositeExtract %5 %47 0
+ %102 = OpCompositeExtract %4 %101 0
+ %103 = OpFMul %4 %99 %102
+ %104 = OpCompositeExtract %4 %101 1
+ %105 = OpFMul %4 %100 %104
+ %106 = OpFAdd %4 %103 %105
+ %107 = OpCompositeExtract %5 %47 1
+ %108 = OpCompositeExtract %4 %107 0
+ %109 = OpFMul %4 %99 %108
+ %110 = OpCompositeExtract %4 %107 1
+ %111 = OpFMul %4 %100 %110
+ %112 = OpFAdd %4 %109 %111
+ %113 = OpCompositeExtract %5 %47 2
+ %114 = OpCompositeExtract %4 %113 0
+ %115 = OpFMul %4 %99 %114
+ %116 = OpCompositeExtract %4 %113 1
+ %117 = OpFMul %4 %100 %116
+ %118 = OpFAdd %4 %115 %117
+ %119 = OpCompositeExtract %5 %47 3
+ %120 = OpCompositeExtract %4 %119 0
+ %121 = OpFMul %4 %99 %120
+ %122 = OpCompositeExtract %4 %119 1
+ %123 = OpFMul %4 %100 %122
+ %124 = OpFAdd %4 %121 %123
+ %58 = OpCompositeConstruct %7 %106 %112 %118 %124
+
+; Multiplying 3-dimensional vector by 3x2 matrix
+ %125 = OpCompositeExtract %4 %37 0
+ %126 = OpCompositeExtract %4 %37 1
+ %127 = OpCompositeExtract %4 %37 2
+ %128 = OpCompositeExtract %6 %48 0
+ %129 = OpCompositeExtract %4 %128 0
+ %130 = OpFMul %4 %125 %129
+ %131 = OpCompositeExtract %4 %128 1
+ %132 = OpFMul %4 %126 %131
+ %133 = OpCompositeExtract %4 %128 2
+ %134 = OpFMul %4 %127 %133
+ %135 = OpFAdd %4 %130 %132
+ %136 = OpFAdd %4 %134 %135
+ %137 = OpCompositeExtract %6 %48 1
+ %138 = OpCompositeExtract %4 %137 0
+ %139 = OpFMul %4 %125 %138
+ %140 = OpCompositeExtract %4 %137 1
+ %141 = OpFMul %4 %126 %140
+ %142 = OpCompositeExtract %4 %137 2
+ %143 = OpFMul %4 %127 %142
+ %144 = OpFAdd %4 %139 %141
+ %145 = OpFAdd %4 %143 %144
+ %59 = OpCompositeConstruct %5 %136 %145
+
+; Multiplying 3-dimensional vector by 3x3 matrix
+ %60 = OpVectorTimesMatrix %6 %38 %49
+
+; Multiplying 3-dimensional vector by 3x4 matrix
+ %61 = OpVectorTimesMatrix %7 %39 %50
+
+; Multiplying 4-dimensional vector by 4x2 matrix
+ %62 = OpVectorTimesMatrix %5 %41 %51
+
+; Multiplying 4-dimensional vector by 4x3 matrix
+ %63 = OpVectorTimesMatrix %6 %42 %52
+
+; Multiplying 4-dimensional vector by 4x4 matrix
+ %64 = OpVectorTimesMatrix %7 %43 %53
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ ASSERT_TRUE(IsValid(env, context.get()));
+ ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
+}
+
TEST(TransformationReplaceLinearAlgebraInstructionTest, ReplaceOpDot) {
std::string reference_shader = R"(
OpCapability Shader