Fold Fmix should accept vector operands. (#2826)

Fixes #2819

diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
index 06a1a81..e0a17e9 100644
--- a/source/opt/const_folding_rules.cpp
+++ b/source/opt/const_folding_rules.cpp
@@ -296,6 +296,51 @@
   };
 }
 
+// Returns the result of folding the constants in |constants| according the
+// |scalar_rule|.  If |result_type| is a vector, then |scalar_rule| is applied
+// per component.
+const analysis::Constant* FoldFPBinaryOp(
+    BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
+    const std::vector<const analysis::Constant*>& constants,
+    IRContext* context) {
+  analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+  analysis::TypeManager* type_mgr = context->get_type_mgr();
+  const analysis::Type* result_type = type_mgr->GetType(result_type_id);
+  const analysis::Vector* vector_type = result_type->AsVector();
+
+  if (constants[0] == nullptr || constants[1] == nullptr) {
+    return nullptr;
+  }
+
+  if (vector_type != nullptr) {
+    std::vector<const analysis::Constant*> a_components;
+    std::vector<const analysis::Constant*> b_components;
+    std::vector<const analysis::Constant*> results_components;
+
+    a_components = constants[0]->GetVectorComponents(const_mgr);
+    b_components = constants[1]->GetVectorComponents(const_mgr);
+
+    // Fold each component of the vector.
+    for (uint32_t i = 0; i < a_components.size(); ++i) {
+      results_components.push_back(scalar_rule(vector_type->element_type(),
+                                               a_components[i], b_components[i],
+                                               const_mgr));
+      if (results_components[i] == nullptr) {
+        return nullptr;
+      }
+    }
+
+    // Build the constant object and return it.
+    std::vector<uint32_t> ids;
+    for (const analysis::Constant* member : results_components) {
+      ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
+    }
+    return const_mgr->GetConstant(vector_type, ids);
+  } else {
+    return scalar_rule(result_type, constants[0], constants[1], const_mgr);
+  }
+}
+
 // Returns a |ConstantFoldingRule| that folds floating point scalars using
 // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
 // elements of the vector.  The |ConstantFoldingRule| that is returned assumes
@@ -305,46 +350,10 @@
   return [scalar_rule](IRContext* context, Instruction* inst,
                        const std::vector<const analysis::Constant*>& constants)
              -> const analysis::Constant* {
-    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
-    analysis::TypeManager* type_mgr = context->get_type_mgr();
-    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
-    const analysis::Vector* vector_type = result_type->AsVector();
-
     if (!inst->IsFloatingPointFoldingAllowed()) {
       return nullptr;
     }
-
-    if (constants[0] == nullptr || constants[1] == nullptr) {
-      return nullptr;
-    }
-
-    if (vector_type != nullptr) {
-      std::vector<const analysis::Constant*> a_components;
-      std::vector<const analysis::Constant*> b_components;
-      std::vector<const analysis::Constant*> results_components;
-
-      a_components = constants[0]->GetVectorComponents(const_mgr);
-      b_components = constants[1]->GetVectorComponents(const_mgr);
-
-      // Fold each component of the vector.
-      for (uint32_t i = 0; i < a_components.size(); ++i) {
-        results_components.push_back(scalar_rule(vector_type->element_type(),
-                                                 a_components[i],
-                                                 b_components[i], const_mgr));
-        if (results_components[i] == nullptr) {
-          return nullptr;
-        }
-      }
-
-      // Build the constant object and return it.
-      std::vector<uint32_t> ids;
-      for (const analysis::Constant* member : results_components) {
-        ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
-      }
-      return const_mgr->GetConstant(vector_type, ids);
-    } else {
-      return scalar_rule(result_type, constants[0], constants[1], const_mgr);
-    }
+    return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
   };
 }
 
@@ -435,29 +444,33 @@
 
 // This macro defines a |BinaryScalarFoldingRule| that applies |op|.  The
 // operator |op| must work for both float and double, and use syntax "f1 op f2".
-#define FOLD_FPARITH_OP(op)                                                \
-  [](const analysis::Type* result_type, const analysis::Constant* a,       \
-     const analysis::Constant* b,                                          \
-     analysis::ConstantManager* const_mgr_in_macro)                        \
-      -> const analysis::Constant* {                                       \
-    assert(result_type != nullptr && a != nullptr && b != nullptr);        \
-    assert(result_type == a->type() && result_type == b->type());          \
-    const analysis::Float* float_type_in_macro = result_type->AsFloat();   \
-    assert(float_type_in_macro != nullptr);                                \
-    if (float_type_in_macro->width() == 32) {                              \
-      float fa = a->GetFloat();                                            \
-      float fb = b->GetFloat();                                            \
-      utils::FloatProxy<float> result_in_macro(fa op fb);                  \
-      std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();   \
-      return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
-    } else if (float_type_in_macro->width() == 64) {                       \
-      double fa = a->GetDouble();                                          \
-      double fb = b->GetDouble();                                          \
-      utils::FloatProxy<double> result_in_macro(fa op fb);                 \
-      std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();   \
-      return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
-    }                                                                      \
-    return nullptr;                                                        \
+#define FOLD_FPARITH_OP(op)                                                   \
+  [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
+     const analysis::Constant* b,                                             \
+     analysis::ConstantManager* const_mgr_in_macro)                           \
+      -> const analysis::Constant* {                                          \
+    assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr);  \
+    assert(result_type_in_macro == a->type() &&                               \
+           result_type_in_macro == b->type());                                \
+    const analysis::Float* float_type_in_macro =                              \
+        result_type_in_macro->AsFloat();                                      \
+    assert(float_type_in_macro != nullptr);                                   \
+    if (float_type_in_macro->width() == 32) {                                 \
+      float fa = a->GetFloat();                                               \
+      float fb = b->GetFloat();                                               \
+      utils::FloatProxy<float> result_in_macro(fa op fb);                     \
+      std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();      \
+      return const_mgr_in_macro->GetConstant(result_type_in_macro,            \
+                                             words_in_macro);                 \
+    } else if (float_type_in_macro->width() == 64) {                          \
+      double fa = a->GetDouble();                                             \
+      double fb = b->GetDouble();                                             \
+      utils::FloatProxy<double> result_in_macro(fa op fb);                    \
+      std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();      \
+      return const_mgr_in_macro->GetConstant(result_type_in_macro,            \
+                                             words_in_macro);                 \
+    }                                                                         \
+    return nullptr;                                                           \
   }
 
 // Define the folding rule for conversion between floating point and integer
@@ -834,31 +847,49 @@
     }
 
     const analysis::Constant* one;
-    if (constants[1]->type()->AsFloat()->width() == 32) {
-      one = const_mgr->GetConstant(constants[1]->type(),
+    bool is_vector = false;
+    const analysis::Type* result_type = constants[1]->type();
+    const analysis::Type* base_type = result_type;
+    if (base_type->AsVector()) {
+      is_vector = true;
+      base_type = base_type->AsVector()->element_type();
+    }
+    assert(base_type->AsFloat() != nullptr &&
+           "FMix is suppose to act on floats or vectors of floats.");
+
+    if (base_type->AsFloat()->width() == 32) {
+      one = const_mgr->GetConstant(base_type,
                                    utils::FloatProxy<float>(1.0f).GetWords());
     } else {
-      one = const_mgr->GetConstant(constants[1]->type(),
+      one = const_mgr->GetConstant(base_type,
                                    utils::FloatProxy<double>(1.0).GetWords());
     }
 
-    const analysis::Constant* temp1 =
-        FOLD_FPARITH_OP(-)(constants[1]->type(), one, constants[3], const_mgr);
+    if (is_vector) {
+      uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id();
+      one =
+          const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
+    }
+
+    const analysis::Constant* temp1 = FoldFPBinaryOp(
+        FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
     if (temp1 == nullptr) {
       return nullptr;
     }
 
-    const analysis::Constant* temp2 = FOLD_FPARITH_OP(*)(
-        constants[1]->type(), constants[1], temp1, const_mgr);
+    const analysis::Constant* temp2 = FoldFPBinaryOp(
+        FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
     if (temp2 == nullptr) {
       return nullptr;
     }
-    const analysis::Constant* temp3 = FOLD_FPARITH_OP(*)(
-        constants[2]->type(), constants[2], constants[3], const_mgr);
+    const analysis::Constant* temp3 =
+        FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
+                       {constants[2], constants[3]}, context);
     if (temp3 == nullptr) {
       return nullptr;
     }
-    return FOLD_FPARITH_OP(+)(temp2->type(), temp2, temp3, const_mgr);
+    return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
+                          context);
   };
 }
 
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index b5998c7..f24f08e 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -222,6 +222,7 @@
 %v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
 %v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4
 %v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5
+%v2float_0p2_0p5 = OpConstantComposite %v2float %float_0p2 %float_0p5
 %v2float_null = OpConstantNull %v2float
 %double_n1 = OpConstant %double -1
 %105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
@@ -643,6 +644,58 @@
 ));
 // clang-format on
 
+using FloatVectorInstructionFoldingTest =
+    ::testing::TestWithParam<InstructionFoldingCase<std::vector<float>>>;
+
+TEST_P(FloatVectorInstructionFoldingTest, Case) {
+  const auto& tc = GetParam();
+
+  // Build module.
+  std::unique_ptr<IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ASSERT_NE(nullptr, context);
+
+  // Fold the instruction to test.
+  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+  Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
+  SpvOp original_opcode = inst->opcode();
+  bool succeeded = context->get_instruction_folder().FoldInstruction(inst);
+
+  // Make sure the instruction folded as expected.
+  EXPECT_EQ(succeeded, inst == nullptr || inst->opcode() != original_opcode);
+  if (succeeded && inst != nullptr) {
+    EXPECT_EQ(inst->opcode(), SpvOpCopyObject);
+    inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
+    std::vector<SpvOp> opcodes = {SpvOpConstantComposite};
+    EXPECT_THAT(opcodes, Contains(inst->opcode()));
+    analysis::ConstantManager* const_mrg = context->get_constant_mgr();
+    const analysis::Constant* result = const_mrg->GetConstantFromInst(inst);
+    EXPECT_NE(result, nullptr);
+    if (result != nullptr) {
+      const std::vector<const analysis::Constant*>& componenets =
+          result->AsVectorConstant()->GetComponents();
+      EXPECT_EQ(componenets.size(), tc.expected_result.size());
+      for (size_t i = 0; i < componenets.size(); i++) {
+        EXPECT_EQ(tc.expected_result[i], componenets[i]->GetFloat());
+      }
+    }
+  }
+}
+
+// clang-format off
+INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
+::testing::Values(
+   // Test case 0: FMix {2.0, 2.0}, {2.0, 3.0} {0.2,0.5}
+   InstructionFoldingCase<std::vector<float>>(
+       Header() + "%main = OpFunction %void None %void_func\n" +
+           "%main_lab = OpLabel\n" +
+           "%2 = OpExtInst %v2float %1 FMix %v2float_2_3 %v2float_0_0 %v2float_0p2_0p5\n" +
+           "OpReturn\n" +
+           "OpFunctionEnd",
+       2, {1.6f,1.5f})
+));
+// clang-format on
 using BooleanInstructionFoldingTest =
     ::testing::TestWithParam<InstructionFoldingCase<bool>>;