spirv-fuzz: Fix memory management in the fact manager (#3313)
Fixes a bug where, while recursively adding id equation facts, a
reference to a set of id equations could be used after it had been
freed (due to equivalence classes of equations being merged).
diff --git a/source/fuzz/fact_manager.cpp b/source/fuzz/fact_manager.cpp
index 2cfe326..bc6f3b1 100644
--- a/source/fuzz/fact_manager.cpp
+++ b/source/fuzz/fact_manager.cpp
@@ -430,6 +430,9 @@
uint32_t maximum_equivalence_class_size);
private:
+ using OperationSet =
+ std::unordered_set<Operation, OperationHash, OperationEquals>;
+
// Adds the synonym |dd1| = |dd2| to the set of managed facts, and recurses
// into sub-components of the data descriptors, if they are composites, to
// record that their components are pairwise-synonymous.
@@ -448,6 +451,8 @@
opt::IRContext* context, const protobufs::DataDescriptor& dd1,
const protobufs::DataDescriptor& dd2) const;
+ OperationSet GetEquations(const protobufs::DataDescriptor* lhs) const;
+
// Requires that |lhs_dd| and every element of |rhs_dds| is present in the
// |synonymous_| equivalence relation, but is not necessarily its own
// representative. Records the fact that the equation
@@ -480,9 +485,7 @@
// All data descriptors occurring in equations are required to be present in
// the |synonymous_| equivalence relation, and to be their own representatives
// in that relation.
- std::unordered_map<
- const protobufs::DataDescriptor*,
- std::unordered_set<Operation, OperationHash, OperationEquals>>
+ std::unordered_map<const protobufs::DataDescriptor*, OperationSet>
id_equations_;
};
@@ -520,6 +523,16 @@
rhs_dd_ptrs, context);
}
+FactManager::DataSynonymAndIdEquationFacts::OperationSet
+FactManager::DataSynonymAndIdEquationFacts::GetEquations(
+ const protobufs::DataDescriptor* lhs) const {
+ auto existing = id_equations_.find(lhs);
+ if (existing == id_equations_.end()) {
+ return OperationSet();
+ }
+ return existing->second;
+}
+
void FactManager::DataSynonymAndIdEquationFacts::AddEquationFactRecursive(
const protobufs::DataDescriptor& lhs_dd, SpvOp opcode,
const std::vector<const protobufs::DataDescriptor*>& rhs_dds,
@@ -538,9 +551,7 @@
if (id_equations_.count(lhs_dd_representative) == 0) {
// We have not seen an equation with this LHS before, so associate the LHS
// with an initially empty set.
- id_equations_.insert(
- {lhs_dd_representative,
- std::unordered_set<Operation, OperationHash, OperationEquals>()});
+ id_equations_.insert({lhs_dd_representative, OperationSet()});
}
{
@@ -562,44 +573,29 @@
switch (opcode) {
case SpvOpIAdd: {
// Equation form: "a = b + c"
- {
- auto existing_first_operand_equations = id_equations_.find(rhs_dds[0]);
- if (existing_first_operand_equations != id_equations_.end()) {
- for (auto equation : existing_first_operand_equations->second) {
- if (equation.opcode == SpvOpISub) {
- // Equation form: "a = (d - e) + c"
- if (synonymous_.IsEquivalent(*equation.operands[1],
- *rhs_dds[1])) {
- // Equation form: "a = (d - c) + c"
- // We can thus infer "a = d"
- AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0],
- context);
- }
- if (synonymous_.IsEquivalent(*equation.operands[0],
- *rhs_dds[1])) {
- // Equation form: "a = (c - e) + c"
- // We can thus infer "a = -e"
- AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
- {equation.operands[1]}, context);
- }
- }
+ for (auto equation : GetEquations(rhs_dds[0])) {
+ if (equation.opcode == SpvOpISub) {
+ // Equation form: "a = (d - e) + c"
+ if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[1])) {
+ // Equation form: "a = (d - c) + c"
+ // We can thus infer "a = d"
+ AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
+ }
+ if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[1])) {
+ // Equation form: "a = (c - e) + c"
+ // We can thus infer "a = -e"
+ AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
+ {equation.operands[1]}, context);
}
}
}
- {
- auto existing_second_operand_equations = id_equations_.find(rhs_dds[1]);
- if (existing_second_operand_equations != id_equations_.end()) {
- for (auto equation : existing_second_operand_equations->second) {
- if (equation.opcode == SpvOpISub) {
- // Equation form: "a = b + (d - e)"
- if (synonymous_.IsEquivalent(*equation.operands[1],
- *rhs_dds[0])) {
- // Equation form: "a = b + (d - b)"
- // We can thus infer "a = d"
- AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0],
- context);
- }
- }
+ for (auto equation : GetEquations(rhs_dds[1])) {
+ if (equation.opcode == SpvOpISub) {
+ // Equation form: "a = b + (d - e)"
+ if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[0])) {
+ // Equation form: "a = b + (d - b)"
+ // We can thus infer "a = d"
+ AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
}
}
}
@@ -607,73 +603,54 @@
}
case SpvOpISub: {
// Equation form: "a = b - c"
- {
- auto existing_first_operand_equations = id_equations_.find(rhs_dds[0]);
- if (existing_first_operand_equations != id_equations_.end()) {
- for (auto equation : existing_first_operand_equations->second) {
- if (equation.opcode == SpvOpIAdd) {
- // Equation form: "a = (d + e) - c"
- if (synonymous_.IsEquivalent(*equation.operands[0],
- *rhs_dds[1])) {
- // Equation form: "a = (c + e) - c"
- // We can thus infer "a = e"
- AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1],
- context);
- }
- if (synonymous_.IsEquivalent(*equation.operands[1],
- *rhs_dds[1])) {
- // Equation form: "a = (d + c) - c"
- // We can thus infer "a = d"
- AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0],
- context);
- }
- }
+ for (auto equation : GetEquations(rhs_dds[0])) {
+ if (equation.opcode == SpvOpIAdd) {
+ // Equation form: "a = (d + e) - c"
+ if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[1])) {
+ // Equation form: "a = (c + e) - c"
+ // We can thus infer "a = e"
+ AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1], context);
+ }
+ if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[1])) {
+ // Equation form: "a = (d + c) - c"
+ // We can thus infer "a = d"
+ AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
+ }
+ }
- if (equation.opcode == SpvOpISub) {
- // Equation form: "a = (d - e) - c"
- if (synonymous_.IsEquivalent(*equation.operands[0],
- *rhs_dds[1])) {
- // Equation form: "a = (c - e) - c"
- // We can thus infer "a = -e"
- AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
- {equation.operands[1]}, context);
- }
- }
+ if (equation.opcode == SpvOpISub) {
+ // Equation form: "a = (d - e) - c"
+ if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[1])) {
+ // Equation form: "a = (c - e) - c"
+ // We can thus infer "a = -e"
+ AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
+ {equation.operands[1]}, context);
}
}
}
- {
- auto existing_second_operand_equations = id_equations_.find(rhs_dds[1]);
- if (existing_second_operand_equations != id_equations_.end()) {
- for (auto equation : existing_second_operand_equations->second) {
- if (equation.opcode == SpvOpIAdd) {
- // Equation form: "a = b - (d + e)"
- if (synonymous_.IsEquivalent(*equation.operands[0],
- *rhs_dds[0])) {
- // Equation form: "a = b - (b + e)"
- // We can thus infer "a = -e"
- AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
- {equation.operands[1]}, context);
- }
- if (synonymous_.IsEquivalent(*equation.operands[1],
- *rhs_dds[0])) {
- // Equation form: "a = b - (d + b)"
- // We can thus infer "a = -d"
- AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
- {equation.operands[0]}, context);
- }
- }
- if (equation.opcode == SpvOpISub) {
- // Equation form: "a = b - (d - e)"
- if (synonymous_.IsEquivalent(*equation.operands[0],
- *rhs_dds[0])) {
- // Equation form: "a = b - (b - e)"
- // We can thus infer "a = e"
- AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1],
- context);
- }
- }
+ for (auto equation : GetEquations(rhs_dds[1])) {
+ if (equation.opcode == SpvOpIAdd) {
+ // Equation form: "a = b - (d + e)"
+ if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[0])) {
+ // Equation form: "a = b - (b + e)"
+ // We can thus infer "a = -e"
+ AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
+ {equation.operands[1]}, context);
+ }
+ if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[0])) {
+ // Equation form: "a = b - (d + b)"
+ // We can thus infer "a = -d"
+ AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
+ {equation.operands[0]}, context);
+ }
+ }
+ if (equation.opcode == SpvOpISub) {
+ // Equation form: "a = b - (d - e)"
+ if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[0])) {
+ // Equation form: "a = b - (b - e)"
+ // We can thus infer "a = e"
+ AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1], context);
}
}
}
@@ -682,14 +659,11 @@
case SpvOpLogicalNot:
case SpvOpSNegate: {
// Equation form: "a = !b" or "a = -b"
- auto existing_equations = id_equations_.find(rhs_dds[0]);
- if (existing_equations != id_equations_.end()) {
- for (auto equation : existing_equations->second) {
- if (equation.opcode == opcode) {
- // Equation form: "a = !!b" or "a = -(-b)"
- // We can thus infer "a = b"
- AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
- }
+ for (auto equation : GetEquations(rhs_dds[0])) {
+ if (equation.opcode == opcode) {
+ // Equation form: "a = !!b" or "a = -(-b)"
+ // We can thus infer "a = b"
+ AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
}
}
break;
@@ -1116,9 +1090,7 @@
// equations about |still_representative|; create an empty set of equations
// if this is the case.
if (!id_equations_.count(still_representative)) {
- id_equations_.insert(
- {still_representative,
- std::unordered_set<Operation, OperationHash, OperationEquals>()});
+ id_equations_.insert({still_representative, OperationSet()});
}
auto still_representative_id_equations =
id_equations_.find(still_representative);