Add folding rule to generate Fma instructions (#4783)

Adding Fma instruction can speed up the code.  This was requested by
swiftshader, so they do not have to do this analysis themselves.  It can
also help reduce the code size, and the work the ICD compilers have to
do.
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index c879a0c..d15ad04 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -1430,6 +1430,64 @@
   };
 }
 
+// Replaces |inst| inplace with an FMA instruction |(x*y)+a|.
+void ReplaceWithFma(Instruction* inst, uint32_t x, uint32_t y, uint32_t a) {
+  uint32_t ext =
+      inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+
+  if (ext == 0) {
+    inst->context()->AddExtInstImport("GLSL.std.450");
+    ext = inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+    assert(ext != 0 &&
+           "Could not add the GLSL.std.450 extended instruction set");
+  }
+
+  std::vector<Operand> operands;
+  operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
+  operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
+  operands.push_back({SPV_OPERAND_TYPE_ID, {x}});
+  operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
+  operands.push_back({SPV_OPERAND_TYPE_ID, {a}});
+
+  inst->SetOpcode(SpvOpExtInst);
+  inst->SetInOperands(std::move(operands));
+}
+
+// Folds a multiple and add into an Fma.
+//
+// Cases:
+// (x * y) + a = Fma x y a
+// a + (x * y) = Fma x y a
+bool MergeMulAddArithmetic(IRContext* context, Instruction* inst,
+                           const std::vector<const analysis::Constant*>&) {
+  assert(inst->opcode() == SpvOpFAdd);
+
+  if (!inst->IsFloatingPointFoldingAllowed()) {
+    return false;
+  }
+
+  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+  for (int i = 0; i < 2; i++) {
+    uint32_t op_id = inst->GetSingleWordInOperand(i);
+    Instruction* op_inst = def_use_mgr->GetDef(op_id);
+
+    if (op_inst->opcode() != SpvOpFMul) {
+      continue;
+    }
+
+    if (!op_inst->IsFloatingPointFoldingAllowed()) {
+      continue;
+    }
+
+    uint32_t x = op_inst->GetSingleWordInOperand(0);
+    uint32_t y = op_inst->GetSingleWordInOperand(1);
+    uint32_t a = inst->GetSingleWordInOperand((i + 1) % 2);
+    ReplaceWithFma(inst, x, y, a);
+    return true;
+  }
+  return false;
+}
+
 FoldingRule IntMultipleBy1() {
   return [](IRContext*, Instruction* inst,
             const std::vector<const analysis::Constant*>& constants) {
@@ -2543,6 +2601,7 @@
   rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
   rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
   rules_[SpvOpFAdd].push_back(FactorAddMuls());
+  rules_[SpvOpFAdd].push_back(MergeMulAddArithmetic);
 
   rules_[SpvOpFDiv].push_back(RedundantFDiv());
   rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index 7565ca7..2ca3256 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -7108,6 +7108,214 @@
         3, true)
  ));
 
+INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTest,
+::testing::Values(
+   // Test case 0: (x * y) + a = Fma(x, y, a)
+   InstructionFoldingCase<bool>(
+       Header() +
+           "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
+           "; CHECK: OpFunction\n" +
+           "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
+           "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
+           "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
+           "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
+           "; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
+           "%main = OpFunction %void None %void_func\n" +
+           "%main_lab = OpLabel\n" +
+           "%x = OpVariable %_ptr_float Function\n" +
+           "%y = OpVariable %_ptr_float Function\n" +
+           "%a = OpVariable %_ptr_float Function\n" +
+           "%lx = OpLoad %float %x\n" +
+           "%ly = OpLoad %float %y\n" +
+           "%mul = OpFMul %float %lx %ly\n" +
+           "%la = OpLoad %float %a\n" +
+           "%3 = OpFAdd %float %mul %la\n" +
+           "OpStore %a %3\n" +
+           "OpReturn\n" +
+           "OpFunctionEnd",
+       3, true),
+    // Test case 1:  a + (x * y) = Fma(x, y, a)
+   InstructionFoldingCase<bool>(
+       Header() +
+           "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
+           "; CHECK: OpFunction\n" +
+           "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
+           "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
+           "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
+           "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
+           "; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
+           "%main = OpFunction %void None %void_func\n" +
+           "%main_lab = OpLabel\n" +
+           "%x = OpVariable %_ptr_float Function\n" +
+           "%y = OpVariable %_ptr_float Function\n" +
+           "%a = OpVariable %_ptr_float Function\n" +
+           "%lx = OpLoad %float %x\n" +
+           "%ly = OpLoad %float %y\n" +
+           "%mul = OpFMul %float %lx %ly\n" +
+           "%la = OpLoad %float %a\n" +
+           "%3 = OpFAdd %float %la %mul\n" +
+           "OpStore %a %3\n" +
+           "OpReturn\n" +
+           "OpFunctionEnd",
+       3, true),
+   // Test case 2: (x * y) + a = Fma(x, y, a) with vectors
+   InstructionFoldingCase<bool>(
+       Header() +
+           "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
+           "; CHECK: OpFunction\n" +
+           "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
+           "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
+           "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
+           "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
+           "; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
+           "%main = OpFunction %void None %void_func\n" +
+           "%main_lab = OpLabel\n" +
+           "%x = OpVariable %_ptr_v4float Function\n" +
+           "%y = OpVariable %_ptr_v4float Function\n" +
+           "%a = OpVariable %_ptr_v4float Function\n" +
+           "%lx = OpLoad %v4float %x\n" +
+           "%ly = OpLoad %v4float %y\n" +
+           "%mul = OpFMul %v4float %lx %ly\n" +
+           "%la = OpLoad %v4float %a\n" +
+           "%3 = OpFAdd %v4float %mul %la\n" +
+           "OpStore %a %3\n" +
+           "OpReturn\n" +
+           "OpFunctionEnd",
+       3, true),
+    // Test case 3:  a + (x * y) = Fma(x, y, a) with vectors
+   InstructionFoldingCase<bool>(
+       Header() +
+           "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
+           "; CHECK: OpFunction\n" +
+           "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
+           "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
+           "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
+           "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
+           "; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
+           "%main = OpFunction %void None %void_func\n" +
+           "%main_lab = OpLabel\n" +
+           "%x = OpVariable %_ptr_float Function\n" +
+           "%y = OpVariable %_ptr_float Function\n" +
+           "%a = OpVariable %_ptr_float Function\n" +
+           "%lx = OpLoad %float %x\n" +
+           "%ly = OpLoad %float %y\n" +
+           "%mul = OpFMul %float %lx %ly\n" +
+           "%la = OpLoad %float %a\n" +
+           "%3 = OpFAdd %float %la %mul\n" +
+           "OpStore %a %3\n" +
+           "OpReturn\n" +
+           "OpFunctionEnd",
+       3, true),
+    // Test 5: that the OpExtInstImport instruction is generated if it is missing.
+   InstructionFoldingCase<bool>(
+           std::string() +
+           "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
+           "; CHECK: OpFunction\n" +
+           "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
+           "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
+           "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
+           "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
+           "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
+           "; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
+           "OpCapability Shader\n" +
+           "OpMemoryModel Logical GLSL450\n" +
+           "OpEntryPoint Fragment %main \"main\"\n" +
+           "OpExecutionMode %main OriginUpperLeft\n" +
+           "OpSource GLSL 140\n" +
+           "OpName %main \"main\"\n" +
+           "%void = OpTypeVoid\n" +
+           "%void_func = OpTypeFunction %void\n" +
+           "%bool = OpTypeBool\n" +
+           "%float = OpTypeFloat 32\n" +
+           "%_ptr_float = OpTypePointer Function %float\n" +
+           "%main = OpFunction %void None %void_func\n" +
+           "%main_lab = OpLabel\n" +
+           "%x = OpVariable %_ptr_float Function\n" +
+           "%y = OpVariable %_ptr_float Function\n" +
+           "%a = OpVariable %_ptr_float Function\n" +
+           "%lx = OpLoad %float %x\n" +
+           "%ly = OpLoad %float %y\n" +
+           "%mul = OpFMul %float %lx %ly\n" +
+           "%la = OpLoad %float %a\n" +
+           "%3 = OpFAdd %float %mul %la\n" +
+           "OpStore %a %3\n" +
+           "OpReturn\n" +
+           "OpFunctionEnd",
+       3, true),
+   // Test 5: Don't fold if the multiple is marked no contract.
+   InstructionFoldingCase<bool>(
+       std::string() +
+           "OpCapability Shader\n" +
+           "OpMemoryModel Logical GLSL450\n" +
+           "OpEntryPoint Fragment %main \"main\"\n" +
+           "OpExecutionMode %main OriginUpperLeft\n" +
+           "OpSource GLSL 140\n" +
+           "OpName %main \"main\"\n" +
+           "OpDecorate %mul NoContraction\n" +
+           "%void = OpTypeVoid\n" +
+           "%void_func = OpTypeFunction %void\n" +
+           "%bool = OpTypeBool\n" +
+           "%float = OpTypeFloat 32\n" +
+           "%_ptr_float = OpTypePointer Function %float\n" +
+           "%main = OpFunction %void None %void_func\n" +
+           "%main_lab = OpLabel\n" +
+           "%x = OpVariable %_ptr_float Function\n" +
+           "%y = OpVariable %_ptr_float Function\n" +
+           "%a = OpVariable %_ptr_float Function\n" +
+           "%lx = OpLoad %float %x\n" +
+           "%ly = OpLoad %float %y\n" +
+           "%mul = OpFMul %float %lx %ly\n" +
+           "%la = OpLoad %float %a\n" +
+           "%3 = OpFAdd %float %mul %la\n" +
+           "OpStore %a %3\n" +
+           "OpReturn\n" +
+           "OpFunctionEnd",
+       3, false),
+       // Test 6: Don't fold if the add is marked no contract.
+       InstructionFoldingCase<bool>(
+           std::string() +
+               "OpCapability Shader\n" +
+               "OpMemoryModel Logical GLSL450\n" +
+               "OpEntryPoint Fragment %main \"main\"\n" +
+               "OpExecutionMode %main OriginUpperLeft\n" +
+               "OpSource GLSL 140\n" +
+               "OpName %main \"main\"\n" +
+               "OpDecorate %3 NoContraction\n" +
+               "%void = OpTypeVoid\n" +
+               "%void_func = OpTypeFunction %void\n" +
+               "%bool = OpTypeBool\n" +
+               "%float = OpTypeFloat 32\n" +
+               "%_ptr_float = OpTypePointer Function %float\n" +
+               "%main = OpFunction %void None %void_func\n" +
+               "%main_lab = OpLabel\n" +
+               "%x = OpVariable %_ptr_float Function\n" +
+               "%y = OpVariable %_ptr_float Function\n" +
+               "%a = OpVariable %_ptr_float Function\n" +
+               "%lx = OpLoad %float %x\n" +
+               "%ly = OpLoad %float %y\n" +
+               "%mul = OpFMul %float %lx %ly\n" +
+               "%la = OpLoad %float %a\n" +
+               "%3 = OpFAdd %float %mul %la\n" +
+               "OpStore %a %3\n" +
+               "OpReturn\n" +
+               "OpFunctionEnd",
+           3, false)
+));
+
 using MatchingInstructionWithNoResultFoldingTest =
 ::testing::TestWithParam<InstructionFoldingCase<bool>>;