Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * [[ Add one-line brief description here ]]
11 : : *
12 : : * [[ Add lengthier description here ]]
13 : : * \todo document this file
14 : : */
15 : :
16 : : #include "theory/arith/arith_ite_utils.h"
17 : :
18 : : #include <ostream>
19 : :
20 : : #include "base/output.h"
21 : : #include "expr/skolem_manager.h"
22 : : #include "options/base_options.h"
23 : : #include "preprocessing/util/ite_utilities.h"
24 : : #include "smt/env.h"
25 : : #include "theory/arith/arith_utilities.h"
26 : : #include "theory/arith/linear/normal_form.h"
27 : : #include "theory/rewriter.h"
28 : : #include "theory/substitutions.h"
29 : : #include "theory/theory_model.h"
30 : :
31 : : using namespace std;
32 : :
33 : : namespace cvc5::internal {
34 : : namespace theory {
35 : : namespace arith {
36 : :
37 : 0 : Node ArithIteUtils::applyReduceVariablesInItes(Node n)
38 : : {
39 : 0 : NodeBuilder nb(nodeManager(), n.getKind());
40 [ - - ]: 0 : if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
41 : : {
42 : 0 : nb << (n.getOperator());
43 : : }
44 [ - - ]: 0 : for (Node::iterator it = n.begin(), end = n.end(); it != end; ++it)
45 : : {
46 : 0 : nb << reduceVariablesInItes(*it);
47 : : }
48 : 0 : Node res = nb;
49 : 0 : return res;
50 : 0 : }
51 : :
52 : 0 : Node ArithIteUtils::reduceVariablesInItes(Node n)
53 : : {
54 : : using namespace cvc5::internal::kind;
55 [ - - ]: 0 : if (d_reduceVar.find(n) != d_reduceVar.end())
56 : : {
57 : 0 : Node res = d_reduceVar[n];
58 [ - - ]: 0 : return res.isNull() ? n : res;
59 : 0 : }
60 : :
61 [ - - ]: 0 : switch (n.getKind())
62 : : {
63 : 0 : case Kind::ITE:
64 : : {
65 : 0 : Node c = n[0], t = n[1], e = n[2];
66 : 0 : TypeNode tn = n.getType();
67 [ - - ]: 0 : if (tn.isRealOrInt())
68 : : {
69 : 0 : Node rc = reduceVariablesInItes(c);
70 : 0 : Node rt = reduceVariablesInItes(t);
71 : 0 : Node re = reduceVariablesInItes(e);
72 : :
73 : 0 : Node vt = d_varParts[t];
74 : 0 : Node ve = d_varParts[e];
75 [ - - ]: 0 : Node vpite = (vt == ve) ? vt : Node::null();
76 : :
77 : 0 : NodeManager* nm = nodeManager();
78 [ - - ]: 0 : if (vpite.isNull())
79 : : {
80 : 0 : Node rite = rc.iteNode(rt, re);
81 : : // do not apply
82 : 0 : d_reduceVar[n] = rite;
83 : 0 : d_constants[n] = nm->mkConstRealOrInt(tn, Rational(0));
84 : 0 : d_varParts[n] = rite; // treat the ite as a variable
85 : 0 : return rite;
86 : 0 : }
87 : : else
88 : : {
89 : 0 : Node constantite = rc.iteNode(d_constants[t], d_constants[e]);
90 : 0 : Node sum = nm->mkNode(Kind::ADD, vpite, constantite);
91 : 0 : d_reduceVar[n] = sum;
92 : 0 : d_constants[n] = constantite;
93 : 0 : d_varParts[n] = vpite;
94 : 0 : return sum;
95 : 0 : }
96 : 0 : }
97 : : else
98 : : { // non-arith ite
99 [ - - ]: 0 : if (!d_contains.containsTermITE(n))
100 : : {
101 : : // don't bother adding to d_reduceVar
102 : 0 : return n;
103 : : }
104 : : else
105 : : {
106 : 0 : Node newIte = applyReduceVariablesInItes(n);
107 [ - - ]: 0 : d_reduceVar[n] = (n == newIte) ? Node::null() : newIte;
108 : 0 : return newIte;
109 : 0 : }
110 : : }
111 : 0 : }
112 : : break;
113 : 0 : default:
114 : : {
115 : 0 : TypeNode tn = n.getType();
116 : 0 : if (tn.isRealOrInt() && linear::Polynomial::isMember(n))
117 : : {
118 : 0 : Node newn = Node::null();
119 [ - - ]: 0 : if (!d_contains.containsTermITE(n))
120 : : {
121 : 0 : newn = n;
122 : : }
123 [ - - ]: 0 : else if (n.getNumChildren() > 0)
124 : : {
125 : 0 : newn = applyReduceVariablesInItes(n);
126 : 0 : newn = rewrite(newn);
127 : 0 : Assert(linear::Polynomial::isMember(newn));
128 : : }
129 : : else
130 : : {
131 : 0 : newn = n;
132 : : }
133 : 0 : NodeManager* nm = nodeManager();
134 : 0 : linear::Polynomial p = linear::Polynomial::parsePolynomial(newn);
135 [ - - ]: 0 : if (p.isConstant())
136 : : {
137 : 0 : d_constants[n] = newn;
138 : 0 : d_varParts[n] = nm->mkConstRealOrInt(tn, Rational(0));
139 : : // don't bother adding to d_reduceVar
140 : 0 : return newn;
141 : : }
142 [ - - ]: 0 : else if (!p.containsConstant())
143 : : {
144 : 0 : d_constants[n] = nm->mkConstRealOrInt(tn, Rational(0));
145 : 0 : d_varParts[n] = newn;
146 : 0 : d_reduceVar[n] = p.getNode();
147 : 0 : return p.getNode();
148 : : }
149 : : else
150 : : {
151 : 0 : linear::Monomial mc = p.getHead();
152 : 0 : d_constants[n] = mc.getConstant().getNode();
153 : 0 : d_varParts[n] = p.getTail().getNode();
154 : 0 : d_reduceVar[n] = newn;
155 : 0 : return newn;
156 : 0 : }
157 : 0 : }
158 : : else
159 : : {
160 [ - - ]: 0 : if (!d_contains.containsTermITE(n))
161 : : {
162 : 0 : return n;
163 : : }
164 [ - - ]: 0 : if (n.getNumChildren() > 0)
165 : : {
166 : 0 : Node res = applyReduceVariablesInItes(n);
167 : 0 : d_reduceVar[n] = res;
168 : 0 : return res;
169 : 0 : }
170 : : else
171 : : {
172 : 0 : return n;
173 : : }
174 : : }
175 : 0 : }
176 : : break;
177 : : }
178 : : Unreachable();
179 : : }
180 : :
181 : 4 : ArithIteUtils::ArithIteUtils(
182 : : Env& env,
183 : : preprocessing::util::ContainsTermITEVisitor& contains,
184 : 4 : SubstitutionMap& subs)
185 : : : EnvObj(env),
186 : 4 : d_contains(contains),
187 : 4 : d_subs(subs),
188 : 4 : d_one(1),
189 : 4 : d_subcount(userContext(), 0),
190 : 4 : d_skolems(userContext()),
191 : 4 : d_implies(),
192 : 8 : d_orBinEqs()
193 : : {
194 : 4 : }
195 : :
196 : 4 : ArithIteUtils::~ArithIteUtils() {}
197 : :
198 : 0 : void ArithIteUtils::clear()
199 : : {
200 : 0 : d_reduceVar.clear();
201 : 0 : d_constants.clear();
202 : 0 : d_varParts.clear();
203 : 0 : }
204 : :
205 : 0 : const Integer& ArithIteUtils::gcdIte(Node n)
206 : : {
207 [ - - ]: 0 : if (d_gcds.find(n) != d_gcds.end())
208 : : {
209 : 0 : return d_gcds[n];
210 : : }
211 [ - - ]: 0 : if (n.isConst())
212 : : {
213 : 0 : const Rational& q = n.getConst<Rational>();
214 [ - - ]: 0 : if (q.isIntegral())
215 : : {
216 : 0 : d_gcds[n] = q.getNumerator();
217 : 0 : return d_gcds[n];
218 : : }
219 : : else
220 : : {
221 : 0 : return d_one;
222 : : }
223 : : }
224 : 0 : else if (n.getKind() == Kind::ITE && n.getType().isRealOrInt())
225 : : {
226 : 0 : const Integer& tgcd = gcdIte(n[1]);
227 [ - - ]: 0 : if (tgcd.isOne())
228 : : {
229 : 0 : d_gcds[n] = d_one;
230 : 0 : return d_one;
231 : : }
232 : : else
233 : : {
234 : 0 : const Integer& egcd = gcdIte(n[2]);
235 : 0 : Integer ite_gcd = tgcd.gcd(egcd);
236 : 0 : d_gcds[n] = ite_gcd;
237 : 0 : return d_gcds[n];
238 : 0 : }
239 : : }
240 : 0 : return d_one;
241 : : }
242 : :
243 : 0 : Node ArithIteUtils::reduceIteConstantIteByGCD_rec(Node n, const Rational& q)
244 : : {
245 [ - - ]: 0 : if (n.isConst())
246 : : {
247 : 0 : Assert(n.getType().isRealOrInt());
248 : 0 : return nodeManager()->mkConstRealOrInt(n.getType(),
249 : 0 : n.getConst<Rational>() * q);
250 : : }
251 : : else
252 : : {
253 : 0 : Assert(n.getKind() == Kind::ITE);
254 : 0 : Assert(n.getType().isInteger());
255 : 0 : Node rc = reduceConstantIteByGCD(n[0]);
256 : 0 : Node rt = reduceIteConstantIteByGCD_rec(n[1], q);
257 : 0 : Node re = reduceIteConstantIteByGCD_rec(n[2], q);
258 : 0 : return rc.iteNode(rt, re);
259 : 0 : }
260 : : }
261 : :
262 : 0 : Node ArithIteUtils::reduceIteConstantIteByGCD(Node n)
263 : : {
264 : 0 : Assert(n.getKind() == Kind::ITE);
265 : 0 : Assert(n.getType().isRealOrInt());
266 : 0 : const Integer& gcd = gcdIte(n);
267 : 0 : NodeManager* nm = nodeManager();
268 [ - - ]: 0 : if (gcd.isOne())
269 : : {
270 : 0 : Node newIte = reduceConstantIteByGCD(n[0]).iteNode(n[1], n[2]);
271 : 0 : d_reduceGcd[n] = newIte;
272 : 0 : return newIte;
273 : 0 : }
274 [ - - ]: 0 : else if (gcd.isZero())
275 : : {
276 : 0 : Node zeroNode = nm->mkConstRealOrInt(n.getType(), Rational(0));
277 : 0 : d_reduceGcd[n] = zeroNode;
278 : 0 : return zeroNode;
279 : 0 : }
280 : : else
281 : : {
282 : 0 : Rational divBy(Integer(1), gcd);
283 : 0 : Node redite = reduceIteConstantIteByGCD_rec(n, divBy);
284 : 0 : Node gcdNode = nm->mkConstRealOrInt(n.getType(), Rational(gcd));
285 : 0 : Node multIte = nm->mkNode(Kind::MULT, gcdNode, redite);
286 : 0 : d_reduceGcd[n] = multIte;
287 : 0 : return multIte;
288 : 0 : }
289 : : }
290 : :
291 : 0 : Node ArithIteUtils::reduceConstantIteByGCD(Node n)
292 : : {
293 [ - - ]: 0 : if (d_reduceGcd.find(n) != d_reduceGcd.end())
294 : : {
295 : 0 : return d_reduceGcd[n];
296 : : }
297 : 0 : if (n.getKind() == Kind::ITE && n.getType().isRealOrInt())
298 : : {
299 : 0 : return reduceIteConstantIteByGCD(n);
300 : : }
301 : :
302 [ - - ]: 0 : if (n.getNumChildren() > 0)
303 : : {
304 : 0 : NodeBuilder nb(nodeManager(), n.getKind());
305 [ - - ]: 0 : if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
306 : : {
307 : 0 : nb << (n.getOperator());
308 : : }
309 : 0 : bool anychange = false;
310 [ - - ]: 0 : for (Node::iterator it = n.begin(), end = n.end(); it != end; ++it)
311 : : {
312 : 0 : Node child = *it;
313 : 0 : Node redchild = reduceConstantIteByGCD(child);
314 : 0 : anychange = anychange || (child != redchild);
315 : 0 : nb << redchild;
316 : 0 : }
317 [ - - ]: 0 : if (anychange)
318 : : {
319 : 0 : Node res = nb;
320 : 0 : d_reduceGcd[n] = res;
321 : 0 : return res;
322 : 0 : }
323 : : else
324 : : {
325 : 0 : d_reduceGcd[n] = n;
326 : 0 : return n;
327 : : }
328 : 0 : }
329 : : else
330 : : {
331 : 0 : return n;
332 : : }
333 : : }
334 : :
335 : 8 : unsigned ArithIteUtils::getSubCount() const { return d_subcount; }
336 : :
337 : 0 : void ArithIteUtils::addSubstitution(TNode f, TNode t)
338 : : {
339 [ - - ]: 0 : Trace("arith::ite") << "adding " << f << " -> " << t << endl;
340 : 0 : d_subcount = d_subcount + 1;
341 : 0 : d_subs.addSubstitution(f, t);
342 : 0 : }
343 : :
344 : 0 : Node ArithIteUtils::applySubstitutions(TNode f)
345 : : {
346 : 0 : AlwaysAssert(!options().base.incrementalSolving);
347 : 0 : return d_subs.apply(f);
348 : : }
349 : :
350 : 0 : Node ArithIteUtils::selectForCmp(Node n) const
351 : : {
352 [ - - ]: 0 : if (n.getKind() == Kind::ITE)
353 : : {
354 [ - - ]: 0 : if (d_skolems.find(n[0]) != d_skolems.end())
355 : : {
356 : 0 : return selectForCmp(n[1]);
357 : : }
358 : : }
359 : 0 : return n;
360 : : }
361 : :
362 : 4 : void ArithIteUtils::learnSubstitutions(const std::vector<Node>& assertions)
363 : : {
364 [ - + ][ - + ]: 4 : AlwaysAssert(!options().base.incrementalSolving);
[ - - ]
365 [ + + ]: 12 : for (size_t i = 0, N = assertions.size(); i < N; ++i)
366 : : {
367 : 8 : collectAssertions(assertions[i]);
368 : : }
369 : : bool solvedSomething;
370 [ - + ]: 4 : do
371 : : {
372 : 4 : solvedSomething = false;
373 : 4 : size_t readPos = 0, writePos = 0, N = d_orBinEqs.size();
374 [ - + ]: 4 : for (; readPos < N; readPos++)
375 : : {
376 : 0 : Node curr = d_orBinEqs[readPos];
377 : 0 : bool solved = solveBinOr(curr);
378 [ - - ]: 0 : if (solved)
379 : : {
380 : 0 : solvedSomething = true;
381 : : }
382 : : else
383 : : {
384 : : // didn't solve, push back
385 : 0 : d_orBinEqs[writePos] = curr;
386 : 0 : writePos++;
387 : : }
388 : 0 : }
389 [ - + ][ - + ]: 4 : Assert(writePos <= N);
[ - - ]
390 : 4 : d_orBinEqs.resize(writePos);
391 : : } while (solvedSomething);
392 : :
393 : 4 : d_implies.clear();
394 : 4 : d_orBinEqs.clear();
395 : 4 : }
396 : :
397 : 0 : void ArithIteUtils::addImplications(Node x, Node y)
398 : : {
399 : : // (or x y)
400 : : // (=> (not x) y)
401 : : // (=> (not y) x)
402 : :
403 : 0 : Node xneg = x.negate();
404 : 0 : Node yneg = y.negate();
405 : 0 : d_implies[xneg].insert(y);
406 : 0 : d_implies[yneg].insert(x);
407 : 0 : }
408 : :
409 : 8 : void ArithIteUtils::collectAssertions(TNode assertion)
410 : : {
411 [ - + ]: 8 : if (assertion.getKind() == Kind::OR)
412 : : {
413 [ - - ]: 0 : if (assertion.getNumChildren() == 2)
414 : : {
415 : 0 : TNode left = assertion[0], right = assertion[1];
416 : 0 : addImplications(left, right);
417 : 0 : if (left.getKind() == Kind::EQUAL && right.getKind() == Kind::EQUAL)
418 : : {
419 : 0 : if (left[0].getType().isInteger() && right[0].getType().isInteger())
420 : : {
421 : 0 : d_orBinEqs.push_back(assertion);
422 : : }
423 : : }
424 : 0 : }
425 : : }
426 [ - + ]: 8 : else if (assertion.getKind() == Kind::AND)
427 : : {
428 [ - - ]: 0 : for (unsigned i = 0, N = assertion.getNumChildren(); i < N; ++i)
429 : : {
430 : 0 : collectAssertions(assertion[i]);
431 : : }
432 : : }
433 : 8 : }
434 : :
435 : 0 : Node ArithIteUtils::findIteCnd(TNode tb, TNode fb) const
436 : : {
437 : 0 : Node negtb = tb.negate();
438 : 0 : Node negfb = fb.negate();
439 : 0 : ImpMap::const_iterator ti = d_implies.find(negtb);
440 : 0 : ImpMap::const_iterator fi = d_implies.find(negfb);
441 : :
442 [ - - ][ - - ]: 0 : if (ti != d_implies.end() && fi != d_implies.end())
[ - - ]
443 : : {
444 : 0 : const std::set<Node>& negtimp = ti->second;
445 : 0 : const std::set<Node>& negfimp = fi->second;
446 : :
447 : : // (or (not x) y)
448 : : // (or x z)
449 : : // (or y z)
450 : : // ---
451 : : // (ite x y z) return x
452 : : // ---
453 : : // (not y) => (not x)
454 : : // (not z) => x
455 : 0 : std::set<Node>::const_iterator ci = negtimp.begin(), cend = negtimp.end();
456 [ - - ]: 0 : for (; ci != cend; ++ci)
457 : : {
458 : 0 : Node impliedByNotTB = *ci;
459 : 0 : Node impliedByNotTBNeg = impliedByNotTB.negate();
460 [ - - ]: 0 : if (negfimp.find(impliedByNotTBNeg) != negfimp.end())
461 : : {
462 : 0 : return impliedByNotTBNeg; // implies tb
463 : : }
464 [ - - ][ - - ]: 0 : }
465 : : }
466 : :
467 : 0 : return Node::null();
468 : 0 : }
469 : :
470 : 0 : bool ArithIteUtils::solveBinOr(TNode binor)
471 : : {
472 : 0 : Assert(binor.getKind() == Kind::OR);
473 : 0 : Assert(binor.getNumChildren() == 2);
474 : 0 : Assert(binor[0].getKind() == Kind::EQUAL);
475 : 0 : Assert(binor[1].getKind() == Kind::EQUAL);
476 : :
477 : : // Node n =
478 : 0 : Node n = applySubstitutions(binor);
479 [ - - ]: 0 : if (n != binor)
480 : : {
481 : 0 : n = rewrite(n);
482 : :
483 : 0 : if (!(n.getKind() == Kind::OR && n.getNumChildren() == 2
484 : 0 : && n[0].getKind() == Kind::EQUAL && n[1].getKind() == Kind::EQUAL))
485 : : {
486 : 0 : return false;
487 : : }
488 : : }
489 : :
490 : 0 : Assert(n.getKind() == Kind::OR);
491 : 0 : Assert(n.getNumChildren() == 2);
492 : 0 : TNode l = n[0];
493 : 0 : TNode r = n[1];
494 : :
495 : 0 : Assert(l.getKind() == Kind::EQUAL);
496 : 0 : Assert(r.getKind() == Kind::EQUAL);
497 : :
498 [ - - ]: 0 : Trace("arith::ite") << "bin or " << n << endl;
499 : :
500 : 0 : bool lArithEq = l.getKind() == Kind::EQUAL && l[0].getType().isInteger();
501 : 0 : bool rArithEq = r.getKind() == Kind::EQUAL && r[0].getType().isInteger();
502 : :
503 [ - - ][ - - ]: 0 : if (lArithEq && rArithEq)
504 : : {
505 : 0 : TNode sel = Node::null();
506 : 0 : TNode otherL = Node::null();
507 : 0 : TNode otherR = Node::null();
508 [ - - ]: 0 : if (l[0] == r[0])
509 : : {
510 : 0 : sel = l[0];
511 : 0 : otherL = l[1];
512 : 0 : otherR = r[1];
513 : : }
514 [ - - ]: 0 : else if (l[0] == r[1])
515 : : {
516 : 0 : sel = l[0];
517 : 0 : otherL = l[1];
518 : 0 : otherR = r[0];
519 : : }
520 [ - - ]: 0 : else if (l[1] == r[0])
521 : : {
522 : 0 : sel = l[1];
523 : 0 : otherL = l[0];
524 : 0 : otherR = r[1];
525 : : }
526 [ - - ]: 0 : else if (l[1] == r[1])
527 : : {
528 : 0 : sel = l[1];
529 : 0 : otherL = l[0];
530 : 0 : otherR = r[0];
531 : : }
532 [ - - ]: 0 : Trace("arith::ite") << "selected " << sel << endl;
533 : 0 : if (sel.isVar() && sel.getKind() != Kind::SKOLEM)
534 : : {
535 [ - - ]: 0 : Trace("arith::ite") << "others l:" << otherL << " r " << otherR << endl;
536 : 0 : Node useForCmpL = selectForCmp(otherL);
537 : 0 : Node useForCmpR = selectForCmp(otherR);
538 : :
539 : 0 : Assert(linear::Polynomial::isMember(sel));
540 : 0 : Assert(linear::Polynomial::isMember(useForCmpL));
541 : 0 : Assert(linear::Polynomial::isMember(useForCmpR));
542 : : linear::Polynomial lside =
543 : 0 : linear::Polynomial::parsePolynomial(useForCmpL);
544 : : linear::Polynomial rside =
545 : 0 : linear::Polynomial::parsePolynomial(useForCmpR);
546 : 0 : linear::Polynomial diff = lside - rside;
547 : :
548 [ - - ]: 0 : Trace("arith::ite") << "diff: " << diff.getNode() << endl;
549 [ - - ]: 0 : if (diff.isConstant())
550 : : {
551 : : // a: (sel = otherL) or (sel = otherR), otherL-otherR = c
552 : :
553 : 0 : NodeManager* nm = nodeManager();
554 : 0 : SkolemManager* sm = nm->getSkolemManager();
555 : :
556 : 0 : Node cnd = findIteCnd(binor[0], binor[1]);
557 : :
558 : 0 : Node eq = sel.eqNode(otherL);
559 : 0 : Node sk = sm->mkPurifySkolem(eq);
560 : 0 : Node ite = sk.iteNode(otherL, otherR);
561 : 0 : d_skolems.insert(sk, cnd);
562 : : // Given (or (= x c) (= x d)), we replace x by (ite @purifyX c d),
563 : : // where @purifyX is the purification skolem for (= x c), where c and
564 : : // d are known to be distinct.
565 : 0 : addSubstitution(sel, ite);
566 : 0 : return true;
567 : 0 : }
568 [ - - ][ - - ]: 0 : }
[ - - ][ - - ]
[ - - ]
569 [ - - ][ - - ]: 0 : }
[ - - ]
570 : 0 : return false;
571 : 0 : }
572 : :
573 : : } // namespace arith
574 : : } // namespace theory
575 : : } // namespace cvc5::internal
|