Fold Fmix should accept vector operands. (#2826)
Fixes #2819
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
index 06a1a81..e0a17e9 100644
--- a/source/opt/const_folding_rules.cpp
+++ b/source/opt/const_folding_rules.cpp
@@ -296,6 +296,51 @@
};
}
+// Returns the result of folding the constants in |constants| according the
+// |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied
+// per component.
+const analysis::Constant* FoldFPBinaryOp(
+ BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
+ const std::vector<const analysis::Constant*>& constants,
+ IRContext* context) {
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ analysis::TypeManager* type_mgr = context->get_type_mgr();
+ const analysis::Type* result_type = type_mgr->GetType(result_type_id);
+ const analysis::Vector* vector_type = result_type->AsVector();
+
+ if (constants[0] == nullptr || constants[1] == nullptr) {
+ return nullptr;
+ }
+
+ if (vector_type != nullptr) {
+ std::vector<const analysis::Constant*> a_components;
+ std::vector<const analysis::Constant*> b_components;
+ std::vector<const analysis::Constant*> results_components;
+
+ a_components = constants[0]->GetVectorComponents(const_mgr);
+ b_components = constants[1]->GetVectorComponents(const_mgr);
+
+ // Fold each component of the vector.
+ for (uint32_t i = 0; i < a_components.size(); ++i) {
+ results_components.push_back(scalar_rule(vector_type->element_type(),
+ a_components[i], b_components[i],
+ const_mgr));
+ if (results_components[i] == nullptr) {
+ return nullptr;
+ }
+ }
+
+ // Build the constant object and return it.
+ std::vector<uint32_t> ids;
+ for (const analysis::Constant* member : results_components) {
+ ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
+ }
+ return const_mgr->GetConstant(vector_type, ids);
+ } else {
+ return scalar_rule(result_type, constants[0], constants[1], const_mgr);
+ }
+}
+
// Returns a |ConstantFoldingRule| that folds floating point scalars using
// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
// elements of the vector. The |ConstantFoldingRule| that is returned assumes
@@ -305,46 +350,10 @@
return [scalar_rule](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
- analysis::ConstantManager* const_mgr = context->get_constant_mgr();
- analysis::TypeManager* type_mgr = context->get_type_mgr();
- const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
- const analysis::Vector* vector_type = result_type->AsVector();
-
if (!inst->IsFloatingPointFoldingAllowed()) {
return nullptr;
}
-
- if (constants[0] == nullptr || constants[1] == nullptr) {
- return nullptr;
- }
-
- if (vector_type != nullptr) {
- std::vector<const analysis::Constant*> a_components;
- std::vector<const analysis::Constant*> b_components;
- std::vector<const analysis::Constant*> results_components;
-
- a_components = constants[0]->GetVectorComponents(const_mgr);
- b_components = constants[1]->GetVectorComponents(const_mgr);
-
- // Fold each component of the vector.
- for (uint32_t i = 0; i < a_components.size(); ++i) {
- results_components.push_back(scalar_rule(vector_type->element_type(),
- a_components[i],
- b_components[i], const_mgr));
- if (results_components[i] == nullptr) {
- return nullptr;
- }
- }
-
- // Build the constant object and return it.
- std::vector<uint32_t> ids;
- for (const analysis::Constant* member : results_components) {
- ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
- }
- return const_mgr->GetConstant(vector_type, ids);
- } else {
- return scalar_rule(result_type, constants[0], constants[1], const_mgr);
- }
+ return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
};
}
@@ -435,29 +444,33 @@
// This macro defines a |BinaryScalarFoldingRule| that applies |op|. The
// operator |op| must work for both float and double, and use syntax "f1 op f2".
-#define FOLD_FPARITH_OP(op) \
- [](const analysis::Type* result_type, const analysis::Constant* a, \
- const analysis::Constant* b, \
- analysis::ConstantManager* const_mgr_in_macro) \
- -> const analysis::Constant* { \
- assert(result_type != nullptr && a != nullptr && b != nullptr); \
- assert(result_type == a->type() && result_type == b->type()); \
- const analysis::Float* float_type_in_macro = result_type->AsFloat(); \
- assert(float_type_in_macro != nullptr); \
- if (float_type_in_macro->width() == 32) { \
- float fa = a->GetFloat(); \
- float fb = b->GetFloat(); \
- utils::FloatProxy<float> result_in_macro(fa op fb); \
- std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
- return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
- } else if (float_type_in_macro->width() == 64) { \
- double fa = a->GetDouble(); \
- double fb = b->GetDouble(); \
- utils::FloatProxy<double> result_in_macro(fa op fb); \
- std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
- return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
- } \
- return nullptr; \
+#define FOLD_FPARITH_OP(op) \
+ [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
+ const analysis::Constant* b, \
+ analysis::ConstantManager* const_mgr_in_macro) \
+ -> const analysis::Constant* { \
+ assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr); \
+ assert(result_type_in_macro == a->type() && \
+ result_type_in_macro == b->type()); \
+ const analysis::Float* float_type_in_macro = \
+ result_type_in_macro->AsFloat(); \
+ assert(float_type_in_macro != nullptr); \
+ if (float_type_in_macro->width() == 32) { \
+ float fa = a->GetFloat(); \
+ float fb = b->GetFloat(); \
+ utils::FloatProxy<float> result_in_macro(fa op fb); \
+ std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
+ return const_mgr_in_macro->GetConstant(result_type_in_macro, \
+ words_in_macro); \
+ } else if (float_type_in_macro->width() == 64) { \
+ double fa = a->GetDouble(); \
+ double fb = b->GetDouble(); \
+ utils::FloatProxy<double> result_in_macro(fa op fb); \
+ std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
+ return const_mgr_in_macro->GetConstant(result_type_in_macro, \
+ words_in_macro); \
+ } \
+ return nullptr; \
}
// Define the folding rule for conversion between floating point and integer
@@ -834,31 +847,49 @@
}
const analysis::Constant* one;
- if (constants[1]->type()->AsFloat()->width() == 32) {
- one = const_mgr->GetConstant(constants[1]->type(),
+ bool is_vector = false;
+ const analysis::Type* result_type = constants[1]->type();
+ const analysis::Type* base_type = result_type;
+ if (base_type->AsVector()) {
+ is_vector = true;
+ base_type = base_type->AsVector()->element_type();
+ }
+ assert(base_type->AsFloat() != nullptr &&
+ "FMix is suppose to act on floats or vectors of floats.");
+
+ if (base_type->AsFloat()->width() == 32) {
+ one = const_mgr->GetConstant(base_type,
utils::FloatProxy<float>(1.0f).GetWords());
} else {
- one = const_mgr->GetConstant(constants[1]->type(),
+ one = const_mgr->GetConstant(base_type,
utils::FloatProxy<double>(1.0).GetWords());
}
- const analysis::Constant* temp1 =
- FOLD_FPARITH_OP(-)(constants[1]->type(), one, constants[3], const_mgr);
+ if (is_vector) {
+ uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id();
+ one =
+ const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
+ }
+
+ const analysis::Constant* temp1 = FoldFPBinaryOp(
+ FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
if (temp1 == nullptr) {
return nullptr;
}
- const analysis::Constant* temp2 = FOLD_FPARITH_OP(*)(
- constants[1]->type(), constants[1], temp1, const_mgr);
+ const analysis::Constant* temp2 = FoldFPBinaryOp(
+ FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
if (temp2 == nullptr) {
return nullptr;
}
- const analysis::Constant* temp3 = FOLD_FPARITH_OP(*)(
- constants[2]->type(), constants[2], constants[3], const_mgr);
+ const analysis::Constant* temp3 =
+ FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
+ {constants[2], constants[3]}, context);
if (temp3 == nullptr) {
return nullptr;
}
- return FOLD_FPARITH_OP(+)(temp2->type(), temp2, temp3, const_mgr);
+ return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
+ context);
};
}
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index b5998c7..f24f08e 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -222,6 +222,7 @@
%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
%v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4
%v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5
+%v2float_0p2_0p5 = OpConstantComposite %v2float %float_0p2 %float_0p5
%v2float_null = OpConstantNull %v2float
%double_n1 = OpConstant %double -1
%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
@@ -643,6 +644,58 @@
));
// clang-format on
+using FloatVectorInstructionFoldingTest =
+ ::testing::TestWithParam<InstructionFoldingCase<std::vector<float>>>;
+
+TEST_P(FloatVectorInstructionFoldingTest, Case) {
+ const auto& tc = GetParam();
+
+ // Build module.
+ std::unique_ptr<IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ASSERT_NE(nullptr, context);
+
+ // Fold the instruction to test.
+ analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+ Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
+ SpvOp original_opcode = inst->opcode();
+ bool succeeded = context->get_instruction_folder().FoldInstruction(inst);
+
+ // Make sure the instruction folded as expected.
+ EXPECT_EQ(succeeded, inst == nullptr || inst->opcode() != original_opcode);
+ if (succeeded && inst != nullptr) {
+ EXPECT_EQ(inst->opcode(), SpvOpCopyObject);
+ inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
+ std::vector<SpvOp> opcodes = {SpvOpConstantComposite};
+ EXPECT_THAT(opcodes, Contains(inst->opcode()));
+ analysis::ConstantManager* const_mrg = context->get_constant_mgr();
+ const analysis::Constant* result = const_mrg->GetConstantFromInst(inst);
+ EXPECT_NE(result, nullptr);
+ if (result != nullptr) {
+ const std::vector<const analysis::Constant*>& componenets =
+ result->AsVectorConstant()->GetComponents();
+ EXPECT_EQ(componenets.size(), tc.expected_result.size());
+ for (size_t i = 0; i < componenets.size(); i++) {
+ EXPECT_EQ(tc.expected_result[i], componenets[i]->GetFloat());
+ }
+ }
+ }
+}
+
+// clang-format off
+INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
+::testing::Values(
+ // Test case 0: FMix {2.0, 2.0}, {2.0, 3.0} {0.2,0.5}
+ InstructionFoldingCase<std::vector<float>>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpExtInst %v2float %1 FMix %v2float_2_3 %v2float_0_0 %v2float_0p2_0p5\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, {1.6f,1.5f})
+));
+// clang-format on
using BooleanInstructionFoldingTest =
::testing::TestWithParam<InstructionFoldingCase<bool>>;