Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * [[ Add one-line brief description here ]]
11 : : *
12 : : * [[ Add lengthier description here ]]
13 : : * \todo document this file
14 : : */
15 : :
16 : : #include "theory/arith/arith_ite_utils.h"
17 : :
18 : : #include <ostream>
19 : :
20 : : #include "base/output.h"
21 : : #include "expr/skolem_manager.h"
22 : : #include "options/base_options.h"
23 : : #include "preprocessing/util/ite_utilities.h"
24 : : #include "smt/env.h"
25 : : #include "theory/arith/arith_utilities.h"
26 : : #include "theory/arith/linear/normal_form.h"
27 : : #include "theory/rewriter.h"
28 : : #include "theory/substitutions.h"
29 : : #include "theory/theory_model.h"
30 : :
31 : : using namespace std;
32 : :
33 : : namespace cvc5::internal {
34 : : namespace theory {
35 : : namespace arith {
36 : :
37 : 0 : Node ArithIteUtils::applyReduceVariablesInItes(Node n){
38 : 0 : NodeBuilder nb(nodeManager(), n.getKind());
39 [ - - ]: 0 : if(n.getMetaKind() == kind::metakind::PARAMETERIZED) {
40 : 0 : nb << (n.getOperator());
41 : : }
42 [ - - ]: 0 : for(Node::iterator it = n.begin(), end = n.end(); it != end; ++it){
43 : 0 : nb << reduceVariablesInItes(*it);
44 : : }
45 : 0 : Node res = nb;
46 : 0 : return res;
47 : 0 : }
48 : :
49 : 0 : Node ArithIteUtils::reduceVariablesInItes(Node n){
50 : : using namespace cvc5::internal::kind;
51 [ - - ]: 0 : if(d_reduceVar.find(n) != d_reduceVar.end()){
52 : 0 : Node res = d_reduceVar[n];
53 [ - - ]: 0 : return res.isNull() ? n : res;
54 : 0 : }
55 : :
56 [ - - ]: 0 : switch(n.getKind()){
57 : 0 : case Kind::ITE:
58 : : {
59 : 0 : Node c = n[0], t = n[1], e = n[2];
60 : 0 : TypeNode tn = n.getType();
61 [ - - ]: 0 : if (tn.isRealOrInt())
62 : : {
63 : 0 : Node rc = reduceVariablesInItes(c);
64 : 0 : Node rt = reduceVariablesInItes(t);
65 : 0 : Node re = reduceVariablesInItes(e);
66 : :
67 : 0 : Node vt = d_varParts[t];
68 : 0 : Node ve = d_varParts[e];
69 [ - - ]: 0 : Node vpite = (vt == ve) ? vt : Node::null();
70 : :
71 : 0 : NodeManager* nm = nodeManager();
72 [ - - ]: 0 : if (vpite.isNull())
73 : : {
74 : 0 : Node rite = rc.iteNode(rt, re);
75 : : // do not apply
76 : 0 : d_reduceVar[n] = rite;
77 : 0 : d_constants[n] = nm->mkConstRealOrInt(tn, Rational(0));
78 : 0 : d_varParts[n] = rite; // treat the ite as a variable
79 : 0 : return rite;
80 : 0 : }
81 : : else
82 : : {
83 : 0 : Node constantite = rc.iteNode(d_constants[t], d_constants[e]);
84 : 0 : Node sum = nm->mkNode(Kind::ADD, vpite, constantite);
85 : 0 : d_reduceVar[n] = sum;
86 : 0 : d_constants[n] = constantite;
87 : 0 : d_varParts[n] = vpite;
88 : 0 : return sum;
89 : 0 : }
90 : 0 : }
91 : : else
92 : : { // non-arith ite
93 [ - - ]: 0 : if (!d_contains.containsTermITE(n))
94 : : {
95 : : // don't bother adding to d_reduceVar
96 : 0 : return n;
97 : : }
98 : : else
99 : : {
100 : 0 : Node newIte = applyReduceVariablesInItes(n);
101 [ - - ]: 0 : d_reduceVar[n] = (n == newIte) ? Node::null() : newIte;
102 : 0 : return newIte;
103 : 0 : }
104 : : }
105 : 0 : }
106 : : break;
107 : 0 : default:
108 : : {
109 : 0 : TypeNode tn = n.getType();
110 : 0 : if (tn.isRealOrInt() && linear::Polynomial::isMember(n))
111 : : {
112 : 0 : Node newn = Node::null();
113 [ - - ]: 0 : if (!d_contains.containsTermITE(n))
114 : : {
115 : 0 : newn = n;
116 : : }
117 [ - - ]: 0 : else if (n.getNumChildren() > 0)
118 : : {
119 : 0 : newn = applyReduceVariablesInItes(n);
120 : 0 : newn = rewrite(newn);
121 : 0 : Assert(linear::Polynomial::isMember(newn));
122 : : }
123 : : else
124 : : {
125 : 0 : newn = n;
126 : : }
127 : 0 : NodeManager* nm = nodeManager();
128 : 0 : linear::Polynomial p = linear::Polynomial::parsePolynomial(newn);
129 [ - - ]: 0 : if (p.isConstant())
130 : : {
131 : 0 : d_constants[n] = newn;
132 : 0 : d_varParts[n] = nm->mkConstRealOrInt(tn, Rational(0));
133 : : // don't bother adding to d_reduceVar
134 : 0 : return newn;
135 : : }
136 [ - - ]: 0 : else if (!p.containsConstant())
137 : : {
138 : 0 : d_constants[n] = nm->mkConstRealOrInt(tn, Rational(0));
139 : 0 : d_varParts[n] = newn;
140 : 0 : d_reduceVar[n] = p.getNode();
141 : 0 : return p.getNode();
142 : : }
143 : : else
144 : : {
145 : 0 : linear::Monomial mc = p.getHead();
146 : 0 : d_constants[n] = mc.getConstant().getNode();
147 : 0 : d_varParts[n] = p.getTail().getNode();
148 : 0 : d_reduceVar[n] = newn;
149 : 0 : return newn;
150 : 0 : }
151 : 0 : }
152 : : else
153 : : {
154 [ - - ]: 0 : if (!d_contains.containsTermITE(n))
155 : : {
156 : 0 : return n;
157 : : }
158 [ - - ]: 0 : if (n.getNumChildren() > 0)
159 : : {
160 : 0 : Node res = applyReduceVariablesInItes(n);
161 : 0 : d_reduceVar[n] = res;
162 : 0 : return res;
163 : 0 : }
164 : : else
165 : : {
166 : 0 : return n;
167 : : }
168 : : }
169 : 0 : }
170 : : break;
171 : : }
172 : : Unreachable();
173 : : }
174 : :
175 : 4 : ArithIteUtils::ArithIteUtils(
176 : : Env& env,
177 : : preprocessing::util::ContainsTermITEVisitor& contains,
178 : 4 : SubstitutionMap& subs)
179 : : : EnvObj(env),
180 : 4 : d_contains(contains),
181 : 4 : d_subs(subs),
182 : 4 : d_one(1),
183 : 4 : d_subcount(userContext(), 0),
184 : 4 : d_skolems(userContext()),
185 : 4 : d_implies(),
186 : 8 : d_orBinEqs()
187 : : {
188 : 4 : }
189 : :
190 : 4 : ArithIteUtils::~ArithIteUtils(){
191 : 4 : }
192 : :
193 : 0 : void ArithIteUtils::clear(){
194 : 0 : d_reduceVar.clear();
195 : 0 : d_constants.clear();
196 : 0 : d_varParts.clear();
197 : 0 : }
198 : :
199 : 0 : const Integer& ArithIteUtils::gcdIte(Node n){
200 [ - - ]: 0 : if(d_gcds.find(n) != d_gcds.end()){
201 : 0 : return d_gcds[n];
202 : : }
203 [ - - ]: 0 : if (n.isConst())
204 : : {
205 : 0 : const Rational& q = n.getConst<Rational>();
206 [ - - ]: 0 : if(q.isIntegral()){
207 : 0 : d_gcds[n] = q.getNumerator();
208 : 0 : return d_gcds[n];
209 : : }else{
210 : 0 : return d_one;
211 : : }
212 : : }
213 : 0 : else if (n.getKind() == Kind::ITE && n.getType().isRealOrInt())
214 : : {
215 : 0 : const Integer& tgcd = gcdIte(n[1]);
216 [ - - ]: 0 : if(tgcd.isOne()){
217 : 0 : d_gcds[n] = d_one;
218 : 0 : return d_one;
219 : : }else{
220 : 0 : const Integer& egcd = gcdIte(n[2]);
221 : 0 : Integer ite_gcd = tgcd.gcd(egcd);
222 : 0 : d_gcds[n] = ite_gcd;
223 : 0 : return d_gcds[n];
224 : 0 : }
225 : : }
226 : 0 : return d_one;
227 : : }
228 : :
229 : 0 : Node ArithIteUtils::reduceIteConstantIteByGCD_rec(Node n, const Rational& q){
230 [ - - ]: 0 : if(n.isConst()){
231 : 0 : Assert(n.getType().isRealOrInt());
232 : 0 : return nodeManager()->mkConstRealOrInt(n.getType(),
233 : 0 : n.getConst<Rational>() * q);
234 : : }else{
235 : 0 : Assert(n.getKind() == Kind::ITE);
236 : 0 : Assert(n.getType().isInteger());
237 : 0 : Node rc = reduceConstantIteByGCD(n[0]);
238 : 0 : Node rt = reduceIteConstantIteByGCD_rec(n[1], q);
239 : 0 : Node re = reduceIteConstantIteByGCD_rec(n[2], q);
240 : 0 : return rc.iteNode(rt, re);
241 : 0 : }
242 : : }
243 : :
244 : 0 : Node ArithIteUtils::reduceIteConstantIteByGCD(Node n){
245 : 0 : Assert(n.getKind() == Kind::ITE);
246 : 0 : Assert(n.getType().isRealOrInt());
247 : 0 : const Integer& gcd = gcdIte(n);
248 : 0 : NodeManager* nm = nodeManager();
249 [ - - ]: 0 : if(gcd.isOne()){
250 : 0 : Node newIte = reduceConstantIteByGCD(n[0]).iteNode(n[1],n[2]);
251 : 0 : d_reduceGcd[n] = newIte;
252 : 0 : return newIte;
253 [ - - ]: 0 : }else if(gcd.isZero()){
254 : 0 : Node zeroNode = nm->mkConstRealOrInt(n.getType(), Rational(0));
255 : 0 : d_reduceGcd[n] = zeroNode;
256 : 0 : return zeroNode;
257 : 0 : }else{
258 : 0 : Rational divBy(Integer(1), gcd);
259 : 0 : Node redite = reduceIteConstantIteByGCD_rec(n, divBy);
260 : 0 : Node gcdNode = nm->mkConstRealOrInt(n.getType(), Rational(gcd));
261 : 0 : Node multIte = nm->mkNode(Kind::MULT, gcdNode, redite);
262 : 0 : d_reduceGcd[n] = multIte;
263 : 0 : return multIte;
264 : 0 : }
265 : : }
266 : :
267 : 0 : Node ArithIteUtils::reduceConstantIteByGCD(Node n){
268 [ - - ]: 0 : if(d_reduceGcd.find(n) != d_reduceGcd.end()){
269 : 0 : return d_reduceGcd[n];
270 : : }
271 : 0 : if (n.getKind() == Kind::ITE && n.getType().isRealOrInt())
272 : : {
273 : 0 : return reduceIteConstantIteByGCD(n);
274 : : }
275 : :
276 [ - - ]: 0 : if(n.getNumChildren() > 0){
277 : 0 : NodeBuilder nb(nodeManager(), n.getKind());
278 [ - - ]: 0 : if(n.getMetaKind() == kind::metakind::PARAMETERIZED) {
279 : 0 : nb << (n.getOperator());
280 : : }
281 : 0 : bool anychange = false;
282 [ - - ]: 0 : for(Node::iterator it = n.begin(), end = n.end(); it != end; ++it){
283 : 0 : Node child = *it;
284 : 0 : Node redchild = reduceConstantIteByGCD(child);
285 : 0 : anychange = anychange || (child != redchild);
286 : 0 : nb << redchild;
287 : 0 : }
288 [ - - ]: 0 : if(anychange){
289 : 0 : Node res = nb;
290 : 0 : d_reduceGcd[n] = res;
291 : 0 : return res;
292 : 0 : }else{
293 : 0 : d_reduceGcd[n] = n;
294 : 0 : return n;
295 : : }
296 : 0 : }else{
297 : 0 : return n;
298 : : }
299 : : }
300 : :
301 : 8 : unsigned ArithIteUtils::getSubCount() const{
302 : 8 : return d_subcount;
303 : : }
304 : :
305 : 0 : void ArithIteUtils::addSubstitution(TNode f, TNode t){
306 [ - - ]: 0 : Trace("arith::ite") << "adding " << f << " -> " << t << endl;
307 : 0 : d_subcount = d_subcount + 1;
308 : 0 : d_subs.addSubstitution(f, t);
309 : 0 : }
310 : :
311 : 0 : Node ArithIteUtils::applySubstitutions(TNode f){
312 : 0 : AlwaysAssert(!options().base.incrementalSolving);
313 : 0 : return d_subs.apply(f);
314 : : }
315 : :
316 : 0 : Node ArithIteUtils::selectForCmp(Node n) const{
317 [ - - ]: 0 : if (n.getKind() == Kind::ITE)
318 : : {
319 [ - - ]: 0 : if(d_skolems.find(n[0]) != d_skolems.end()){
320 : 0 : return selectForCmp(n[1]);
321 : : }
322 : : }
323 : 0 : return n;
324 : : }
325 : :
326 : 4 : void ArithIteUtils::learnSubstitutions(const std::vector<Node>& assertions){
327 [ - + ][ - + ]: 4 : AlwaysAssert(!options().base.incrementalSolving);
[ - - ]
328 [ + + ]: 12 : for(size_t i=0, N=assertions.size(); i < N; ++i){
329 : 8 : collectAssertions(assertions[i]);
330 : : }
331 : : bool solvedSomething;
332 [ - + ]: 4 : do{
333 : 4 : solvedSomething = false;
334 : 4 : size_t readPos = 0, writePos = 0, N = d_orBinEqs.size();
335 [ - + ]: 4 : for(; readPos < N; readPos++){
336 : 0 : Node curr = d_orBinEqs[readPos];
337 : 0 : bool solved = solveBinOr(curr);
338 [ - - ]: 0 : if(solved){
339 : 0 : solvedSomething = true;
340 : : }else{
341 : : // didn't solve, push back
342 : 0 : d_orBinEqs[writePos] = curr;
343 : 0 : writePos++;
344 : : }
345 : 0 : }
346 [ - + ][ - + ]: 4 : Assert(writePos <= N);
[ - - ]
347 : 4 : d_orBinEqs.resize(writePos);
348 : : }while(solvedSomething);
349 : :
350 : 4 : d_implies.clear();
351 : 4 : d_orBinEqs.clear();
352 : 4 : }
353 : :
354 : 0 : void ArithIteUtils::addImplications(Node x, Node y){
355 : : // (or x y)
356 : : // (=> (not x) y)
357 : : // (=> (not y) x)
358 : :
359 : 0 : Node xneg = x.negate();
360 : 0 : Node yneg = y.negate();
361 : 0 : d_implies[xneg].insert(y);
362 : 0 : d_implies[yneg].insert(x);
363 : 0 : }
364 : :
365 : 8 : void ArithIteUtils::collectAssertions(TNode assertion){
366 [ - + ]: 8 : if (assertion.getKind() == Kind::OR)
367 : : {
368 [ - - ]: 0 : if(assertion.getNumChildren() == 2){
369 : 0 : TNode left = assertion[0], right = assertion[1];
370 : 0 : addImplications(left, right);
371 : 0 : if (left.getKind() == Kind::EQUAL && right.getKind() == Kind::EQUAL)
372 : : {
373 : 0 : if(left[0].getType().isInteger() && right[0].getType().isInteger()){
374 : 0 : d_orBinEqs.push_back(assertion);
375 : : }
376 : : }
377 : 0 : }
378 : : }
379 [ - + ]: 8 : else if (assertion.getKind() == Kind::AND)
380 : : {
381 [ - - ]: 0 : for(unsigned i=0, N=assertion.getNumChildren(); i < N; ++i){
382 : 0 : collectAssertions(assertion[i]);
383 : : }
384 : : }
385 : 8 : }
386 : :
387 : 0 : Node ArithIteUtils::findIteCnd(TNode tb, TNode fb) const{
388 : 0 : Node negtb = tb.negate();
389 : 0 : Node negfb = fb.negate();
390 : 0 : ImpMap::const_iterator ti = d_implies.find(negtb);
391 : 0 : ImpMap::const_iterator fi = d_implies.find(negfb);
392 : :
393 [ - - ][ - - ]: 0 : if(ti != d_implies.end() && fi != d_implies.end()){
[ - - ]
394 : 0 : const std::set<Node>& negtimp = ti->second;
395 : 0 : const std::set<Node>& negfimp = fi->second;
396 : :
397 : : // (or (not x) y)
398 : : // (or x z)
399 : : // (or y z)
400 : : // ---
401 : : // (ite x y z) return x
402 : : // ---
403 : : // (not y) => (not x)
404 : : // (not z) => x
405 : 0 : std::set<Node>::const_iterator ci = negtimp.begin(), cend = negtimp.end();
406 [ - - ]: 0 : for (; ci != cend; ++ci)
407 : : {
408 : 0 : Node impliedByNotTB = *ci;
409 : 0 : Node impliedByNotTBNeg = impliedByNotTB.negate();
410 [ - - ]: 0 : if(negfimp.find(impliedByNotTBNeg) != negfimp.end()){
411 : 0 : return impliedByNotTBNeg; // implies tb
412 : : }
413 [ - - ][ - - ]: 0 : }
414 : : }
415 : :
416 : 0 : return Node::null();
417 : 0 : }
418 : :
419 : 0 : bool ArithIteUtils::solveBinOr(TNode binor){
420 : 0 : Assert(binor.getKind() == Kind::OR);
421 : 0 : Assert(binor.getNumChildren() == 2);
422 : 0 : Assert(binor[0].getKind() == Kind::EQUAL);
423 : 0 : Assert(binor[1].getKind() == Kind::EQUAL);
424 : :
425 : : //Node n =
426 : 0 : Node n = applySubstitutions(binor);
427 [ - - ]: 0 : if(n != binor){
428 : 0 : n = rewrite(n);
429 : :
430 : 0 : if (!(n.getKind() == Kind::OR && n.getNumChildren() == 2
431 : 0 : && n[0].getKind() == Kind::EQUAL && n[1].getKind() == Kind::EQUAL))
432 : : {
433 : 0 : return false;
434 : : }
435 : : }
436 : :
437 : 0 : Assert(n.getKind() == Kind::OR);
438 : 0 : Assert(n.getNumChildren() == 2);
439 : 0 : TNode l = n[0];
440 : 0 : TNode r = n[1];
441 : :
442 : 0 : Assert(l.getKind() == Kind::EQUAL);
443 : 0 : Assert(r.getKind() == Kind::EQUAL);
444 : :
445 [ - - ]: 0 : Trace("arith::ite") << "bin or " << n << endl;
446 : :
447 : 0 : bool lArithEq = l.getKind() == Kind::EQUAL && l[0].getType().isInteger();
448 : 0 : bool rArithEq = r.getKind() == Kind::EQUAL && r[0].getType().isInteger();
449 : :
450 [ - - ][ - - ]: 0 : if(lArithEq && rArithEq){
451 : 0 : TNode sel = Node::null();
452 : 0 : TNode otherL = Node::null();
453 : 0 : TNode otherR = Node::null();
454 [ - - ]: 0 : if(l[0] == r[0]) {
455 : 0 : sel = l[0]; otherL = l[1]; otherR = r[1];
456 [ - - ]: 0 : }else if(l[0] == r[1]){
457 : 0 : sel = l[0]; otherL = l[1]; otherR = r[0];
458 [ - - ]: 0 : }else if(l[1] == r[0]){
459 : 0 : sel = l[1]; otherL = l[0]; otherR = r[1];
460 [ - - ]: 0 : }else if(l[1] == r[1]){
461 : 0 : sel = l[1]; otherL = l[0]; otherR = r[0];
462 : : }
463 [ - - ]: 0 : Trace("arith::ite") << "selected " << sel << endl;
464 : 0 : if (sel.isVar() && sel.getKind() != Kind::SKOLEM)
465 : : {
466 [ - - ]: 0 : Trace("arith::ite") << "others l:" << otherL << " r " << otherR << endl;
467 : 0 : Node useForCmpL = selectForCmp(otherL);
468 : 0 : Node useForCmpR = selectForCmp(otherR);
469 : :
470 : 0 : Assert(linear::Polynomial::isMember(sel));
471 : 0 : Assert(linear::Polynomial::isMember(useForCmpL));
472 : 0 : Assert(linear::Polynomial::isMember(useForCmpR));
473 : 0 : linear::Polynomial lside = linear::Polynomial::parsePolynomial( useForCmpL );
474 : 0 : linear::Polynomial rside = linear::Polynomial::parsePolynomial( useForCmpR );
475 : 0 : linear::Polynomial diff = lside-rside;
476 : :
477 [ - - ]: 0 : Trace("arith::ite") << "diff: " << diff.getNode() << endl;
478 [ - - ]: 0 : if(diff.isConstant()){
479 : : // a: (sel = otherL) or (sel = otherR), otherL-otherR = c
480 : :
481 : 0 : NodeManager* nm = nodeManager();
482 : 0 : SkolemManager* sm = nm->getSkolemManager();
483 : :
484 : 0 : Node cnd = findIteCnd(binor[0], binor[1]);
485 : :
486 : 0 : Node eq = sel.eqNode(otherL);
487 : 0 : Node sk = sm->mkPurifySkolem(eq);
488 : 0 : Node ite = sk.iteNode(otherL, otherR);
489 : 0 : d_skolems.insert(sk, cnd);
490 : : // Given (or (= x c) (= x d)), we replace x by (ite @purifyX c d),
491 : : // where @purifyX is the purification skolem for (= x c), where c and
492 : : // d are known to be distinct.
493 : 0 : addSubstitution(sel, ite);
494 : 0 : return true;
495 : 0 : }
496 [ - - ][ - - ]: 0 : }
[ - - ][ - - ]
[ - - ]
497 [ - - ][ - - ]: 0 : }
[ - - ]
498 : 0 : return false;
499 : 0 : }
500 : :
501 : : } // namespace arith
502 : : } // namespace theory
503 : : } // namespace cvc5::internal
|