LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/theory/arith/rewriter - addition.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 127 151 84.1 %
Date: 2026-06-10 10:33:01 Functions: 9 10 90.0 %
Branches: 81 126 64.3 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * This file is part of the cvc5 project.
       3                 :            :  *
       4                 :            :  * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
       5                 :            :  * in the top-level source directory and their institutional affiliations.
       6                 :            :  * All rights reserved.  See the file COPYING in the top-level source
       7                 :            :  * directory for licensing information.
       8                 :            :  * ****************************************************************************
       9                 :            :  *
      10                 :            :  * Addition utilities for the arithmetic rewriter.
      11                 :            :  */
      12                 :            : 
      13                 :            : #include "theory/arith/rewriter/addition.h"
      14                 :            : 
      15                 :            : #include <iostream>
      16                 :            : 
      17                 :            : #include "base/check.h"
      18                 :            : #include "expr/node.h"
      19                 :            : #include "theory/arith/rewriter/node_utils.h"
      20                 :            : #include "theory/arith/rewriter/ordering.h"
      21                 :            : #include "util/real_algebraic_number.h"
      22                 :            : 
      23                 :            : namespace cvc5::internal {
      24                 :            : namespace theory {
      25                 :            : namespace arith {
      26                 :            : namespace rewriter {
      27                 :            : 
      28                 :          0 : std::ostream& operator<<(std::ostream& os, const Sum& sum)
      29                 :            : {
      30         [ -  - ]:          0 :   for (auto it = sum.begin(); it != sum.end(); ++it)
      31                 :            :   {
      32         [ -  - ]:          0 :     if (it != sum.begin()) os << " + ";
      33         [ -  - ]:          0 :     if (it->first.isConst())
      34                 :            :     {
      35                 :          0 :       Assert(it->first.getConst<Rational>().isOne());
      36                 :          0 :       os << it->second;
      37                 :          0 :       continue;
      38                 :            :     }
      39                 :          0 :     os << it->second << "*" << it->first;
      40                 :            :   }
      41                 :          0 :   return os;
      42                 :            : }
      43                 :            : 
      44                 :            : namespace {
      45                 :            : 
      46                 :            : /**
      47                 :            :  * Adds a factor n to a product, consisting of the numerical multiplicity and
      48                 :            :  * the remaining (non-numerical) factors. If n is a product itself, its children
      49                 :            :  * are merged into the product. If n is a constant or a real algebraic number,
      50                 :            :  * it is multiplied to the multiplicity. Otherwise, n is added to product.
      51                 :            :  *
      52                 :            :  * Invariant:
      53                 :            :  *   multiplicity' * multiply(product') = n * multiplicity * multiply(product)
      54                 :            :  */
      55                 :   33353901 : void addToProduct(std::vector<Node>& product,
      56                 :            :                   RealAlgebraicNumber& multiplicity,
      57                 :            :                   TNode n)
      58                 :            : {
      59    [ +  + ][ + ]:   33353901 :   switch (n.getKind())
      60                 :            :   {
      61                 :    5925364 :     case Kind::MULT:
      62                 :            :     case Kind::NONLINEAR_MULT:
      63         [ +  + ]:   18663863 :       for (const auto& child : n)
      64                 :            :       {
      65                 :            :         // make sure constants are properly extracted.
      66                 :            :         // recursion is safe, as mult is already flattened
      67                 :   12738499 :         addToProduct(product, multiplicity, child);
      68                 :   12738499 :       }
      69                 :    5925364 :       break;
      70                 :        116 :     case Kind::REAL_ALGEBRAIC_NUMBER: multiplicity *= getRAN(n); break;
      71                 :   27428421 :     default:
      72         [ +  + ]:   27428421 :       if (n.isConst())
      73                 :            :       {
      74                 :   11175991 :         multiplicity *= n.getConst<Rational>();
      75                 :            :       }
      76                 :            :       else
      77                 :            :       {
      78                 :   16252430 :         product.emplace_back(n);
      79                 :            :       }
      80                 :            :   }
      81                 :   33353901 : }
      82                 :            : 
      83                 :            : /**
      84                 :            :  * Add a new summand, consisting of the product and the multiplicity, to a sum.
      85                 :            :  * Either adds the summand as a new entry to the sum, or adds the multiplicity
      86                 :            :  * to an already existing summand. Removes the entry, if the multiplicity is
      87                 :            :  * zero afterwards.
      88                 :            :  *
      89                 :            :  * Invariant:
      90                 :            :  *   add(s.n * s.ran for s in sum')
      91                 :            :  *   = add(s.n * s.ran for s in sum) + multiplicity * product
      92                 :            :  */
      93                 :   20295204 : void addToSum(Sum& sum, TNode product, const RealAlgebraicNumber& multiplicity)
      94                 :            : {
      95         [ +  + ]:   20295204 :   if (multiplicity.isZero()) return;
      96                 :   18162171 :   auto it = sum.find(product);
      97         [ +  + ]:   18162171 :   if (it == sum.end())
      98                 :            :   {
      99                 :   17335042 :     sum.emplace(product, multiplicity);
     100                 :            :   }
     101                 :            :   else
     102                 :            :   {
     103                 :     827129 :     it->second += multiplicity;
     104         [ +  + ]:     827129 :     if (it->second.isZero())
     105                 :            :     {
     106                 :     335544 :       sum.erase(it);
     107                 :            :     }
     108                 :            :   }
     109                 :            : }
     110                 :            : 
     111                 :            : /**
     112                 :            :  * Evaluates `basemultiplicity * baseproduct * sum` into a single node (of kind
     113                 :            :  * `ADD`, unless the sum has less than two summands).
     114                 :            :  */
     115                 :      71598 : Node collectSumWithBase(NodeManager* nm,
     116                 :            :                         const Sum& sum,
     117                 :            :                         const RealAlgebraicNumber& basemultiplicity,
     118                 :            :                         const std::vector<Node>& baseproduct)
     119                 :            : {
     120         [ -  + ]:      71598 :   if (sum.empty()) return mkConst(nm, Rational(0));
     121                 :            :   // construct the sum as nodes.
     122                 :      71598 :   NodeBuilder nb(nm, Kind::ADD);
     123         [ +  + ]:     233736 :   for (const auto& summand : sum)
     124                 :            :   {
     125 [ -  + ][ -  + ]:     162138 :     Assert(!summand.second.isZero());
                 [ -  - ]
     126                 :     162138 :     RealAlgebraicNumber mult = summand.second * basemultiplicity;
     127                 :     162138 :     std::vector<Node> product = baseproduct;
     128                 :     162138 :     rewriter::addToProduct(product, mult, summand.first);
     129                 :     162138 :     nb << mkMultTerm(nm, mult, std::move(product));
     130                 :     162138 :   }
     131         [ -  + ]:      71598 :   if (nb.getNumChildren() == 1)
     132                 :            :   {
     133                 :          0 :     return nb[0];
     134                 :            :   }
     135                 :      71598 :   return nb.constructNode();
     136                 :      71598 : }
     137                 :            : }  // namespace
     138                 :            : 
     139                 :    5435506 : bool isIntegral(const Sum& sum)
     140                 :            : {
     141                 :    5435506 :   std::vector<TNode> queue;
     142         [ +  + ]:   17108530 :   for (const auto& s : sum)
     143                 :            :   {
     144                 :   11673029 :     queue.emplace_back(s.first);
     145         [ +  + ]:   11673029 :     if (!s.second.isRational()) return false;
     146                 :            :   }
     147         [ +  + ]:   16751249 :   while (!queue.empty())
     148                 :            :   {
     149                 :   12636083 :     TNode cur = queue.back();
     150                 :   12636083 :     queue.pop_back();
     151                 :            : 
     152         [ +  + ]:   12636083 :     if (cur.isConst()) continue;
     153         [ +  + ]:   10503217 :     switch (cur.getKind())
     154                 :            :     {
     155                 :    1303087 :       case Kind::ADD:
     156                 :            :       case Kind::NEG:
     157                 :            :       case Kind::SUB:
     158                 :            :       case Kind::MULT:
     159                 :            :       case Kind::NONLINEAR_MULT:
     160                 :    1303087 :         queue.insert(queue.end(), cur.begin(), cur.end());
     161                 :    1303087 :         break;
     162                 :    9200130 :       default:
     163         [ +  + ]:    9200130 :         if (!cur.getType().isInteger()) return false;
     164                 :            :     }
     165    [ +  + ][ + ]:   12636083 :   }
     166                 :    4115166 :   return true;
     167                 :    5435506 : }
     168                 :            : 
     169                 :   24212113 : void addToSum(Sum& sum, TNode n, bool negate)
     170                 :            : {
     171         [ +  + ]:   24212113 :   if (n.getKind() == Kind::ADD)
     172                 :            :   {
     173         [ +  + ]:   14665341 :     for (const auto& child : n)
     174                 :            :     {
     175                 :   10349933 :       addToSum(sum, child, negate);
     176                 :   10349933 :     }
     177                 :    4315408 :     return;
     178                 :            :   }
     179                 :   19896705 :   std::vector<Node> monomial;
     180                 :   19896705 :   RealAlgebraicNumber multiplicity(Integer(1));
     181         [ +  + ]:   19896705 :   if (negate)
     182                 :            :   {
     183                 :    6385673 :     multiplicity = Integer(-1);
     184                 :            :   }
     185                 :   19896705 :   addToProduct(monomial, multiplicity, n);
     186                 :   19896705 :   addToSum(sum, mkNonlinearMult(n.getNodeManager(), monomial), multiplicity);
     187                 :   19896705 : }
     188                 :            : 
     189                 :     642264 : void addToSumNoMixed(Sum& sum, TNode n, bool negate)
     190                 :            : {
     191                 :     642264 :   Kind k = n.getKind();
     192         [ +  + ]:     642264 :   if (k == Kind::ADD)
     193                 :            :   {
     194         [ +  + ]:     317748 :     for (const auto& child : n)
     195                 :            :     {
     196                 :     215897 :       addToSum(
     197         [ +  + ]:     431794 :           sum, child.getKind() == Kind::TO_REAL ? child[0] : child, negate);
     198                 :     215897 :     }
     199                 :     101851 :     return;
     200                 :            :   }
     201         [ +  + ]:     540413 :   else if (k == Kind::TO_REAL)
     202                 :            :   {
     203                 :          3 :     addToSum(sum, n[0], negate);
     204                 :          3 :     return;
     205                 :            :   }
     206                 :     540410 :   addToSum(sum, n, negate);
     207                 :            : }
     208                 :            : 
     209                 :     231006 : void addMonomialToSum(Sum& sum,
     210                 :            :                       TNode product,
     211                 :            :                       RealAlgebraicNumber& multiplicity)
     212                 :            : {
     213 [ -  + ][ -  + ]:     231006 :   Assert(product.getKind() != Kind::ADD);
                 [ -  - ]
     214                 :     231006 :   std::vector<Node> monomial;
     215                 :     231006 :   addToProduct(monomial, multiplicity, product);
     216                 :     231006 :   addToSum(
     217                 :     462012 :       sum, mkNonlinearMult(product.getNodeManager(), monomial), multiplicity);
     218                 :     231006 : }
     219                 :            : 
     220                 :    7699578 : Node collectSum(NodeManager* nm, const Sum& sum)
     221                 :            : {
     222         [ +  + ]:    7699578 :   if (sum.empty()) return mkConst(nm, Rational(0));
     223         [ +  - ]:    7261350 :   Trace("arith-rewriter") << "Collecting sum " << sum << std::endl;
     224                 :            :   // construct the sum as nodes.
     225                 :    7261350 :   NodeBuilder nb(nm, Kind::ADD);
     226         [ +  + ]:   19565400 :   for (const auto& s : sum)
     227                 :            :   {
     228                 :   12304050 :     nb << mkMultTerm(s.second, s.first);
     229                 :            :   }
     230         [ +  + ]:    7261350 :   if (nb.getNumChildren() == 1)
     231                 :            :   {
     232                 :    3302054 :     return nb[0];
     233                 :            :   }
     234                 :    3959296 :   return nb.constructNode();
     235                 :    7261350 : }
     236                 :            : 
     237                 :      71598 : Node distributeMultiplication(NodeManager* nm,
     238                 :            :                               const std::vector<TNode>& factors)
     239                 :            : {
     240         [ -  + ]:      71598 :   if (TraceIsOn("arith-rewriter-distribute"))
     241                 :            :   {
     242         [ -  - ]:          0 :     Trace("arith-rewriter-distribute") << "Distributing" << std::endl;
     243         [ -  - ]:          0 :     for (const auto& f : factors)
     244                 :            :     {
     245         [ -  - ]:          0 :       Trace("arith-rewriter-distribute") << "\t" << f << std::endl;
     246                 :            :     }
     247                 :            :   }
     248                 :            :   // factors that are not sums, separated into numerical and non-numerical
     249                 :      71598 :   RealAlgebraicNumber basemultiplicity(Integer(1));
     250                 :      71598 :   std::vector<Node> base;
     251                 :            :   // maps products to their (possibly real algebraic) multiplicities.
     252                 :            :   // The current (intermediate) value is the sum of these (multiplied by the
     253                 :            :   // base factors).
     254                 :      71598 :   Sum sum;
     255                 :            :   // Add a base summand
     256                 :      71598 :   sum.emplace(mkConst(nm, Rational(1)), RealAlgebraicNumber(Integer(1)));
     257                 :            : 
     258                 :            :   // multiply factors one by one to basmultiplicity * base * sum
     259         [ +  + ]:     216826 :   for (const auto& factor : factors)
     260                 :            :   {
     261                 :            :     // Subtractions are rewritten already, we only need to care about additions
     262 [ -  + ][ -  + ]:     145228 :     Assert(factor.getKind() != Kind::SUB);
                 [ -  - ]
     263                 :     145228 :     Assert(factor.getKind() != Kind::NEG
     264                 :            :            || (factor[0].isConst() || isRAN(factor[0])));
     265         [ +  + ]:     145228 :     if (factor.getKind() != Kind::ADD)
     266                 :            :     {
     267 [ +  + ][ +  - ]:      73067 :       Assert(!(factor.isConst() && factor.getConst<Rational>().isZero()));
         [ -  + ][ -  + ]
                 [ -  - ]
     268                 :      73067 :       addToProduct(base, basemultiplicity, factor);
     269                 :      73067 :       continue;
     270                 :            :     }
     271                 :            :     // temporary to store factor * sum, will be moved to sum at the end
     272                 :      72161 :     Sum newsum;
     273                 :            : 
     274         [ +  + ]:     146219 :     for (const auto& summand : sum)
     275                 :            :     {
     276         [ +  + ]:     241551 :       for (const auto& child : factor)
     277                 :            :       {
     278                 :            :         // add summand * child to newsum
     279                 :     167493 :         RealAlgebraicNumber multiplicity = summand.second;
     280         [ +  + ]:     167493 :         if (child.isConst())
     281                 :            :         {
     282                 :      41250 :           multiplicity *= child.getConst<Rational>();
     283                 :      41250 :           addToSum(newsum, summand.first, multiplicity);
     284                 :      41250 :           continue;
     285                 :            :         }
     286         [ -  + ]:     126243 :         if (isRAN(child))
     287                 :            :         {
     288                 :          0 :           multiplicity *= getRAN(child);
     289                 :          0 :           addToSum(newsum, summand.first, multiplicity);
     290                 :          0 :           continue;
     291                 :            :         }
     292                 :            : 
     293                 :            :         // construct the new product
     294                 :     126243 :         std::vector<Node> newProduct;
     295                 :     126243 :         addToProduct(newProduct, multiplicity, summand.first);
     296                 :     126243 :         addToProduct(newProduct, multiplicity, child);
     297                 :     126243 :         std::sort(newProduct.begin(), newProduct.end(), LeafNodeComparator());
     298                 :     126243 :         addToSum(newsum, mkNonlinearMult(nm, newProduct), multiplicity);
     299 [ +  + ][ +  + ]:     208743 :       }
     300                 :            :     }
     301         [ -  + ]:      72161 :     if (TraceIsOn("arith-rewriter-distribute"))
     302                 :            :     {
     303         [ -  - ]:          0 :       Trace("arith-rewriter-distribute")
     304                 :          0 :           << "multiplied with " << factor << std::endl;
     305         [ -  - ]:          0 :       Trace("arith-rewriter-distribute")
     306                 :          0 :           << "base: " << basemultiplicity << " * " << base << std::endl;
     307         [ -  - ]:          0 :       Trace("arith-rewriter-distribute") << "sum:" << std::endl;
     308         [ -  - ]:          0 :       for (const auto& summand : newsum)
     309                 :            :       {
     310         [ -  - ]:          0 :         Trace("arith-rewriter-distribute")
     311                 :          0 :             << "\t" << summand.second << " * " << summand.first << std::endl;
     312                 :            :       }
     313                 :            :     }
     314                 :            : 
     315                 :      72161 :     sum = std::move(newsum);
     316                 :      72161 :   }
     317                 :            :   // now mult(factors) == base * add(sum)
     318                 :            : 
     319                 :     143196 :   return collectSumWithBase(nm, sum, basemultiplicity, base);
     320                 :      71598 : }
     321                 :            : 
     322                 :            : }  // namespace rewriter
     323                 :            : }  // namespace arith
     324                 :            : }  // namespace theory
     325                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14