LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/theory/arith - arith_msum.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 143 157 91.1 %
Date: 2025-01-03 12:37:04 Functions: 9 10 90.0 %
Branches: 126 182 69.2 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * Top contributors (to current version):
       3                 :            :  *   Andrew Reynolds, Aina Niemetz, Andres Noetzli
       4                 :            :  *
       5                 :            :  * This file is part of the cvc5 project.
       6                 :            :  *
       7                 :            :  * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS
       8                 :            :  * in the top-level source directory and their institutional affiliations.
       9                 :            :  * All rights reserved.  See the file COPYING in the top-level source
      10                 :            :  * directory for licensing information.
      11                 :            :  * ****************************************************************************
      12                 :            :  *
      13                 :            :  * Arithmetic utilities regarding monomial sums.
      14                 :            :  */
      15                 :            : 
      16                 :            : #include "theory/arith/arith_msum.h"
      17                 :            : 
      18                 :            : #include "theory/rewriter.h"
      19                 :            : #include "util/rational.h"
      20                 :            : 
      21                 :            : using namespace cvc5::internal::kind;
      22                 :            : 
      23                 :            : namespace cvc5::internal {
      24                 :            : namespace theory {
      25                 :            : 
      26                 :        978 : bool ArithMSum::getMonomial(Node n, Node& c, Node& v)
      27                 :            : {
      28 [ +  - ][ +  - ]:        978 :   if (n.getKind() == Kind::MULT && n.getNumChildren() == 2 && n[0].isConst())
         [ +  - ][ +  - ]
         [ +  - ][ -  - ]
      29                 :            :   {
      30                 :        978 :     c = n[0];
      31                 :        978 :     v = n[1];
      32                 :        978 :     return true;
      33                 :            :   }
      34                 :          0 :   return false;
      35                 :            : }
      36                 :            : 
      37                 :    2001730 : bool ArithMSum::getMonomial(Node n, std::map<Node, Node>& msum)
      38                 :            : {
      39         [ +  + ]:    2001730 :   if (n.isConst())
      40                 :            :   {
      41         [ +  - ]:     533831 :     if (msum.find(Node::null()) == msum.end())
      42                 :            :     {
      43                 :     533831 :       msum[Node::null()] = n;
      44                 :     533831 :       return true;
      45                 :            :     }
      46                 :            :   }
      47         [ +  - ]:     559003 :   else if (n.getKind() == Kind::MULT && n.getNumChildren() == 2
      48 [ +  + ][ +  - ]:    2026900 :            && n[0].isConst())
         [ +  + ][ +  + ]
                 [ -  - ]
      49                 :            :   {
      50         [ +  - ]:     559003 :     if (msum.find(n[1]) == msum.end())
      51                 :            :     {
      52                 :     559003 :       msum[n[1]] = n[0];
      53                 :     559003 :       return true;
      54                 :            :     }
      55                 :            :   }
      56                 :            :   else
      57                 :            :   {
      58         [ +  - ]:     908899 :     if (msum.find(n) == msum.end())
      59                 :            :     {
      60                 :     908899 :       msum[n] = Node::null();
      61                 :     908899 :       return true;
      62                 :            :     }
      63                 :            :   }
      64                 :          0 :   return false;
      65                 :            : }
      66                 :            : 
      67                 :    1219610 : bool ArithMSum::getMonomialSum(Node n, std::map<Node, Node>& msum)
      68                 :            : {
      69         [ +  + ]:    1219610 :   if (n.getKind() == Kind::ADD)
      70                 :            :   {
      71         [ +  + ]:    1839630 :     for (Node nc : n)
      72                 :            :     {
      73         [ -  + ]:    1310880 :       if (!getMonomial(nc, msum))
      74                 :            :       {
      75                 :          0 :         return false;
      76                 :            :       }
      77                 :            :     }
      78                 :     528757 :     return true;
      79                 :            :   }
      80                 :     690856 :   return getMonomial(n, msum);
      81                 :            : }
      82                 :            : 
      83                 :     443029 : bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
      84                 :            : {
      85                 :    1329090 :   if (lit.getKind() == Kind::GEQ
      86                 :     443029 :       || (lit.getKind() == Kind::EQUAL && lit[0].getType().isRealOrInt()))
      87                 :            :   {
      88         [ +  - ]:     364755 :     if (getMonomialSum(lit[0], msum))
      89                 :            :     {
      90                 :     364755 :       if (lit[1].isConst() && lit[1].getConst<Rational>().isZero())
      91                 :            :       {
      92                 :      75844 :         return true;
      93                 :            :       }
      94                 :            :       else
      95                 :            :       {
      96                 :            :         // subtract the other side
      97                 :     288911 :         std::map<Node, Node> msum2;
      98                 :     288911 :         NodeManager* nm = NodeManager::currentNM();
      99         [ +  - ]:     288911 :         if (getMonomialSum(lit[1], msum2))
     100                 :            :         {
     101                 :     608881 :           for (std::map<Node, Node>::iterator it = msum2.begin();
     102         [ +  + ]:     608881 :                it != msum2.end();
     103                 :     319970 :                ++it)
     104                 :            :           {
     105                 :     319970 :             std::map<Node, Node>::iterator it2 = msum.find(it->first);
     106         [ +  + ]:     319970 :             if (it2 != msum.end())
     107                 :            :             {
     108                 :        130 :               Rational r1 = it2->second.isNull()
     109                 :            :                                 ? Rational(1)
     110         [ +  + ]:        260 :                                 : it2->second.getConst<Rational>();
     111                 :        130 :               Rational r2 = it->second.isNull()
     112                 :            :                                 ? Rational(1)
     113         [ +  + ]:        130 :                                 : it->second.getConst<Rational>();
     114                 :        130 :               msum[it->first] = nm->mkConstRealOrInt(r1 - r2);
     115                 :            :             }
     116                 :            :             else
     117                 :            :             {
     118                 :     639680 :               msum[it->first] = it->second.isNull()
     119 [ +  + ][ -  - ]:    1113710 :                                     ? nm->mkConstInt(Rational(-1))
     120                 :            :                                     : nm->mkConstRealOrInt(
     121 [ +  + ][ +  + ]:     793868 :                                           -it->second.getConst<Rational>());
                 [ -  - ]
     122                 :            :             }
     123                 :            :           }
     124                 :     288911 :           return true;
     125                 :            :         }
     126                 :            :       }
     127                 :            :     }
     128                 :            :   }
     129                 :      78274 :   return false;
     130                 :            : }
     131                 :            : 
     132                 :       1779 : Node ArithMSum::mkNode(const std::map<Node, Node>& msum)
     133                 :            : {
     134                 :       1779 :   NodeManager* nm = NodeManager::currentNM();
     135                 :       1779 :   std::vector<Node> children;
     136         [ +  + ]:       3558 :   for (std::map<Node, Node>::const_iterator it = msum.begin(); it != msum.end();
     137                 :       1779 :        ++it)
     138                 :            :   {
     139                 :       3558 :     Node m;
     140         [ +  + ]:       1779 :     if (!it->first.isNull())
     141                 :            :     {
     142                 :        921 :       m = mkCoeffTerm(it->second, it->first);
     143                 :            :     }
     144                 :            :     else
     145                 :            :     {
     146 [ -  + ][ -  + ]:        858 :       Assert(!it->second.isNull());
                 [ -  - ]
     147                 :        858 :       m = it->second;
     148                 :            :     }
     149                 :       1779 :     children.push_back(m);
     150                 :            :   }
     151                 :       1779 :   return children.size() > 1
     152                 :            :              ? nm->mkNode(Kind::ADD, children)
     153                 :       3558 :              : (children.size() == 1 ? children[0]
     154 [ -  + ][ +  - ]:       8895 :                                      : nm->mkConstInt(Rational(0)));
         [ -  + ][ -  - ]
     155                 :            : }
     156                 :            : 
     157                 :     245998 : int ArithMSum::isolate(
     158                 :            :     Node v, const std::map<Node, Node>& msum, Node& veq_c, Node& val, Kind k)
     159                 :            : {
     160 [ -  + ][ -  + ]:     245998 :   Assert(veq_c.isNull());
                 [ -  - ]
     161                 :     245998 :   std::map<Node, Node>::const_iterator itv = msum.find(v);
     162         [ +  + ]:     245998 :   if (itv != msum.end())
     163                 :            :   {
     164                 :     239495 :     NodeManager* nm = NodeManager::currentNM();
     165                 :     239495 :     std::vector<Node> children;
     166                 :            :     Rational r =
     167         [ +  + ]:     239495 :         itv->second.isNull() ? Rational(1) : itv->second.getConst<Rational>();
     168         [ +  + ]:     239495 :     if (r.sgn() != 0)
     169                 :            :     {
     170                 :     239493 :       TypeNode vtn = v.getType();
     171                 :     814878 :       for (std::map<Node, Node>::const_iterator it = msum.begin();
     172         [ +  + ]:     814878 :            it != msum.end();
     173                 :     575385 :            ++it)
     174                 :            :       {
     175         [ +  + ]:     575385 :         if (it->first != v)
     176                 :            :         {
     177                 :     671784 :           Node m;
     178         [ +  + ]:     335892 :           if (!it->first.isNull())
     179                 :            :           {
     180                 :     243265 :             m = mkCoeffTerm(it->second, it->first);
     181                 :            :           }
     182                 :            :           else
     183                 :            :           {
     184                 :      92627 :             m = it->second;
     185                 :            :           }
     186                 :     335892 :           children.push_back(m);
     187                 :            :         }
     188                 :            :       }
     189                 :     239493 :       val = children.size() > 1
     190 [ +  + ][ +  + ]:     822756 :                 ? nm->mkNode(Kind::ADD, children)
     191                 :     318663 :                 : (children.size() == 1 ? children[0]
     192 [ +  + ][ -  - ]:     504093 :                                         : nm->mkConstInt(Rational(0)));
     193 [ +  + ][ +  + ]:     239493 :       if (!r.isOne() && !r.isNegativeOne())
                 [ +  + ]
     194                 :            :       {
     195         [ +  + ]:      12677 :         if (vtn.isInteger())
     196                 :            :         {
     197                 :       6460 :           veq_c = nm->mkConstRealOrInt(r.abs());
     198                 :            :         }
     199                 :            :         else
     200                 :            :         {
     201                 :      12434 :           val = nm->mkNode(
     202                 :      18651 :               Kind::MULT, val, nm->mkConstReal(Rational(1) / r.abs()));
     203                 :            :         }
     204                 :            :       }
     205                 :     810448 :       val = r.sgn() == 1 ? nm->mkNode(
     206                 :     405224 :                 Kind::MULT, nm->mkConstRealOrInt(Rational(-1)), val)
     207                 :     239493 :                          : val;
     208 [ +  + ][ +  + ]:     239493 :       return (r.sgn() == 1 || k == Kind::EQUAL) ? 1 : -1;
     209                 :            :     }
     210                 :            :   }
     211                 :       6505 :   return 0;
     212                 :            : }
     213                 :            : 
     214                 :      14011 : int ArithMSum::isolate(
     215                 :            :     Node v, const std::map<Node, Node>& msum, Node& veq, Kind k, bool doCoeff)
     216                 :            : {
     217                 :      28022 :   Node veq_c;
     218                 :      28022 :   Node val;
     219                 :            :   // isolate v in the (in)equality
     220                 :      14011 :   int ires = isolate(v, msum, veq_c, val, k);
     221         [ +  + ]:      14011 :   if (ires != 0)
     222                 :            :   {
     223                 :      13972 :     NodeManager* nm = NodeManager::currentNM();
     224                 :      13972 :     Node vc = v;
     225         [ +  + ]:      13972 :     if (!veq_c.isNull())
     226                 :            :     {
     227         [ +  + ]:        147 :       if (doCoeff)
     228                 :            :       {
     229                 :        119 :         vc = nm->mkNode(Kind::MULT, veq_c, vc);
     230                 :            :       }
     231                 :            :       else
     232                 :            :       {
     233                 :         28 :         return 0;
     234                 :            :       }
     235                 :            :     }
     236                 :      13944 :     bool inOrder = ires == 1;
     237                 :            :     // ensure type is correct for equality
     238         [ +  + ]:      13944 :     if (k == Kind::EQUAL)
     239                 :            :     {
     240                 :       5827 :       bool vci = vc.getType().isInteger();
     241                 :       5827 :       bool vi = val.getType().isInteger();
     242 [ +  + ][ +  + ]:       5827 :       if (!vci && vi)
     243                 :            :       {
     244                 :         12 :         val = nm->mkNode(Kind::TO_REAL, val);
     245                 :            :       }
     246 [ +  + ][ +  + ]:       5815 :       else if (vci && !vi)
     247                 :            :       {
     248                 :          4 :         val = nm->mkNode(Kind::TO_INTEGER, val);
     249                 :            :       }
     250 [ -  + ][ -  - ]:      11654 :       Assert(val.getType() == vc.getType())
     251                 :       5827 :           << val << " " << vc << " " << val.getType() << " " << vc.getType();
     252                 :            :     }
     253 [ +  + ][ +  + ]:      13944 :     veq = nm->mkNode(k, inOrder ? vc : val, inOrder ? val : vc);
     254                 :            :   }
     255                 :      13983 :   return ires;
     256                 :            : }
     257                 :            : 
     258                 :        140 : Node ArithMSum::solveEqualityFor(Node lit, Node v)
     259                 :            : {
     260 [ -  + ][ -  + ]:        140 :   Assert(lit.getKind() == Kind::EQUAL);
                 [ -  - ]
     261                 :            :   // first look directly at sides
     262                 :        280 :   TypeNode tn = lit[0].getType();
     263         [ +  + ]:        256 :   for (unsigned r = 0; r < 2; r++)
     264                 :            :   {
     265         [ +  + ]:        198 :     if (lit[r] == v)
     266                 :            :     {
     267                 :         82 :       return lit[1 - r];
     268                 :            :     }
     269                 :            :   }
     270         [ +  - ]:         58 :   if (tn.isRealOrInt())
     271                 :            :   {
     272                 :         58 :     std::map<Node, Node> msum;
     273         [ +  - ]:         58 :     if (ArithMSum::getMonomialSumLit(lit, msum))
     274                 :            :     {
     275                 :         58 :       Node val, veqc;
     276         [ +  + ]:         58 :       if (ArithMSum::isolate(v, msum, veqc, val, Kind::EQUAL) != 0)
     277                 :            :       {
     278         [ +  - ]:         50 :         if (veqc.isNull())
     279                 :            :         {
     280                 :            :           // in this case, we have an integer equality with a coefficient
     281                 :            :           // on the variable we solved for that could not be eliminated,
     282                 :            :           // hence we fail.
     283                 :         50 :           return val;
     284                 :            :         }
     285                 :            :       }
     286                 :            :     }
     287                 :            :   }
     288                 :          8 :   return Node::null();
     289                 :            : }
     290                 :            : 
     291                 :          0 : bool ArithMSum::decompose(Node n, Node v, Node& coeff, Node& rem)
     292                 :            : {
     293                 :          0 :   std::map<Node, Node> msum;
     294         [ -  - ]:          0 :   if (getMonomialSum(n, msum))
     295                 :            :   {
     296                 :          0 :     std::map<Node, Node>::iterator it = msum.find(v);
     297         [ -  - ]:          0 :     if (it == msum.end())
     298                 :            :     {
     299                 :          0 :       return false;
     300                 :            :     }
     301                 :            :     else
     302                 :            :     {
     303                 :          0 :       coeff = it->second;
     304                 :          0 :       msum.erase(v);
     305                 :          0 :       rem = mkNode(msum);
     306                 :          0 :       return true;
     307                 :            :     }
     308                 :            :   }
     309                 :          0 :   return false;
     310                 :            : }
     311                 :            : 
     312                 :       5420 : void ArithMSum::debugPrintMonomialSum(std::map<Node, Node>& msum, const char* c)
     313                 :            : {
     314         [ +  + ]:      16474 :   for (std::map<Node, Node>::iterator it = msum.begin(); it != msum.end(); ++it)
     315                 :            :   {
     316         [ +  - ]:      11054 :     Trace(c) << "  ";
     317         [ +  + ]:      11054 :     if (!it->second.isNull())
     318                 :            :     {
     319         [ +  - ]:       4977 :       Trace(c) << it->second;
     320         [ +  + ]:       4977 :       if (!it->first.isNull())
     321                 :            :       {
     322         [ +  - ]:       2778 :         Trace(c) << " * ";
     323                 :            :       }
     324                 :            :     }
     325         [ +  + ]:      11054 :     if (!it->first.isNull())
     326                 :            :     {
     327         [ +  - ]:       8855 :       Trace(c) << it->first;
     328                 :            :     }
     329         [ +  - ]:      11054 :     Trace(c) << std::endl;
     330                 :            :   }
     331         [ +  - ]:       5420 :   Trace(c) << std::endl;
     332                 :       5420 : }
     333                 :            : 
     334                 :            : }  // namespace theory
     335                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14