Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * Implementation of parametric integer and (PIAND) solver.
11 : : */
12 : :
13 : : #include "theory/arith/nl/piand_solver.h"
14 : :
15 : : #include "options/arith_options.h"
16 : : #include "options/smt_options.h"
17 : : #include "preprocessing/passes/bv_to_int.h"
18 : : #include "theory/arith/arith_msum.h"
19 : : #include "theory/arith/arith_utilities.h"
20 : : #include "theory/arith/inference_manager.h"
21 : : #include "theory/arith/nl/nl_model.h"
22 : : #include "theory/rewriter.h"
23 : : #include "util/bitvector.h"
24 : : #include "util/iand.h"
25 : :
26 : : using namespace cvc5::internal::kind;
27 : :
28 : : namespace cvc5::internal {
29 : : namespace theory {
30 : : namespace arith {
31 : : namespace nl {
32 : :
33 : 33012 : PIAndSolver::PIAndSolver(Env& env, InferenceManager& im, NlModel& model)
34 : : : EnvObj(env),
35 : 33012 : d_im(im),
36 : 33012 : d_model(model),
37 : 33012 : d_iandUtils(env.getNodeManager()),
38 : 66024 : d_initRefine(userContext())
39 : : {
40 : 33012 : NodeManager* nm = nodeManager();
41 : 33012 : d_false = nm->mkConst(false);
42 : 33012 : d_true = nm->mkConst(true);
43 : 33012 : d_zero = nm->mkConstInt(Rational(0));
44 : 33012 : d_one = nm->mkConstInt(Rational(1));
45 : 33012 : d_two = nm->mkConstInt(Rational(2));
46 : 33012 : }
47 : :
48 : 32728 : PIAndSolver::~PIAndSolver() {}
49 : :
50 : 11182 : void PIAndSolver::initLastCall(const std::vector<Node>& xts)
51 : : {
52 : 11182 : d_piands.clear();
53 : :
54 [ + - ]: 11182 : Trace("piand-mv") << "PIAND terms : " << std::endl;
55 [ + + ]: 72195 : for (const Node& a : xts)
56 : : {
57 : 61013 : Kind ak = a.getKind();
58 [ + + ]: 61013 : if (ak != Kind::PIAND)
59 : : {
60 : : // don't care about other terms
61 : 60936 : continue;
62 : : }
63 : 77 : d_piands[a[0]].push_back(a);
64 : : }
65 [ + - ]: 22364 : Trace("piand") << "We have " << d_piands.size() << " PIAND bit-width."
66 : 11182 : << std::endl;
67 : 11182 : }
68 : :
69 : 11182 : void PIAndSolver::checkInitialRefine()
70 : : {
71 [ + - ]: 11182 : Trace("piand-check") << "PIAndSolver::checkInitialRefine" << std::endl;
72 : 11182 : NodeManager* nm = nodeManager();
73 [ + + ]: 11252 : for (const std::pair<const Node, std::vector<Node> >& is : d_piands)
74 : : {
75 : : // the reference bitwidth
76 : 70 : Node k = is.first;
77 [ + + ]: 147 : for (const Node& i : is.second)
78 : : {
79 : 77 : Node x = i[1];
80 : 77 : Node y = i[2];
81 [ + + ]: 77 : if (d_initRefine.find(i) != d_initRefine.end())
82 : : {
83 : : // already sent initial axioms for i in this user context
84 : 21 : continue;
85 : : }
86 : 56 : d_initRefine.insert(i);
87 : 56 : Node twok = nm->mkNode(Kind::POW2, k);
88 : 112 : Node arg0Mod = nm->mkNode(Kind::INTS_MODULUS, x, twok);
89 : 112 : Node arg1Mod = nm->mkNode(Kind::INTS_MODULUS, y, twok);
90 : 112 : Node arg0Mod2 = nm->mkNode(Kind::INTS_MODULUS, x, d_two);
91 : 112 : Node arg1Mod2 = nm->mkNode(Kind::INTS_MODULUS, y, d_two);
92 : 112 : Node plus = nm->mkNode(Kind::ADD, x, y);
93 : 112 : Node twok_minus_one = nm->mkNode(Kind::SUB, twok, d_one);
94 : 112 : Node k_gt_0 = nm->mkNode(Kind::GT, k, d_zero);
95 : 112 : Node x_geq_zero = nm->mkNode(Kind::GEQ, x, d_zero);
96 : 112 : Node x_lt_pow2 = nm->mkNode(Kind::LT, x, twok);
97 : 112 : Node x_range = nm->mkNode(Kind::AND, x_geq_zero, x_lt_pow2);
98 : 112 : Node y_geq_zero = nm->mkNode(Kind::GEQ, y, d_zero);
99 : 112 : Node y_lt_pow2 = nm->mkNode(Kind::LT, y, twok);
100 : 112 : Node y_range = nm->mkNode(Kind::AND, y_geq_zero, y_lt_pow2);
101 : :
102 : : // initial refinement lemmas
103 : 56 : std::vector<Node> conj;
104 : :
105 : : // x is an upper bound: x > 0 && x < 2^k && y = 2^k -1 -> piand(k,x,y) = x
106 : 112 : Node y_modpow2_eq_max = nm->mkNode(Kind::EQUAL, y, twok_minus_one);
107 : 112 : Node assum_max = nm->mkNode(Kind::AND, k_gt_0, y_modpow2_eq_max, x_range);
108 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, assum_max, i.eqNode(x)));
109 : :
110 : : // y is an upper bound: y > 0 && y < 2^k && x = 2^k -1 -> piand(k,x,y) = y
111 : 112 : Node x_modpow2_eq_max = nm->mkNode(Kind::EQUAL, x, twok_minus_one);
112 : : Node assum_max_x =
113 : 112 : nm->mkNode(Kind::AND, k_gt_0, x_modpow2_eq_max, y_range);
114 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, assum_max_x, i.eqNode(y)));
115 : :
116 : : // min-y: y = 0 -> piand(k,x,y) = 0
117 : 112 : Node eq_y_zero = nm->mkNode(Kind::EQUAL, y, d_zero);
118 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, eq_y_zero, i.eqNode(d_zero)));
119 : :
120 : : // min-x: x = 0 -> piand(k,x,y) = 0
121 : 112 : Node eq_x_zero = nm->mkNode(Kind::EQUAL, x, d_zero);
122 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, eq_x_zero, i.eqNode(d_zero)));
123 : :
124 : : // idempotence: k > 0 && x > 0 && x < 2^k && x = y -> piand(k,x,y) = x
125 : 112 : Node eq_y_x = nm->mkNode(Kind::EQUAL, y, x);
126 : 112 : Node assum_idempotence = nm->mkNode(Kind::AND, k_gt_0, eq_y_x, x_range);
127 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, assum_idempotence, i.eqNode(x)));
128 : :
129 : : // symmetry: piand(k,x,y) = piand(k,y,x)
130 : 112 : Node piand_y_x = nm->mkNode(Kind::PIAND, k, y, x);
131 : 56 : conj.push_back(nm->mkNode(Kind::EQUAL, i, piand_y_x));
132 : :
133 : : // range1: 0 <= piand(k,x,y)
134 : 56 : conj.push_back(nm->mkNode(Kind::LEQ, d_zero, i));
135 : :
136 : : // range 2: 0 <= x -> piand(k,x,y) <= x
137 : 112 : Node i_leq_x = nm->mkNode(Kind::LEQ, i, x);
138 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, x_geq_zero, i_leq_x));
139 : :
140 : : // range 3: 0 <= y -> piand(k,x,y) <= y
141 : 112 : Node i_leq_y = nm->mkNode(Kind::LEQ, i, y);
142 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, y_geq_zero, i_leq_y));
143 : :
144 : : // non-positive bitwidth: k <= 0 -> piand(k,x, y) = 0
145 : 112 : Node k_le_0 = nm->mkNode(Kind::LEQ, k, d_zero);
146 : 56 : conj.push_back(nm->mkNode(Kind::IMPLIES, k_le_0, i.eqNode(d_zero)));
147 : :
148 : : // lsb lemma for x: x mod 2 = 0 => piand(k,x,y) % 2 = 0
149 : 112 : Node piand_mod_two = nm->mkNode(Kind::INTS_MODULUS, i, d_two);
150 : 112 : Node arg1Mod2_eq_zero = nm->mkNode(Kind::EQUAL, arg1Mod2, d_zero);
151 : 56 : conj.push_back(nm->mkNode(
152 : 112 : Kind::IMPLIES, arg1Mod2_eq_zero, piand_mod_two.eqNode(d_zero)));
153 : :
154 : : // lsb lemma for y: y mod 2 = 0 => piand(k,x,y) % 2 = 0
155 : 112 : Node arg0Mod2_eq_zero = nm->mkNode(Kind::EQUAL, arg0Mod2, d_zero);
156 : 56 : conj.push_back(nm->mkNode(
157 : 112 : Kind::IMPLIES, arg0Mod2_eq_zero, piand_mod_two.eqNode(d_zero)));
158 : :
159 : : // insert lemmas
160 [ - + ]: 56 : Node lem = conj.size() == 1 ? conj[0] : nm->mkNode(Kind::AND, conj);
161 [ + - ]: 112 : Trace("piand-lemma") << "PIAndSolver::Lemma: " << lem << " ; INIT_REFINE"
162 : 56 : << std::endl;
163 : 56 : d_im.addPendingLemma(lem, InferenceId::ARITH_NL_PIAND_INIT_REFINE);
164 [ + + ][ + + ]: 98 : }
165 : 70 : }
166 : 11182 : }
167 : :
168 : 2189 : void PIAndSolver::checkFullRefine()
169 : : {
170 : 2189 : NodeManager* nm = nodeManager();
171 [ + - ]: 2189 : Trace("piand-check") << "PIAndSolver::checkFullRefine";
172 [ + + ]: 2203 : for (const std::pair<const Node, std::vector<Node> >& is : d_piands)
173 : : {
174 : 14 : int index = 0;
175 [ + + ]: 28 : for (const Node& i : is.second)
176 : : {
177 : 14 : index++;
178 : 14 : Node valAndXY = d_model.computeAbstractModelValue(i);
179 : 14 : Node valAndXYC = d_model.computeConcreteModelValue(i);
180 : 14 : valAndXYC = rewrite(valAndXYC);
181 : :
182 : 14 : Node k = i[0];
183 : 14 : Node x = i[1];
184 : 14 : Node y = i[2];
185 : 14 : Node valK = d_model.computeConcreteModelValue(k);
186 : 14 : Node valX = d_model.computeConcreteModelValue(x);
187 : 14 : Node valY = d_model.computeConcreteModelValue(y);
188 : :
189 : 14 : Integer model_piand = valAndXYC.getConst<Rational>().getNumerator();
190 : 14 : Integer model_k = valK.getConst<Rational>().getNumerator();
191 : 14 : Integer model_x = valX.getConst<Rational>().getNumerator();
192 : 14 : Integer model_y = valY.getConst<Rational>().getNumerator();
193 : :
194 [ - + ]: 14 : if (TraceIsOn("piand-check"))
195 : : {
196 [ - - ]: 0 : Trace("piand-check")
197 : 0 : << "* " << i << ", value = " << valAndXY << std::endl;
198 [ - - ]: 0 : Trace("piand-check") << " actual (" << valX << ", " << valY
199 : 0 : << ") = " << valAndXYC << std::endl;
200 : : }
201 [ + - ]: 14 : if (valAndXY == valAndXYC)
202 : : {
203 [ + - ]: 14 : Trace("piand-check") << "...already correct" << std::endl;
204 : 14 : continue;
205 : : }
206 : :
207 : 0 : Integer ione = 1;
208 : 0 : Integer itwo = 2;
209 : 0 : Integer ipow2 = itwo.pow(model_k.getLong());
210 : 0 : Integer max_int = ipow2 - 1;
211 : 0 : Node k_gt_0 = nm->mkNode(Kind::GT, k, d_zero);
212 : 0 : Node twok = nm->mkNode(Kind::POW2, k);
213 : 0 : Node arg0Mod = nm->mkNode(Kind::INTS_MODULUS, x, twok);
214 : 0 : Node arg1Mod = nm->mkNode(Kind::INTS_MODULUS, y, twok);
215 : 0 : Node arg0Mod2 = nm->mkNode(Kind::INTS_MODULUS, x, d_two);
216 : 0 : Node arg1Mod2 = nm->mkNode(Kind::INTS_MODULUS, y, d_two);
217 : :
218 : : // base case: piand(k,1,1) = 1
219 : 0 : if (model_k > 0 && model_x == 1 && model_y == 1 && model_piand != 1)
220 : : {
221 : 0 : Node x_equal_one = nm->mkNode(Kind::EQUAL, x, d_one);
222 : 0 : Node y_equal_one = nm->mkNode(Kind::EQUAL, y, d_one);
223 : 0 : Node assum = nm->mkNode(Kind::AND, k_gt_0, x_equal_one, y_equal_one);
224 : 0 : Node piand_one = nm->mkNode(Kind::EQUAL, i, d_one);
225 : 0 : Node xy_one_lem = nm->mkNode(Kind::IMPLIES, assum, piand_one);
226 : 0 : d_im.addPendingLemma(xy_one_lem,
227 : : InferenceId::ARITH_NL_PIAND_BASE_CASE_REFINE,
228 : : nullptr,
229 : : true);
230 : 0 : }
231 : :
232 : 0 : Node x_geq_zero = nm->mkNode(Kind::GEQ, x, d_zero);
233 : 0 : Node x_lt_pow2 = nm->mkNode(Kind::LT, x, twok);
234 : 0 : Node x_range = nm->mkNode(Kind::AND, x_geq_zero, x_lt_pow2);
235 : 0 : Node y_geq_zero = nm->mkNode(Kind::GEQ, y, d_zero);
236 : 0 : Node y_lt_pow2 = nm->mkNode(Kind::LT, y, twok);
237 : 0 : Node y_range = nm->mkNode(Kind::AND, y_geq_zero, y_lt_pow2);
238 : 0 : int j = -1;
239 [ - - ]: 0 : for (const Node& n : is.second)
240 : : {
241 : 0 : j++;
242 [ - - ]: 0 : if (j > index)
243 : : {
244 : 0 : Node k2 = n[0];
245 : 0 : Node x2 = n[1];
246 : 0 : Node y2 = n[2];
247 : 0 : Node valK2 = d_model.computeConcreteModelValue(k2);
248 : 0 : Node valX2 = d_model.computeConcreteModelValue(x2);
249 : 0 : Node valY2 = d_model.computeConcreteModelValue(y2);
250 : 0 : Node valAndXYC2 = d_model.computeConcreteModelValue(n);
251 : 0 : Integer model_piand2 = valAndXYC2.getConst<Rational>().getNumerator();
252 : 0 : Integer model_k2 = valK2.getConst<Rational>().getNumerator();
253 : 0 : Integer model_x2 = valX2.getConst<Rational>().getNumerator();
254 : 0 : Integer model_y2 = valY2.getConst<Rational>().getNumerator();
255 : :
256 : 0 : Node arg20Mod = nm->mkNode(Kind::INTS_MODULUS, x2, twok);
257 : 0 : Node arg21Mod = nm->mkNode(Kind::INTS_MODULUS, y2, twok);
258 : :
259 : 0 : Node x2_geq_zero = nm->mkNode(Kind::GEQ, x2, d_zero);
260 : 0 : Node x2_lt_pow2 = nm->mkNode(Kind::LT, x2, twok);
261 : 0 : Node x2_range = nm->mkNode(Kind::AND, x2_geq_zero, x2_lt_pow2);
262 : :
263 : : // difference: x != x2 /\ y = y2 => piand(k,x,y) != x2 \/
264 : : // piand(k,x2,y2) != x
265 : 0 : if (model_k > 0 && model_k == model_k2 && model_x != model_x2
266 : 0 : && model_y == model_y2 && model_piand == model_x2
267 : 0 : && model_piand2 == model_x)
268 : : {
269 : : Node noneqx = nm->mkNode(Kind::AND,
270 : 0 : k.eqNode(k2),
271 : 0 : (x.eqNode(x2)).notNode(),
272 : 0 : y.eqNode(y2));
273 : : Node ranges_assum =
274 : 0 : nm->mkNode(Kind::AND, x_range, x2_range, y_range);
275 : : Node assum_difference =
276 : 0 : nm->mkNode(Kind::AND, k_gt_0, noneqx, ranges_assum);
277 : : Node difference = nm->mkNode(
278 : 0 : Kind::OR, i.eqNode(x2).notNode(), n.eqNode(x).notNode());
279 : : Node diff_lemm =
280 : 0 : nm->mkNode(Kind::IMPLIES, assum_difference, difference);
281 : 0 : d_im.addPendingLemma(diff_lemm,
282 : : InferenceId::ARITH_NL_PIAND_DIFFERENCE_REFINE,
283 : : nullptr,
284 : : true);
285 : 0 : }
286 : :
287 : : // symmetry: piand(k,x,y) = piand(k,y,x)
288 : 0 : if (model_k == model_k2 && model_x == model_y2 && model_x2 == model_y
289 : 0 : && model_piand != model_piand2)
290 : : {
291 : : Node assum_sym = nm->mkNode(
292 : 0 : Kind::AND, k.eqNode(k2), (x.eqNode(y2)), y.eqNode(x2));
293 : 0 : Node sym_lemm = nm->mkNode(Kind::IMPLIES, assum_sym, i.eqNode(n));
294 : 0 : d_im.addPendingLemma(sym_lemm,
295 : : InferenceId::ARITH_NL_PIAND_SYMETRY_REFINE,
296 : : nullptr,
297 : : true);
298 : 0 : }
299 : 0 : }
300 : : }
301 : :
302 : : // contradition: x+y mod 2^k = 2^k-1 => piand(k,x,y) = 0
303 : 0 : if (model_x + model_y == max_int && model_piand != 0)
304 : : {
305 : 0 : Node x_plus_y = nm->mkNode(Kind::ADD, x, y);
306 : 0 : Node x_plus_y_mod = nm->mkNode(Kind::INTS_MODULUS, x_plus_y, twok);
307 : 0 : Node twok_minus_one = nm->mkNode(Kind::SUB, twok, d_one);
308 : 0 : Node assum = nm->mkNode(Kind::EQUAL, x_plus_y_mod, twok_minus_one);
309 : 0 : Node piand_zero = nm->mkNode(Kind::EQUAL, i, d_zero);
310 : 0 : Node neg_lem = nm->mkNode(Kind::IMPLIES, assum, piand_zero);
311 : 0 : d_im.addPendingLemma(neg_lem,
312 : : InferenceId::ARITH_NL_PIAND_CONTRADITION_REFINE,
313 : : nullptr,
314 : : true);
315 : 0 : }
316 : :
317 : : // one: k > 0 && y = 1 -> piand(k,x,y) = x mod 2
318 : 0 : if (model_k > 0 && model_y == 1 && model_piand != model_x.modByPow2(1))
319 : : {
320 : 0 : Node y_equal_one = nm->mkNode(Kind::EQUAL, y, d_one);
321 : 0 : Node asum_lsb = nm->mkNode(Kind::AND, k_gt_0, y_equal_one);
322 : 0 : Node lsb = nm->mkNode(Kind::EQUAL, i, arg0Mod2);
323 : 0 : Node y_one_lem = nm->mkNode(Kind::IMPLIES, asum_lsb, lsb);
324 : 0 : d_im.addPendingLemma(
325 : : y_one_lem, InferenceId::ARITH_NL_PIAND_ONE_REFINE, nullptr, true);
326 : 0 : }
327 : :
328 : : // one: k > 0 && x = 1 -> piand(k,x,y) = y mod 2
329 : 0 : if (model_k > 0 && model_x == 1 && model_piand != model_y.modByPow2(1))
330 : : {
331 : 0 : Node x_equal_one = nm->mkNode(Kind::EQUAL, x, d_one);
332 : 0 : Node asum_lsb2 = nm->mkNode(Kind::AND, k_gt_0, x_equal_one);
333 : 0 : Node lsb2 = nm->mkNode(Kind::EQUAL, i, arg1Mod2);
334 : 0 : Node x_one_lem = nm->mkNode(Kind::IMPLIES, asum_lsb2, lsb2);
335 : 0 : d_im.addPendingLemma(
336 : : x_one_lem, InferenceId::ARITH_NL_PIAND_ONE_REFINE, nullptr, true);
337 : 0 : }
338 : :
339 : 0 : Node lem_sum = sumBasedLemma(i, Kind::GEQ);
340 : 0 : d_im.addPendingLemma(
341 : : lem_sum, InferenceId::ARITH_NL_PIAND_SUM_REFINE, nullptr, true);
342 [ - + ][ - + ]: 168 : }
[ - + ][ - + ]
[ - + ][ - + ]
[ - + ][ - + ]
[ - + ][ - + ]
[ - + ][ - + ]
343 : : }
344 : 2189 : }
345 : :
346 : 0 : Node PIAndSolver::sumBasedLemma(Node i, Kind kind)
347 : : {
348 : 0 : Assert(i.getKind() == Kind::PIAND);
349 : 0 : Node k = d_model.computeConcreteModelValue(i[0]);
350 : 0 : Node x = i[1];
351 : 0 : Node y = i[2];
352 : 0 : uint64_t granularity = options().smt.BVAndIntegerGranularity;
353 : 0 : uint64_t int_k = k.getConst<Rational>().getNumerator().toUnsignedInt();
354 : 0 : NodeManager* nm = nodeManager();
355 : : // (i[0] >= k /\ 0 <= x < 2^k /\ 0 <= y < 2^k) => i = sum
356 : 0 : Node width = nm->mkNode(kind, i[0], k);
357 : 0 : Node condition;
358 : 0 : Node pow2_k = nm->mkConstInt(Integer(2).pow(int_k));
359 : 0 : Node zero = nm->mkConstInt(Rational(0));
360 : 0 : Node x_pos = nm->mkNode(Kind::GEQ, x, zero);
361 : 0 : Node y_pos = nm->mkNode(Kind::GEQ, y, zero);
362 : 0 : Node x_lt_pow2 = nm->mkNode(Kind::LT, x, pow2_k);
363 : 0 : Node y_lt_pow2 = nm->mkNode(Kind::LT, y, pow2_k);
364 : 0 : Node bound_x = nm->mkNode(Kind::AND, x_lt_pow2, x_pos);
365 : 0 : Node bound_y = nm->mkNode(Kind::AND, y_lt_pow2, y_pos);
366 : 0 : condition = nm->mkNode(Kind::AND, bound_x, bound_y, width);
367 : : Node then = nm->mkNode(
368 : 0 : Kind::EQUAL, i, d_iandUtils.createSumNode(x, y, int_k, granularity));
369 : 0 : Node lem = nm->mkNode(Kind::IMPLIES, condition, then);
370 : 0 : return lem;
371 : 0 : }
372 : :
373 : : } // namespace nl
374 : : } // namespace arith
375 : : } // namespace theory
376 : : } // namespace cvc5::internal
|