LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/theory/arith/nl - equality_substitution.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 105 119 88.2 %
Date: 2024-11-18 12:41:18 Functions: 7 7 100.0 %
Branches: 68 115 59.1 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * Top contributors (to current version):
       3                 :            :  *   Gereon Kremer
       4                 :            :  *
       5                 :            :  * This file is part of the cvc5 project.
       6                 :            :  *
       7                 :            :  * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS
       8                 :            :  * in the top-level source directory and their institutional affiliations.
       9                 :            :  * All rights reserved.  See the file COPYING in the top-level source
      10                 :            :  * directory for licensing information.
      11                 :            :  * ****************************************************************************
      12                 :            :  *
      13                 :            :  * Implementation of new non-linear solver.
      14                 :            :  */
      15                 :            : 
      16                 :            : #include "theory/arith/nl/equality_substitution.h"
      17                 :            : 
      18                 :            : #include "smt/env.h"
      19                 :            : #include "theory/arith/arith_utilities.h"
      20                 :            : 
      21                 :            : namespace cvc5::internal {
      22                 :            : namespace theory {
      23                 :            : namespace arith {
      24                 :            : namespace nl {
      25                 :            : 
      26                 :            : namespace {
      27                 :            : struct ShouldTraverse : public SubstitutionMap::ShouldTraverseCallback
      28                 :            : {
      29                 :      21920 :   bool operator()(TNode n) const override
      30                 :            :   {
      31    [ +  + ][ + ]:      21920 :     switch (theory::kindToTheoryId(n.getKind()))
      32                 :            :     {
      33                 :       9339 :       case TheoryId::THEORY_BOOL:
      34                 :       9339 :       case TheoryId::THEORY_BUILTIN: return true;
      35                 :      11662 :       case TheoryId::THEORY_ARITH: return !isTranscendentalKind(n.getKind());
      36                 :        919 :       default: return false;
      37                 :            :     }
      38                 :            :   }
      39                 :            : };
      40                 :            : }  // namespace
      41                 :            : 
      42                 :      32273 : EqualitySubstitution::EqualitySubstitution(Env& env)
      43                 :      32273 :     : EnvObj(env), d_substitutions(std::make_unique<SubstitutionMap>())
      44                 :            : {
      45                 :      32273 : }
      46                 :        110 : void EqualitySubstitution::reset()
      47                 :            : {
      48                 :        110 :   d_substitutions = std::make_unique<SubstitutionMap>();
      49                 :        110 :   d_conflict.clear();
      50                 :        110 :   d_conflictMap.clear();
      51                 :        110 :   d_trackOrigin.clear();
      52                 :        110 : }
      53                 :            : 
      54                 :        110 : std::vector<Node> EqualitySubstitution::eliminateEqualities(
      55                 :            :     const std::vector<Node>& assertions)
      56                 :            : {
      57         [ -  + ]:        110 :   if (TraceIsOn("nl-eqs"))
      58                 :            :   {
      59         [ -  - ]:          0 :     Trace("nl-eqs") << "Input:" << std::endl;
      60         [ -  - ]:          0 :     for (const auto& a : assertions)
      61                 :            :     {
      62         [ -  - ]:          0 :       Trace("nl-eqs") << "\t" << a << std::endl;
      63                 :            :     }
      64                 :            :   }
      65                 :        220 :   std::set<TNode> tracker;
      66                 :        220 :   std::vector<Node> asserts = assertions;
      67                 :        220 :   std::vector<Node> next;
      68                 :        110 :   const ShouldTraverse stc;
      69                 :            : 
      70                 :        110 :   size_t last_size = 0;
      71         [ +  + ]:        272 :   while (asserts.size() != last_size)
      72                 :            :   {
      73                 :        166 :     last_size = asserts.size();
      74                 :            :     // collect all eliminations from original into d_substitutions
      75         [ +  + ]:       4023 :     for (const auto& orig : asserts)
      76                 :            :     {
      77         [ +  + ]:       3912 :       if (orig.getKind() != Kind::EQUAL) continue;
      78                 :        533 :       tracker.clear();
      79                 :        533 :       d_substitutions->invalidateCache();
      80                 :            :       Node o =
      81                 :        533 :           d_substitutions->apply(orig, d_env.getRewriter(), &tracker, &stc);
      82         [ +  + ]:        533 :       if (o.getKind() != Kind::EQUAL) continue;
      83 [ -  + ][ -  + ]:        478 :       Assert(o.getNumChildren() == 2);
                 [ -  - ]
      84         [ +  + ]:        988 :       for (size_t i = 0; i < 2; ++i)
      85                 :            :       {
      86 [ +  + ][ +  + ]:        733 :         const auto& l = (o[i].getKind() == Kind::TO_REAL ? o[i][0] : o[i]);
                 [ -  - ]
      87 [ +  + ][ +  + ]:        733 :         const auto& r = (o[1-i].getKind() == Kind::TO_REAL ? o[1-i][0] : o[1-i]);
                 [ -  - ]
      88                 :            :         // lhs can't be constant
      89         [ +  + ]:        733 :         if (l.isConst()) continue;
      90                 :            :         // types must match (otherwise we might have int/real issues)
      91         [ +  + ]:        626 :         if (r.getType() != l.getType()) continue;
      92                 :            :         // can't substitute stuff from other theories
      93         [ +  + ]:        622 :         if (!Theory::isLeafOf(l, TheoryId::THEORY_ARITH)) continue;
      94                 :            :         // can't substitute the same thing twice
      95         [ -  + ]:        305 :         if (d_substitutions->hasSubstitution(l)) continue;
      96                 :            :         // lhs can't be a subexpression of rhs, would leaf to recursion
      97         [ +  + ]:        305 :         if (expr::hasSubterm(r, l)) continue;
      98                 :            :         // the same, but after substitution
      99                 :        223 :         d_substitutions->invalidateCache();
     100         [ -  + ]:        223 :         if (expr::hasSubterm(d_substitutions->apply(r, nullptr, nullptr, &stc), l)) continue;
     101         [ +  - ]:        446 :         Trace("nl-eqs") << "Found substitution " << l << " -> " << r
     102                 :          0 :                         << std::endl
     103                 :        223 :                         << " from " << o << " / " << orig << std::endl;
     104                 :        223 :         d_substitutions->addSubstitution(l, r);
     105                 :        223 :         d_trackOrigin.emplace(l, o);
     106         [ +  + ]:        223 :         if (o != orig)
     107                 :            :         {
     108                 :         38 :           addToConflictMap(o, orig, tracker);
     109                 :            :         }
     110                 :        223 :         break;
     111                 :            :       }
     112                 :            :     }
     113                 :            : 
     114                 :            :     // simplify with subs from original into next
     115                 :        166 :     next.clear();
     116         [ +  + ]:       3987 :     for (const auto& a : asserts)
     117                 :            :     {
     118                 :       3825 :       tracker.clear();
     119                 :       3825 :       d_substitutions->invalidateCache();
     120                 :            :       Node simp =
     121                 :       3825 :           d_substitutions->apply(a, d_env.getRewriter(), &tracker, &stc);
     122         [ +  + ]:       3825 :       if (simp.isConst())
     123                 :            :       {
     124         [ +  + ]:        350 :         if (simp.getConst<bool>())
     125                 :            :         {
     126                 :        346 :           continue;
     127                 :            :         }
     128         [ +  - ]:          4 :         Trace("nl-eqs") << "Simplified " << a << " to " << simp << std::endl;
     129         [ +  + ]:         12 :         for (TNode t : tracker)
     130                 :            :         {
     131         [ +  - ]:          8 :           Trace("nl-eqs") << "Tracker has " << t << std::endl;
     132                 :          8 :           auto toit = d_trackOrigin.find(t);
     133 [ -  + ][ -  + ]:          8 :           Assert(toit != d_trackOrigin.end());
                 [ -  - ]
     134                 :          8 :           d_conflict.emplace_back(toit->second);
     135                 :            :         }
     136                 :          4 :         d_conflict.emplace_back(a);
     137                 :          4 :         postprocessConflict(d_conflict);
     138         [ +  - ]:          4 :         Trace("nl-eqs") << "Direct conflict: " << d_conflict << std::endl;
     139         [ +  - ]:          8 :         Trace("nl-eqs") << std::endl
     140                 :          0 :                         << d_conflict.size() << " vs "
     141                 :          0 :                         << std::distance(d_substitutions->begin(),
     142                 :          4 :                                          d_substitutions->end())
     143                 :          0 :                         << std::endl
     144                 :          4 :                         << std::endl;
     145                 :          4 :         return {};
     146                 :            :       }
     147         [ +  + ]:       3475 :       if (simp != a)
     148                 :            :       {
     149         [ +  - ]:        787 :         Trace("nl-eqs") << "Simplified " << a << " to " << simp << std::endl;
     150                 :        787 :         addToConflictMap(simp, a, tracker);
     151                 :            :       }
     152                 :       3475 :       next.emplace_back(simp);
     153                 :            :     }
     154                 :        162 :     asserts = std::move(next);
     155                 :            :   }
     156                 :        106 :   d_conflict.clear();
     157         [ -  + ]:        106 :   if (TraceIsOn("nl-eqs"))
     158                 :            :   {
     159         [ -  - ]:          0 :     Trace("nl-eqs") << "Output:" << std::endl;
     160         [ -  - ]:          0 :     for (const auto& a : asserts)
     161                 :            :     {
     162         [ -  - ]:          0 :       Trace("nl-eqs") << "\t" << a << std::endl;
     163                 :            :     }
     164         [ -  - ]:          0 :     Trace("nl-eqs") << "Substitutions:" << std::endl;
     165         [ -  - ]:          0 :     for (const auto& subs : d_substitutions->getSubstitutions())
     166                 :            :     {
     167         [ -  - ]:          0 :       Trace("nl-eqs") << "\t" << subs.first << " -> " << subs.second
     168                 :          0 :                       << std::endl;
     169                 :            :     }
     170                 :            :   }
     171                 :        106 :   return asserts;
     172                 :            : }
     173                 :         98 : void EqualitySubstitution::postprocessConflict(
     174                 :            :     std::vector<Node>& conflict) const
     175                 :            : {
     176         [ +  - ]:         98 :   Trace("nl-eqs") << "Postprocessing " << conflict << std::endl;
     177                 :         98 :   std::set<Node> result;
     178         [ +  + ]:        424 :   for (const auto& c : conflict)
     179                 :            :   {
     180                 :        326 :     auto it = d_conflictMap.find(c);
     181         [ +  + ]:        326 :     if (it == d_conflictMap.end())
     182                 :            :     {
     183                 :        280 :       result.insert(c);
     184                 :            :     }
     185                 :            :     else
     186                 :            :     {
     187         [ +  - ]:         46 :       Trace("nl-eqs") << "Origin of " << c << ": " << it->second << std::endl;
     188                 :         46 :       result.insert(it->second.begin(), it->second.end());
     189                 :            :     }
     190                 :            :   }
     191                 :         98 :   conflict.clear();
     192                 :         98 :   conflict.insert(conflict.end(), result.begin(), result.end());
     193         [ +  - ]:         98 :   Trace("nl-eqs") << "-> " << conflict << std::endl;
     194                 :         98 : }
     195                 :       2246 : void EqualitySubstitution::insertOrigins(std::set<Node>& dest,
     196                 :            :                                          const Node& n) const
     197                 :            : {
     198                 :       2246 :   auto it = d_conflictMap.find(n);
     199         [ +  + ]:       2246 :   if (it == d_conflictMap.end())
     200                 :            :   {
     201                 :       1926 :     dest.insert(n);
     202                 :            :   }
     203                 :            :   else
     204                 :            :   {
     205                 :        320 :     dest.insert(it->second.begin(), it->second.end());
     206                 :            :   }
     207                 :       2246 : }
     208                 :        825 : void EqualitySubstitution::addToConflictMap(const Node& n,
     209                 :            :                                             const Node& orig,
     210                 :            :                                             const std::set<TNode>& tracker)
     211                 :            : {
     212                 :        825 :   std::set<Node> origins;
     213                 :        825 :   insertOrigins(origins, orig);
     214         [ +  + ]:       2246 :   for (const auto& t : tracker)
     215                 :            :   {
     216                 :       1421 :     auto tit = d_trackOrigin.find(t);
     217 [ -  + ][ -  + ]:       1421 :     Assert(tit != d_trackOrigin.end());
                 [ -  - ]
     218                 :       1421 :     insertOrigins(origins, tit->second);
     219                 :            :   }
     220                 :        825 :   d_conflictMap.emplace(n, std::vector<Node>(origins.begin(), origins.end()));
     221                 :        825 : }
     222                 :            : 
     223                 :            : }  // namespace nl
     224                 :            : }  // namespace arith
     225                 :            : }  // namespace theory
     226                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14