Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * Addition utilities for the arithmetic rewriter.
11 : : */
12 : :
13 : : #include "theory/arith/rewriter/addition.h"
14 : :
15 : : #include <iostream>
16 : :
17 : : #include "base/check.h"
18 : : #include "expr/node.h"
19 : : #include "theory/arith/rewriter/node_utils.h"
20 : : #include "theory/arith/rewriter/ordering.h"
21 : : #include "util/real_algebraic_number.h"
22 : :
23 : : namespace cvc5::internal {
24 : : namespace theory {
25 : : namespace arith {
26 : : namespace rewriter {
27 : :
28 : 0 : std::ostream& operator<<(std::ostream& os, const Sum& sum)
29 : : {
30 [ - - ]: 0 : for (auto it = sum.begin(); it != sum.end(); ++it)
31 : : {
32 [ - - ]: 0 : if (it != sum.begin()) os << " + ";
33 [ - - ]: 0 : if (it->first.isConst())
34 : : {
35 : 0 : Assert(it->first.getConst<Rational>().isOne());
36 : 0 : os << it->second;
37 : 0 : continue;
38 : : }
39 : 0 : os << it->second << "*" << it->first;
40 : : }
41 : 0 : return os;
42 : : }
43 : :
44 : : namespace {
45 : :
46 : : /**
47 : : * Adds a factor n to a product, consisting of the numerical multiplicity and
48 : : * the remaining (non-numerical) factors. If n is a product itself, its children
49 : : * are merged into the product. If n is a constant or a real algebraic number,
50 : : * it is multiplied to the multiplicity. Otherwise, n is added to product.
51 : : *
52 : : * Invariant:
53 : : * multiplicity' * multiply(product') = n * multiplicity * multiply(product)
54 : : */
55 : 33353901 : void addToProduct(std::vector<Node>& product,
56 : : RealAlgebraicNumber& multiplicity,
57 : : TNode n)
58 : : {
59 [ + + ][ + ]: 33353901 : switch (n.getKind())
60 : : {
61 : 5925364 : case Kind::MULT:
62 : : case Kind::NONLINEAR_MULT:
63 [ + + ]: 18663863 : for (const auto& child : n)
64 : : {
65 : : // make sure constants are properly extracted.
66 : : // recursion is safe, as mult is already flattened
67 : 12738499 : addToProduct(product, multiplicity, child);
68 : 12738499 : }
69 : 5925364 : break;
70 : 116 : case Kind::REAL_ALGEBRAIC_NUMBER: multiplicity *= getRAN(n); break;
71 : 27428421 : default:
72 [ + + ]: 27428421 : if (n.isConst())
73 : : {
74 : 11175991 : multiplicity *= n.getConst<Rational>();
75 : : }
76 : : else
77 : : {
78 : 16252430 : product.emplace_back(n);
79 : : }
80 : : }
81 : 33353901 : }
82 : :
83 : : /**
84 : : * Add a new summand, consisting of the product and the multiplicity, to a sum.
85 : : * Either adds the summand as a new entry to the sum, or adds the multiplicity
86 : : * to an already existing summand. Removes the entry, if the multiplicity is
87 : : * zero afterwards.
88 : : *
89 : : * Invariant:
90 : : * add(s.n * s.ran for s in sum')
91 : : * = add(s.n * s.ran for s in sum) + multiplicity * product
92 : : */
93 : 20295204 : void addToSum(Sum& sum, TNode product, const RealAlgebraicNumber& multiplicity)
94 : : {
95 [ + + ]: 20295204 : if (multiplicity.isZero()) return;
96 : 18162171 : auto it = sum.find(product);
97 [ + + ]: 18162171 : if (it == sum.end())
98 : : {
99 : 17335042 : sum.emplace(product, multiplicity);
100 : : }
101 : : else
102 : : {
103 : 827129 : it->second += multiplicity;
104 [ + + ]: 827129 : if (it->second.isZero())
105 : : {
106 : 335544 : sum.erase(it);
107 : : }
108 : : }
109 : : }
110 : :
111 : : /**
112 : : * Evaluates `basemultiplicity * baseproduct * sum` into a single node (of kind
113 : : * `ADD`, unless the sum has less than two summands).
114 : : */
115 : 71598 : Node collectSumWithBase(NodeManager* nm,
116 : : const Sum& sum,
117 : : const RealAlgebraicNumber& basemultiplicity,
118 : : const std::vector<Node>& baseproduct)
119 : : {
120 [ - + ]: 71598 : if (sum.empty()) return mkConst(nm, Rational(0));
121 : : // construct the sum as nodes.
122 : 71598 : NodeBuilder nb(nm, Kind::ADD);
123 [ + + ]: 233736 : for (const auto& summand : sum)
124 : : {
125 [ - + ][ - + ]: 162138 : Assert(!summand.second.isZero());
[ - - ]
126 : 162138 : RealAlgebraicNumber mult = summand.second * basemultiplicity;
127 : 162138 : std::vector<Node> product = baseproduct;
128 : 162138 : rewriter::addToProduct(product, mult, summand.first);
129 : 162138 : nb << mkMultTerm(nm, mult, std::move(product));
130 : 162138 : }
131 [ - + ]: 71598 : if (nb.getNumChildren() == 1)
132 : : {
133 : 0 : return nb[0];
134 : : }
135 : 71598 : return nb.constructNode();
136 : 71598 : }
137 : : } // namespace
138 : :
139 : 5435506 : bool isIntegral(const Sum& sum)
140 : : {
141 : 5435506 : std::vector<TNode> queue;
142 [ + + ]: 17108530 : for (const auto& s : sum)
143 : : {
144 : 11673029 : queue.emplace_back(s.first);
145 [ + + ]: 11673029 : if (!s.second.isRational()) return false;
146 : : }
147 [ + + ]: 16751249 : while (!queue.empty())
148 : : {
149 : 12636083 : TNode cur = queue.back();
150 : 12636083 : queue.pop_back();
151 : :
152 [ + + ]: 12636083 : if (cur.isConst()) continue;
153 [ + + ]: 10503217 : switch (cur.getKind())
154 : : {
155 : 1303087 : case Kind::ADD:
156 : : case Kind::NEG:
157 : : case Kind::SUB:
158 : : case Kind::MULT:
159 : : case Kind::NONLINEAR_MULT:
160 : 1303087 : queue.insert(queue.end(), cur.begin(), cur.end());
161 : 1303087 : break;
162 : 9200130 : default:
163 [ + + ]: 9200130 : if (!cur.getType().isInteger()) return false;
164 : : }
165 [ + + ][ + ]: 12636083 : }
166 : 4115166 : return true;
167 : 5435506 : }
168 : :
169 : 24212113 : void addToSum(Sum& sum, TNode n, bool negate)
170 : : {
171 [ + + ]: 24212113 : if (n.getKind() == Kind::ADD)
172 : : {
173 [ + + ]: 14665341 : for (const auto& child : n)
174 : : {
175 : 10349933 : addToSum(sum, child, negate);
176 : 10349933 : }
177 : 4315408 : return;
178 : : }
179 : 19896705 : std::vector<Node> monomial;
180 : 19896705 : RealAlgebraicNumber multiplicity(Integer(1));
181 [ + + ]: 19896705 : if (negate)
182 : : {
183 : 6385673 : multiplicity = Integer(-1);
184 : : }
185 : 19896705 : addToProduct(monomial, multiplicity, n);
186 : 19896705 : addToSum(sum, mkNonlinearMult(n.getNodeManager(), monomial), multiplicity);
187 : 19896705 : }
188 : :
189 : 642264 : void addToSumNoMixed(Sum& sum, TNode n, bool negate)
190 : : {
191 : 642264 : Kind k = n.getKind();
192 [ + + ]: 642264 : if (k == Kind::ADD)
193 : : {
194 [ + + ]: 317748 : for (const auto& child : n)
195 : : {
196 : 215897 : addToSum(
197 [ + + ]: 431794 : sum, child.getKind() == Kind::TO_REAL ? child[0] : child, negate);
198 : 215897 : }
199 : 101851 : return;
200 : : }
201 [ + + ]: 540413 : else if (k == Kind::TO_REAL)
202 : : {
203 : 3 : addToSum(sum, n[0], negate);
204 : 3 : return;
205 : : }
206 : 540410 : addToSum(sum, n, negate);
207 : : }
208 : :
209 : 231006 : void addMonomialToSum(Sum& sum,
210 : : TNode product,
211 : : RealAlgebraicNumber& multiplicity)
212 : : {
213 [ - + ][ - + ]: 231006 : Assert(product.getKind() != Kind::ADD);
[ - - ]
214 : 231006 : std::vector<Node> monomial;
215 : 231006 : addToProduct(monomial, multiplicity, product);
216 : 231006 : addToSum(
217 : 462012 : sum, mkNonlinearMult(product.getNodeManager(), monomial), multiplicity);
218 : 231006 : }
219 : :
220 : 7699578 : Node collectSum(NodeManager* nm, const Sum& sum)
221 : : {
222 [ + + ]: 7699578 : if (sum.empty()) return mkConst(nm, Rational(0));
223 [ + - ]: 7261350 : Trace("arith-rewriter") << "Collecting sum " << sum << std::endl;
224 : : // construct the sum as nodes.
225 : 7261350 : NodeBuilder nb(nm, Kind::ADD);
226 [ + + ]: 19565400 : for (const auto& s : sum)
227 : : {
228 : 12304050 : nb << mkMultTerm(s.second, s.first);
229 : : }
230 [ + + ]: 7261350 : if (nb.getNumChildren() == 1)
231 : : {
232 : 3302054 : return nb[0];
233 : : }
234 : 3959296 : return nb.constructNode();
235 : 7261350 : }
236 : :
237 : 71598 : Node distributeMultiplication(NodeManager* nm,
238 : : const std::vector<TNode>& factors)
239 : : {
240 [ - + ]: 71598 : if (TraceIsOn("arith-rewriter-distribute"))
241 : : {
242 [ - - ]: 0 : Trace("arith-rewriter-distribute") << "Distributing" << std::endl;
243 [ - - ]: 0 : for (const auto& f : factors)
244 : : {
245 [ - - ]: 0 : Trace("arith-rewriter-distribute") << "\t" << f << std::endl;
246 : : }
247 : : }
248 : : // factors that are not sums, separated into numerical and non-numerical
249 : 71598 : RealAlgebraicNumber basemultiplicity(Integer(1));
250 : 71598 : std::vector<Node> base;
251 : : // maps products to their (possibly real algebraic) multiplicities.
252 : : // The current (intermediate) value is the sum of these (multiplied by the
253 : : // base factors).
254 : 71598 : Sum sum;
255 : : // Add a base summand
256 : 71598 : sum.emplace(mkConst(nm, Rational(1)), RealAlgebraicNumber(Integer(1)));
257 : :
258 : : // multiply factors one by one to basmultiplicity * base * sum
259 [ + + ]: 216826 : for (const auto& factor : factors)
260 : : {
261 : : // Subtractions are rewritten already, we only need to care about additions
262 [ - + ][ - + ]: 145228 : Assert(factor.getKind() != Kind::SUB);
[ - - ]
263 : 145228 : Assert(factor.getKind() != Kind::NEG
264 : : || (factor[0].isConst() || isRAN(factor[0])));
265 [ + + ]: 145228 : if (factor.getKind() != Kind::ADD)
266 : : {
267 [ + + ][ + - ]: 73067 : Assert(!(factor.isConst() && factor.getConst<Rational>().isZero()));
[ - + ][ - + ]
[ - - ]
268 : 73067 : addToProduct(base, basemultiplicity, factor);
269 : 73067 : continue;
270 : : }
271 : : // temporary to store factor * sum, will be moved to sum at the end
272 : 72161 : Sum newsum;
273 : :
274 [ + + ]: 146219 : for (const auto& summand : sum)
275 : : {
276 [ + + ]: 241551 : for (const auto& child : factor)
277 : : {
278 : : // add summand * child to newsum
279 : 167493 : RealAlgebraicNumber multiplicity = summand.second;
280 [ + + ]: 167493 : if (child.isConst())
281 : : {
282 : 41250 : multiplicity *= child.getConst<Rational>();
283 : 41250 : addToSum(newsum, summand.first, multiplicity);
284 : 41250 : continue;
285 : : }
286 [ - + ]: 126243 : if (isRAN(child))
287 : : {
288 : 0 : multiplicity *= getRAN(child);
289 : 0 : addToSum(newsum, summand.first, multiplicity);
290 : 0 : continue;
291 : : }
292 : :
293 : : // construct the new product
294 : 126243 : std::vector<Node> newProduct;
295 : 126243 : addToProduct(newProduct, multiplicity, summand.first);
296 : 126243 : addToProduct(newProduct, multiplicity, child);
297 : 126243 : std::sort(newProduct.begin(), newProduct.end(), LeafNodeComparator());
298 : 126243 : addToSum(newsum, mkNonlinearMult(nm, newProduct), multiplicity);
299 [ + + ][ + + ]: 208743 : }
300 : : }
301 [ - + ]: 72161 : if (TraceIsOn("arith-rewriter-distribute"))
302 : : {
303 [ - - ]: 0 : Trace("arith-rewriter-distribute")
304 : 0 : << "multiplied with " << factor << std::endl;
305 [ - - ]: 0 : Trace("arith-rewriter-distribute")
306 : 0 : << "base: " << basemultiplicity << " * " << base << std::endl;
307 [ - - ]: 0 : Trace("arith-rewriter-distribute") << "sum:" << std::endl;
308 [ - - ]: 0 : for (const auto& summand : newsum)
309 : : {
310 [ - - ]: 0 : Trace("arith-rewriter-distribute")
311 : 0 : << "\t" << summand.second << " * " << summand.first << std::endl;
312 : : }
313 : : }
314 : :
315 : 72161 : sum = std::move(newsum);
316 : 72161 : }
317 : : // now mult(factors) == base * add(sum)
318 : :
319 : 143196 : return collectSumWithBase(nm, sum, basemultiplicity, base);
320 : 71598 : }
321 : :
322 : : } // namespace rewriter
323 : : } // namespace arith
324 : : } // namespace theory
325 : : } // namespace cvc5::internal
|