Branch data Line data Source code
1 : : /******************************************************************************
2 : : * Top contributors (to current version):
3 : : * Mudathir Mohamed, Aina Niemetz, Andrew Reynolds
4 : : *
5 : : * This file is part of the cvc5 project.
6 : : *
7 : : * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
8 : : * in the top-level source directory and their institutional affiliations.
9 : : * All rights reserved. See the file COPYING in the top-level source
10 : : * directory for licensing information.
11 : : * ****************************************************************************
12 : : *
13 : : * Solver for the theory of bags.
14 : : */
15 : :
16 : : #include "theory/bags/bag_solver.h"
17 : :
18 : : #include "expr/emptybag.h"
19 : : #include "theory/bags/bags_utils.h"
20 : : #include "theory/bags/inference_generator.h"
21 : : #include "theory/bags/inference_manager.h"
22 : : #include "theory/bags/solver_state.h"
23 : : #include "theory/bags/term_registry.h"
24 : : #include "theory/uf/equality_engine_iterator.h"
25 : : #include "util/rational.h"
26 : :
27 : : using namespace std;
28 : : using namespace cvc5::context;
29 : : using namespace cvc5::internal::kind;
30 : :
31 : : namespace cvc5::internal {
32 : : namespace theory {
33 : : namespace bags {
34 : :
35 : 49907 : BagSolver::BagSolver(Env& env, SolverState& s, InferenceManager& im)
36 : : : EnvObj(env),
37 : 49907 : d_state(s),
38 : 49907 : d_ig(env.getNodeManager(), &s, &im),
39 : 49907 : d_im(im),
40 : 99814 : d_mapCache(userContext())
41 : : {
42 : 49907 : d_zero = nodeManager()->mkConstInt(Rational(0));
43 : 49907 : d_one = nodeManager()->mkConstInt(Rational(1));
44 : 49907 : d_true = nodeManager()->mkConst(true);
45 : 49907 : d_false = nodeManager()->mkConst(false);
46 : 49907 : }
47 : :
48 : 49562 : BagSolver::~BagSolver() {}
49 : :
50 : 44878 : void BagSolver::checkBasicOperations()
51 : : {
52 : 44878 : checkDisequalBagTerms();
53 : :
54 : : // At this point, all bag and count representatives should be in the solver
55 : : // state.
56 [ + + ]: 50583 : for (const Node& bag : d_state.getBags())
57 : : {
58 : : // iterate through all bags terms in each equivalent class
59 : : eq::EqClassIterator it =
60 : 5705 : eq::EqClassIterator(bag, d_state.getEqualityEngine());
61 [ + + ]: 20317 : while (!it.isFinished())
62 : : {
63 : 14612 : Node n = (*it);
64 : 14612 : Kind k = n.getKind();
65 [ + + ][ + + ]: 14612 : switch (k)
[ + + ][ + + ]
[ + + ][ + + ]
[ + ]
66 : : {
67 : 735 : case Kind::BAG_EMPTY: checkEmpty(n); break;
68 : 1084 : case Kind::BAG_MAKE: checkBagMake(n); break;
69 : 651 : case Kind::BAG_UNION_DISJOINT: checkUnionDisjoint(n); break;
70 : 85 : case Kind::BAG_UNION_MAX: checkUnionMax(n); break;
71 : 49 : case Kind::BAG_INTER_MIN: checkIntersectionMin(n); break;
72 : 68 : case Kind::BAG_DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
73 : 108 : case Kind::BAG_DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
74 : 29 : case Kind::BAG_SETOF: checkSetof(n); break;
75 : 423 : case Kind::BAG_FILTER: checkFilter(n); break;
76 : 114 : case Kind::TABLE_PRODUCT: checkProduct(n); break;
77 : 32 : case Kind::TABLE_JOIN: checkJoin(n); break;
78 : 286 : case Kind::TABLE_GROUP: checkGroup(n); break;
79 : 10948 : default: break;
80 : : }
81 : 14612 : it++;
82 : 14612 : }
83 : : }
84 : :
85 : : // add non negative constraints for all multiplicities
86 [ + + ]: 50583 : for (const Node& n : d_state.getBags())
87 : : {
88 [ + + ]: 17257 : for (const Node& e : d_state.getElements(n))
89 : : {
90 : 11552 : checkNonNegativeCountTerms(n, d_state.getRepresentative(e));
91 : 5705 : }
92 : : }
93 : 44878 : }
94 : :
95 : 44194 : void BagSolver::checkQuantifiedOperations()
96 : : {
97 [ + + ]: 46232 : for (const Node& bag : d_state.getBags())
98 : : {
99 : : // iterate through all bags terms in each equivalent class
100 : : eq::EqClassIterator it =
101 : 2038 : eq::EqClassIterator(bag, d_state.getEqualityEngine());
102 [ + + ]: 6753 : while (!it.isFinished())
103 : : {
104 : 4715 : Node n = (*it);
105 : 4715 : Kind k = n.getKind();
106 [ + + ]: 4715 : switch (k)
107 : : {
108 : 431 : case Kind::BAG_MAP: checkMap(n); break;
109 : 4284 : default: break;
110 : : }
111 : 4715 : it++;
112 : 4715 : }
113 : : }
114 : :
115 : : // add non negative constraints for all multiplicities
116 [ + + ]: 46232 : for (const Node& n : d_state.getBags())
117 : : {
118 [ + + ]: 5910 : for (const Node& e : d_state.getElements(n))
119 : : {
120 : 3872 : checkNonNegativeCountTerms(n, d_state.getRepresentative(e));
121 : 2038 : }
122 : : }
123 : 44194 : }
124 : :
125 : 961 : set<Node> BagSolver::getElementsForBinaryOperator(const Node& n)
126 : : {
127 : 961 : set<Node> elements;
128 : 961 : const set<Node>& downwards = d_state.getElements(n);
129 : 961 : const set<Node>& upwards0 = d_state.getElements(n[0]);
130 : 961 : const set<Node>& upwards1 = d_state.getElements(n[1]);
131 : :
132 : 961 : set_union(downwards.begin(),
133 : : downwards.end(),
134 : : upwards0.begin(),
135 : : upwards0.end(),
136 : : inserter(elements, elements.begin()));
137 : 961 : elements.insert(upwards1.begin(), upwards1.end());
138 : 1922 : return elements;
139 : 961 : }
140 : :
141 : 735 : void BagSolver::checkEmpty(const Node& n)
142 : : {
143 [ - + ][ - + ]: 735 : Assert(n.getKind() == Kind::BAG_EMPTY);
[ - - ]
144 [ + + ]: 1862 : for (const Node& e : d_state.getElements(n))
145 : : {
146 : 2254 : InferInfo i = d_ig.empty(n, d_state.getRepresentative(e));
147 : 1127 : d_im.lemmaTheoryInference(&i);
148 : 1862 : }
149 : 735 : }
150 : :
151 : 651 : void BagSolver::checkUnionDisjoint(const Node& n)
152 : : {
153 [ - + ][ - + ]: 651 : Assert(n.getKind() == Kind::BAG_UNION_DISJOINT);
[ - - ]
154 : 651 : std::set<Node> elements = getElementsForBinaryOperator(n);
155 [ + + ]: 2451 : for (const Node& e : elements)
156 : : {
157 : 3600 : InferInfo i = d_ig.unionDisjoint(n, d_state.getRepresentative(e));
158 : 1800 : d_im.lemmaTheoryInference(&i);
159 : 1800 : }
160 : 651 : }
161 : :
162 : 85 : void BagSolver::checkUnionMax(const Node& n)
163 : : {
164 [ - + ][ - + ]: 85 : Assert(n.getKind() == Kind::BAG_UNION_MAX);
[ - - ]
165 : 85 : std::set<Node> elements = getElementsForBinaryOperator(n);
166 [ + + ]: 210 : for (const Node& e : elements)
167 : : {
168 : 250 : InferInfo i = d_ig.unionMax(n, d_state.getRepresentative(e));
169 : 125 : d_im.lemmaTheoryInference(&i);
170 : 125 : }
171 : 85 : }
172 : :
173 : 49 : void BagSolver::checkIntersectionMin(const Node& n)
174 : : {
175 [ - + ][ - + ]: 49 : Assert(n.getKind() == Kind::BAG_INTER_MIN);
[ - - ]
176 : 49 : std::set<Node> elements = getElementsForBinaryOperator(n);
177 [ + + ]: 100 : for (const Node& e : elements)
178 : : {
179 : 102 : InferInfo i = d_ig.intersection(n, d_state.getRepresentative(e));
180 : 51 : d_im.lemmaTheoryInference(&i);
181 : 51 : }
182 : 49 : }
183 : :
184 : 68 : void BagSolver::checkDifferenceSubtract(const Node& n)
185 : : {
186 [ - + ][ - + ]: 68 : Assert(n.getKind() == Kind::BAG_DIFFERENCE_SUBTRACT);
[ - - ]
187 : 68 : std::set<Node> elements = getElementsForBinaryOperator(n);
188 [ + + ]: 168 : for (const Node& e : elements)
189 : : {
190 : 200 : InferInfo i = d_ig.differenceSubtract(n, d_state.getRepresentative(e));
191 : 100 : d_im.lemmaTheoryInference(&i);
192 : 100 : }
193 : 68 : }
194 : :
195 : 44936 : bool BagSolver::checkBagMake()
196 : : {
197 : 44936 : bool sentLemma = false;
198 [ + + ]: 50838 : for (const Node& bag : d_state.getBags())
199 : : {
200 : 5902 : TypeNode bagType = bag.getType();
201 : 5902 : NodeManager* nm = nodeManager();
202 : 5902 : Node empty = nm->mkConst(EmptyBag(bagType));
203 : 5902 : if (d_state.areEqual(empty, bag) || d_state.areDisequal(empty, bag))
204 : : {
205 : 2943 : continue;
206 : : }
207 : :
208 : : // look for BAG_MAKE terms in the equivalent class
209 : : eq::EqClassIterator it =
210 : 2959 : eq::EqClassIterator(bag, d_state.getEqualityEngine());
211 [ + + ]: 7761 : while (!it.isFinished())
212 : : {
213 : 4889 : Node n = (*it);
214 [ + + ]: 4889 : if (n.getKind() == Kind::BAG_MAKE)
215 : : {
216 [ + - ]: 87 : Trace("bags-check") << "splitting on node " << std::endl;
217 : 87 : InferInfo i = d_ig.bagMake(n);
218 : 87 : sentLemma |= d_im.lemmaTheoryInference(&i);
219 : : // it is enough to split only once per equivalent class
220 : 87 : break;
221 : 87 : }
222 : 4802 : it++;
223 [ + + ]: 4889 : }
224 [ + + ][ + + ]: 8845 : }
225 : 44936 : return sentLemma;
226 : : }
227 : :
228 : 1084 : void BagSolver::checkBagMake(const Node& n)
229 : : {
230 [ - + ][ - + ]: 1084 : Assert(n.getKind() == Kind::BAG_MAKE);
[ - - ]
231 [ + - ]: 2168 : Trace("bags::BagSolver::postCheck")
232 : 0 : << "BagSolver::checkBagMake Elements of " << n
233 : 1084 : << " are: " << d_state.getElements(n) << std::endl;
234 [ + + ]: 3276 : for (const Node& e : d_state.getElements(n))
235 : : {
236 : 4384 : InferInfo i = d_ig.bagMake(n, d_state.getRepresentative(e));
237 : 2192 : d_im.lemmaTheoryInference(&i);
238 : 3276 : }
239 : 1084 : }
240 : 15424 : void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
241 : : {
242 : 30848 : InferInfo i = d_ig.nonNegativeCount(bag, element);
243 : 15424 : d_im.lemmaTheoryInference(&i);
244 : 15424 : }
245 : :
246 : 108 : void BagSolver::checkDifferenceRemove(const Node& n)
247 : : {
248 [ - + ][ - + ]: 108 : Assert(n.getKind() == Kind::BAG_DIFFERENCE_REMOVE);
[ - - ]
249 : 108 : std::set<Node> elements = getElementsForBinaryOperator(n);
250 [ + + ]: 248 : for (const Node& e : elements)
251 : : {
252 : 280 : InferInfo i = d_ig.differenceRemove(n, d_state.getRepresentative(e));
253 : 140 : d_im.lemmaTheoryInference(&i);
254 : 140 : }
255 : 108 : }
256 : :
257 : 29 : void BagSolver::checkSetof(Node n)
258 : : {
259 [ - + ][ - + ]: 29 : Assert(n.getKind() == Kind::BAG_SETOF);
[ - - ]
260 : 29 : set<Node> elements;
261 : 29 : const set<Node>& downwards = d_state.getElements(n);
262 : 29 : const set<Node>& upwards = d_state.getElements(n[0]);
263 : :
264 : 29 : elements.insert(downwards.begin(), downwards.end());
265 : 29 : elements.insert(upwards.begin(), upwards.end());
266 : :
267 [ + + ]: 54 : for (const Node& e : elements)
268 : : {
269 : 50 : InferInfo i = d_ig.setof(n, d_state.getRepresentative(e));
270 : 25 : d_im.lemmaTheoryInference(&i);
271 : 25 : }
272 : 29 : }
273 : :
274 : 44878 : void BagSolver::checkDisequalBagTerms()
275 : : {
276 [ + + ]: 47609 : for (const auto& [equality, witness] : d_state.getDisequalBagTerms())
277 : : {
278 : 5462 : InferInfo info = d_ig.bagDisequality(equality, witness);
279 : 2731 : d_im.lemmaTheoryInference(&info);
280 : 2731 : }
281 : 44878 : }
282 : :
283 : 431 : void BagSolver::checkMap(Node n)
284 : : {
285 [ - + ][ - + ]: 431 : Assert(n.getKind() == Kind::BAG_MAP);
[ - - ]
286 : 431 : const set<Node>& downwards = d_state.getElements(n);
287 : 431 : const set<Node>& upwards = d_state.getElements(n[1]);
288 [ + + ]: 997 : for (const Node& x : upwards)
289 : : {
290 : 1132 : InferInfo upInference = d_ig.mapUp1(n, x);
291 : 566 : d_im.lemmaTheoryInference(&upInference);
292 : 566 : }
293 : :
294 [ + + ]: 431 : if (d_state.isInjective(n[0]))
295 : : {
296 [ + + ]: 117 : for (const Node& z : downwards)
297 : : {
298 : 150 : InferInfo upInference = d_ig.mapDownInjective(n, z);
299 : 75 : d_im.lemmaTheoryInference(&upInference);
300 : 75 : }
301 : : }
302 : : else
303 : : {
304 [ + + ]: 840 : for (const Node& z : downwards)
305 : : {
306 : 902 : Node y = d_state.getRepresentative(z);
307 [ + + ]: 451 : if (!d_mapCache.count(n))
308 : : {
309 : : std::shared_ptr<context::CDHashMap<Node, std::pair<Node, Node>>> nMap =
310 : : std::make_shared<context::CDHashMap<Node, std::pair<Node, Node>>>(
311 : 19 : userContext());
312 : 19 : d_mapCache[n] = nMap;
313 : 19 : }
314 [ + + ]: 451 : if (!d_mapCache[n].get()->count(y))
315 : : {
316 : 96 : auto [downInference, uf, preImageSize] = d_ig.mapDown(n, y);
317 : 48 : d_im.lemmaTheoryInference(&downInference);
318 : 48 : std::pair<Node, Node> yPair = std::make_pair(uf, preImageSize);
319 : 48 : d_mapCache[n].get()->insert(y, yPair);
320 : 48 : }
321 : :
322 : : context::CDHashMap<Node, std::pair<Node, Node>>::iterator it =
323 : 451 : d_mapCache[n].get()->find(y);
324 : :
325 : 451 : auto [uf, preImageSize] = it->second;
326 : :
327 [ + + ]: 1309 : for (const Node& x : upwards)
328 : : {
329 : 1716 : InferInfo upInference = d_ig.mapUp2(n, uf, preImageSize, y, x);
330 : 858 : d_im.lemmaTheoryInference(&upInference);
331 : 858 : }
332 : 451 : }
333 : : }
334 : 431 : }
335 : :
336 : 423 : void BagSolver::checkFilter(Node n)
337 : : {
338 [ - + ][ - + ]: 423 : Assert(n.getKind() == Kind::BAG_FILTER);
[ - - ]
339 : :
340 : 423 : set<Node> elements;
341 : 423 : const set<Node>& downwards = d_state.getElements(n);
342 : 423 : const set<Node>& upwards = d_state.getElements(n[1]);
343 : 423 : elements.insert(downwards.begin(), downwards.end());
344 : 423 : elements.insert(upwards.begin(), upwards.end());
345 : :
346 [ + + ]: 1377 : for (const Node& e : elements)
347 : : {
348 : 1908 : InferInfo i = d_ig.filterDown(n, d_state.getRepresentative(e));
349 : 954 : d_im.lemmaTheoryInference(&i);
350 : 954 : }
351 [ + + ]: 1377 : for (const Node& e : elements)
352 : : {
353 : 1908 : InferInfo i = d_ig.filterUp(n, d_state.getRepresentative(e));
354 : 954 : d_im.lemmaTheoryInference(&i);
355 : 954 : }
356 : 423 : }
357 : :
358 : 114 : void BagSolver::checkProduct(Node n)
359 : : {
360 [ - + ][ - + ]: 114 : Assert(n.getKind() == Kind::TABLE_PRODUCT);
[ - - ]
361 : 114 : const set<Node>& elementsA = d_state.getElements(n[0]);
362 : 114 : const set<Node>& elementsB = d_state.getElements(n[1]);
363 : :
364 [ + + ]: 186 : for (const Node& e1 : elementsA)
365 : : {
366 [ + + ]: 192 : for (const Node& e2 : elementsB)
367 : : {
368 : : InferInfo i = d_ig.productUp(
369 : 240 : n, d_state.getRepresentative(e1), d_state.getRepresentative(e2));
370 : 120 : d_im.lemmaTheoryInference(&i);
371 : 120 : }
372 : : }
373 : :
374 : 114 : std::set<Node> elements = d_state.getElements(n);
375 [ + + ]: 230 : for (const Node& e : elements)
376 : : {
377 : 232 : InferInfo i = d_ig.productDown(n, d_state.getRepresentative(e));
378 : 116 : d_im.lemmaTheoryInference(&i);
379 : 116 : }
380 : 114 : }
381 : :
382 : 32 : void BagSolver::checkJoin(Node n)
383 : : {
384 [ - + ][ - + ]: 32 : Assert(n.getKind() == Kind::TABLE_JOIN);
[ - - ]
385 : 32 : const set<Node>& elementsA = d_state.getElements(n[0]);
386 : 32 : const set<Node>& elementsB = d_state.getElements(n[1]);
387 : :
388 [ + + ]: 98 : for (const Node& e1 : elementsA)
389 : : {
390 [ + + ]: 252 : for (const Node& e2 : elementsB)
391 : : {
392 : : InferInfo i = d_ig.joinUp(
393 : 372 : n, d_state.getRepresentative(e1), d_state.getRepresentative(e2));
394 : 186 : d_im.lemmaTheoryInference(&i);
395 : 186 : }
396 : : }
397 : :
398 : 32 : std::set<Node> elements = d_state.getElements(n);
399 [ + + ]: 180 : for (const Node& e : elements)
400 : : {
401 : 296 : InferInfo i = d_ig.joinDown(n, d_state.getRepresentative(e));
402 : 148 : d_im.lemmaTheoryInference(&i);
403 : 148 : }
404 : 32 : }
405 : :
406 : 286 : void BagSolver::checkGroup(Node n)
407 : : {
408 [ - + ][ - + ]: 286 : Assert(n.getKind() == Kind::TABLE_GROUP);
[ - - ]
409 : :
410 : 286 : InferInfo notEmpty = d_ig.groupNotEmpty(n);
411 : 286 : d_im.lemmaTheoryInference(¬Empty);
412 : :
413 : 286 : Node part = d_ig.defineSkolemPartFunction(n);
414 : :
415 : 286 : const set<Node>& elementsA = d_state.getElements(n[0]);
416 : : std::shared_ptr<context::CDHashSet<Node>> skolems =
417 : 286 : d_state.getPartElementSkolems(n);
418 [ + + ]: 1778 : for (const Node& a : elementsA)
419 : : {
420 [ - + ]: 1492 : if (skolems->contains(a))
421 : : {
422 : : // skip skolem elements that were introduced by groupPartCount below.
423 : 0 : continue;
424 : : }
425 : 2984 : Node aRep = d_state.getRepresentative(a);
426 : 2984 : InferInfo i = d_ig.groupUp1(n, aRep, part);
427 : 1492 : d_im.lemmaTheoryInference(&i);
428 : 1492 : i = d_ig.groupUp2(n, aRep, part);
429 : 1492 : d_im.lemmaTheoryInference(&i);
430 : 1492 : }
431 : :
432 : 286 : std::set<Node> parts = d_state.getElements(n);
433 [ + + ]: 1040 : for (std::set<Node>::iterator partIt1 = parts.begin(); partIt1 != parts.end();
434 : 754 : ++partIt1)
435 : : {
436 : 1508 : Node part1 = d_state.getRepresentative(*partIt1);
437 : 754 : std::vector<Node> partEqc;
438 : 754 : d_state.getEquivalenceClass(part1, partEqc);
439 : 754 : bool newPart = true;
440 [ + + ]: 8602 : for (Node p : partEqc)
441 : : {
442 [ + + ][ + + ]: 7848 : if (p.getKind() == Kind::APPLY_UF && p.getOperator() == part)
[ + + ][ + + ]
[ - - ]
443 : : {
444 : 2052 : newPart = false;
445 : : }
446 : 7848 : }
447 [ + + ]: 754 : if (newPart)
448 : : {
449 : : // only apply the groupPartCount rule for a part that does not have
450 : : // nodes of the form (part x) introduced by the group up rule above.
451 : 308 : InferInfo partCardinality = d_ig.groupPartCount(n, part1, part);
452 : 154 : d_im.lemmaTheoryInference(&partCardinality);
453 : 154 : }
454 : :
455 : 754 : std::set<Node> partElements = d_state.getElements(part1);
456 : 754 : for (std::set<Node>::iterator i = partElements.begin();
457 [ + + ]: 4912 : i != partElements.end();
458 : 4158 : ++i)
459 : : {
460 : 8316 : Node x = d_state.getRepresentative(*i);
461 [ + + ]: 4158 : if (!skolems->contains(x))
462 : : {
463 : : // only apply down rules for elements not generated by groupPartCount
464 : : // rule above
465 : 7872 : InferInfo down = d_ig.groupDown(n, part1, x, part);
466 : 3936 : d_im.lemmaTheoryInference(&down);
467 : 3936 : }
468 : :
469 : 4158 : std::set<Node>::iterator j = i;
470 : 4158 : ++j;
471 [ + + ]: 17898 : while (j != partElements.end())
472 : : {
473 : 27480 : Node y = d_state.getRepresentative(*j);
474 : : // x, y should have the same projection
475 : : InferInfo sameProjection =
476 : 27480 : d_ig.groupSameProjection(n, part1, x, y, part);
477 : 13740 : d_im.lemmaTheoryInference(&sameProjection);
478 : 13740 : ++j;
479 : 13740 : }
480 : :
481 [ + + ]: 36720 : for (const Node& a : elementsA)
482 : : {
483 : 65124 : Node y = d_state.getRepresentative(a);
484 [ + + ]: 32562 : if (x != y)
485 : : {
486 : : // x, y should have the same projection
487 : 57248 : InferInfo samePart = d_ig.groupSamePart(n, part1, x, y, part);
488 : 28624 : d_im.lemmaTheoryInference(&samePart);
489 : 28624 : }
490 : 32562 : }
491 : 4158 : }
492 : 754 : }
493 : 286 : }
494 : :
495 : : } // namespace bags
496 : : } // namespace theory
497 : : } // namespace cvc5::internal
|