Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * Solver for the theory of bags.
11 : : */
12 : :
13 : : #include "theory/bags/bag_solver.h"
14 : :
15 : : #include "expr/emptybag.h"
16 : : #include "theory/bags/bags_utils.h"
17 : : #include "theory/bags/inference_generator.h"
18 : : #include "theory/bags/inference_manager.h"
19 : : #include "theory/bags/solver_state.h"
20 : : #include "theory/bags/term_registry.h"
21 : : #include "theory/uf/equality_engine_iterator.h"
22 : : #include "util/rational.h"
23 : :
24 : : using namespace std;
25 : : using namespace cvc5::context;
26 : : using namespace cvc5::internal::kind;
27 : :
28 : : namespace cvc5::internal {
29 : : namespace theory {
30 : : namespace bags {
31 : :
32 : 51329 : BagSolver::BagSolver(Env& env, SolverState& s, InferenceManager& im)
33 : : : EnvObj(env),
34 : 51329 : d_state(s),
35 : 51329 : d_ig(env.getNodeManager(), &s, &im),
36 : 51329 : d_im(im),
37 : 102658 : d_mapCache(userContext())
38 : : {
39 : 51329 : d_zero = nodeManager()->mkConstInt(Rational(0));
40 : 51329 : d_one = nodeManager()->mkConstInt(Rational(1));
41 : 51329 : d_true = nodeManager()->mkConst(true);
42 : 51329 : d_false = nodeManager()->mkConst(false);
43 : 51329 : }
44 : :
45 : 50981 : BagSolver::~BagSolver() {}
46 : :
47 : 49515 : void BagSolver::checkBasicOperations()
48 : : {
49 : 49515 : checkDisequalBagTerms();
50 : :
51 : : // At this point, all bag and count representatives should be in the solver
52 : : // state.
53 [ + + ]: 55265 : for (const Node& bag : d_state.getBags())
54 : : {
55 : : // iterate through all bags terms in each equivalent class
56 : : eq::EqClassIterator it =
57 : 5750 : eq::EqClassIterator(bag, d_state.getEqualityEngine());
58 [ + + ]: 20431 : while (!it.isFinished())
59 : : {
60 : 14681 : Node n = (*it);
61 : 14681 : Kind k = n.getKind();
62 [ + + ][ + + ]: 14681 : switch (k)
[ + + ][ + + ]
[ + + ][ + + ]
[ + ]
63 : : {
64 : 731 : case Kind::BAG_EMPTY: checkEmpty(n); break;
65 : 1082 : case Kind::BAG_MAKE: checkBagMake(n); break;
66 : 647 : case Kind::BAG_UNION_DISJOINT: checkUnionDisjoint(n); break;
67 : 83 : case Kind::BAG_UNION_MAX: checkUnionMax(n); break;
68 : 49 : case Kind::BAG_INTER_MIN: checkIntersectionMin(n); break;
69 : 68 : case Kind::BAG_DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
70 : 106 : case Kind::BAG_DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
71 : 29 : case Kind::BAG_SETOF: checkSetof(n); break;
72 : 455 : case Kind::BAG_FILTER: checkFilter(n); break;
73 : 112 : case Kind::TABLE_PRODUCT: checkProduct(n); break;
74 : 30 : case Kind::TABLE_JOIN: checkJoin(n); break;
75 : 286 : case Kind::TABLE_GROUP: checkGroup(n); break;
76 : 11003 : default: break;
77 : : }
78 : 14681 : it++;
79 : 14681 : }
80 : : }
81 : :
82 : : // add non negative constraints for all multiplicities
83 [ + + ]: 55265 : for (const Node& n : d_state.getBags())
84 : : {
85 [ + + ]: 17676 : for (const Node& e : d_state.getElements(n))
86 : : {
87 : 11926 : checkNonNegativeCountTerms(n, d_state.getRepresentative(e));
88 : 5750 : }
89 : : }
90 : 49515 : }
91 : :
92 : 48820 : void BagSolver::checkQuantifiedOperations()
93 : : {
94 [ + + ]: 50848 : for (const Node& bag : d_state.getBags())
95 : : {
96 : : // iterate through all bags terms in each equivalent class
97 : : eq::EqClassIterator it =
98 : 2028 : eq::EqClassIterator(bag, d_state.getEqualityEngine());
99 [ + + ]: 6713 : while (!it.isFinished())
100 : : {
101 : 4685 : Node n = (*it);
102 : 4685 : Kind k = n.getKind();
103 [ + + ]: 4685 : switch (k)
104 : : {
105 : 431 : case Kind::BAG_MAP: checkMap(n); break;
106 : 4254 : default: break;
107 : : }
108 : 4685 : it++;
109 : 4685 : }
110 : : }
111 : :
112 : : // add non negative constraints for all multiplicities
113 [ + + ]: 50848 : for (const Node& n : d_state.getBags())
114 : : {
115 [ + + ]: 6051 : for (const Node& e : d_state.getElements(n))
116 : : {
117 : 4023 : checkNonNegativeCountTerms(n, d_state.getRepresentative(e));
118 : 2028 : }
119 : : }
120 : 48820 : }
121 : :
122 : 953 : set<Node> BagSolver::getElementsForBinaryOperator(const Node& n)
123 : : {
124 : 953 : set<Node> elements;
125 : 953 : const set<Node>& downwards = d_state.getElements(n);
126 : 953 : const set<Node>& upwards0 = d_state.getElements(n[0]);
127 : 953 : const set<Node>& upwards1 = d_state.getElements(n[1]);
128 : :
129 : 953 : set_union(downwards.begin(),
130 : : downwards.end(),
131 : : upwards0.begin(),
132 : : upwards0.end(),
133 : : inserter(elements, elements.begin()));
134 : 953 : elements.insert(upwards1.begin(), upwards1.end());
135 : 1906 : return elements;
136 : 953 : }
137 : :
138 : 731 : void BagSolver::checkEmpty(const Node& n)
139 : : {
140 [ - + ][ - + ]: 731 : Assert(n.getKind() == Kind::BAG_EMPTY);
[ - - ]
141 [ + + ]: 1856 : for (const Node& e : d_state.getElements(n))
142 : : {
143 : 2250 : InferInfo i = d_ig.empty(n, d_state.getRepresentative(e));
144 : 1125 : d_im.lemmaTheoryInference(&i);
145 : 1856 : }
146 : 731 : }
147 : :
148 : 647 : void BagSolver::checkUnionDisjoint(const Node& n)
149 : : {
150 [ - + ][ - + ]: 647 : Assert(n.getKind() == Kind::BAG_UNION_DISJOINT);
[ - - ]
151 : 647 : std::set<Node> elements = getElementsForBinaryOperator(n);
152 [ + + ]: 2437 : for (const Node& e : elements)
153 : : {
154 : 3580 : InferInfo i = d_ig.unionDisjoint(n, d_state.getRepresentative(e));
155 : 1790 : d_im.lemmaTheoryInference(&i);
156 : 1790 : }
157 : 647 : }
158 : :
159 : 83 : void BagSolver::checkUnionMax(const Node& n)
160 : : {
161 [ - + ][ - + ]: 83 : Assert(n.getKind() == Kind::BAG_UNION_MAX);
[ - - ]
162 : 83 : std::set<Node> elements = getElementsForBinaryOperator(n);
163 [ + + ]: 206 : for (const Node& e : elements)
164 : : {
165 : 246 : InferInfo i = d_ig.unionMax(n, d_state.getRepresentative(e));
166 : 123 : d_im.lemmaTheoryInference(&i);
167 : 123 : }
168 : 83 : }
169 : :
170 : 49 : void BagSolver::checkIntersectionMin(const Node& n)
171 : : {
172 [ - + ][ - + ]: 49 : Assert(n.getKind() == Kind::BAG_INTER_MIN);
[ - - ]
173 : 49 : std::set<Node> elements = getElementsForBinaryOperator(n);
174 [ + + ]: 100 : for (const Node& e : elements)
175 : : {
176 : 102 : InferInfo i = d_ig.intersection(n, d_state.getRepresentative(e));
177 : 51 : d_im.lemmaTheoryInference(&i);
178 : 51 : }
179 : 49 : }
180 : :
181 : 68 : void BagSolver::checkDifferenceSubtract(const Node& n)
182 : : {
183 [ - + ][ - + ]: 68 : Assert(n.getKind() == Kind::BAG_DIFFERENCE_SUBTRACT);
[ - - ]
184 : 68 : std::set<Node> elements = getElementsForBinaryOperator(n);
185 [ + + ]: 168 : for (const Node& e : elements)
186 : : {
187 : 200 : InferInfo i = d_ig.differenceSubtract(n, d_state.getRepresentative(e));
188 : 100 : d_im.lemmaTheoryInference(&i);
189 : 100 : }
190 : 68 : }
191 : :
192 : 49573 : bool BagSolver::checkBagMake()
193 : : {
194 : 49573 : bool sentLemma = false;
195 [ + + ]: 55520 : for (const Node& bag : d_state.getBags())
196 : : {
197 : 5947 : TypeNode bagType = bag.getType();
198 : 5947 : NodeManager* nm = nodeManager();
199 : 5947 : Node empty = nm->mkConst(EmptyBag(bagType));
200 : 5947 : if (d_state.areEqual(empty, bag) || d_state.areDisequal(empty, bag))
201 : : {
202 : 2933 : continue;
203 : : }
204 : :
205 : : // look for BAG_MAKE terms in the equivalent class
206 : : eq::EqClassIterator it =
207 : 3014 : eq::EqClassIterator(bag, d_state.getEqualityEngine());
208 [ + + ]: 7919 : while (!it.isFinished())
209 : : {
210 : 4992 : Node n = (*it);
211 [ + + ]: 4992 : if (n.getKind() == Kind::BAG_MAKE)
212 : : {
213 [ + - ]: 87 : Trace("bags-check") << "splitting on node " << std::endl;
214 : 87 : InferInfo i = d_ig.bagMake(n);
215 : 87 : sentLemma |= d_im.lemmaTheoryInference(&i);
216 : : // it is enough to split only once per equivalent class
217 : 87 : break;
218 : 87 : }
219 : 4905 : it++;
220 [ + + ]: 4992 : }
221 [ + + ][ + + ]: 8880 : }
222 : 49573 : return sentLemma;
223 : : }
224 : :
225 : 1082 : void BagSolver::checkBagMake(const Node& n)
226 : : {
227 [ - + ][ - + ]: 1082 : Assert(n.getKind() == Kind::BAG_MAKE);
[ - - ]
228 [ + - ]: 2164 : Trace("bags::BagSolver::postCheck")
229 : 0 : << "BagSolver::checkBagMake Elements of " << n
230 : 1082 : << " are: " << d_state.getElements(n) << std::endl;
231 [ + + ]: 3276 : for (const Node& e : d_state.getElements(n))
232 : : {
233 : 4388 : InferInfo i = d_ig.bagMake(n, d_state.getRepresentative(e));
234 : 2194 : d_im.lemmaTheoryInference(&i);
235 : 3276 : }
236 : 1082 : }
237 : 15949 : void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
238 : : {
239 : 31898 : InferInfo i = d_ig.nonNegativeCount(bag, element);
240 : 15949 : d_im.lemmaTheoryInference(&i);
241 : 15949 : }
242 : :
243 : 106 : void BagSolver::checkDifferenceRemove(const Node& n)
244 : : {
245 [ - + ][ - + ]: 106 : Assert(n.getKind() == Kind::BAG_DIFFERENCE_REMOVE);
[ - - ]
246 : 106 : std::set<Node> elements = getElementsForBinaryOperator(n);
247 [ + + ]: 244 : for (const Node& e : elements)
248 : : {
249 : 276 : InferInfo i = d_ig.differenceRemove(n, d_state.getRepresentative(e));
250 : 138 : d_im.lemmaTheoryInference(&i);
251 : 138 : }
252 : 106 : }
253 : :
254 : 29 : void BagSolver::checkSetof(Node n)
255 : : {
256 [ - + ][ - + ]: 29 : Assert(n.getKind() == Kind::BAG_SETOF);
[ - - ]
257 : 29 : set<Node> elements;
258 : 29 : const set<Node>& downwards = d_state.getElements(n);
259 : 29 : const set<Node>& upwards = d_state.getElements(n[0]);
260 : :
261 : 29 : elements.insert(downwards.begin(), downwards.end());
262 : 29 : elements.insert(upwards.begin(), upwards.end());
263 : :
264 [ + + ]: 54 : for (const Node& e : elements)
265 : : {
266 : 50 : InferInfo i = d_ig.setof(n, d_state.getRepresentative(e));
267 : 25 : d_im.lemmaTheoryInference(&i);
268 : 25 : }
269 : 29 : }
270 : :
271 : 49515 : void BagSolver::checkDisequalBagTerms()
272 : : {
273 [ + + ]: 52261 : for (const auto& [equality, witness] : d_state.getDisequalBagTerms())
274 : : {
275 : 5492 : InferInfo info = d_ig.bagDisequality(equality, witness);
276 : 2746 : d_im.lemmaTheoryInference(&info);
277 : 2746 : }
278 : 49515 : }
279 : :
280 : 431 : void BagSolver::checkMap(Node n)
281 : : {
282 [ - + ][ - + ]: 431 : Assert(n.getKind() == Kind::BAG_MAP);
[ - - ]
283 : 431 : const set<Node>& downwards = d_state.getElements(n);
284 : 431 : const set<Node>& upwards = d_state.getElements(n[1]);
285 [ + + ]: 1039 : for (const Node& x : upwards)
286 : : {
287 : 1216 : InferInfo upInference = d_ig.mapUp1(n, x);
288 : 608 : d_im.lemmaTheoryInference(&upInference);
289 : 608 : }
290 : :
291 [ + + ]: 431 : if (d_state.isInjective(n[0]))
292 : : {
293 [ + + ]: 150 : for (const Node& z : downwards)
294 : : {
295 : 204 : InferInfo upInference = d_ig.mapDownInjective(n, z);
296 : 102 : d_im.lemmaTheoryInference(&upInference);
297 : 102 : }
298 : : }
299 : : else
300 : : {
301 [ + + ]: 855 : for (const Node& z : downwards)
302 : : {
303 : 944 : Node y = d_state.getRepresentative(z);
304 [ + + ]: 472 : if (!d_mapCache.count(n))
305 : : {
306 : : std::shared_ptr<context::CDHashMap<Node, std::pair<Node, Node>>> nMap =
307 : : std::make_shared<context::CDHashMap<Node, std::pair<Node, Node>>>(
308 : 19 : userContext());
309 : 19 : d_mapCache[n] = nMap;
310 : 19 : }
311 [ + + ]: 472 : if (!d_mapCache[n].get()->count(y))
312 : : {
313 : 104 : auto [downInference, uf, preImageSize] = d_ig.mapDown(n, y);
314 : 52 : d_im.lemmaTheoryInference(&downInference);
315 : 52 : std::pair<Node, Node> yPair = std::make_pair(uf, preImageSize);
316 : 52 : d_mapCache[n].get()->insert(y, yPair);
317 : 52 : }
318 : :
319 : : context::CDHashMap<Node, std::pair<Node, Node>>::iterator it =
320 : 472 : d_mapCache[n].get()->find(y);
321 : :
322 : 472 : auto [uf, preImageSize] = it->second;
323 : :
324 [ + + ]: 1474 : for (const Node& x : upwards)
325 : : {
326 : 2004 : InferInfo upInference = d_ig.mapUp2(n, uf, preImageSize, y, x);
327 : 1002 : d_im.lemmaTheoryInference(&upInference);
328 : 1002 : }
329 : 472 : }
330 : : }
331 : 431 : }
332 : :
333 : 455 : void BagSolver::checkFilter(Node n)
334 : : {
335 [ - + ][ - + ]: 455 : Assert(n.getKind() == Kind::BAG_FILTER);
[ - - ]
336 : :
337 : 455 : set<Node> elements;
338 : 455 : const set<Node>& downwards = d_state.getElements(n);
339 : 455 : const set<Node>& upwards = d_state.getElements(n[1]);
340 : 455 : elements.insert(downwards.begin(), downwards.end());
341 : 455 : elements.insert(upwards.begin(), upwards.end());
342 : :
343 [ + + ]: 1575 : for (const Node& e : elements)
344 : : {
345 : 2240 : InferInfo i = d_ig.filterDown(n, d_state.getRepresentative(e));
346 : 1120 : d_im.lemmaTheoryInference(&i);
347 : 1120 : }
348 [ + + ]: 1575 : for (const Node& e : elements)
349 : : {
350 : 2240 : InferInfo i = d_ig.filterUp(n, d_state.getRepresentative(e));
351 : 1120 : d_im.lemmaTheoryInference(&i);
352 : 1120 : }
353 : 455 : }
354 : :
355 : 112 : void BagSolver::checkProduct(Node n)
356 : : {
357 [ - + ][ - + ]: 112 : Assert(n.getKind() == Kind::TABLE_PRODUCT);
[ - - ]
358 : 112 : const set<Node>& elementsA = d_state.getElements(n[0]);
359 : 112 : const set<Node>& elementsB = d_state.getElements(n[1]);
360 : :
361 [ + + ]: 182 : for (const Node& e1 : elementsA)
362 : : {
363 [ + + ]: 188 : for (const Node& e2 : elementsB)
364 : : {
365 : : InferInfo i = d_ig.productUp(
366 : 236 : n, d_state.getRepresentative(e1), d_state.getRepresentative(e2));
367 : 118 : d_im.lemmaTheoryInference(&i);
368 : 118 : }
369 : : }
370 : :
371 : 112 : std::set<Node> elements = d_state.getElements(n);
372 [ + + ]: 226 : for (const Node& e : elements)
373 : : {
374 : 228 : InferInfo i = d_ig.productDown(n, d_state.getRepresentative(e));
375 : 114 : d_im.lemmaTheoryInference(&i);
376 : 114 : }
377 : 112 : }
378 : :
379 : 30 : void BagSolver::checkJoin(Node n)
380 : : {
381 [ - + ][ - + ]: 30 : Assert(n.getKind() == Kind::TABLE_JOIN);
[ - - ]
382 : 30 : const set<Node>& elementsA = d_state.getElements(n[0]);
383 : 30 : const set<Node>& elementsB = d_state.getElements(n[1]);
384 : :
385 [ + + ]: 92 : for (const Node& e1 : elementsA)
386 : : {
387 [ + + ]: 236 : for (const Node& e2 : elementsB)
388 : : {
389 : : InferInfo i = d_ig.joinUp(
390 : 348 : n, d_state.getRepresentative(e1), d_state.getRepresentative(e2));
391 : 174 : d_im.lemmaTheoryInference(&i);
392 : 174 : }
393 : : }
394 : :
395 : 30 : std::set<Node> elements = d_state.getElements(n);
396 [ + + ]: 172 : for (const Node& e : elements)
397 : : {
398 : 284 : InferInfo i = d_ig.joinDown(n, d_state.getRepresentative(e));
399 : 142 : d_im.lemmaTheoryInference(&i);
400 : 142 : }
401 : 30 : }
402 : :
403 : 286 : void BagSolver::checkGroup(Node n)
404 : : {
405 [ - + ][ - + ]: 286 : Assert(n.getKind() == Kind::TABLE_GROUP);
[ - - ]
406 : :
407 : 286 : InferInfo notEmpty = d_ig.groupNotEmpty(n);
408 : 286 : d_im.lemmaTheoryInference(¬Empty);
409 : :
410 : 286 : Node part = d_ig.defineSkolemPartFunction(n);
411 : :
412 : 286 : const set<Node>& elementsA = d_state.getElements(n[0]);
413 : : std::shared_ptr<context::CDHashSet<Node>> skolems =
414 : 286 : d_state.getPartElementSkolems(n);
415 [ + + ]: 1778 : for (const Node& a : elementsA)
416 : : {
417 [ - + ]: 1492 : if (skolems->contains(a))
418 : : {
419 : : // skip skolem elements that were introduced by groupPartCount below.
420 : 0 : continue;
421 : : }
422 : 2984 : Node aRep = d_state.getRepresentative(a);
423 : 2984 : InferInfo i = d_ig.groupUp1(n, aRep, part);
424 : 1492 : d_im.lemmaTheoryInference(&i);
425 : 1492 : i = d_ig.groupUp2(n, aRep, part);
426 : 1492 : d_im.lemmaTheoryInference(&i);
427 : 1492 : }
428 : :
429 : 286 : std::set<Node> parts = d_state.getElements(n);
430 [ + + ]: 1040 : for (std::set<Node>::iterator partIt1 = parts.begin(); partIt1 != parts.end();
431 : 754 : ++partIt1)
432 : : {
433 : 1508 : Node part1 = d_state.getRepresentative(*partIt1);
434 : 754 : std::vector<Node> partEqc;
435 : 754 : d_state.getEquivalenceClass(part1, partEqc);
436 : 754 : bool newPart = true;
437 [ + + ]: 8602 : for (Node p : partEqc)
438 : : {
439 [ + + ][ + + ]: 7848 : if (p.getKind() == Kind::APPLY_UF && p.getOperator() == part)
[ + + ][ + + ]
[ - - ]
440 : : {
441 : 2052 : newPart = false;
442 : : }
443 : 7848 : }
444 [ + + ]: 754 : if (newPart)
445 : : {
446 : : // only apply the groupPartCount rule for a part that does not have
447 : : // nodes of the form (part x) introduced by the group up rule above.
448 : 308 : InferInfo partCardinality = d_ig.groupPartCount(n, part1, part);
449 : 154 : d_im.lemmaTheoryInference(&partCardinality);
450 : 154 : }
451 : :
452 : 754 : std::set<Node> partElements = d_state.getElements(part1);
453 : 754 : for (std::set<Node>::iterator i = partElements.begin();
454 [ + + ]: 4912 : i != partElements.end();
455 : 4158 : ++i)
456 : : {
457 : 8316 : Node x = d_state.getRepresentative(*i);
458 [ + + ]: 4158 : if (!skolems->contains(x))
459 : : {
460 : : // only apply down rules for elements not generated by groupPartCount
461 : : // rule above
462 : 7872 : InferInfo down = d_ig.groupDown(n, part1, x, part);
463 : 3936 : d_im.lemmaTheoryInference(&down);
464 : 3936 : }
465 : :
466 : 4158 : std::set<Node>::iterator j = i;
467 : 4158 : ++j;
468 [ + + ]: 17898 : while (j != partElements.end())
469 : : {
470 : 27480 : Node y = d_state.getRepresentative(*j);
471 : : // x, y should have the same projection
472 : : InferInfo sameProjection =
473 : 27480 : d_ig.groupSameProjection(n, part1, x, y, part);
474 : 13740 : d_im.lemmaTheoryInference(&sameProjection);
475 : 13740 : ++j;
476 : 13740 : }
477 : :
478 [ + + ]: 36720 : for (const Node& a : elementsA)
479 : : {
480 : 65124 : Node y = d_state.getRepresentative(a);
481 [ + + ]: 32562 : if (x != y)
482 : : {
483 : : // x, y should have the same projection
484 : 57248 : InferInfo samePart = d_ig.groupSamePart(n, part1, x, y, part);
485 : 28624 : d_im.lemmaTheoryInference(&samePart);
486 : 28624 : }
487 : 32562 : }
488 : 4158 : }
489 : 754 : }
490 : 286 : }
491 : :
492 : : } // namespace bags
493 : : } // namespace theory
494 : : } // namespace cvc5::internal
|