Constant folding for OpVectorTimesScalar
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
index 4f2125a..361d715 100644
--- a/source/opt/const_folding_rules.cpp
+++ b/source/opt/const_folding_rules.cpp
@@ -100,7 +100,67 @@
analysis::TypeManager* type_mgr = context->get_type_mgr();
return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
};
-} // namespace
+}
+
+ConstantFoldingRule FoldVectorTimesScalar() {
+ return [](ir::Instruction* inst,
+ const std::vector<const analysis::Constant*>& constants)
+ -> const analysis::Constant* {
+ assert(inst->opcode() == SpvOpVectorTimesScalar);
+ const analysis::Constant* c1 = constants[0];
+ const analysis::Constant* c2 = constants[1];
+ if (c1 == nullptr || c2 == nullptr) {
+ return nullptr;
+ }
+
+ ir::IRContext* context = inst->context();
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ analysis::TypeManager* type_mgr = context->get_type_mgr();
+
+ // Check result type.
+ const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
+ const analysis::Vector* vector_type = result_type->AsVector();
+ assert(vector_type != nullptr);
+ const analysis::Type* element_type = vector_type->element_type();
+ assert(element_type != nullptr);
+ const analysis::Float* float_type = element_type->AsFloat();
+ assert(float_type != nullptr);
+
+ // Check types of c1 and c2.
+ assert(c1->type()->AsVector() == vector_type);
+ assert(c1->type()->AsVector()->element_type() == element_type &&
+ c2->type() == element_type);
+
+ // Get a float vector that is the result of vector-times-scalar.
+ std::vector<const analysis::Constant*> c1_components =
+ c1->GetVectorComponents(const_mgr);
+ std::vector<uint32_t> ids;
+ if (float_type->width() == 32) {
+ float scalar = c2->GetFloat();
+ for (uint32_t i = 0; i < c1_components.size(); ++i) {
+ spvutils::FloatProxy<float> result(c1_components[i]->GetFloat() *
+ scalar);
+ std::vector<uint32_t> words = result.GetWords();
+ const analysis::Constant* new_elem =
+ const_mgr->GetConstant(float_type, words);
+ ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
+ }
+ return const_mgr->GetConstant(vector_type, ids);
+ } else if (float_type->width() == 64) {
+ double scalar = c2->GetDouble();
+ for (uint32_t i = 0; i < c1_components.size(); ++i) {
+ spvutils::FloatProxy<double> result(c1_components[i]->GetDouble() *
+ scalar);
+ std::vector<uint32_t> words = result.GetWords();
+ const analysis::Constant* new_elem =
+ const_mgr->GetConstant(float_type, words);
+ ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
+ }
+ return const_mgr->GetConstant(vector_type, ids);
+ }
+ return nullptr;
+ };
+}
ConstantFoldingRule FoldCompositeWithConstants() {
// Folds an OpCompositeConstruct where all of the inputs are constants to a
@@ -560,6 +620,7 @@
rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
+ rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
rules_[SpvOpFNegate].push_back(FoldFNegate());
}
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index f05eadc..1874ded 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -3838,7 +3838,52 @@
"%4 = OpIMul %v2int %v2int_2_2 %3\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
- 4, true)
+ 4, true),
+ // Test case 18: Fold OpVectorTimesScalar
+ // {4,4} = OpVectorTimesScalar v2float {2,2} 2
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[v2float:%\\w+]] = OpTypeVector [[float]] 2\n" +
+ "; CHECK: [[float_4:%\\w+]] = OpConstant [[float]] 4\n" +
+ "; CHECK: [[v2float_4_4:%\\w+]] = OpConstantComposite [[v2float]] [[float_4]] [[float_4]]\n" +
+ "; CHECK: %2 = OpCopyObject [[v2float]] [[v2float_4_4]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpVectorTimesScalar %v2float %v2float_2_2 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, true),
+ // Test case 19: Fold OpVectorTimesScalar
+ // {-0,-0} = OpVectorTimesScalar v2float v2float_null -1
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
+ "; CHECK: [[v2float:%\\w+]] = OpTypeVector [[float]] 2\n" +
+ "; CHECK: [[float_n0:%\\w+]] = OpConstant [[float]] -0\n" +
+ "; CHECK: [[v2float_n0_n0:%\\w+]] = OpConstantComposite [[v2float]] [[float_n0]] [[float_n0]]\n" +
+ "; CHECK: %2 = OpCopyObject [[v2float]] [[v2float_n0_n0]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpVectorTimesScalar %v2float %v2float_null %float_n1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, true),
+ // Test case 20: Fold OpVectorTimesScalar
+ // {4,4} = OpVectorTimesScalar v2double {2,2} 2
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" +
+ "; CHECK: [[v2double:%\\w+]] = OpTypeVector [[double]] 2\n" +
+ "; CHECK: [[double_4:%\\w+]] = OpConstant [[double]] 4\n" +
+ "; CHECK: [[v2double_4_4:%\\w+]] = OpConstantComposite [[v2double]] [[double_4]] [[double_4]]\n" +
+ "; CHECK: %2 = OpCopyObject [[v2double]] [[v2double_4_4]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpVectorTimesScalar %v2double %v2double_2_2 %double_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, true)
));
INSTANTIATE_TEST_CASE_P(MergeDivTest, MatchingInstructionFoldingTest,