Fold multiply and subtraction into FMA with negation (#4808)

This change adds a folding rule which transforms x * y - a and a - x * y
into FMA(x, y, -a) and FMA(-x, y, a), respectively.

While the SPIR-V instruction count remains the same, target instruction
sets typically feature FMA instruction variants that can negate an
operand. Also this transformation may unlock further optimizations which
eliminate the negation.

(Google bug: b/226145988)
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index ab7a20e..0d8f7c8 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -1488,6 +1488,74 @@
   return false;
 }
 
+// Replaces |sub| inplace with an FMA instruction |(x*y)+a| where |a| first gets
+// negated if |negate_addition| is true, otherwise |x| gets negated.
+void ReplaceWithFmaAndNegate(Instruction* sub, uint32_t x, uint32_t y,
+                             uint32_t a, bool negate_addition) {
+  uint32_t ext =
+      sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+
+  if (ext == 0) {
+    sub->context()->AddExtInstImport("GLSL.std.450");
+    ext = sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+    assert(ext != 0 &&
+           "Could not add the GLSL.std.450 extended instruction set");
+  }
+
+  InstructionBuilder ir_builder(
+      sub->context(), sub,
+      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+  Instruction* neg = ir_builder.AddUnaryOp(sub->type_id(), SpvOpFNegate,
+                                           negate_addition ? a : x);
+  uint32_t neg_op = neg->result_id();  // -a : -x
+
+  std::vector<Operand> operands;
+  operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
+  operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
+  operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? x : neg_op}});
+  operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
+  operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? neg_op : a}});
+
+  sub->SetOpcode(SpvOpExtInst);
+  sub->SetInOperands(std::move(operands));
+}
+
+// Folds a multiply and subtract into an Fma and negation.
+//
+// Cases:
+// (x * y) - a = Fma x y -a
+// a - (x * y) = Fma -x y a
+bool MergeMulSubArithmetic(IRContext* context, Instruction* sub,
+                           const std::vector<const analysis::Constant*>&) {
+  assert(sub->opcode() == SpvOpFSub);
+
+  if (!sub->IsFloatingPointFoldingAllowed()) {
+    return false;
+  }
+
+  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+  for (int i = 0; i < 2; i++) {
+    uint32_t op_id = sub->GetSingleWordInOperand(i);
+    Instruction* mul = def_use_mgr->GetDef(op_id);
+
+    if (mul->opcode() != SpvOpFMul) {
+      continue;
+    }
+
+    if (!mul->IsFloatingPointFoldingAllowed()) {
+      continue;
+    }
+
+    uint32_t x = mul->GetSingleWordInOperand(0);
+    uint32_t y = mul->GetSingleWordInOperand(1);
+    uint32_t a = sub->GetSingleWordInOperand((i + 1) % 2);
+    ReplaceWithFmaAndNegate(sub, x, y, a, i == 0);
+    return true;
+  }
+  return false;
+}
+
 FoldingRule IntMultipleBy1() {
   return [](IRContext*, Instruction* inst,
             const std::vector<const analysis::Constant*>& constants) {
@@ -2831,6 +2899,7 @@
   rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
   rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
   rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
+  rules_[SpvOpFSub].push_back(MergeMulSubArithmetic);
 
   rules_[SpvOpIAdd].push_back(RedundantIAdd());
   rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index e2240b8..8ee14a9 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -7359,7 +7359,7 @@
            "OpReturn\n" +
            "OpFunctionEnd",
        3, true),
-    // Test 5: that the OpExtInstImport instruction is generated if it is missing.
+    // Test 4: that the OpExtInstImport instruction is generated if it is missing.
    InstructionFoldingCase<bool>(
            std::string() +
            "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
@@ -7454,7 +7454,63 @@
                "OpStore %a %3\n" +
                "OpReturn\n" +
                "OpFunctionEnd",
-           3, false)
+           3, false),
+    // Test case 7: (x * y) - a = Fma(x, y, -a)
+    InstructionFoldingCase<bool>(
+       Header() +
+           "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
+           "; CHECK: OpFunction\n" +
+           "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
+           "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
+           "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
+           "; CHECK: [[na:%\\w+]] = OpFNegate {{%\\w+}} [[la]]\n" +
+           "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[na]]\n" +
+           "; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
+           "%main = OpFunction %void None %void_func\n" +
+           "%main_lab = OpLabel\n" +
+           "%x = OpVariable %_ptr_float Function\n" +
+           "%y = OpVariable %_ptr_float Function\n" +
+           "%a = OpVariable %_ptr_float Function\n" +
+           "%lx = OpLoad %float %x\n" +
+           "%ly = OpLoad %float %y\n" +
+           "%mul = OpFMul %float %lx %ly\n" +
+           "%la = OpLoad %float %a\n" +
+           "%3 = OpFSub %float %mul %la\n" +
+           "OpStore %a %3\n" +
+           "OpReturn\n" +
+           "OpFunctionEnd",
+       3, true),
+   // Test case 8: a - (x * y) = Fma(-x, y, a)
+   InstructionFoldingCase<bool>(
+       Header() +
+           "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
+           "; CHECK: OpFunction\n" +
+           "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
+           "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
+           "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
+           "; CHECK: [[nx:%\\w+]] = OpFNegate {{%\\w+}} [[lx]]\n" +
+           "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[nx]] [[ly]] [[la]]\n" +
+           "; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
+           "%main = OpFunction %void None %void_func\n" +
+           "%main_lab = OpLabel\n" +
+           "%x = OpVariable %_ptr_float Function\n" +
+           "%y = OpVariable %_ptr_float Function\n" +
+           "%a = OpVariable %_ptr_float Function\n" +
+           "%lx = OpLoad %float %x\n" +
+           "%ly = OpLoad %float %y\n" +
+           "%mul = OpFMul %float %lx %ly\n" +
+           "%la = OpLoad %float %a\n" +
+           "%3 = OpFSub %float %la %mul\n" +
+           "OpStore %a %3\n" +
+           "OpReturn\n" +
+           "OpFunctionEnd",
+       3, true)
 ));
 
 using MatchingInstructionWithNoResultFoldingTest =