Implement constant folding for many transcendentals (#3166)

* Implement constant folding for many transcendentals

This change adds support for folding of sin/cos/tan/asin/acos/atan,
exp/log/exp2/log2, sqrt, atan2 and pow.

The mechanism allows to use any C function to implement folding in the
future; for now I limited the actual additions to the most commonly used
intrinsics in the shaders.

Unary folder had to be tweaked to work with extended instructions - for
extended instructions, constants.size() == 2 and constants[0] ==
nullptr. This adjustment is similar to the one binary folder already
performs.

Fixes #1390.

* Fix Android build

On old versions of Android NDK, we don't get std::exp2/std::log2
because of partial C++11 support.

We do get ::exp2, but not ::log2 so we need to emulate that.
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
index 2a2493f..d262a7e 100644
--- a/source/opt/const_folding_rules.cpp
+++ b/source/opt/const_folding_rules.cpp
@@ -265,7 +265,10 @@
       return nullptr;
     }
 
-    if (constants[0] == nullptr) {
+    const analysis::Constant* arg =
+        (inst->opcode() == SpvOpExtInst) ? constants[1] : constants[0];
+
+    if (arg == nullptr) {
       return nullptr;
     }
 
@@ -273,7 +276,7 @@
       std::vector<const analysis::Constant*> a_components;
       std::vector<const analysis::Constant*> results_components;
 
-      a_components = constants[0]->GetVectorComponents(const_mgr);
+      a_components = arg->GetVectorComponents(const_mgr);
 
       // Fold each component of the vector.
       for (uint32_t i = 0; i < a_components.size(); ++i) {
@@ -291,7 +294,7 @@
       }
       return const_mgr->GetConstant(vector_type, ids);
     } else {
-      return scalar_rule(result_type, constants[0], const_mgr);
+      return scalar_rule(result_type, arg, const_mgr);
     }
   };
 }
@@ -1070,6 +1073,60 @@
   return nullptr;
 }
 
+UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) {
+  return
+      [fp](const analysis::Type* result_type, const analysis::Constant* a,
+           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
+        assert(result_type != nullptr && a != nullptr);
+        const analysis::Float* float_type = a->type()->AsFloat();
+        assert(float_type != nullptr);
+        assert(float_type == result_type->AsFloat());
+        if (float_type->width() == 32) {
+          float fa = a->GetFloat();
+          float res = static_cast<float>(fp(fa));
+          utils::FloatProxy<float> result(res);
+          std::vector<uint32_t> words = result.GetWords();
+          return const_mgr->GetConstant(result_type, words);
+        } else if (float_type->width() == 64) {
+          double fa = a->GetDouble();
+          double res = fp(fa);
+          utils::FloatProxy<double> result(res);
+          std::vector<uint32_t> words = result.GetWords();
+          return const_mgr->GetConstant(result_type, words);
+        }
+        return nullptr;
+      };
+}
+
+BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
+                                                               double)) {
+  return
+      [fp](const analysis::Type* result_type, const analysis::Constant* a,
+           const analysis::Constant* b,
+           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
+        assert(result_type != nullptr && a != nullptr);
+        const analysis::Float* float_type = a->type()->AsFloat();
+        assert(float_type != nullptr);
+        assert(float_type == result_type->AsFloat());
+        assert(float_type == b->type()->AsFloat());
+        if (float_type->width() == 32) {
+          float fa = a->GetFloat();
+          float fb = b->GetFloat();
+          float res = static_cast<float>(fp(fa, fb));
+          utils::FloatProxy<float> result(res);
+          std::vector<uint32_t> words = result.GetWords();
+          return const_mgr->GetConstant(result_type, words);
+        } else if (float_type->width() == 64) {
+          double fa = a->GetDouble();
+          double fb = b->GetDouble();
+          double res = fp(fa, fb);
+          utils::FloatProxy<double> result(res);
+          std::vector<uint32_t> words = result.GetWords();
+          return const_mgr->GetConstant(result_type, words);
+        }
+        return nullptr;
+      };
+}
 }  // namespace
 
 void ConstantFoldingRules::AddFoldingRules() {
@@ -1175,6 +1232,45 @@
         FoldClamp2);
     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
         FoldClamp3);
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back(
+        FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin)));
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back(
+        FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos)));
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back(
+        FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan)));
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back(
+        FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin)));
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back(
+        FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos)));
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back(
+        FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan)));
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back(
+        FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp)));
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back(
+        FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
+
+#ifdef __ANDROID__
+    // Android NDK r15c tageting ABI 15 doesn't have full support for C++11
+    // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
+    // available up until ABI 18 so we use a shim
+    auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
+        FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2)));
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
+        FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim)));
+#else
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
+        FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2)));
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
+        FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2)));
+#endif
+
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back(
+        FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt)));
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back(
+        FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2)));
+    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back(
+        FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow)));
   }
 }
 }  // namespace opt
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index 26d1220..db01924 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -1693,7 +1693,7 @@
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 0.2f),
-    // Test case 21: FMax 1.0 4.0
+    // Test case 23: FMax 1.0 4.0
     InstructionFoldingCase<float>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -1701,7 +1701,7 @@
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 4.0f),
-    // Test case 22: FMax 1.0 0.2
+    // Test case 24: FMax 1.0 0.2
     InstructionFoldingCase<float>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -1709,7 +1709,7 @@
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 1.0f),
-    // Test case 23: FClamp 1.0 0.2 4.0
+    // Test case 25: FClamp 1.0 0.2 4.0
     InstructionFoldingCase<float>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -1717,7 +1717,7 @@
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 1.0f),
-    // Test case 24: FClamp 0.2 2.0 4.0
+    // Test case 26: FClamp 0.2 2.0 4.0
     InstructionFoldingCase<float>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -1725,7 +1725,7 @@
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 2.0f),
-    // Test case 25: FClamp 2049.0 2.0 4.0
+    // Test case 27: FClamp 2049.0 2.0 4.0
     InstructionFoldingCase<float>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -1733,7 +1733,7 @@
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 4.0f),
-    // Test case 26: FClamp 1.0 2.0 x
+    // Test case 28: FClamp 1.0 2.0 x
     InstructionFoldingCase<float>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -1742,7 +1742,7 @@
             "OpReturn\n" +
             "OpFunctionEnd",
         2, 2.0),
-    // Test case 27: FClamp 1.0 x 0.5
+    // Test case 29: FClamp 1.0 x 0.5
     InstructionFoldingCase<float>(
         Header() + "%main = OpFunction %void None %void_func\n" +
             "%main_lab = OpLabel\n" +
@@ -1750,7 +1750,111 @@
             "%2 = OpExtInst %float %1 FClamp %float_1 %undef %float_0p5\n" +
             "OpReturn\n" +
             "OpFunctionEnd",
-        2, 0.5)
+        2, 0.5),
+    // Test case 30: Sin 0.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpExtInst %float %1 Sin %float_0\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0.0),
+    // Test case 31: Cos 0.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpExtInst %float %1 Cos %float_0\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 1.0),
+    // Test case 32: Tan 0.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpExtInst %float %1 Tan %float_0\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0.0),
+    // Test case 33: Asin 0.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpExtInst %float %1 Asin %float_0\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0.0),
+    // Test case 34: Acos 1.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpExtInst %float %1 Acos %float_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0.0),
+    // Test case 35: Atan 0.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpExtInst %float %1 Atan %float_0\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0.0),
+    // Test case 36: Exp 0.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpExtInst %float %1 Exp %float_0\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 1.0),
+    // Test case 37: Log 1.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpExtInst %float %1 Log %float_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0.0),
+    // Test case 38: Exp2 2.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpExtInst %float %1 Exp2 %float_2\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 4.0),
+    // Test case 39: Log2 4.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpExtInst %float %1 Log2 %float_4\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 2.0),
+    // Test case 40: Sqrt 4.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpExtInst %float %1 Sqrt %float_4\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 2.0),
+    // Test case 41: Atan2 0.0 1.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpExtInst %float %1 Atan2 %float_0 %float_1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 0.0),
+    // Test case 42: Pow 2.0 3.0
+    InstructionFoldingCase<float>(
+        Header() + "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpExtInst %float %1 Pow %float_2 %float_3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        2, 8.0)
 ));
 // clang-format on
 
@@ -1967,7 +2071,25 @@
                 "%2 = OpExtInst %double %1 FClamp %double_1 %undef %double_0p5\n" +
                 "OpReturn\n" +
                 "OpFunctionEnd",
-            2, 0.5)
+            2, 0.5),
+        // Test case 21: Sqrt 4.0
+        InstructionFoldingCase<double>(
+            Header() + "%main = OpFunction %void None %void_func\n" +
+                "%main_lab = OpLabel\n" +
+                "%undef = OpUndef %double\n" +
+                "%2 = OpExtInst %double %1 Sqrt %double_4\n" +
+                "OpReturn\n" +
+                "OpFunctionEnd",
+            2, 2.0),
+        // Test case 22: Pow 2.0 3.0
+        InstructionFoldingCase<double>(
+            Header() + "%main = OpFunction %void None %void_func\n" +
+                "%main_lab = OpLabel\n" +
+                "%undef = OpUndef %double\n" +
+                "%2 = OpExtInst %double %1 Pow %double_2 %double_3\n" +
+                "OpReturn\n" +
+                "OpFunctionEnd",
+            2, 8.0)
 ));
 // clang-format on