Branch data Line data Source code
1 : : /******************************************************************************
2 : : * Top contributors (to current version):
3 : : * Tim King, Aina Niemetz, Andrew Reynolds
4 : : *
5 : : * This file is part of the cvc5 project.
6 : : *
7 : : * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS
8 : : * in the top-level source directory and their institutional affiliations.
9 : : * All rights reserved. See the file COPYING in the top-level source
10 : : * directory for licensing information.
11 : : * ****************************************************************************
12 : : *
13 : : * [[ Add one-line brief description here ]]
14 : : *
15 : : * [[ Add lengthier description here ]]
16 : : * \todo document this file
17 : : */
18 : :
19 : : #include "theory/arith/arith_ite_utils.h"
20 : :
21 : : #include <ostream>
22 : :
23 : : #include "base/output.h"
24 : : #include "expr/skolem_manager.h"
25 : : #include "options/base_options.h"
26 : : #include "preprocessing/util/ite_utilities.h"
27 : : #include "smt/env.h"
28 : : #include "theory/arith/arith_utilities.h"
29 : : #include "theory/arith/linear/normal_form.h"
30 : : #include "theory/rewriter.h"
31 : : #include "theory/substitutions.h"
32 : : #include "theory/theory_model.h"
33 : :
34 : : using namespace std;
35 : :
36 : : namespace cvc5::internal {
37 : : namespace theory {
38 : : namespace arith {
39 : :
40 : 0 : Node ArithIteUtils::applyReduceVariablesInItes(Node n){
41 : 0 : NodeBuilder nb(nodeManager(), n.getKind());
42 [ - - ]: 0 : if(n.getMetaKind() == kind::metakind::PARAMETERIZED) {
43 : 0 : nb << (n.getOperator());
44 : : }
45 [ - - ]: 0 : for(Node::iterator it = n.begin(), end = n.end(); it != end; ++it){
46 : 0 : nb << reduceVariablesInItes(*it);
47 : : }
48 : 0 : Node res = nb;
49 : 0 : return res;
50 : : }
51 : :
52 : 0 : Node ArithIteUtils::reduceVariablesInItes(Node n){
53 : : using namespace cvc5::internal::kind;
54 [ - - ]: 0 : if(d_reduceVar.find(n) != d_reduceVar.end()){
55 : 0 : Node res = d_reduceVar[n];
56 [ - - ]: 0 : return res.isNull() ? n : res;
57 : : }
58 : :
59 [ - - ]: 0 : switch(n.getKind()){
60 : 0 : case Kind::ITE:
61 : : {
62 : 0 : Node c = n[0], t = n[1], e = n[2];
63 : 0 : TypeNode tn = n.getType();
64 [ - - ]: 0 : if (tn.isRealOrInt())
65 : : {
66 : 0 : Node rc = reduceVariablesInItes(c);
67 : 0 : Node rt = reduceVariablesInItes(t);
68 : 0 : Node re = reduceVariablesInItes(e);
69 : :
70 : 0 : Node vt = d_varParts[t];
71 : 0 : Node ve = d_varParts[e];
72 [ - - ]: 0 : Node vpite = (vt == ve) ? vt : Node::null();
73 : :
74 : 0 : NodeManager* nm = nodeManager();
75 [ - - ]: 0 : if (vpite.isNull())
76 : : {
77 : 0 : Node rite = rc.iteNode(rt, re);
78 : : // do not apply
79 : 0 : d_reduceVar[n] = rite;
80 : 0 : d_constants[n] = nm->mkConstRealOrInt(tn, Rational(0));
81 : 0 : d_varParts[n] = rite; // treat the ite as a variable
82 : 0 : return rite;
83 : : }
84 : : else
85 : : {
86 : 0 : Node constantite = rc.iteNode(d_constants[t], d_constants[e]);
87 : 0 : Node sum = nm->mkNode(Kind::ADD, vpite, constantite);
88 : 0 : d_reduceVar[n] = sum;
89 : 0 : d_constants[n] = constantite;
90 : 0 : d_varParts[n] = vpite;
91 : 0 : return sum;
92 : : }
93 : : }
94 : : else
95 : : { // non-arith ite
96 [ - - ]: 0 : if (!d_contains.containsTermITE(n))
97 : : {
98 : : // don't bother adding to d_reduceVar
99 : 0 : return n;
100 : : }
101 : : else
102 : : {
103 : 0 : Node newIte = applyReduceVariablesInItes(n);
104 [ - - ]: 0 : d_reduceVar[n] = (n == newIte) ? Node::null() : newIte;
105 : 0 : return newIte;
106 : : }
107 : : }
108 : : }
109 : : break;
110 : 0 : default:
111 : : {
112 : 0 : TypeNode tn = n.getType();
113 : 0 : if (tn.isRealOrInt() && linear::Polynomial::isMember(n))
114 : : {
115 : 0 : Node newn = Node::null();
116 [ - - ]: 0 : if (!d_contains.containsTermITE(n))
117 : : {
118 : 0 : newn = n;
119 : : }
120 [ - - ]: 0 : else if (n.getNumChildren() > 0)
121 : : {
122 : 0 : newn = applyReduceVariablesInItes(n);
123 : 0 : newn = rewrite(newn);
124 : 0 : Assert(linear::Polynomial::isMember(newn));
125 : : }
126 : : else
127 : : {
128 : 0 : newn = n;
129 : : }
130 : 0 : NodeManager* nm = nodeManager();
131 : 0 : linear::Polynomial p = linear::Polynomial::parsePolynomial(newn);
132 [ - - ]: 0 : if (p.isConstant())
133 : : {
134 : 0 : d_constants[n] = newn;
135 : 0 : d_varParts[n] = nm->mkConstRealOrInt(tn, Rational(0));
136 : : // don't bother adding to d_reduceVar
137 : 0 : return newn;
138 : : }
139 [ - - ]: 0 : else if (!p.containsConstant())
140 : : {
141 : 0 : d_constants[n] = nm->mkConstRealOrInt(tn, Rational(0));
142 : 0 : d_varParts[n] = newn;
143 : 0 : d_reduceVar[n] = p.getNode();
144 : 0 : return p.getNode();
145 : : }
146 : : else
147 : : {
148 : 0 : linear::Monomial mc = p.getHead();
149 : 0 : d_constants[n] = mc.getConstant().getNode();
150 : 0 : d_varParts[n] = p.getTail().getNode();
151 : 0 : d_reduceVar[n] = newn;
152 : 0 : return newn;
153 : : }
154 : : }
155 : : else
156 : : {
157 [ - - ]: 0 : if (!d_contains.containsTermITE(n))
158 : : {
159 : 0 : return n;
160 : : }
161 [ - - ]: 0 : if (n.getNumChildren() > 0)
162 : : {
163 : 0 : Node res = applyReduceVariablesInItes(n);
164 : 0 : d_reduceVar[n] = res;
165 : 0 : return res;
166 : : }
167 : : else
168 : : {
169 : 0 : return n;
170 : : }
171 : : }
172 : : }
173 : : break;
174 : : }
175 : : Unreachable();
176 : : }
177 : :
178 : 4 : ArithIteUtils::ArithIteUtils(
179 : : Env& env,
180 : : preprocessing::util::ContainsTermITEVisitor& contains,
181 : 4 : SubstitutionMap& subs)
182 : : : EnvObj(env),
183 : : d_contains(contains),
184 : : d_subs(subs),
185 : : d_one(1),
186 : 0 : d_subcount(userContext(), 0),
187 : 8 : d_skolems(userContext()),
188 : : d_implies(),
189 : 4 : d_orBinEqs()
190 : : {
191 : 4 : }
192 : :
193 : 4 : ArithIteUtils::~ArithIteUtils(){
194 : 4 : }
195 : :
196 : 0 : void ArithIteUtils::clear(){
197 : 0 : d_reduceVar.clear();
198 : 0 : d_constants.clear();
199 : 0 : d_varParts.clear();
200 : 0 : }
201 : :
202 : 0 : const Integer& ArithIteUtils::gcdIte(Node n){
203 [ - - ]: 0 : if(d_gcds.find(n) != d_gcds.end()){
204 : 0 : return d_gcds[n];
205 : : }
206 [ - - ]: 0 : if (n.isConst())
207 : : {
208 : 0 : const Rational& q = n.getConst<Rational>();
209 [ - - ]: 0 : if(q.isIntegral()){
210 : 0 : d_gcds[n] = q.getNumerator();
211 : 0 : return d_gcds[n];
212 : : }else{
213 : 0 : return d_one;
214 : : }
215 : : }
216 : 0 : else if (n.getKind() == Kind::ITE && n.getType().isRealOrInt())
217 : : {
218 : 0 : const Integer& tgcd = gcdIte(n[1]);
219 [ - - ]: 0 : if(tgcd.isOne()){
220 : 0 : d_gcds[n] = d_one;
221 : 0 : return d_one;
222 : : }else{
223 : 0 : const Integer& egcd = gcdIte(n[2]);
224 : 0 : Integer ite_gcd = tgcd.gcd(egcd);
225 : 0 : d_gcds[n] = ite_gcd;
226 : 0 : return d_gcds[n];
227 : : }
228 : : }
229 : 0 : return d_one;
230 : : }
231 : :
232 : 0 : Node ArithIteUtils::reduceIteConstantIteByGCD_rec(Node n, const Rational& q){
233 [ - - ]: 0 : if(n.isConst()){
234 : 0 : Assert(n.getType().isRealOrInt());
235 : 0 : return nodeManager()->mkConstRealOrInt(n.getType(),
236 : 0 : n.getConst<Rational>() * q);
237 : : }else{
238 : 0 : Assert(n.getKind() == Kind::ITE);
239 : 0 : Assert(n.getType().isInteger());
240 : 0 : Node rc = reduceConstantIteByGCD(n[0]);
241 : 0 : Node rt = reduceIteConstantIteByGCD_rec(n[1], q);
242 : 0 : Node re = reduceIteConstantIteByGCD_rec(n[2], q);
243 : 0 : return rc.iteNode(rt, re);
244 : : }
245 : : }
246 : :
247 : 0 : Node ArithIteUtils::reduceIteConstantIteByGCD(Node n){
248 : 0 : Assert(n.getKind() == Kind::ITE);
249 : 0 : Assert(n.getType().isRealOrInt());
250 : 0 : const Integer& gcd = gcdIte(n);
251 : 0 : NodeManager* nm = nodeManager();
252 [ - - ]: 0 : if(gcd.isOne()){
253 : 0 : Node newIte = reduceConstantIteByGCD(n[0]).iteNode(n[1],n[2]);
254 : 0 : d_reduceGcd[n] = newIte;
255 : 0 : return newIte;
256 [ - - ]: 0 : }else if(gcd.isZero()){
257 : 0 : Node zeroNode = nm->mkConstRealOrInt(n.getType(), Rational(0));
258 : 0 : d_reduceGcd[n] = zeroNode;
259 : 0 : return zeroNode;
260 : : }else{
261 : 0 : Rational divBy(Integer(1), gcd);
262 : 0 : Node redite = reduceIteConstantIteByGCD_rec(n, divBy);
263 : 0 : Node gcdNode = nm->mkConstRealOrInt(n.getType(), Rational(gcd));
264 : 0 : Node multIte = nm->mkNode(Kind::MULT, gcdNode, redite);
265 : 0 : d_reduceGcd[n] = multIte;
266 : 0 : return multIte;
267 : : }
268 : : }
269 : :
270 : 0 : Node ArithIteUtils::reduceConstantIteByGCD(Node n){
271 [ - - ]: 0 : if(d_reduceGcd.find(n) != d_reduceGcd.end()){
272 : 0 : return d_reduceGcd[n];
273 : : }
274 : 0 : if (n.getKind() == Kind::ITE && n.getType().isRealOrInt())
275 : : {
276 : 0 : return reduceIteConstantIteByGCD(n);
277 : : }
278 : :
279 [ - - ]: 0 : if(n.getNumChildren() > 0){
280 : 0 : NodeBuilder nb(nodeManager(), n.getKind());
281 [ - - ]: 0 : if(n.getMetaKind() == kind::metakind::PARAMETERIZED) {
282 : 0 : nb << (n.getOperator());
283 : : }
284 : 0 : bool anychange = false;
285 [ - - ]: 0 : for(Node::iterator it = n.begin(), end = n.end(); it != end; ++it){
286 : 0 : Node child = *it;
287 : 0 : Node redchild = reduceConstantIteByGCD(child);
288 : 0 : anychange = anychange || (child != redchild);
289 : 0 : nb << redchild;
290 : : }
291 [ - - ]: 0 : if(anychange){
292 : 0 : Node res = nb;
293 : 0 : d_reduceGcd[n] = res;
294 : 0 : return res;
295 : : }else{
296 : 0 : d_reduceGcd[n] = n;
297 : 0 : return n;
298 : : }
299 : : }else{
300 : 0 : return n;
301 : : }
302 : : }
303 : :
304 : 8 : unsigned ArithIteUtils::getSubCount() const{
305 : 8 : return d_subcount;
306 : : }
307 : :
308 : 0 : void ArithIteUtils::addSubstitution(TNode f, TNode t){
309 [ - - ]: 0 : Trace("arith::ite") << "adding " << f << " -> " << t << endl;
310 : 0 : d_subcount = d_subcount + 1;
311 : 0 : d_subs.addSubstitution(f, t);
312 : 0 : }
313 : :
314 : 0 : Node ArithIteUtils::applySubstitutions(TNode f){
315 : 0 : AlwaysAssert(!options().base.incrementalSolving);
316 : 0 : return d_subs.apply(f);
317 : : }
318 : :
319 : 0 : Node ArithIteUtils::selectForCmp(Node n) const{
320 [ - - ]: 0 : if (n.getKind() == Kind::ITE)
321 : : {
322 [ - - ]: 0 : if(d_skolems.find(n[0]) != d_skolems.end()){
323 : 0 : return selectForCmp(n[1]);
324 : : }
325 : : }
326 : 0 : return n;
327 : : }
328 : :
329 : 4 : void ArithIteUtils::learnSubstitutions(const std::vector<Node>& assertions){
330 [ - + ][ - + ]: 4 : AlwaysAssert(!options().base.incrementalSolving);
[ - - ]
331 [ + + ]: 12 : for(size_t i=0, N=assertions.size(); i < N; ++i){
332 : 8 : collectAssertions(assertions[i]);
333 : : }
334 : : bool solvedSomething;
335 [ - + ]: 4 : do{
336 : 4 : solvedSomething = false;
337 : 4 : size_t readPos = 0, writePos = 0, N = d_orBinEqs.size();
338 [ - + ]: 4 : for(; readPos < N; readPos++){
339 : 0 : Node curr = d_orBinEqs[readPos];
340 : 0 : bool solved = solveBinOr(curr);
341 [ - - ]: 0 : if(solved){
342 : 0 : solvedSomething = true;
343 : : }else{
344 : : // didn't solve, push back
345 : 0 : d_orBinEqs[writePos] = curr;
346 : 0 : writePos++;
347 : : }
348 : : }
349 [ - + ][ - + ]: 4 : Assert(writePos <= N);
[ - - ]
350 : 4 : d_orBinEqs.resize(writePos);
351 : : }while(solvedSomething);
352 : :
353 : 4 : d_implies.clear();
354 : 4 : d_orBinEqs.clear();
355 : 4 : }
356 : :
357 : 0 : void ArithIteUtils::addImplications(Node x, Node y){
358 : : // (or x y)
359 : : // (=> (not x) y)
360 : : // (=> (not y) x)
361 : :
362 : 0 : Node xneg = x.negate();
363 : 0 : Node yneg = y.negate();
364 : 0 : d_implies[xneg].insert(y);
365 : 0 : d_implies[yneg].insert(x);
366 : 0 : }
367 : :
368 : 8 : void ArithIteUtils::collectAssertions(TNode assertion){
369 [ - + ]: 8 : if (assertion.getKind() == Kind::OR)
370 : : {
371 [ - - ]: 0 : if(assertion.getNumChildren() == 2){
372 : 0 : TNode left = assertion[0], right = assertion[1];
373 : 0 : addImplications(left, right);
374 : 0 : if (left.getKind() == Kind::EQUAL && right.getKind() == Kind::EQUAL)
375 : : {
376 : 0 : if(left[0].getType().isInteger() && right[0].getType().isInteger()){
377 : 0 : d_orBinEqs.push_back(assertion);
378 : : }
379 : : }
380 : : }
381 : : }
382 [ - + ]: 8 : else if (assertion.getKind() == Kind::AND)
383 : : {
384 [ - - ]: 0 : for(unsigned i=0, N=assertion.getNumChildren(); i < N; ++i){
385 : 0 : collectAssertions(assertion[i]);
386 : : }
387 : : }
388 : 8 : }
389 : :
390 : 0 : Node ArithIteUtils::findIteCnd(TNode tb, TNode fb) const{
391 : 0 : Node negtb = tb.negate();
392 : 0 : Node negfb = fb.negate();
393 : 0 : ImpMap::const_iterator ti = d_implies.find(negtb);
394 : 0 : ImpMap::const_iterator fi = d_implies.find(negfb);
395 : :
396 [ - - ][ - - ]: 0 : if(ti != d_implies.end() && fi != d_implies.end()){
[ - - ]
397 : 0 : const std::set<Node>& negtimp = ti->second;
398 : 0 : const std::set<Node>& negfimp = fi->second;
399 : :
400 : : // (or (not x) y)
401 : : // (or x z)
402 : : // (or y z)
403 : : // ---
404 : : // (ite x y z) return x
405 : : // ---
406 : : // (not y) => (not x)
407 : : // (not z) => x
408 : 0 : std::set<Node>::const_iterator ci = negtimp.begin(), cend = negtimp.end();
409 [ - - ]: 0 : for (; ci != cend; ++ci)
410 : : {
411 : 0 : Node impliedByNotTB = *ci;
412 : 0 : Node impliedByNotTBNeg = impliedByNotTB.negate();
413 [ - - ]: 0 : if(negfimp.find(impliedByNotTBNeg) != negfimp.end()){
414 : 0 : return impliedByNotTBNeg; // implies tb
415 : : }
416 : : }
417 : : }
418 : :
419 : 0 : return Node::null();
420 : : }
421 : :
422 : 0 : bool ArithIteUtils::solveBinOr(TNode binor){
423 : 0 : Assert(binor.getKind() == Kind::OR);
424 : 0 : Assert(binor.getNumChildren() == 2);
425 : 0 : Assert(binor[0].getKind() == Kind::EQUAL);
426 : 0 : Assert(binor[1].getKind() == Kind::EQUAL);
427 : :
428 : : //Node n =
429 : 0 : Node n = applySubstitutions(binor);
430 [ - - ]: 0 : if(n != binor){
431 : 0 : n = rewrite(n);
432 : :
433 : 0 : if (!(n.getKind() == Kind::OR && n.getNumChildren() == 2
434 : 0 : && n[0].getKind() == Kind::EQUAL && n[1].getKind() == Kind::EQUAL))
435 : : {
436 : 0 : return false;
437 : : }
438 : : }
439 : :
440 : 0 : Assert(n.getKind() == Kind::OR);
441 : 0 : Assert(n.getNumChildren() == 2);
442 : 0 : TNode l = n[0];
443 : 0 : TNode r = n[1];
444 : :
445 : 0 : Assert(l.getKind() == Kind::EQUAL);
446 : 0 : Assert(r.getKind() == Kind::EQUAL);
447 : :
448 [ - - ]: 0 : Trace("arith::ite") << "bin or " << n << endl;
449 : :
450 : 0 : bool lArithEq = l.getKind() == Kind::EQUAL && l[0].getType().isInteger();
451 : 0 : bool rArithEq = r.getKind() == Kind::EQUAL && r[0].getType().isInteger();
452 : :
453 [ - - ][ - - ]: 0 : if(lArithEq && rArithEq){
454 : 0 : TNode sel = Node::null();
455 : 0 : TNode otherL = Node::null();
456 : 0 : TNode otherR = Node::null();
457 [ - - ]: 0 : if(l[0] == r[0]) {
458 : 0 : sel = l[0]; otherL = l[1]; otherR = r[1];
459 [ - - ]: 0 : }else if(l[0] == r[1]){
460 : 0 : sel = l[0]; otherL = l[1]; otherR = r[0];
461 [ - - ]: 0 : }else if(l[1] == r[0]){
462 : 0 : sel = l[1]; otherL = l[0]; otherR = r[1];
463 [ - - ]: 0 : }else if(l[1] == r[1]){
464 : 0 : sel = l[1]; otherL = l[0]; otherR = r[0];
465 : : }
466 [ - - ]: 0 : Trace("arith::ite") << "selected " << sel << endl;
467 : 0 : if (sel.isVar() && sel.getKind() != Kind::SKOLEM)
468 : : {
469 [ - - ]: 0 : Trace("arith::ite") << "others l:" << otherL << " r " << otherR << endl;
470 : 0 : Node useForCmpL = selectForCmp(otherL);
471 : 0 : Node useForCmpR = selectForCmp(otherR);
472 : :
473 : 0 : Assert(linear::Polynomial::isMember(sel));
474 : 0 : Assert(linear::Polynomial::isMember(useForCmpL));
475 : 0 : Assert(linear::Polynomial::isMember(useForCmpR));
476 : 0 : linear::Polynomial lside = linear::Polynomial::parsePolynomial( useForCmpL );
477 : 0 : linear::Polynomial rside = linear::Polynomial::parsePolynomial( useForCmpR );
478 : 0 : linear::Polynomial diff = lside-rside;
479 : :
480 [ - - ]: 0 : Trace("arith::ite") << "diff: " << diff.getNode() << endl;
481 [ - - ]: 0 : if(diff.isConstant()){
482 : : // a: (sel = otherL) or (sel = otherR), otherL-otherR = c
483 : :
484 : 0 : NodeManager* nm = nodeManager();
485 : 0 : SkolemManager* sm = nm->getSkolemManager();
486 : :
487 : 0 : Node cnd = findIteCnd(binor[0], binor[1]);
488 : :
489 : 0 : Node eq = sel.eqNode(otherL);
490 : 0 : Node sk = sm->mkPurifySkolem(eq);
491 : 0 : Node ite = sk.iteNode(otherL, otherR);
492 : 0 : d_skolems.insert(sk, cnd);
493 : : // Given (or (= x c) (= x d)), we replace x by (ite @purifyX c d),
494 : : // where @purifyX is the purification skolem for (= x c), where c and
495 : : // d are known to be distinct.
496 : 0 : addSubstitution(sel, ite);
497 : 0 : return true;
498 : : }
499 : : }
500 : : }
501 : 0 : return false;
502 : : }
503 : :
504 : : } // namespace arith
505 : : } // namespace theory
506 : : } // namespace cvc5::internal
|