spirv-fuzz: Fix memory management in the fact manager (#3313)

Fixes a bug where, while recursively adding id equation facts, a
reference to a set of id equations could be used after it had been
freed (due to equivalence classes of equations being merged).
diff --git a/source/fuzz/fact_manager.cpp b/source/fuzz/fact_manager.cpp
index 2cfe326..bc6f3b1 100644
--- a/source/fuzz/fact_manager.cpp
+++ b/source/fuzz/fact_manager.cpp
@@ -430,6 +430,9 @@
                              uint32_t maximum_equivalence_class_size);
 
  private:
+  using OperationSet =
+      std::unordered_set<Operation, OperationHash, OperationEquals>;
+
   // Adds the synonym |dd1| = |dd2| to the set of managed facts, and recurses
   // into sub-components of the data descriptors, if they are composites, to
   // record that their components are pairwise-synonymous.
@@ -448,6 +451,8 @@
       opt::IRContext* context, const protobufs::DataDescriptor& dd1,
       const protobufs::DataDescriptor& dd2) const;
 
+  OperationSet GetEquations(const protobufs::DataDescriptor* lhs) const;
+
   // Requires that |lhs_dd| and every element of |rhs_dds| is present in the
   // |synonymous_| equivalence relation, but is not necessarily its own
   // representative.  Records the fact that the equation
@@ -480,9 +485,7 @@
   // All data descriptors occurring in equations are required to be present in
   // the |synonymous_| equivalence relation, and to be their own representatives
   // in that relation.
-  std::unordered_map<
-      const protobufs::DataDescriptor*,
-      std::unordered_set<Operation, OperationHash, OperationEquals>>
+  std::unordered_map<const protobufs::DataDescriptor*, OperationSet>
       id_equations_;
 };
 
@@ -520,6 +523,16 @@
                            rhs_dd_ptrs, context);
 }
 
+FactManager::DataSynonymAndIdEquationFacts::OperationSet
+FactManager::DataSynonymAndIdEquationFacts::GetEquations(
+    const protobufs::DataDescriptor* lhs) const {
+  auto existing = id_equations_.find(lhs);
+  if (existing == id_equations_.end()) {
+    return OperationSet();
+  }
+  return existing->second;
+}
+
 void FactManager::DataSynonymAndIdEquationFacts::AddEquationFactRecursive(
     const protobufs::DataDescriptor& lhs_dd, SpvOp opcode,
     const std::vector<const protobufs::DataDescriptor*>& rhs_dds,
@@ -538,9 +551,7 @@
   if (id_equations_.count(lhs_dd_representative) == 0) {
     // We have not seen an equation with this LHS before, so associate the LHS
     // with an initially empty set.
-    id_equations_.insert(
-        {lhs_dd_representative,
-         std::unordered_set<Operation, OperationHash, OperationEquals>()});
+    id_equations_.insert({lhs_dd_representative, OperationSet()});
   }
 
   {
@@ -562,44 +573,29 @@
   switch (opcode) {
     case SpvOpIAdd: {
       // Equation form: "a = b + c"
-      {
-        auto existing_first_operand_equations = id_equations_.find(rhs_dds[0]);
-        if (existing_first_operand_equations != id_equations_.end()) {
-          for (auto equation : existing_first_operand_equations->second) {
-            if (equation.opcode == SpvOpISub) {
-              // Equation form: "a = (d - e) + c"
-              if (synonymous_.IsEquivalent(*equation.operands[1],
-                                           *rhs_dds[1])) {
-                // Equation form: "a = (d - c) + c"
-                // We can thus infer "a = d"
-                AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0],
-                                            context);
-              }
-              if (synonymous_.IsEquivalent(*equation.operands[0],
-                                           *rhs_dds[1])) {
-                // Equation form: "a = (c - e) + c"
-                // We can thus infer "a = -e"
-                AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
-                                         {equation.operands[1]}, context);
-              }
-            }
+      for (auto equation : GetEquations(rhs_dds[0])) {
+        if (equation.opcode == SpvOpISub) {
+          // Equation form: "a = (d - e) + c"
+          if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[1])) {
+            // Equation form: "a = (d - c) + c"
+            // We can thus infer "a = d"
+            AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
+          }
+          if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[1])) {
+            // Equation form: "a = (c - e) + c"
+            // We can thus infer "a = -e"
+            AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
+                                     {equation.operands[1]}, context);
           }
         }
       }
-      {
-        auto existing_second_operand_equations = id_equations_.find(rhs_dds[1]);
-        if (existing_second_operand_equations != id_equations_.end()) {
-          for (auto equation : existing_second_operand_equations->second) {
-            if (equation.opcode == SpvOpISub) {
-              // Equation form: "a = b + (d - e)"
-              if (synonymous_.IsEquivalent(*equation.operands[1],
-                                           *rhs_dds[0])) {
-                // Equation form: "a = b + (d - b)"
-                // We can thus infer "a = d"
-                AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0],
-                                            context);
-              }
-            }
+      for (auto equation : GetEquations(rhs_dds[1])) {
+        if (equation.opcode == SpvOpISub) {
+          // Equation form: "a = b + (d - e)"
+          if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[0])) {
+            // Equation form: "a = b + (d - b)"
+            // We can thus infer "a = d"
+            AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
           }
         }
       }
@@ -607,73 +603,54 @@
     }
     case SpvOpISub: {
       // Equation form: "a = b - c"
-      {
-        auto existing_first_operand_equations = id_equations_.find(rhs_dds[0]);
-        if (existing_first_operand_equations != id_equations_.end()) {
-          for (auto equation : existing_first_operand_equations->second) {
-            if (equation.opcode == SpvOpIAdd) {
-              // Equation form: "a = (d + e) - c"
-              if (synonymous_.IsEquivalent(*equation.operands[0],
-                                           *rhs_dds[1])) {
-                // Equation form: "a = (c + e) - c"
-                // We can thus infer "a = e"
-                AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1],
-                                            context);
-              }
-              if (synonymous_.IsEquivalent(*equation.operands[1],
-                                           *rhs_dds[1])) {
-                // Equation form: "a = (d + c) - c"
-                // We can thus infer "a = d"
-                AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0],
-                                            context);
-              }
-            }
+      for (auto equation : GetEquations(rhs_dds[0])) {
+        if (equation.opcode == SpvOpIAdd) {
+          // Equation form: "a = (d + e) - c"
+          if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[1])) {
+            // Equation form: "a = (c + e) - c"
+            // We can thus infer "a = e"
+            AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1], context);
+          }
+          if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[1])) {
+            // Equation form: "a = (d + c) - c"
+            // We can thus infer "a = d"
+            AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
+          }
+        }
 
-            if (equation.opcode == SpvOpISub) {
-              // Equation form: "a = (d - e) - c"
-              if (synonymous_.IsEquivalent(*equation.operands[0],
-                                           *rhs_dds[1])) {
-                // Equation form: "a = (c - e) - c"
-                // We can thus infer "a = -e"
-                AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
-                                         {equation.operands[1]}, context);
-              }
-            }
+        if (equation.opcode == SpvOpISub) {
+          // Equation form: "a = (d - e) - c"
+          if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[1])) {
+            // Equation form: "a = (c - e) - c"
+            // We can thus infer "a = -e"
+            AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
+                                     {equation.operands[1]}, context);
           }
         }
       }
 
-      {
-        auto existing_second_operand_equations = id_equations_.find(rhs_dds[1]);
-        if (existing_second_operand_equations != id_equations_.end()) {
-          for (auto equation : existing_second_operand_equations->second) {
-            if (equation.opcode == SpvOpIAdd) {
-              // Equation form: "a = b - (d + e)"
-              if (synonymous_.IsEquivalent(*equation.operands[0],
-                                           *rhs_dds[0])) {
-                // Equation form: "a = b - (b + e)"
-                // We can thus infer "a = -e"
-                AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
-                                         {equation.operands[1]}, context);
-              }
-              if (synonymous_.IsEquivalent(*equation.operands[1],
-                                           *rhs_dds[0])) {
-                // Equation form: "a = b - (d + b)"
-                // We can thus infer "a = -d"
-                AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
-                                         {equation.operands[0]}, context);
-              }
-            }
-            if (equation.opcode == SpvOpISub) {
-              // Equation form: "a = b - (d - e)"
-              if (synonymous_.IsEquivalent(*equation.operands[0],
-                                           *rhs_dds[0])) {
-                // Equation form: "a = b - (b - e)"
-                // We can thus infer "a = e"
-                AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1],
-                                            context);
-              }
-            }
+      for (auto equation : GetEquations(rhs_dds[1])) {
+        if (equation.opcode == SpvOpIAdd) {
+          // Equation form: "a = b - (d + e)"
+          if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[0])) {
+            // Equation form: "a = b - (b + e)"
+            // We can thus infer "a = -e"
+            AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
+                                     {equation.operands[1]}, context);
+          }
+          if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[0])) {
+            // Equation form: "a = b - (d + b)"
+            // We can thus infer "a = -d"
+            AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
+                                     {equation.operands[0]}, context);
+          }
+        }
+        if (equation.opcode == SpvOpISub) {
+          // Equation form: "a = b - (d - e)"
+          if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[0])) {
+            // Equation form: "a = b - (b - e)"
+            // We can thus infer "a = e"
+            AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1], context);
           }
         }
       }
@@ -682,14 +659,11 @@
     case SpvOpLogicalNot:
     case SpvOpSNegate: {
       // Equation form: "a = !b" or "a = -b"
-      auto existing_equations = id_equations_.find(rhs_dds[0]);
-      if (existing_equations != id_equations_.end()) {
-        for (auto equation : existing_equations->second) {
-          if (equation.opcode == opcode) {
-            // Equation form: "a = !!b" or "a = -(-b)"
-            // We can thus infer "a = b"
-            AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
-          }
+      for (auto equation : GetEquations(rhs_dds[0])) {
+        if (equation.opcode == opcode) {
+          // Equation form: "a = !!b" or "a = -(-b)"
+          // We can thus infer "a = b"
+          AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
         }
       }
       break;
@@ -1116,9 +1090,7 @@
     // equations about |still_representative|; create an empty set of equations
     // if this is the case.
     if (!id_equations_.count(still_representative)) {
-      id_equations_.insert(
-          {still_representative,
-           std::unordered_set<Operation, OperationHash, OperationEquals>()});
+      id_equations_.insert({still_representative, OperationSet()});
     }
     auto still_representative_id_equations =
         id_equations_.find(still_representative);