Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * Implementation of arithmetic proof checker.
11 : : */
12 : :
13 : : #include "theory/arith/proof_checker.h"
14 : :
15 : : #include <iostream>
16 : : #include <set>
17 : :
18 : : #include "expr/skolem_manager.h"
19 : : #include "theory/arith/arith_poly_norm.h"
20 : : #include "theory/arith/arith_utilities.h"
21 : : #include "theory/arith/linear/constraint.h"
22 : : #include "theory/arith/operator_elim.h"
23 : :
24 : : using namespace cvc5::internal::kind;
25 : :
26 : : namespace cvc5::internal {
27 : : namespace theory {
28 : : namespace arith {
29 : :
30 : 51048 : ArithProofRuleChecker::ArithProofRuleChecker(NodeManager* nm)
31 : : : ProofRuleChecker(nm),
32 : 51048 : d_extChecker(nm),
33 : 51048 : d_trChecker(nm)
34 : : #ifdef CVC5_POLY_IMP
35 : : ,
36 : 102096 : d_covChecker(nm)
37 : : #endif
38 : : {
39 : 51048 : }
40 : :
41 : 19795 : void ArithProofRuleChecker::registerTo(ProofChecker* pc)
42 : : {
43 : 19795 : pc->registerChecker(ProofRule::MACRO_ARITH_SCALE_SUM_UB, this);
44 : 19795 : pc->registerChecker(ProofRule::ARITH_SUM_UB, this);
45 : 19795 : pc->registerChecker(ProofRule::ARITH_TRICHOTOMY, this);
46 : 19795 : pc->registerChecker(ProofRule::INT_TIGHT_UB, this);
47 : 19795 : pc->registerChecker(ProofRule::INT_TIGHT_LB, this);
48 : 19795 : pc->registerChecker(ProofRule::ARITH_REDUCTION, this);
49 : 19795 : pc->registerChecker(ProofRule::ARITH_MULT_POS, this);
50 : 19795 : pc->registerChecker(ProofRule::ARITH_MULT_NEG, this);
51 : 19795 : pc->registerChecker(ProofRule::ARITH_POLY_NORM, this);
52 : 19795 : pc->registerChecker(ProofRule::ARITH_POLY_NORM_REL, this);
53 : : // register the extended proof checkers
54 : 19795 : d_extChecker.registerTo(pc);
55 : 19795 : d_trChecker.registerTo(pc);
56 : : #ifdef CVC5_POLY_IMP
57 : 19795 : d_covChecker.registerTo(pc);
58 : : #endif
59 : 19795 : }
60 : :
61 : 1049476 : Node ArithProofRuleChecker::checkInternal(ProofRule id,
62 : : const std::vector<Node>& children,
63 : : const std::vector<Node>& args)
64 : : {
65 : 1049476 : NodeManager* nm = nodeManager();
66 [ - + ]: 1049476 : if (TraceIsOn("arith::pf::check"))
67 : : {
68 [ - - ]: 0 : Trace("arith::pf::check") << "Arith ProofRule:" << id << std::endl;
69 [ - - ]: 0 : Trace("arith::pf::check") << " children: " << std::endl;
70 [ - - ]: 0 : for (const auto& c : children)
71 : : {
72 [ - - ]: 0 : Trace("arith::pf::check") << " * " << c << std::endl;
73 : : }
74 [ - - ]: 0 : Trace("arith::pf::check") << " args:" << std::endl;
75 [ - - ]: 0 : for (const auto& c : args)
76 : : {
77 [ - - ]: 0 : Trace("arith::pf::check") << " * " << c << std::endl;
78 : : }
79 : : }
80 [ + + ][ + + ]: 1049476 : switch (id)
[ + + ][ + + ]
[ + + ][ - ]
81 : : {
82 : 16811 : case ProofRule::ARITH_MULT_POS:
83 : : {
84 [ - + ][ - + ]: 16811 : Assert(children.empty());
[ - - ]
85 [ - + ][ - + ]: 16811 : Assert(args.size() == 2);
[ - - ]
86 : 16811 : Node mult = args[0];
87 : 16811 : Kind rel = args[1].getKind();
88 [ + + ][ - + ]: 16811 : Assert(rel == Kind::EQUAL || rel == Kind::DISTINCT || rel == Kind::LT
[ + + ][ + + ]
[ + + ][ + + ]
[ + + ][ + + ]
[ + + ][ + - ]
[ - + ][ - + ]
[ - - ]
89 : : || rel == Kind::LEQ || rel == Kind::GT || rel == Kind::GEQ);
90 : 16811 : Node lhs = args[1][0];
91 : 16811 : Node rhs = args[1][1];
92 : 33622 : Node zero = nm->mkConstRealOrInt(mult.getType(), Rational(0));
93 : : return nm->mkNode(Kind::IMPLIES,
94 [ + + ][ - - ]: 100866 : nm->mkAnd(std::vector<Node>{
95 : 67244 : nm->mkNode(Kind::GT, mult, zero), args[1]}),
96 : 50433 : nm->mkNode(rel,
97 : 33622 : nm->mkNode(Kind::MULT, mult, lhs),
98 : 84055 : nm->mkNode(Kind::MULT, mult, rhs)));
99 : 16811 : }
100 : 121788 : case ProofRule::ARITH_MULT_NEG:
101 : : {
102 [ - + ][ - + ]: 121788 : Assert(children.empty());
[ - - ]
103 [ - + ][ - + ]: 121788 : Assert(args.size() == 2);
[ - - ]
104 : 121788 : Node mult = args[0];
105 : 121788 : Kind rel = args[1].getKind();
106 [ + + ][ - + ]: 121788 : Assert(rel == Kind::EQUAL || rel == Kind::DISTINCT || rel == Kind::LT
[ + + ][ + + ]
[ + + ][ + + ]
[ + + ][ + + ]
[ + + ][ + - ]
[ - + ][ - + ]
[ - - ]
107 : : || rel == Kind::LEQ || rel == Kind::GT || rel == Kind::GEQ);
108 [ + - ]: 121788 : Kind rel_inv = (rel == Kind::DISTINCT ? rel : reverseRelationKind(rel));
109 : 121788 : Node lhs = args[1][0];
110 : 121788 : Node rhs = args[1][1];
111 : 243576 : Node zero = nm->mkConstRealOrInt(mult.getType(), Rational(0));
112 : : return nm->mkNode(Kind::IMPLIES,
113 [ + + ][ - - ]: 730728 : nm->mkAnd(std::vector<Node>{
114 : 487152 : nm->mkNode(Kind::LT, mult, zero), args[1]}),
115 : 365364 : nm->mkNode(rel_inv,
116 : 243576 : nm->mkNode(Kind::MULT, mult, lhs),
117 : 608940 : nm->mkNode(Kind::MULT, mult, rhs)));
118 : 121788 : }
119 : 93086 : case ProofRule::ARITH_SUM_UB:
120 : : {
121 [ - + ]: 93086 : if (children.size() < 2)
122 : : {
123 : 0 : return Node::null();
124 : : }
125 : :
126 : : // Whether a strict inequality is in the sum.
127 : 93086 : bool strict = false;
128 : 93086 : NodeBuilder leftSum(nm, Kind::ADD);
129 : 93086 : NodeBuilder rightSum(nm, Kind::ADD);
130 [ + + ]: 498234 : for (size_t i = 0; i < children.size(); ++i)
131 : : {
132 : : // Adjust strictness
133 [ + + ][ - ]: 405148 : switch (children[i].getKind())
134 : : {
135 : 76714 : case Kind::LT:
136 : : {
137 : 76714 : strict = true;
138 : 76714 : break;
139 : : }
140 : 328434 : case Kind::LEQ:
141 : : case Kind::EQUAL:
142 : : {
143 : 328434 : break;
144 : : }
145 : 0 : default:
146 : : {
147 [ - - ]: 0 : Trace("arith::pf::check")
148 : 0 : << "Bad kind: " << children[i].getKind() << std::endl;
149 : 0 : return Node::null();
150 : : }
151 : : }
152 : 405148 : leftSum << children[i][0];
153 : 405148 : rightSum << children[i][1];
154 : : }
155 : : Node r = nm->mkNode(strict ? Kind::LT : Kind::LEQ,
156 : 186172 : leftSum.constructNode(),
157 [ + + ]: 465430 : rightSum.constructNode());
158 : 93086 : return r;
159 : 93086 : }
160 : 716849 : case ProofRule::MACRO_ARITH_SCALE_SUM_UB:
161 : : {
162 : : //================================================= Arithmetic rules
163 : : // ======== Adding Inequalities
164 : : // Note: an ArithLiteral is a term of the form (>< poly const)
165 : : // where
166 : : // >< is >=, >, ==, <, <=, or not(== ...).
167 : : // poly is a polynomial
168 : : // const is a rational constant
169 : :
170 : : // Children: (P1:l1, ..., Pn:ln)
171 : : // where each li is an ArithLiteral
172 : : // not(= ...) is dis-allowed!
173 : : //
174 : : // Arguments: (k1, ..., kn), non-zero reals
175 : : // ---------------------
176 : : // Conclusion: (>< t1 t2)
177 : : // where >< is the fusion of the combination of the ><i, (flipping each
178 : : // it its ki is negative). >< is always one of <, <= NB: this implies
179 : : // that lower bounds must have negative ki,
180 : : // and upper bounds must have positive ki.
181 : : // t1 is the sum of the scaled polynomials (k_1 * poly_1 + ... + k_n *
182 : : // poly_n) t2 is the sum of the scaled constants (k_1 * const_1 + ... +
183 : : // k_n * const_n)
184 [ - + ][ - + ]: 716849 : Assert(children.size() == args.size());
[ - - ]
185 [ - + ]: 716849 : if (children.size() < 2)
186 : : {
187 : 0 : return Node::null();
188 : : }
189 : :
190 : : // Whether a strict inequality is in the sum.
191 : 716849 : bool strict = false;
192 : 716849 : NodeBuilder leftSum(nm, Kind::ADD);
193 : 716849 : NodeBuilder rightSum(nm, Kind::ADD);
194 [ + + ]: 2684514 : for (size_t i = 0; i < children.size(); ++i)
195 : : {
196 : 1967665 : Rational scalar = args[i].getConst<Rational>();
197 [ - + ]: 1967665 : if (scalar == 0)
198 : : {
199 [ - - ]: 0 : Trace("arith::pf::check") << "Error: zero scalar" << std::endl;
200 : 0 : return Node::null();
201 : : }
202 : :
203 : : // Adjust strictness
204 [ + + ][ - ]: 1967665 : switch (children[i].getKind())
205 : : {
206 : 369046 : case Kind::GT:
207 : : case Kind::LT:
208 : : {
209 : 369046 : strict = true;
210 : 369046 : break;
211 : : }
212 : 1598619 : case Kind::GEQ:
213 : : case Kind::LEQ:
214 : : case Kind::EQUAL:
215 : : {
216 : 1598619 : break;
217 : : }
218 : 0 : default:
219 : : {
220 [ - - ]: 0 : Trace("arith::pf::check")
221 : 0 : << "Bad kind: " << children[i].getKind() << std::endl;
222 : : }
223 : : }
224 : : // check for spurious mixed arithmetic
225 : 3935330 : if (children[i][0].getType().isReal()
226 : 3935330 : || children[i][1].getType().isReal())
227 : : {
228 [ - + ]: 702853 : if (args[i].getType().isInteger())
229 : : {
230 : : // Should use real for predicates over reals. This is only
231 : : // necessary for avoiding spurious usage of mixed arithmetic, but we
232 : : // check here to be pedantic.
233 : 0 : return Node::null();
234 : : }
235 : : }
236 [ + + ][ - + ]: 1264812 : else if (args[i].getType().isReal() && scalar.isIntegral())
[ + - ][ - + ]
[ - - ]
237 : : {
238 : : // conversely, don't use (integral) real for integer relation.
239 : 0 : return Node::null();
240 : : }
241 : : // Check sign
242 [ + + ][ + - ]: 1967665 : switch (children[i].getKind())
243 : : {
244 : 474768 : case Kind::GT:
245 : : case Kind::GEQ:
246 : : {
247 [ - + ]: 474768 : if (scalar > 0)
248 : : {
249 [ - - ]: 0 : Trace("arith::pf::check")
250 : 0 : << "Positive scalar for lower bound: " << scalar << " for "
251 : 0 : << children[i] << std::endl;
252 : 0 : return Node::null();
253 : : }
254 : 474768 : break;
255 : : }
256 : 595513 : case Kind::LEQ:
257 : : case Kind::LT:
258 : : {
259 [ - + ]: 595513 : if (scalar < 0)
260 : : {
261 [ - - ]: 0 : Trace("arith::pf::check")
262 : 0 : << "Negative scalar for upper bound: " << scalar << " for "
263 : 0 : << children[i] << std::endl;
264 : 0 : return Node::null();
265 : : }
266 : 595513 : break;
267 : : }
268 : 897384 : case Kind::EQUAL:
269 : : {
270 : 897384 : break;
271 : : }
272 : 0 : default:
273 : : {
274 [ - - ]: 0 : Trace("arith::pf::check")
275 : 0 : << "Bad kind: " << children[i].getKind() << std::endl;
276 : : }
277 : : }
278 : : // if multiplying by one, don't introduce MULT
279 [ + + ]: 1967665 : if (scalar == 1)
280 : : {
281 : 911092 : leftSum << children[i][0];
282 : 911092 : rightSum << children[i][1];
283 : : }
284 : : else
285 : : {
286 : 1056573 : leftSum << nm->mkNode(Kind::MULT, args[i], children[i][0]);
287 : 1056573 : rightSum << nm->mkNode(Kind::MULT, args[i], children[i][1]);
288 : : }
289 [ + - ]: 1967665 : }
290 : : Node r = nm->mkNode(strict ? Kind::LT : Kind::LEQ,
291 : 1433698 : leftSum.constructNode(),
292 [ + + ]: 3584245 : rightSum.constructNode());
293 : 716849 : return r;
294 : 716849 : }
295 : 783 : case ProofRule::INT_TIGHT_LB:
296 : : {
297 : : // Children: (P:(> i c))
298 : : // where i has integer type.
299 : : // Arguments: none
300 : : // ---------------------
301 : : // Conclusion: (>= i leastIntGreaterThan(c)})
302 : 1566 : if (children.size() != 1
303 [ - + ]: 783 : || (children[0].getKind() != Kind::GT
304 [ - - ]: 0 : && children[0].getKind() != Kind::GEQ)
305 : 1566 : || !children[0][0].getType().isInteger() || !children[0][1].isConst())
306 : : {
307 [ - - ]: 0 : Trace("arith::pf::check") << "Illformed input: " << children;
308 : 0 : return Node::null();
309 : : }
310 : : else
311 : : {
312 : 783 : Rational originalBound = children[0][1].getConst<Rational>();
313 : 783 : Rational newBound = leastIntGreaterThan(originalBound);
314 : 783 : Node rational = nm->mkConstInt(newBound);
315 : 783 : return nm->mkNode(Kind::GEQ, children[0][0], rational);
316 : 783 : }
317 : : }
318 : 10371 : case ProofRule::INT_TIGHT_UB:
319 : : {
320 : : // ======== Tightening Strict Integer Upper Bounds
321 : : // Children: (P:(< i c))
322 : : // where i has integer type.
323 : : // Arguments: none
324 : : // ---------------------
325 : : // Conclusion: (<= i greatestIntLessThan(c)})
326 : 20742 : if (children.size() != 1
327 [ - + ]: 10371 : || (children[0].getKind() != Kind::LT
328 [ - - ]: 0 : && children[0].getKind() != Kind::LEQ)
329 : 20742 : || !children[0][0].getType().isInteger() || !children[0][1].isConst())
330 : : {
331 [ - - ]: 0 : Trace("arith::pf::check") << "Illformed input: " << children;
332 : 0 : return Node::null();
333 : : }
334 : : else
335 : : {
336 : 10371 : Rational originalBound = children[0][1].getConst<Rational>();
337 : 10371 : Rational newBound = greatestIntLessThan(originalBound);
338 : 10371 : Node rational = nm->mkConstInt(newBound);
339 : 10371 : return nm->mkNode(Kind::LEQ, children[0][0], rational);
340 : 10371 : }
341 : : }
342 : 6235 : case ProofRule::ARITH_TRICHOTOMY:
343 : : {
344 : 6235 : Node a = negateProofLiteral(children[0]);
345 : 6235 : Node b = negateProofLiteral(children[1]);
346 : 6235 : if (a[0] == b[0] && a[1] == b[1])
347 : : {
348 : 6235 : std::set<Kind> cmps;
349 : 6235 : cmps.insert(a.getKind());
350 : 6235 : cmps.insert(b.getKind());
351 : 6235 : Kind retk = Kind::UNDEFINED_KIND;
352 [ + + ]: 6235 : if (cmps.count(Kind::EQUAL) == 0)
353 : : {
354 : 3599 : retk = Kind::EQUAL;
355 : : }
356 [ + + ]: 6235 : if (cmps.count(Kind::GT) == 0)
357 : : {
358 [ - + ]: 1371 : if (retk != Kind::UNDEFINED_KIND)
359 : : {
360 [ - - ]: 0 : Trace("arith::pf::check")
361 : 0 : << "Error: No GT and " << retk << std::endl;
362 : 0 : return Node::null();
363 : : }
364 : 1371 : retk = Kind::GT;
365 : : }
366 [ + + ]: 6235 : if (cmps.count(Kind::LT) == 0)
367 : : {
368 [ - + ]: 1265 : if (retk != Kind::UNDEFINED_KIND)
369 : : {
370 [ - - ]: 0 : Trace("arith::pf::check")
371 : 0 : << "Error: No LT and " << retk << std::endl;
372 : 0 : return Node::null();
373 : : }
374 : 1265 : retk = Kind::LT;
375 : : }
376 : 6235 : return nm->mkNode(retk, a[0], a[1]);
377 : 6235 : }
378 : : else
379 : : {
380 [ - - ]: 0 : Trace("arith::pf::check")
381 : 0 : << "Error: Different polynomials / values" << std::endl;
382 [ - - ]: 0 : Trace("arith::pf::check") << " a: " << a << std::endl;
383 [ - - ]: 0 : Trace("arith::pf::check") << " b: " << b << std::endl;
384 : 0 : return Node::null();
385 : : }
386 : : // Check that all have the same constant:
387 : 6235 : }
388 : 1457 : case ProofRule::ARITH_REDUCTION:
389 : : {
390 [ - + ][ - + ]: 1457 : Assert(children.empty());
[ - - ]
391 [ - + ][ - + ]: 1457 : Assert(args.size() == 1);
[ - - ]
392 : 1457 : return OperatorElim::getAxiomFor(nm, args[0]);
393 : : }
394 : 53320 : case ProofRule::ARITH_POLY_NORM:
395 : : {
396 [ - + ][ - + ]: 53320 : Assert(children.empty());
[ - - ]
397 [ - + ][ - + ]: 53320 : Assert(args.size() == 1);
[ - - ]
398 : 53320 : if (args[0].getKind() != Kind::EQUAL
399 : 106640 : || !args[0][0].getType().isRealOrInt())
400 : : {
401 : 0 : return Node::null();
402 : : }
403 [ - + ]: 53320 : if (!PolyNorm::isArithPolyNorm(args[0][0], args[0][1]))
404 : : {
405 : 0 : return Node::null();
406 : : }
407 : 53320 : return args[0];
408 : : }
409 : 28776 : case ProofRule::ARITH_POLY_NORM_REL:
410 : : {
411 [ - + ][ - + ]: 28776 : Assert(children.size() == 1);
[ - - ]
412 [ - + ][ - + ]: 28776 : Assert(args.size() == 1);
[ - - ]
413 [ - + ]: 28776 : if (args[0].getKind() != Kind::EQUAL)
414 : : {
415 : 0 : return Node::null();
416 : : }
417 : 28776 : Kind k = args[0][0].getKind();
418 [ + + ][ + + ]: 28776 : if (k != Kind::LT && k != Kind::LEQ && k != Kind::EQUAL && k != Kind::GT
[ + + ][ + + ]
419 [ - + ]: 11836 : && k != Kind::GEQ)
420 : : {
421 : 0 : return Node::null();
422 : : }
423 [ - + ]: 28776 : if (children[0].getKind() != Kind::EQUAL)
424 : : {
425 : 0 : return Node::null();
426 : : }
427 : 28776 : Node l = children[0][0];
428 : 28776 : Node r = children[0][1];
429 [ + - ][ - + ]: 28776 : if (l.getKind() != Kind::MULT || r.getKind() != Kind::MULT)
[ - + ]
430 : : {
431 : 0 : return Node::null();
432 : : }
433 : 28776 : Node lr = l[1];
434 [ + + ]: 28776 : lr = lr.getKind() == Kind::TO_REAL ? lr[0] : lr;
435 : 28776 : Node rr = r[1];
436 [ + + ]: 28776 : rr = rr.getKind() == Kind::TO_REAL ? rr[0] : rr;
437 [ + - ][ - + ]: 28776 : if (lr.getKind() != Kind::SUB || rr.getKind() != Kind::SUB)
[ - + ]
438 : : {
439 : 0 : return Node::null();
440 : : }
441 : 28776 : Node cx = l[0];
442 : 28776 : Node x1 = lr[0];
443 : 28776 : Node x2 = lr[1];
444 : 28776 : Node cy = r[0];
445 : 28776 : Node y1 = rr[0];
446 : 28776 : Node y2 = rr[1];
447 : 28776 : if ((cx.getKind() == Kind::CONST_INTEGER
448 [ + - ]: 11720 : || cx.getKind() == Kind::CONST_RATIONAL)
449 [ + + ][ + + ]: 52216 : && (cy.getKind() == Kind::CONST_INTEGER
[ + - ]
450 [ + - ]: 11720 : || cy.getKind() == Kind::CONST_RATIONAL))
451 : : {
452 : 28776 : Rational c1 = cx.getConst<Rational>();
453 : 28776 : Rational c2 = cy.getConst<Rational>();
454 [ + - ][ - + ]: 28776 : if (c1.sgn() == 0 || c2.sgn() == 0)
[ - + ]
455 : : {
456 : 0 : return Node::null();
457 : : }
458 [ + + ][ - + ]: 28776 : if (k != Kind::EQUAL && c1.sgn() != c2.sgn())
[ - + ]
459 : : {
460 : 0 : return Node::null();
461 : : }
462 [ + - ][ + - ]: 28776 : }
463 : 57552 : Node ret = nm->mkNode(k, x1, x2).eqNode(nm->mkNode(k, y1, y2));
464 [ - + ]: 28776 : if (ret != args[0])
465 : : {
466 : 0 : return Node::null();
467 : : }
468 : 28776 : return ret;
469 : 28776 : }
470 : 0 : default: return Node::null();
471 : : }
472 : : }
473 : : } // namespace arith
474 : : } // namespace theory
475 : : } // namespace cvc5::internal
|