Branch data Line data Source code
1 : : /****************************************************************************** 2 : : * Top contributors (to current version): 3 : : * Gereon Kremer 4 : : * 5 : : * This file is part of the cvc5 project. 6 : : * 7 : : * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS 8 : : * in the top-level source directory and their institutional affiliations. 9 : : * All rights reserved. See the file COPYING in the top-level source 10 : : * directory for licensing information. 11 : : * **************************************************************************** 12 : : * 13 : : * Implementation of new non-linear solver. 14 : : */ 15 : : 16 : : #include "theory/arith/nl/equality_substitution.h" 17 : : 18 : : #include "smt/env.h" 19 : : #include "theory/arith/arith_utilities.h" 20 : : 21 : : namespace cvc5::internal { 22 : : namespace theory { 23 : : namespace arith { 24 : : namespace nl { 25 : : 26 : : namespace { 27 : : struct ShouldTraverse : public SubstitutionMap::ShouldTraverseCallback 28 : : { 29 : 21920 : bool operator()(TNode n) const override 30 : : { 31 [ + + ][ + ]: 21920 : switch (theory::kindToTheoryId(n.getKind())) 32 : : { 33 : 9339 : case TheoryId::THEORY_BOOL: 34 : 9339 : case TheoryId::THEORY_BUILTIN: return true; 35 : 11662 : case TheoryId::THEORY_ARITH: return !isTranscendentalKind(n.getKind()); 36 : 919 : default: return false; 37 : : } 38 : : } 39 : : }; 40 : : } // namespace 41 : : 42 : 32273 : EqualitySubstitution::EqualitySubstitution(Env& env) 43 : 32273 : : EnvObj(env), d_substitutions(std::make_unique<SubstitutionMap>()) 44 : : { 45 : 32273 : } 46 : 110 : void EqualitySubstitution::reset() 47 : : { 48 : 110 : d_substitutions = std::make_unique<SubstitutionMap>(); 49 : 110 : d_conflict.clear(); 50 : 110 : d_conflictMap.clear(); 51 : 110 : d_trackOrigin.clear(); 52 : 110 : } 53 : : 54 : 110 : std::vector<Node> EqualitySubstitution::eliminateEqualities( 55 : : const std::vector<Node>& assertions) 56 : : { 57 [ - + ]: 110 : if (TraceIsOn("nl-eqs")) 58 : : { 59 [ - - ]: 0 : Trace("nl-eqs") << "Input:" << std::endl; 60 [ - - ]: 0 : for (const auto& a : assertions) 61 : : { 62 [ - - ]: 0 : Trace("nl-eqs") << "\t" << a << std::endl; 63 : : } 64 : : } 65 : 220 : std::set<TNode> tracker; 66 : 220 : std::vector<Node> asserts = assertions; 67 : 220 : std::vector<Node> next; 68 : 110 : const ShouldTraverse stc; 69 : : 70 : 110 : size_t last_size = 0; 71 [ + + ]: 272 : while (asserts.size() != last_size) 72 : : { 73 : 166 : last_size = asserts.size(); 74 : : // collect all eliminations from original into d_substitutions 75 [ + + ]: 4023 : for (const auto& orig : asserts) 76 : : { 77 [ + + ]: 3912 : if (orig.getKind() != Kind::EQUAL) continue; 78 : 533 : tracker.clear(); 79 : 533 : d_substitutions->invalidateCache(); 80 : : Node o = 81 : 533 : d_substitutions->apply(orig, d_env.getRewriter(), &tracker, &stc); 82 [ + + ]: 533 : if (o.getKind() != Kind::EQUAL) continue; 83 [ - + ][ - + ]: 478 : Assert(o.getNumChildren() == 2); [ - - ] 84 [ + + ]: 988 : for (size_t i = 0; i < 2; ++i) 85 : : { 86 [ + + ][ + + ]: 733 : const auto& l = (o[i].getKind() == Kind::TO_REAL ? o[i][0] : o[i]); [ - - ] 87 [ + + ][ + + ]: 733 : const auto& r = (o[1-i].getKind() == Kind::TO_REAL ? o[1-i][0] : o[1-i]); [ - - ] 88 : : // lhs can't be constant 89 [ + + ]: 733 : if (l.isConst()) continue; 90 : : // types must match (otherwise we might have int/real issues) 91 [ + + ]: 626 : if (r.getType() != l.getType()) continue; 92 : : // can't substitute stuff from other theories 93 [ + + ]: 622 : if (!Theory::isLeafOf(l, TheoryId::THEORY_ARITH)) continue; 94 : : // can't substitute the same thing twice 95 [ - + ]: 305 : if (d_substitutions->hasSubstitution(l)) continue; 96 : : // lhs can't be a subexpression of rhs, would leaf to recursion 97 [ + + ]: 305 : if (expr::hasSubterm(r, l)) continue; 98 : : // the same, but after substitution 99 : 223 : d_substitutions->invalidateCache(); 100 [ - + ]: 223 : if (expr::hasSubterm(d_substitutions->apply(r, nullptr, nullptr, &stc), l)) continue; 101 [ + - ]: 446 : Trace("nl-eqs") << "Found substitution " << l << " -> " << r 102 : 0 : << std::endl 103 : 223 : << " from " << o << " / " << orig << std::endl; 104 : 223 : d_substitutions->addSubstitution(l, r); 105 : 223 : d_trackOrigin.emplace(l, o); 106 [ + + ]: 223 : if (o != orig) 107 : : { 108 : 38 : addToConflictMap(o, orig, tracker); 109 : : } 110 : 223 : break; 111 : : } 112 : : } 113 : : 114 : : // simplify with subs from original into next 115 : 166 : next.clear(); 116 [ + + ]: 3987 : for (const auto& a : asserts) 117 : : { 118 : 3825 : tracker.clear(); 119 : 3825 : d_substitutions->invalidateCache(); 120 : : Node simp = 121 : 3825 : d_substitutions->apply(a, d_env.getRewriter(), &tracker, &stc); 122 [ + + ]: 3825 : if (simp.isConst()) 123 : : { 124 [ + + ]: 350 : if (simp.getConst<bool>()) 125 : : { 126 : 346 : continue; 127 : : } 128 [ + - ]: 4 : Trace("nl-eqs") << "Simplified " << a << " to " << simp << std::endl; 129 [ + + ]: 12 : for (TNode t : tracker) 130 : : { 131 [ + - ]: 8 : Trace("nl-eqs") << "Tracker has " << t << std::endl; 132 : 8 : auto toit = d_trackOrigin.find(t); 133 [ - + ][ - + ]: 8 : Assert(toit != d_trackOrigin.end()); [ - - ] 134 : 8 : d_conflict.emplace_back(toit->second); 135 : : } 136 : 4 : d_conflict.emplace_back(a); 137 : 4 : postprocessConflict(d_conflict); 138 [ + - ]: 4 : Trace("nl-eqs") << "Direct conflict: " << d_conflict << std::endl; 139 [ + - ]: 8 : Trace("nl-eqs") << std::endl 140 : 0 : << d_conflict.size() << " vs " 141 : 0 : << std::distance(d_substitutions->begin(), 142 : 4 : d_substitutions->end()) 143 : 0 : << std::endl 144 : 4 : << std::endl; 145 : 4 : return {}; 146 : : } 147 [ + + ]: 3475 : if (simp != a) 148 : : { 149 [ + - ]: 787 : Trace("nl-eqs") << "Simplified " << a << " to " << simp << std::endl; 150 : 787 : addToConflictMap(simp, a, tracker); 151 : : } 152 : 3475 : next.emplace_back(simp); 153 : : } 154 : 162 : asserts = std::move(next); 155 : : } 156 : 106 : d_conflict.clear(); 157 [ - + ]: 106 : if (TraceIsOn("nl-eqs")) 158 : : { 159 [ - - ]: 0 : Trace("nl-eqs") << "Output:" << std::endl; 160 [ - - ]: 0 : for (const auto& a : asserts) 161 : : { 162 [ - - ]: 0 : Trace("nl-eqs") << "\t" << a << std::endl; 163 : : } 164 [ - - ]: 0 : Trace("nl-eqs") << "Substitutions:" << std::endl; 165 [ - - ]: 0 : for (const auto& subs : d_substitutions->getSubstitutions()) 166 : : { 167 [ - - ]: 0 : Trace("nl-eqs") << "\t" << subs.first << " -> " << subs.second 168 : 0 : << std::endl; 169 : : } 170 : : } 171 : 106 : return asserts; 172 : : } 173 : 98 : void EqualitySubstitution::postprocessConflict( 174 : : std::vector<Node>& conflict) const 175 : : { 176 [ + - ]: 98 : Trace("nl-eqs") << "Postprocessing " << conflict << std::endl; 177 : 98 : std::set<Node> result; 178 [ + + ]: 424 : for (const auto& c : conflict) 179 : : { 180 : 326 : auto it = d_conflictMap.find(c); 181 [ + + ]: 326 : if (it == d_conflictMap.end()) 182 : : { 183 : 280 : result.insert(c); 184 : : } 185 : : else 186 : : { 187 [ + - ]: 46 : Trace("nl-eqs") << "Origin of " << c << ": " << it->second << std::endl; 188 : 46 : result.insert(it->second.begin(), it->second.end()); 189 : : } 190 : : } 191 : 98 : conflict.clear(); 192 : 98 : conflict.insert(conflict.end(), result.begin(), result.end()); 193 [ + - ]: 98 : Trace("nl-eqs") << "-> " << conflict << std::endl; 194 : 98 : } 195 : 2246 : void EqualitySubstitution::insertOrigins(std::set<Node>& dest, 196 : : const Node& n) const 197 : : { 198 : 2246 : auto it = d_conflictMap.find(n); 199 [ + + ]: 2246 : if (it == d_conflictMap.end()) 200 : : { 201 : 1926 : dest.insert(n); 202 : : } 203 : : else 204 : : { 205 : 320 : dest.insert(it->second.begin(), it->second.end()); 206 : : } 207 : 2246 : } 208 : 825 : void EqualitySubstitution::addToConflictMap(const Node& n, 209 : : const Node& orig, 210 : : const std::set<TNode>& tracker) 211 : : { 212 : 825 : std::set<Node> origins; 213 : 825 : insertOrigins(origins, orig); 214 [ + + ]: 2246 : for (const auto& t : tracker) 215 : : { 216 : 1421 : auto tit = d_trackOrigin.find(t); 217 [ - + ][ - + ]: 1421 : Assert(tit != d_trackOrigin.end()); [ - - ] 218 : 1421 : insertOrigins(origins, tit->second); 219 : : } 220 : 825 : d_conflictMap.emplace(n, std::vector<Node>(origins.begin(), origins.end())); 221 : 825 : } 222 : : 223 : : } // namespace nl 224 : : } // namespace arith 225 : : } // namespace theory 226 : : } // namespace cvc5::internal