Implement constant folding for many transcendentals (#3166)
* Implement constant folding for many transcendentals
This change adds support for folding of sin/cos/tan/asin/acos/atan,
exp/log/exp2/log2, sqrt, atan2 and pow.
The mechanism allows to use any C function to implement folding in the
future; for now I limited the actual additions to the most commonly used
intrinsics in the shaders.
Unary folder had to be tweaked to work with extended instructions - for
extended instructions, constants.size() == 2 and constants[0] ==
nullptr. This adjustment is similar to the one binary folder already
performs.
Fixes #1390.
* Fix Android build
On old versions of Android NDK, we don't get std::exp2/std::log2
because of partial C++11 support.
We do get ::exp2, but not ::log2 so we need to emulate that.
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
index 2a2493f..d262a7e 100644
--- a/source/opt/const_folding_rules.cpp
+++ b/source/opt/const_folding_rules.cpp
@@ -265,7 +265,10 @@
return nullptr;
}
- if (constants[0] == nullptr) {
+ const analysis::Constant* arg =
+ (inst->opcode() == SpvOpExtInst) ? constants[1] : constants[0];
+
+ if (arg == nullptr) {
return nullptr;
}
@@ -273,7 +276,7 @@
std::vector<const analysis::Constant*> a_components;
std::vector<const analysis::Constant*> results_components;
- a_components = constants[0]->GetVectorComponents(const_mgr);
+ a_components = arg->GetVectorComponents(const_mgr);
// Fold each component of the vector.
for (uint32_t i = 0; i < a_components.size(); ++i) {
@@ -291,7 +294,7 @@
}
return const_mgr->GetConstant(vector_type, ids);
} else {
- return scalar_rule(result_type, constants[0], const_mgr);
+ return scalar_rule(result_type, arg, const_mgr);
}
};
}
@@ -1070,6 +1073,60 @@
return nullptr;
}
+UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) {
+ return
+ [fp](const analysis::Type* result_type, const analysis::Constant* a,
+ analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
+ assert(result_type != nullptr && a != nullptr);
+ const analysis::Float* float_type = a->type()->AsFloat();
+ assert(float_type != nullptr);
+ assert(float_type == result_type->AsFloat());
+ if (float_type->width() == 32) {
+ float fa = a->GetFloat();
+ float res = static_cast<float>(fp(fa));
+ utils::FloatProxy<float> result(res);
+ std::vector<uint32_t> words = result.GetWords();
+ return const_mgr->GetConstant(result_type, words);
+ } else if (float_type->width() == 64) {
+ double fa = a->GetDouble();
+ double res = fp(fa);
+ utils::FloatProxy<double> result(res);
+ std::vector<uint32_t> words = result.GetWords();
+ return const_mgr->GetConstant(result_type, words);
+ }
+ return nullptr;
+ };
+}
+
+BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
+ double)) {
+ return
+ [fp](const analysis::Type* result_type, const analysis::Constant* a,
+ const analysis::Constant* b,
+ analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
+ assert(result_type != nullptr && a != nullptr);
+ const analysis::Float* float_type = a->type()->AsFloat();
+ assert(float_type != nullptr);
+ assert(float_type == result_type->AsFloat());
+ assert(float_type == b->type()->AsFloat());
+ if (float_type->width() == 32) {
+ float fa = a->GetFloat();
+ float fb = b->GetFloat();
+ float res = static_cast<float>(fp(fa, fb));
+ utils::FloatProxy<float> result(res);
+ std::vector<uint32_t> words = result.GetWords();
+ return const_mgr->GetConstant(result_type, words);
+ } else if (float_type->width() == 64) {
+ double fa = a->GetDouble();
+ double fb = b->GetDouble();
+ double res = fp(fa, fb);
+ utils::FloatProxy<double> result(res);
+ std::vector<uint32_t> words = result.GetWords();
+ return const_mgr->GetConstant(result_type, words);
+ }
+ return nullptr;
+ };
+}
} // namespace
void ConstantFoldingRules::AddFoldingRules() {
@@ -1175,6 +1232,45 @@
FoldClamp2);
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
FoldClamp3);
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back(
+ FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin)));
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back(
+ FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos)));
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back(
+ FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan)));
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back(
+ FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin)));
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back(
+ FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos)));
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back(
+ FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan)));
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back(
+ FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp)));
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back(
+ FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
+
+#ifdef __ANDROID__
+ // Android NDK r15c tageting ABI 15 doesn't have full support for C++11
+ // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
+ // available up until ABI 18 so we use a shim
+ auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
+ FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2)));
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
+ FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim)));
+#else
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
+ FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2)));
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
+ FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2)));
+#endif
+
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back(
+ FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt)));
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back(
+ FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2)));
+ ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back(
+ FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow)));
}
}
} // namespace opt
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index 26d1220..db01924 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -1693,7 +1693,7 @@
"OpReturn\n" +
"OpFunctionEnd",
2, 0.2f),
- // Test case 21: FMax 1.0 4.0
+ // Test case 23: FMax 1.0 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@@ -1701,7 +1701,7 @@
"OpReturn\n" +
"OpFunctionEnd",
2, 4.0f),
- // Test case 22: FMax 1.0 0.2
+ // Test case 24: FMax 1.0 0.2
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@@ -1709,7 +1709,7 @@
"OpReturn\n" +
"OpFunctionEnd",
2, 1.0f),
- // Test case 23: FClamp 1.0 0.2 4.0
+ // Test case 25: FClamp 1.0 0.2 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@@ -1717,7 +1717,7 @@
"OpReturn\n" +
"OpFunctionEnd",
2, 1.0f),
- // Test case 24: FClamp 0.2 2.0 4.0
+ // Test case 26: FClamp 0.2 2.0 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@@ -1725,7 +1725,7 @@
"OpReturn\n" +
"OpFunctionEnd",
2, 2.0f),
- // Test case 25: FClamp 2049.0 2.0 4.0
+ // Test case 27: FClamp 2049.0 2.0 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@@ -1733,7 +1733,7 @@
"OpReturn\n" +
"OpFunctionEnd",
2, 4.0f),
- // Test case 26: FClamp 1.0 2.0 x
+ // Test case 28: FClamp 1.0 2.0 x
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@@ -1742,7 +1742,7 @@
"OpReturn\n" +
"OpFunctionEnd",
2, 2.0),
- // Test case 27: FClamp 1.0 x 0.5
+ // Test case 29: FClamp 1.0 x 0.5
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@@ -1750,7 +1750,111 @@
"%2 = OpExtInst %float %1 FClamp %float_1 %undef %float_0p5\n" +
"OpReturn\n" +
"OpFunctionEnd",
- 2, 0.5)
+ 2, 0.5),
+ // Test case 30: Sin 0.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpExtInst %float %1 Sin %float_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0.0),
+ // Test case 31: Cos 0.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpExtInst %float %1 Cos %float_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 1.0),
+ // Test case 32: Tan 0.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpExtInst %float %1 Tan %float_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0.0),
+ // Test case 33: Asin 0.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpExtInst %float %1 Asin %float_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0.0),
+ // Test case 34: Acos 1.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpExtInst %float %1 Acos %float_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0.0),
+ // Test case 35: Atan 0.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpExtInst %float %1 Atan %float_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0.0),
+ // Test case 36: Exp 0.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpExtInst %float %1 Exp %float_0\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 1.0),
+ // Test case 37: Log 1.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpExtInst %float %1 Log %float_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0.0),
+ // Test case 38: Exp2 2.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpExtInst %float %1 Exp2 %float_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 4.0),
+ // Test case 39: Log2 4.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpExtInst %float %1 Log2 %float_4\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 2.0),
+ // Test case 40: Sqrt 4.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpExtInst %float %1 Sqrt %float_4\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 2.0),
+ // Test case 41: Atan2 0.0 1.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpExtInst %float %1 Atan2 %float_0 %float_1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0.0),
+ // Test case 42: Pow 2.0 3.0
+ InstructionFoldingCase<float>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpExtInst %float %1 Pow %float_2 %float_3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 8.0)
));
// clang-format on
@@ -1967,7 +2071,25 @@
"%2 = OpExtInst %double %1 FClamp %double_1 %undef %double_0p5\n" +
"OpReturn\n" +
"OpFunctionEnd",
- 2, 0.5)
+ 2, 0.5),
+ // Test case 21: Sqrt 4.0
+ InstructionFoldingCase<double>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%undef = OpUndef %double\n" +
+ "%2 = OpExtInst %double %1 Sqrt %double_4\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 2.0),
+ // Test case 22: Pow 2.0 3.0
+ InstructionFoldingCase<double>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%undef = OpUndef %double\n" +
+ "%2 = OpExtInst %double %1 Pow %double_2 %double_3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 8.0)
));
// clang-format on