Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * Implementation of parametric integer and (PIAND) solver.
11 : : */
12 : :
13 : : #include "theory/arith/nl/piand_solver.h"
14 : :
15 : : #include "options/arith_options.h"
16 : : #include "options/smt_options.h"
17 : : #include "preprocessing/passes/bv_to_int.h"
18 : : #include "theory/arith/arith_msum.h"
19 : : #include "theory/arith/arith_utilities.h"
20 : : #include "theory/arith/inference_manager.h"
21 : : #include "theory/arith/nl/nl_model.h"
22 : : #include "theory/rewriter.h"
23 : : #include "util/bitvector.h"
24 : : #include "util/iand.h"
25 : :
26 : : using namespace cvc5::internal::kind;
27 : :
28 : : namespace cvc5::internal {
29 : : namespace theory {
30 : : namespace arith {
31 : : namespace nl {
32 : :
33 : 33863 : PIAndSolver::PIAndSolver(Env& env, InferenceManager& im, NlModel& model)
34 : : : EnvObj(env),
35 : 33863 : d_im(im),
36 : 33863 : d_model(model),
37 : 33863 : d_iandUtils(env.getNodeManager()),
38 : 67726 : d_initRefine(userContext())
39 : : {
40 : 33863 : NodeManager* nm = nodeManager();
41 : 33863 : d_false = nm->mkConst(false);
42 : 33863 : d_true = nm->mkConst(true);
43 : 33863 : d_zero = nm->mkConstInt(Rational(0));
44 : 33863 : d_one = nm->mkConstInt(Rational(1));
45 : 33863 : d_two = nm->mkConstInt(Rational(2));
46 : 33863 : }
47 : :
48 : 33576 : PIAndSolver::~PIAndSolver() {}
49 : :
50 : 11256 : void PIAndSolver::initLastCall(const std::vector<Node>& xts)
51 : : {
52 : 11256 : d_piands.clear();
53 : :
54 [ + - ]: 11256 : Trace("piand-mv") << "PIAND terms : " << std::endl;
55 [ + + ]: 71943 : for (const Node& a : xts)
56 : : {
57 : 60687 : Kind ak = a.getKind();
58 [ + + ]: 60687 : if (ak != Kind::PIAND)
59 : : {
60 : : // don't care about other terms
61 : 60610 : continue;
62 : : }
63 : 77 : d_piands[a[0]].push_back(a);
64 : : }
65 [ + - ]: 22512 : Trace("piand") << "We have " << d_piands.size() << " PIAND bit-width."
66 : 11256 : << std::endl;
67 : 11256 : }
68 : :
69 : 11256 : void PIAndSolver::checkInitialRefine()
70 : : {
71 [ + - ]: 11256 : Trace("piand-check") << "PIAndSolver::checkInitialRefine" << std::endl;
72 : 11256 : NodeManager* nm = nodeManager();
73 [ + + ]: 11326 : for (const std::pair<const Node, std::vector<Node> >& is : d_piands)
74 : : {
75 : : // the reference bitwidth
76 : 70 : Node k = is.first;
77 [ + + ]: 147 : for (const Node& i : is.second)
78 : : {
79 : 77 : Node x = i[1];
80 : 77 : Node y = i[2];
81 [ + + ]: 77 : if (d_initRefine.find(i) != d_initRefine.end())
82 : : {
83 : : // already sent initial axioms for i in this user context
84 : 21 : continue;
85 : : }
86 : 56 : d_initRefine.insert(i);
87 : 56 : Node twok = nm->mkNode(Kind::POW2, k);
88 : 112 : Node arg0Mod = nm->mkNode(Kind::INTS_MODULUS, x, twok);
89 : 112 : Node arg1Mod = nm->mkNode(Kind::INTS_MODULUS, y, twok);
90 : 112 : Node arg0Mod2 = nm->mkNode(Kind::INTS_MODULUS, x, d_two);
91 : 112 : Node arg1Mod2 = nm->mkNode(Kind::INTS_MODULUS, y, d_two);
92 : 112 : Node plus = nm->mkNode(Kind::ADD, x, y);
93 : 112 : Node twok_minus_one = nm->mkNode(Kind::SUB, twok, d_one);
94 : 112 : Node k_gt_0 = nm->mkNode(Kind::GT, k, d_zero);
95 : 112 : Node x_geq_zero = nm->mkNode(Kind::GEQ, x, d_zero);
96 : 112 : Node x_lt_pow2 = nm->mkNode(Kind::LT, x, twok);
97 : 112 : Node x_range = nm->mkNode(Kind::AND, x_geq_zero, x_lt_pow2);
98 : 112 : Node y_geq_zero = nm->mkNode(Kind::GEQ, y, d_zero);
99 : 112 : Node y_lt_pow2 = nm->mkNode(Kind::LT, y, twok);
100 : 112 : Node y_range = nm->mkNode(Kind::AND, y_geq_zero, y_lt_pow2);
101 : :
102 : : // initial refinement lemmas
103 : 56 : std::vector<Node> conj;
104 : :
105 : : // x is an upper bound: x > 0 && x < 2^k && y = 2^k -1 -> piand(k,x,y) = x
106 : 112 : Node y_modpow2_eq_max = nm->mkNode(Kind::EQUAL, y, twok_minus_one);
107 : 112 : Node assum_max = nm->mkNode(Kind::AND, k_gt_0, y_modpow2_eq_max, x_range);
108 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, assum_max, i.eqNode(x)));
109 : :
110 : : // y is an upper bound: y > 0 && y < 2^k && x = 2^k -1 -> piand(k,x,y) = y
111 : 112 : Node x_modpow2_eq_max = nm->mkNode(Kind::EQUAL, x, twok_minus_one);
112 : : Node assum_max_x =
113 : 112 : nm->mkNode(Kind::AND, k_gt_0, x_modpow2_eq_max, y_range);
114 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, assum_max_x, i.eqNode(y)));
115 : :
116 : : // min-y: y = 0 -> piand(k,x,y) = 0
117 : 112 : Node eq_y_zero = nm->mkNode(Kind::EQUAL, y, d_zero);
118 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, eq_y_zero, i.eqNode(d_zero)));
119 : :
120 : : // min-x: x = 0 -> piand(k,x,y) = 0
121 : 112 : Node eq_x_zero = nm->mkNode(Kind::EQUAL, x, d_zero);
122 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, eq_x_zero, i.eqNode(d_zero)));
123 : :
124 : : // idempotence: k > 0 && x > 0 && x < 2^k && x = y -> piand(k,x,y) = x
125 : 112 : Node eq_y_x = nm->mkNode(Kind::EQUAL, y, x);
126 : 112 : Node assum_idempotence = nm->mkNode(Kind::AND, k_gt_0, eq_y_x, x_range);
127 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, assum_idempotence, i.eqNode(x)));
128 : :
129 : : // symmetry: piand(k,x,y) = piand(k,y,x)
130 : 112 : Node piand_y_x = nm->mkNode(Kind::PIAND, k, y, x);
131 : 56 : conj.push_back(nm->mkNode(Kind::EQUAL, i, piand_y_x));
132 : :
133 : : // range1: 0 <= piand(k,x,y)
134 : 56 : conj.push_back(nm->mkNode(Kind::LEQ, d_zero, i));
135 : :
136 : : // range 2: 0 <= x -> piand(k,x,y) <= x
137 : 112 : Node i_leq_x = nm->mkNode(Kind::LEQ, i, x);
138 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, x_geq_zero, i_leq_x));
139 : :
140 : : // range 3: 0 <= y -> piand(k,x,y) <= y
141 : 112 : Node i_leq_y = nm->mkNode(Kind::LEQ, i, y);
142 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, y_geq_zero, i_leq_y));
143 : :
144 : : // non-positive bitwidth: k <= 0 -> piand(k,x, y) = 0
145 : 112 : Node k_le_0 = nm->mkNode(Kind::LEQ, k, d_zero);
146 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, k_le_0, i.eqNode(d_zero)));
147 : :
148 : : // lsb lemma for x: x mod 2 = 0 => piand(k,x,y) % 2 = 0
149 : 112 : Node piand_mod_two = nm->mkNode(Kind::INTS_MODULUS, i, d_two);
150 : 112 : Node arg1Mod2_eq_zero = nm->mkNode(Kind::EQUAL, arg1Mod2, d_zero);
151 : 56 : conj.push_back(nm->mkNode(
152 : 112 : Kind::IMPLIES, arg1Mod2_eq_zero, piand_mod_two.eqNode(d_zero)));
153 : :
154 : : // lsb lemma for y: y mod 2 = 0 => piand(k,x,y) % 2 = 0
155 : 112 : Node arg0Mod2_eq_zero = nm->mkNode(Kind::EQUAL, arg0Mod2, d_zero);
156 : 56 : conj.push_back(nm->mkNode(
157 : 112 : Kind::IMPLIES, arg0Mod2_eq_zero, piand_mod_two.eqNode(d_zero)));
158 : :
159 : : // insert lemmas
160 [ - + ]: 56 : Node lem = conj.size() == 1 ? conj[0] : nm->mkNode(Kind::AND, conj);
161 [ + - ]: 112 : Trace("piand-lemma") << "PIAndSolver::Lemma: " << lem << " ; INIT_REFINE"
162 : 56 : << std::endl;
163 : 56 : d_im.addPendingLemma(lem, InferenceId::ARITH_NL_PIAND_INIT_REFINE);
164 [ + + ][ + + ]: 98 : }
165 : 70 : }
166 : 11256 : }
167 : :
168 : 2182 : void PIAndSolver::checkFullRefine()
169 : : {
170 : 2182 : NodeManager* nm = nodeManager();
171 [ + - ]: 2182 : Trace("piand-check") << "PIAndSolver::checkFullRefine";
172 [ + + ]: 2196 : for (const std::pair<const Node, std::vector<Node> >& is : d_piands)
173 : : {
174 : 14 : int index = 0;
175 [ + + ]: 28 : for (const Node& i : is.second)
176 : : {
177 : 14 : index++;
178 : 14 : Node valAndXY = d_model.computeAbstractModelValue(i);
179 : 14 : Node valAndXYC = d_model.computeConcreteModelValue(i);
180 : 14 : valAndXYC = rewrite(valAndXYC);
181 : :
182 : 14 : Node k = i[0];
183 : 14 : Node x = i[1];
184 : 14 : Node y = i[2];
185 : 14 : Node valK = d_model.computeConcreteModelValue(k);
186 : 14 : Node valX = d_model.computeConcreteModelValue(x);
187 : 14 : Node valY = d_model.computeConcreteModelValue(y);
188 : :
189 : 14 : Integer model_piand = valAndXYC.getConst<Rational>().getNumerator();
190 : 14 : Integer model_k = valK.getConst<Rational>().getNumerator();
191 : 14 : Integer model_x = valX.getConst<Rational>().getNumerator();
192 : 14 : Integer model_y = valY.getConst<Rational>().getNumerator();
193 : :
194 [ - + ]: 14 : if (TraceIsOn("piand-check"))
195 : : {
196 [ - - ]: 0 : Trace("piand-check")
197 : 0 : << "* " << i << ", value = " << valAndXY << std::endl;
198 [ - - ]: 0 : Trace("piand-check") << " actual (" << valX << ", " << valY
199 : 0 : << ") = " << valAndXYC << std::endl;
200 : : }
201 [ + - ]: 14 : if (valAndXY == valAndXYC)
202 : : {
203 [ + - ]: 14 : Trace("piand-check") << "...already correct" << std::endl;
204 : 14 : continue;
205 : : }
206 : :
207 : 0 : Integer ione = 1;
208 : 0 : Integer itwo = 2;
209 : 0 : Integer ipow2 = itwo.pow(model_k.getLong());
210 : 0 : Integer max_int = ipow2 - 1;
211 : 0 : Node k_gt_0 = nm->mkNode(Kind::GT, k, d_zero);
212 : 0 : Node twok = nm->mkNode(Kind::POW2, k);
213 : 0 : Node arg0Mod = nm->mkNode(Kind::INTS_MODULUS, x, twok);
214 : 0 : Node arg1Mod = nm->mkNode(Kind::INTS_MODULUS, y, twok);
215 : 0 : Node arg0Mod2 = nm->mkNode(Kind::INTS_MODULUS, x, d_two);
216 : 0 : Node arg1Mod2 = nm->mkNode(Kind::INTS_MODULUS, y, d_two);
217 : :
218 : : // base case: piand(k,1,1) = 1
219 : 0 : if (model_k > 0 && model_x == 1 && model_y == 1 && model_piand != 1)
220 : : {
221 : 0 : Node x_equal_one = nm->mkNode(Kind::EQUAL, x, d_one);
222 : 0 : Node y_equal_one = nm->mkNode(Kind::EQUAL, y, d_one);
223 : 0 : Node assum = nm->mkNode(Kind::AND, k_gt_0, x_equal_one, y_equal_one);
224 : 0 : Node piand_one = nm->mkNode(Kind::EQUAL, i, d_one);
225 : 0 : Node xy_one_lem = nm->mkNode(Kind::IMPLIES, assum, piand_one);
226 : 0 : d_im.addPendingLemma(xy_one_lem,
227 : : InferenceId::ARITH_NL_PIAND_BASE_CASE_REFINE,
228 : : nullptr,
229 : : true);
230 : 0 : }
231 : :
232 : 0 : Node x_geq_zero = nm->mkNode(Kind::GEQ, x, d_zero);
233 : 0 : Node x_lt_pow2 = nm->mkNode(Kind::LT, x, twok);
234 : 0 : Node x_range = nm->mkNode(Kind::AND, x_geq_zero, x_lt_pow2);
235 : 0 : Node y_geq_zero = nm->mkNode(Kind::GEQ, y, d_zero);
236 : 0 : Node y_lt_pow2 = nm->mkNode(Kind::LT, y, twok);
237 : 0 : Node y_range = nm->mkNode(Kind::AND, y_geq_zero, y_lt_pow2);
238 : 0 : int j = -1;
239 [ - - ]: 0 : for (const Node& n : is.second)
240 : : {
241 : 0 : j++;
242 [ - - ]: 0 : if (j > index)
243 : : {
244 : 0 : Node k2 = n[0];
245 : 0 : Node x2 = n[1];
246 : 0 : Node y2 = n[2];
247 : 0 : Node valK2 = d_model.computeConcreteModelValue(k2);
248 : 0 : Node valX2 = d_model.computeConcreteModelValue(x2);
249 : 0 : Node valY2 = d_model.computeConcreteModelValue(y2);
250 : 0 : Node valAndXYC2 = d_model.computeConcreteModelValue(n);
251 : 0 : Integer model_piand2 = valAndXYC2.getConst<Rational>().getNumerator();
252 : 0 : Integer model_k2 = valK2.getConst<Rational>().getNumerator();
253 : 0 : Integer model_x2 = valX2.getConst<Rational>().getNumerator();
254 : 0 : Integer model_y2 = valY2.getConst<Rational>().getNumerator();
255 : :
256 : 0 : Node arg20Mod = nm->mkNode(Kind::INTS_MODULUS, x2, twok);
257 : 0 : Node arg21Mod = nm->mkNode(Kind::INTS_MODULUS, y2, twok);
258 : :
259 : 0 : Node x2_geq_zero = nm->mkNode(Kind::GEQ, x2, d_zero);
260 : 0 : Node x2_lt_pow2 = nm->mkNode(Kind::LT, x2, twok);
261 : 0 : Node x2_range = nm->mkNode(Kind::AND, x2_geq_zero, x2_lt_pow2);
262 : :
263 : : // difference: x != x2 /\ y = y2 => piand(k,x,y) != x2 \/
264 : : // piand(k,x2,y2) != x
265 : 0 : if (model_k > 0 && model_k == model_k2 && model_x != model_x2
266 : 0 : && model_y == model_y2 && model_piand == model_x2
267 : 0 : && model_piand2 == model_x)
268 : : {
269 : 0 : Node noneqx = nm->mkNode(
270 : : Kind::AND,
271 : 0 : {k.eqNode(k2), (x.eqNode(x2)).notNode(), y.eqNode(y2)});
272 : : Node ranges_assum =
273 : 0 : nm->mkNode(Kind::AND, x_range, x2_range, y_range);
274 : : Node assum_difference =
275 : 0 : nm->mkNode(Kind::AND, k_gt_0, noneqx, ranges_assum);
276 : 0 : Node difference = nm->mkNode(
277 : 0 : Kind::OR, {i.eqNode(x2).notNode(), n.eqNode(x).notNode()});
278 : : Node diff_lemm =
279 : 0 : nm->mkNode(Kind::IMPLIES, assum_difference, difference);
280 : 0 : d_im.addPendingLemma(diff_lemm,
281 : : InferenceId::ARITH_NL_PIAND_DIFFERENCE_REFINE,
282 : : nullptr,
283 : : true);
284 : 0 : }
285 : :
286 : : // symmetry: piand(k,x,y) = piand(k,y,x)
287 : 0 : if (model_k == model_k2 && model_x == model_y2 && model_x2 == model_y
288 : 0 : && model_piand != model_piand2)
289 : : {
290 : 0 : Node assum_sym = nm->mkNode(
291 : 0 : Kind::AND, {k.eqNode(k2), (x.eqNode(y2)), y.eqNode(x2)});
292 : 0 : Node sym_lemm = nm->mkNode(Kind::IMPLIES, assum_sym, i.eqNode(n));
293 : 0 : d_im.addPendingLemma(sym_lemm,
294 : : InferenceId::ARITH_NL_PIAND_SYMETRY_REFINE,
295 : : nullptr,
296 : : true);
297 : 0 : }
298 : 0 : }
299 : : }
300 : :
301 : : // contradition: x+y mod 2^k = 2^k-1 => piand(k,x,y) = 0
302 : 0 : if (model_x + model_y == max_int && model_piand != 0)
303 : : {
304 : 0 : Node x_plus_y = nm->mkNode(Kind::ADD, x, y);
305 : 0 : Node x_plus_y_mod = nm->mkNode(Kind::INTS_MODULUS, x_plus_y, twok);
306 : 0 : Node twok_minus_one = nm->mkNode(Kind::SUB, twok, d_one);
307 : 0 : Node assum = nm->mkNode(Kind::EQUAL, x_plus_y_mod, twok_minus_one);
308 : 0 : Node piand_zero = nm->mkNode(Kind::EQUAL, i, d_zero);
309 : 0 : Node neg_lem = nm->mkNode(Kind::IMPLIES, assum, piand_zero);
310 : 0 : d_im.addPendingLemma(neg_lem,
311 : : InferenceId::ARITH_NL_PIAND_CONTRADITION_REFINE,
312 : : nullptr,
313 : : true);
314 : 0 : }
315 : :
316 : : // one: k > 0 && y = 1 -> piand(k,x,y) = x mod 2
317 : 0 : if (model_k > 0 && model_y == 1 && model_piand != model_x.modByPow2(1))
318 : : {
319 : 0 : Node y_equal_one = nm->mkNode(Kind::EQUAL, y, d_one);
320 : 0 : Node asum_lsb = nm->mkNode(Kind::AND, k_gt_0, y_equal_one);
321 : 0 : Node lsb = nm->mkNode(Kind::EQUAL, i, arg0Mod2);
322 : 0 : Node y_one_lem = nm->mkNode(Kind::IMPLIES, asum_lsb, lsb);
323 : 0 : d_im.addPendingLemma(
324 : : y_one_lem, InferenceId::ARITH_NL_PIAND_ONE_REFINE, nullptr, true);
325 : 0 : }
326 : :
327 : : // one: k > 0 && x = 1 -> piand(k,x,y) = y mod 2
328 : 0 : if (model_k > 0 && model_x == 1 && model_piand != model_y.modByPow2(1))
329 : : {
330 : 0 : Node x_equal_one = nm->mkNode(Kind::EQUAL, x, d_one);
331 : 0 : Node asum_lsb2 = nm->mkNode(Kind::AND, k_gt_0, x_equal_one);
332 : 0 : Node lsb2 = nm->mkNode(Kind::EQUAL, i, arg1Mod2);
333 : 0 : Node x_one_lem = nm->mkNode(Kind::IMPLIES, asum_lsb2, lsb2);
334 : 0 : d_im.addPendingLemma(
335 : : x_one_lem, InferenceId::ARITH_NL_PIAND_ONE_REFINE, nullptr, true);
336 : 0 : }
337 : :
338 : 0 : Node lem_sum = sumBasedLemma(i, Kind::GEQ);
339 : 0 : d_im.addPendingLemma(
340 : : lem_sum, InferenceId::ARITH_NL_PIAND_SUM_REFINE, nullptr, true);
341 [ - + ][ - + ]: 168 : }
[ - + ][ - + ]
[ - + ][ - + ]
[ - + ][ - + ]
[ - + ][ - + ]
[ - + ][ - + ]
342 : : }
343 : 2182 : }
344 : :
345 : 0 : Node PIAndSolver::sumBasedLemma(Node i, Kind kind)
346 : : {
347 : 0 : Assert(i.getKind() == Kind::PIAND);
348 : 0 : Node k = d_model.computeConcreteModelValue(i[0]);
349 : 0 : Node x = i[1];
350 : 0 : Node y = i[2];
351 : 0 : uint64_t granularity = options().smt.BVAndIntegerGranularity;
352 : 0 : uint64_t int_k = k.getConst<Rational>().getNumerator().toUnsignedInt();
353 : 0 : NodeManager* nm = nodeManager();
354 : : // (i[0] >= k /\ 0 <= x < 2^k /\ 0 <= y < 2^k) => i = sum
355 : 0 : Node width = nm->mkNode(kind, i[0], k);
356 : 0 : Node condition;
357 : 0 : Node pow2_k = nm->mkConstInt(Integer(2).pow(int_k));
358 : 0 : Node zero = nm->mkConstInt(Rational(0));
359 : 0 : Node x_pos = nm->mkNode(Kind::GEQ, x, zero);
360 : 0 : Node y_pos = nm->mkNode(Kind::GEQ, y, zero);
361 : 0 : Node x_lt_pow2 = nm->mkNode(Kind::LT, x, pow2_k);
362 : 0 : Node y_lt_pow2 = nm->mkNode(Kind::LT, y, pow2_k);
363 : 0 : Node bound_x = nm->mkNode(Kind::AND, x_lt_pow2, x_pos);
364 : 0 : Node bound_y = nm->mkNode(Kind::AND, y_lt_pow2, y_pos);
365 : 0 : condition = nm->mkNode(Kind::AND, bound_x, bound_y, width);
366 : : Node then = nm->mkNode(
367 : 0 : Kind::EQUAL, i, d_iandUtils.createSumNode(x, y, int_k, granularity));
368 : 0 : Node lem = nm->mkNode(Kind::IMPLIES, condition, then);
369 : 0 : return lem;
370 : 0 : }
371 : :
372 : : } // namespace nl
373 : : } // namespace arith
374 : : } // namespace theory
375 : : } // namespace cvc5::internal
|