Branch data Line data Source code
1 : : /******************************************************************************
2 : : * Top contributors (to current version):
3 : : * Mudathir Mohamed, Aina Niemetz, Andrew Reynolds
4 : : *
5 : : * This file is part of the cvc5 project.
6 : : *
7 : : * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
8 : : * in the top-level source directory and their institutional affiliations.
9 : : * All rights reserved. See the file COPYING in the top-level source
10 : : * directory for licensing information.
11 : : * ****************************************************************************
12 : : *
13 : : * Utility functions for bags.
14 : : */
15 : : #include "bags_utils.h"
16 : :
17 : : #include "expr/dtype.h"
18 : : #include "expr/dtype_cons.h"
19 : : #include "expr/emptybag.h"
20 : : #include "smt/logic_exception.h"
21 : : #include "theory/bags/bag_reduction.h"
22 : : #include "theory/datatypes/project_op.h"
23 : : #include "theory/datatypes/tuple_utils.h"
24 : : #include "theory/rewriter.h"
25 : : #include "theory/sets/normal_form.h"
26 : : #include "theory/type_enumerator.h"
27 : : #include "theory/uf/equality_engine.h"
28 : : #include "util/rational.h"
29 : :
30 : : using namespace cvc5::internal::kind;
31 : : using namespace cvc5::internal::theory::datatypes;
32 : :
33 : : namespace cvc5::internal {
34 : : namespace theory {
35 : : namespace bags {
36 : :
37 : 59 : Node BagsUtils::computeDisjointUnion(TypeNode bagType,
38 : : const std::vector<Node>& bags)
39 : : {
40 : 59 : NodeManager* nm = NodeManager::currentNM();
41 [ + + ]: 59 : if (bags.empty())
42 : : {
43 : 6 : return nm->mkConst(EmptyBag(bagType));
44 : : }
45 [ + + ]: 56 : if (bags.size() == 1)
46 : : {
47 : 24 : return bags[0];
48 : : }
49 : 64 : Node unionDisjoint = bags[0];
50 [ + + ]: 100 : for (size_t i = 1; i < bags.size(); i++)
51 : : {
52 [ - + ]: 68 : if (bags[i].getKind() == Kind::BAG_EMPTY)
53 : : {
54 : 0 : continue;
55 : : }
56 : : unionDisjoint =
57 : 68 : nm->mkNode(Kind::BAG_UNION_DISJOINT, unionDisjoint, bags[i]);
58 : : }
59 : 32 : return unionDisjoint;
60 : : }
61 : :
62 : 643 : bool BagsUtils::isConstant(TNode n)
63 : : {
64 [ - + ]: 643 : if (n.getKind() == Kind::BAG_EMPTY)
65 : : {
66 : : // empty bags are already normalized
67 : 0 : return true;
68 : : }
69 [ - + ]: 643 : if (n.getKind() == Kind::BAG_MAKE)
70 : : {
71 : : // see the implementation in MkBagTypeRule::computeIsConst
72 : 0 : return n.isConst();
73 : : }
74 [ + - ]: 643 : if (n.getKind() == Kind::BAG_UNION_DISJOINT)
75 : : {
76 : 643 : if (!(n[0].getKind() == Kind::BAG_MAKE && n[0].isConst()))
77 : : {
78 : : // the first child is not a constant
79 : 291 : return false;
80 : : }
81 : : // store the previous element to check the ordering of elements
82 : 1056 : Node previousElement = n[0][0];
83 : 1056 : Node current = n[1];
84 [ + + ]: 717 : while (current.getKind() == Kind::BAG_UNION_DISJOINT)
85 : : {
86 : 368 : if (!(current[0].getKind() == Kind::BAG_MAKE && current[0].isConst()))
87 : : {
88 : : // the current element is not a constant
89 : 0 : return false;
90 : : }
91 [ + + ]: 368 : if (previousElement >= current[0][0])
92 : : {
93 : : // the ordering is violated
94 : 3 : return false;
95 : : }
96 : 365 : previousElement = current[0][0];
97 : 365 : current = current[1];
98 : : }
99 : : // check last element
100 [ + + ][ + + ]: 349 : if (!(current.getKind() == Kind::BAG_MAKE && current.isConst()))
[ + + ]
101 : : {
102 : : // the last element is not a constant
103 : 25 : return false;
104 : : }
105 [ + + ]: 324 : if (previousElement >= current[0])
106 : : {
107 : : // the ordering is violated
108 : 20 : return false;
109 : : }
110 : 304 : return true;
111 : : }
112 : :
113 : : // only nodes with kinds EMPTY_BAG, BAG_MAKE, and BAG_UNION_DISJOINT can be
114 : : // constants
115 : 0 : return false;
116 : : }
117 : :
118 : 14207 : bool BagsUtils::areChildrenConstants(TNode n)
119 : : {
120 : 31415 : return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
121 : : }
122 : :
123 : 703 : Node BagsUtils::evaluate(Rewriter* rewriter, TNode n)
124 : : {
125 [ - + ][ - + ]: 703 : Assert(areChildrenConstants(n));
[ - - ]
126 [ - + ]: 703 : if (n.isConst())
127 : : {
128 : : // a constant node is already in a normal form
129 : 0 : return n;
130 : : }
131 [ + + ][ + + ]: 703 : switch (n.getKind())
[ + + ][ + + ]
[ + + ][ + - ]
[ + + ][ + + ]
[ - ]
132 : : {
133 : 51 : case Kind::BAG_MAKE: return evaluateMakeBag(n);
134 : 317 : case Kind::BAG_COUNT: return evaluateBagCount(n);
135 : 6 : case Kind::BAG_SETOF: return evaluateSetof(n);
136 : 153 : case Kind::BAG_UNION_DISJOINT: return evaluateUnionDisjoint(n);
137 : 7 : case Kind::BAG_UNION_MAX: return evaluateUnionMax(n);
138 : 72 : case Kind::BAG_INTER_MIN: return evaluateIntersectionMin(n);
139 : 6 : case Kind::BAG_DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
140 : 10 : case Kind::BAG_DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
141 : 29 : case Kind::BAG_CARD: return evaluateCard(n);
142 : 11 : case Kind::BAG_MAP: return evaluateBagMap(n);
143 : 3 : case Kind::BAG_FILTER: return evaluateBagFilter(n);
144 : 0 : case Kind::BAG_FOLD: return evaluateBagFold(n);
145 : 6 : case Kind::TABLE_PRODUCT: return evaluateProduct(n);
146 : 4 : case Kind::TABLE_JOIN: return evaluateJoin(rewriter, n);
147 : 16 : case Kind::TABLE_GROUP: return evaluateGroup(n);
148 : 12 : case Kind::TABLE_PROJECT: return evaluateTableProject(n);
149 : 0 : default: break;
150 : : }
151 : 0 : Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
152 : 0 : << std::endl;
153 : : }
154 : :
155 : : template <typename T1, typename T2, typename T3, typename T4, typename T5>
156 : 248 : Node BagsUtils::evaluateBinaryOperation(const TNode& n,
157 : : T1&& equal,
158 : : T2&& less,
159 : : T3&& greaterOrEqual,
160 : : T4&& remainderOfA,
161 : : T5&& remainderOfB)
162 : : {
163 : 496 : std::map<Node, Rational> elementsA = getBagElements(n[0]);
164 : 496 : std::map<Node, Rational> elementsB = getBagElements(n[1]);
165 : 496 : std::map<Node, Rational> elements;
166 : :
167 : 248 : std::map<Node, Rational>::const_iterator itA = elementsA.begin();
168 : 248 : std::map<Node, Rational>::const_iterator itB = elementsB.begin();
169 : :
170 [ + - ]: 496 : Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
171 : 248 : << n.getKind() << "] " << std::endl
172 : 0 : << "elements A: " << elementsA << std::endl
173 : 0 : << "elements B: " << elementsB << std::endl;
174 : :
175 [ + + ][ + + ]: 566 : while (itA != elementsA.end() && itB != elementsB.end())
[ + + ]
176 : : {
177 [ + + ]: 318 : if (itA->first == itB->first)
178 : : {
179 : 73 : equal(elements, itA, itB);
180 : 73 : itA++;
181 : 73 : itB++;
182 : : }
183 [ + + ]: 245 : else if (itA->first < itB->first)
184 : : {
185 : 222 : less(elements, itA, itB);
186 : 222 : itA++;
187 : : }
188 : : else
189 : : {
190 : 23 : greaterOrEqual(elements, itA, itB);
191 : 23 : itB++;
192 : : }
193 : : }
194 : :
195 : : // handle the remaining elements from A
196 : 248 : remainderOfA(elements, elementsA, itA);
197 : : // handle the remaining elements from B
198 : 248 : remainderOfB(elements, elementsB, itB);
199 : :
200 [ + - ]: 248 : Trace("bags-evaluate") << "elements: " << elements << std::endl;
201 : 248 : Node bag = constructConstantBagFromElements(n.getType(), elements);
202 [ + - ]: 248 : Trace("bags-evaluate") << "bag: " << bag << std::endl;
203 : 496 : return bag;
204 : : }
205 : :
206 : 992 : std::map<Node, Rational> BagsUtils::getBagElements(TNode n)
207 : : {
208 : 992 : std::map<Node, Rational> elements;
209 [ + + ]: 992 : if (n.getKind() == Kind::BAG_EMPTY)
210 : : {
211 : 290 : return elements;
212 : : }
213 [ + + ]: 1115 : while (n.getKind() == Kind::BAG_UNION_DISJOINT)
214 : : {
215 [ - + ][ - + ]: 413 : Assert(n[0].getKind() == Kind::BAG_MAKE);
[ - - ]
216 : 1239 : Node element = n[0][0];
217 : 826 : Rational count = n[0][1].getConst<Rational>();
218 : 413 : elements[element] = count;
219 : 413 : n = n[1];
220 : : }
221 [ - + ][ - + ]: 702 : Assert(n.getKind() == Kind::BAG_MAKE);
[ - - ]
222 : 1404 : Node lastElement = n[0];
223 : 1404 : Rational lastCount = n[1].getConst<Rational>();
224 : 702 : elements[lastElement] = lastCount;
225 : 702 : return elements;
226 : : }
227 : :
228 : 372 : Node BagsUtils::constructConstantBagFromElements(
229 : : TypeNode t, const std::map<Node, Rational>& elements)
230 : : {
231 [ - + ][ - + ]: 372 : Assert(t.isBag());
[ - - ]
232 : 372 : NodeManager* nm = NodeManager::currentNM();
233 [ + + ]: 372 : if (elements.empty())
234 : : {
235 : 222 : return nm->mkConst(EmptyBag(t));
236 : : }
237 : 522 : TypeNode elementType = t.getBagElementType();
238 : 261 : std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin();
239 : 783 : Node bag = nm->mkNode(Kind::BAG_MAKE, it->first, nm->mkConstInt(it->second));
240 [ + + ]: 565 : while (++it != elements.rend())
241 : : {
242 : 608 : Node n = nm->mkNode(Kind::BAG_MAKE, it->first, nm->mkConstInt(it->second));
243 : 304 : bag = nm->mkNode(Kind::BAG_UNION_DISJOINT, n, bag);
244 : : }
245 : 261 : return bag;
246 : : }
247 : :
248 : 384 : Node BagsUtils::constructBagFromElements(TypeNode t,
249 : : const std::map<Node, Node>& elements)
250 : : {
251 [ - + ][ - + ]: 384 : Assert(t.isBag());
[ - - ]
252 : 384 : NodeManager* nm = NodeManager::currentNM();
253 [ + + ]: 384 : if (elements.empty())
254 : : {
255 : 158 : return nm->mkConst(EmptyBag(t));
256 : : }
257 : 610 : TypeNode elementType = t.getBagElementType();
258 : 305 : std::map<Node, Node>::const_reverse_iterator it = elements.rbegin();
259 : 915 : Node bag = nm->mkNode(Kind::BAG_MAKE, it->first, it->second);
260 [ + + ]: 450 : while (++it != elements.rend())
261 : : {
262 : 290 : Node n = nm->mkNode(Kind::BAG_MAKE, it->first, it->second);
263 : 145 : bag = nm->mkNode(Kind::BAG_UNION_DISJOINT, n, bag);
264 : : }
265 : 305 : return bag;
266 : : }
267 : :
268 : 51 : Node BagsUtils::evaluateMakeBag(TNode n)
269 : : {
270 : : // the case where n is const should be handled earlier.
271 : : // here we handle the case where the multiplicity is zero or negative
272 : 102 : Assert(n.getKind() == Kind::BAG_MAKE && !n.isConst()
273 : : && n[1].getConst<Rational>().sgn() < 1);
274 : 102 : Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
275 : 51 : return emptybag;
276 : : }
277 : :
278 : 317 : Node BagsUtils::evaluateBagCount(TNode n)
279 : : {
280 [ - + ][ - + ]: 317 : Assert(n.getKind() == Kind::BAG_COUNT);
[ - - ]
281 : : // Examples
282 : : // --------
283 : : // - (bag.count "x" (as bag.empty (Bag String))) = 0
284 : : // - (bag.count "x" (bag "y" 5)) = 0
285 : : // - (bag.count "x" (bag "x" 4)) = 4
286 : : // - (bag.count "x" (bag.union_disjoint (bag "x" 4) (bag "y" 5)) = 4
287 : : // - (bag.count "x" (bag.union_disjoint (bag "y" 5) (bag "z" 5)) = 0
288 : :
289 : 634 : std::map<Node, Rational> elements = getBagElements(n[1]);
290 : 317 : std::map<Node, Rational>::iterator it = elements.find(n[0]);
291 : :
292 : 317 : NodeManager* nm = NodeManager::currentNM();
293 [ + + ]: 317 : if (it != elements.end())
294 : : {
295 : 330 : Node count = nm->mkConstInt(it->second);
296 : 165 : return count;
297 : : }
298 : 304 : return nm->mkConstInt(Rational(0));
299 : : }
300 : :
301 : 6 : Node BagsUtils::evaluateSetof(TNode n)
302 : : {
303 [ - + ][ - + ]: 6 : Assert(n.getKind() == Kind::BAG_SETOF);
[ - - ]
304 : :
305 : : // Examples
306 : : // --------
307 : : // - (bag.setof (as bag.empty (Bag String))) = (as bag.empty (Bag
308 : : // String))
309 : : // - (bag.setof (bag "x" 4)) = (bag "x" 1)
310 : : // - (bag.setof (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
311 : : // (bag.disjoint_union (bag "x" 1) (bag "y" 1)
312 : :
313 : 12 : std::map<Node, Rational> oldElements = getBagElements(n[0]);
314 : : // copy elements from the old bag
315 : 12 : std::map<Node, Rational> newElements(oldElements);
316 : 12 : Rational one = Rational(1);
317 : 6 : std::map<Node, Rational>::iterator it;
318 [ + + ]: 10 : for (it = newElements.begin(); it != newElements.end(); it++)
319 : : {
320 : 4 : it->second = one;
321 : : }
322 : 12 : Node bag = constructConstantBagFromElements(n[0].getType(), newElements);
323 : 12 : return bag;
324 : : }
325 : :
326 : 153 : Node BagsUtils::evaluateUnionDisjoint(TNode n)
327 : : {
328 [ - + ][ - + ]: 153 : Assert(n.getKind() == Kind::BAG_UNION_DISJOINT);
[ - - ]
329 : : // Example
330 : : // -------
331 : : // input: (bag.union_disjoint A B)
332 : : // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
333 : : // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
334 : : // output:
335 : : // (bag.union_disjoint A B)
336 : : // where A = (bag "x" 7)
337 : : // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
338 : :
339 : 14 : auto equal = [](std::map<Node, Rational>& elements,
340 : : std::map<Node, Rational>::const_iterator& itA,
341 : : std::map<Node, Rational>::const_iterator& itB) {
342 : : // compute the sum of the multiplicities
343 : 14 : elements[itA->first] = itA->second + itB->second;
344 : 14 : };
345 : :
346 : 220 : auto less = [](std::map<Node, Rational>& elements,
347 : : std::map<Node, Rational>::const_iterator& itA,
348 : : std::map<Node, Rational>::const_iterator& itB) {
349 : : // add the element to the result
350 : 220 : elements[itA->first] = itA->second;
351 : 220 : };
352 : :
353 : 19 : auto greaterOrEqual = [](std::map<Node, Rational>& elements,
354 : : std::map<Node, Rational>::const_iterator& itA,
355 : : std::map<Node, Rational>::const_iterator& itB) {
356 : : // add the element to the result
357 : 19 : elements[itB->first] = itB->second;
358 : 19 : };
359 : :
360 : 185 : auto remainderOfA = [](std::map<Node, Rational>& elements,
361 : : std::map<Node, Rational>& elementsA,
362 : : std::map<Node, Rational>::const_iterator& itA) {
363 : : // append the remainder of A
364 [ + + ]: 185 : while (itA != elementsA.end())
365 : : {
366 : 32 : elements[itA->first] = itA->second;
367 : 32 : itA++;
368 : : }
369 : 153 : };
370 : :
371 : 226 : auto remainderOfB = [](std::map<Node, Rational>& elements,
372 : : std::map<Node, Rational>& elementsB,
373 : : std::map<Node, Rational>::const_iterator& itB) {
374 : : // append the remainder of B
375 [ + + ]: 226 : while (itB != elementsB.end())
376 : : {
377 : 73 : elements[itB->first] = itB->second;
378 : 73 : itB++;
379 : : }
380 : 153 : };
381 : :
382 : : return evaluateBinaryOperation(
383 : 306 : n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
384 : : }
385 : :
386 : 7 : Node BagsUtils::evaluateUnionMax(TNode n)
387 : : {
388 [ - + ][ - + ]: 7 : Assert(n.getKind() == Kind::BAG_UNION_MAX);
[ - - ]
389 : : // Example
390 : : // -------
391 : : // input: (bag.union_max A B)
392 : : // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
393 : : // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
394 : : // output:
395 : : // (bag.union_disjoint A B)
396 : : // where A = (bag "x" 4)
397 : : // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
398 : :
399 : 3 : auto equal = [](std::map<Node, Rational>& elements,
400 : : std::map<Node, Rational>::const_iterator& itA,
401 : : std::map<Node, Rational>::const_iterator& itB) {
402 : : // compute the maximum multiplicity
403 : 3 : elements[itA->first] = std::max(itA->second, itB->second);
404 : 3 : };
405 : :
406 : 2 : auto less = [](std::map<Node, Rational>& elements,
407 : : std::map<Node, Rational>::const_iterator& itA,
408 : : std::map<Node, Rational>::const_iterator& itB) {
409 : : // add to the result
410 : 2 : elements[itA->first] = itA->second;
411 : 2 : };
412 : :
413 : 1 : auto greaterOrEqual = [](std::map<Node, Rational>& elements,
414 : : std::map<Node, Rational>::const_iterator& itA,
415 : : std::map<Node, Rational>::const_iterator& itB) {
416 : : // add to the result
417 : 1 : elements[itB->first] = itB->second;
418 : 1 : };
419 : :
420 : 9 : auto remainderOfA = [](std::map<Node, Rational>& elements,
421 : : std::map<Node, Rational>& elementsA,
422 : : std::map<Node, Rational>::const_iterator& itA) {
423 : : // append the remainder of A
424 [ + + ]: 9 : while (itA != elementsA.end())
425 : : {
426 : 2 : elements[itA->first] = itA->second;
427 : 2 : itA++;
428 : : }
429 : 7 : };
430 : :
431 : 9 : auto remainderOfB = [](std::map<Node, Rational>& elements,
432 : : std::map<Node, Rational>& elementsB,
433 : : std::map<Node, Rational>::const_iterator& itB) {
434 : : // append the remainder of B
435 [ + + ]: 9 : while (itB != elementsB.end())
436 : : {
437 : 2 : elements[itB->first] = itB->second;
438 : 2 : itB++;
439 : : }
440 : 7 : };
441 : :
442 : : return evaluateBinaryOperation(
443 : 14 : n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
444 : : }
445 : :
446 : 72 : Node BagsUtils::evaluateIntersectionMin(TNode n)
447 : : {
448 [ - + ][ - + ]: 72 : Assert(n.getKind() == Kind::BAG_INTER_MIN);
[ - - ]
449 : : // Example
450 : : // -------
451 : : // input: (bag.inter_min A B)
452 : : // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
453 : : // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
454 : : // output:
455 : : // (bag "x" 3)
456 : :
457 : 47 : auto equal = [](std::map<Node, Rational>& elements,
458 : : std::map<Node, Rational>::const_iterator& itA,
459 : : std::map<Node, Rational>::const_iterator& itB) {
460 : : // compute the minimum multiplicity
461 : 47 : elements[itA->first] = std::min(itA->second, itB->second);
462 : 47 : };
463 : :
464 : 0 : auto less = [](std::map<Node, Rational>& elements,
465 : : std::map<Node, Rational>::const_iterator& itA,
466 : : std::map<Node, Rational>::const_iterator& itB) {
467 : : // do nothing
468 : 0 : };
469 : :
470 : 1 : auto greaterOrEqual = [](std::map<Node, Rational>& elements,
471 : : std::map<Node, Rational>::const_iterator& itA,
472 : : std::map<Node, Rational>::const_iterator& itB) {
473 : : // do nothing
474 : 1 : };
475 : :
476 : 72 : auto remainderOfA = [](std::map<Node, Rational>& elements,
477 : : std::map<Node, Rational>& elementsA,
478 : : std::map<Node, Rational>::const_iterator& itA) {
479 : : // do nothing
480 : 72 : };
481 : :
482 : 72 : auto remainderOfB = [](std::map<Node, Rational>& elements,
483 : : std::map<Node, Rational>& elementsB,
484 : : std::map<Node, Rational>::const_iterator& itB) {
485 : : // do nothing
486 : 72 : };
487 : :
488 : : return evaluateBinaryOperation(
489 : 144 : n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
490 : : }
491 : :
492 : 6 : Node BagsUtils::evaluateDifferenceSubtract(TNode n)
493 : : {
494 [ - + ][ - + ]: 6 : Assert(n.getKind() == Kind::BAG_DIFFERENCE_SUBTRACT);
[ - - ]
495 : : // Example
496 : : // -------
497 : : // input: (bag.difference_subtract A B)
498 : : // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
499 : : // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
500 : : // output:
501 : : // (bag.union_disjoint (bag "x" 1) (bag "z" 2))
502 : :
503 : 4 : auto equal = [](std::map<Node, Rational>& elements,
504 : : std::map<Node, Rational>::const_iterator& itA,
505 : : std::map<Node, Rational>::const_iterator& itB) {
506 : : // subtract the multiplicities
507 : 4 : elements[itA->first] = itA->second - itB->second;
508 : 4 : };
509 : :
510 : 0 : auto less = [](std::map<Node, Rational>& elements,
511 : : std::map<Node, Rational>::const_iterator& itA,
512 : : std::map<Node, Rational>::const_iterator& itB) {
513 : : // itA->first is not in B, so we add it to the difference subtract
514 : 0 : elements[itA->first] = itA->second;
515 : 0 : };
516 : :
517 : 1 : auto greaterOrEqual = [](std::map<Node, Rational>& elements,
518 : : std::map<Node, Rational>::const_iterator& itA,
519 : : std::map<Node, Rational>::const_iterator& itB) {
520 : : // itB->first is not in A, so we just skip it
521 : 1 : };
522 : :
523 : 7 : auto remainderOfA = [](std::map<Node, Rational>& elements,
524 : : std::map<Node, Rational>& elementsA,
525 : : std::map<Node, Rational>::const_iterator& itA) {
526 : : // append the remainder of A
527 [ + + ]: 7 : while (itA != elementsA.end())
528 : : {
529 : 1 : elements[itA->first] = itA->second;
530 : 1 : itA++;
531 : : }
532 : 6 : };
533 : :
534 : 6 : auto remainderOfB = [](std::map<Node, Rational>& elements,
535 : : std::map<Node, Rational>& elementsB,
536 : : std::map<Node, Rational>::const_iterator& itB) {
537 : : // do nothing
538 : 6 : };
539 : :
540 : : return evaluateBinaryOperation(
541 : 12 : n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
542 : : }
543 : :
544 : 10 : Node BagsUtils::evaluateDifferenceRemove(TNode n)
545 : : {
546 [ - + ][ - + ]: 10 : Assert(n.getKind() == Kind::BAG_DIFFERENCE_REMOVE);
[ - - ]
547 : : // Example
548 : : // -------
549 : : // input: (bag.difference_remove A B)
550 : : // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
551 : : // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
552 : : // output:
553 : : // (bag "z" 2)
554 : :
555 : 5 : auto equal = [](std::map<Node, Rational>& elements,
556 : : std::map<Node, Rational>::const_iterator& itA,
557 : : std::map<Node, Rational>::const_iterator& itB) {
558 : : // skip the shared element by doing nothing
559 : 5 : };
560 : :
561 : 0 : auto less = [](std::map<Node, Rational>& elements,
562 : : std::map<Node, Rational>::const_iterator& itA,
563 : : std::map<Node, Rational>::const_iterator& itB) {
564 : : // itA->first is not in B, so we add it to the difference remove
565 : 0 : elements[itA->first] = itA->second;
566 : 0 : };
567 : :
568 : 1 : auto greaterOrEqual = [](std::map<Node, Rational>& elements,
569 : : std::map<Node, Rational>::const_iterator& itA,
570 : : std::map<Node, Rational>::const_iterator& itB) {
571 : : // itB->first is not in A, so we just skip it
572 : 1 : };
573 : :
574 : 11 : auto remainderOfA = [](std::map<Node, Rational>& elements,
575 : : std::map<Node, Rational>& elementsA,
576 : : std::map<Node, Rational>::const_iterator& itA) {
577 : : // append the remainder of A
578 [ + + ]: 11 : while (itA != elementsA.end())
579 : : {
580 : 1 : elements[itA->first] = itA->second;
581 : 1 : itA++;
582 : : }
583 : 10 : };
584 : :
585 : 10 : auto remainderOfB = [](std::map<Node, Rational>& elements,
586 : : std::map<Node, Rational>& elementsB,
587 : : std::map<Node, Rational>::const_iterator& itB) {
588 : : // do nothing
589 : 10 : };
590 : :
591 : : return evaluateBinaryOperation(
592 : 20 : n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
593 : : }
594 : :
595 : 0 : Node BagsUtils::evaluateChoose(TNode n)
596 : : {
597 : 0 : Assert(n.getKind() == Kind::BAG_CHOOSE);
598 : : // Examples
599 : : // --------
600 : : // - (bag.choose (bag "x" 4)) = "x"
601 : :
602 [ - - ]: 0 : if (n[0].getKind() == Kind::BAG_MAKE)
603 : : {
604 : 0 : return n[0][0];
605 : : }
606 : 0 : throw LogicException("BAG_CHOOSE_TOTAL is not supported yet");
607 : : }
608 : :
609 : 29 : Node BagsUtils::evaluateCard(TNode n)
610 : : {
611 [ - + ][ - + ]: 29 : Assert(n.getKind() == Kind::BAG_CARD);
[ - - ]
612 : : // Examples
613 : : // --------
614 : : // - (card (as bag.empty (Bag String))) = 0
615 : : // - (bag.choose (bag "x" 4)) = 4
616 : : // - (bag.choose (bag.union_disjoint (bag "x" 4) (bag "y" 1))) = 5
617 : :
618 : 58 : std::map<Node, Rational> elements = getBagElements(n[0]);
619 : 58 : Rational sum(0);
620 [ + + ]: 46 : for (std::pair<Node, Rational> element : elements)
621 : : {
622 : 17 : sum += element.second;
623 : : }
624 : :
625 : 29 : NodeManager* nm = NodeManager::currentNM();
626 : 29 : Node sumNode = nm->mkConstInt(sum);
627 : 58 : return sumNode;
628 : : }
629 : :
630 : 23 : Node BagsUtils::evaluateBagMap(TNode n)
631 : : {
632 [ - + ][ - + ]: 23 : Assert(n.getKind() == Kind::BAG_MAP);
[ - - ]
633 : :
634 : : // Examples
635 : : // --------
636 : : // - (bag.map ((lambda ((x String)) "z")
637 : : // (bag.union_disjoint (bag "a" 2) (bag "b" 3)) =
638 : : // (bag.union_disjoint
639 : : // (bag ((lambda ((x String)) "z") "a") 2)
640 : : // (bag ((lambda ((x String)) "z") "b") 3)) =
641 : : // (bag "z" 5)
642 : :
643 : 46 : std::map<Node, Rational> elements = BagsUtils::getBagElements(n[1]);
644 : 46 : std::map<Node, Rational> mappedElements;
645 : 23 : std::map<Node, Rational>::iterator it = elements.begin();
646 : 23 : NodeManager* nm = NodeManager::currentNM();
647 [ + + ]: 48 : while (it != elements.end())
648 : : {
649 : 75 : Node mappedElement = nm->mkNode(Kind::APPLY_UF, n[0], it->first);
650 : 25 : mappedElements[mappedElement] = it->second;
651 : 25 : ++it;
652 : : }
653 : 69 : TypeNode t = nm->mkBagType(n[0].getType().getRangeType());
654 : 23 : Node ret = BagsUtils::constructConstantBagFromElements(t, mappedElements);
655 : 46 : return ret;
656 : : }
657 : :
658 : 6 : Node BagsUtils::evaluateBagFilter(TNode n)
659 : : {
660 [ - + ][ - + ]: 6 : Assert(n.getKind() == Kind::BAG_FILTER);
[ - - ]
661 : :
662 : : // - (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T))
663 : : // - (bag.filter p (bag.union_disjoint (bag "a" 3) (bag "b" 2))) =
664 : : // (bag.union_disjoint
665 : : // (ite (p "a") (bag "a" 3) (as bag.empty (Bag T)))
666 : : // (ite (p "b") (bag "b" 2) (as bag.empty (Bag T)))
667 : :
668 : 12 : Node P = n[0];
669 : 12 : Node A = n[1];
670 : 12 : TypeNode bagType = A.getType();
671 : 6 : NodeManager* nm = NodeManager::currentNM();
672 : 12 : Node empty = nm->mkConst(EmptyBag(bagType));
673 : :
674 : 12 : std::map<Node, Rational> elements = getBagElements(n[1]);
675 : 12 : std::vector<Node> bags;
676 : :
677 [ + + ]: 9 : for (const auto& [e, count] : elements)
678 : : {
679 : 6 : Node multiplicity = nm->mkConstInt(count);
680 : 9 : Node bag = nm->mkNode(Kind::BAG_MAKE, e, multiplicity);
681 : 9 : Node pOfe = nm->mkNode(Kind::APPLY_UF, P, e);
682 : 9 : Node ite = nm->mkNode(Kind::ITE, pOfe, bag, empty);
683 : 3 : bags.push_back(ite);
684 : : }
685 : 6 : Node ret = computeDisjointUnion(bagType, bags);
686 : 12 : return ret;
687 : : }
688 : :
689 : 10 : Node BagsUtils::evaluateBagFold(TNode n)
690 : : {
691 [ - + ][ - + ]: 10 : Assert(n.getKind() == Kind::BAG_FOLD);
[ - - ]
692 : :
693 : : // Examples
694 : : // --------
695 : : // minimum string
696 : : // - (bag.fold
697 : : // ((lambda ((x String) (y String)) (ite (str.< x y) x y))
698 : : // ""
699 : : // (bag.union_disjoint (bag "a" 2) (bag "b" 3))
700 : : // = "a"
701 : :
702 : 20 : Node f = n[0]; // combining function
703 : 10 : Node ret = n[1]; // initial value
704 : 20 : Node A = n[2]; // bag
705 : 20 : std::map<Node, Rational> elements = BagsUtils::getBagElements(A);
706 : :
707 : 10 : std::map<Node, Rational>::iterator it = elements.begin();
708 [ + + ]: 24 : while (it != elements.end())
709 : : {
710 : : // apply the combination function n times, where n is the multiplicity
711 : 28 : Rational count = it->second;
712 [ - + ][ - + ]: 14 : Assert(count.sgn() >= 0) << "negative multiplicity" << std::endl;
[ - - ]
713 [ + + ]: 37 : while (!count.isZero())
714 : : {
715 : 23 : ret = NodeManager::mkNode(Kind::APPLY_UF, f, it->first, ret);
716 : 23 : count = count - 1;
717 : : }
718 : 14 : ++it;
719 : : }
720 : 20 : return ret;
721 : : }
722 : :
723 : 4 : Node BagsUtils::evaluateBagPartition(Rewriter* rewriter, TNode n)
724 : : {
725 [ - + ][ - + ]: 4 : Assert(n.getKind() == Kind::BAG_PARTITION);
[ - - ]
726 : 4 : NodeManager* nm = NodeManager::currentNM();
727 : :
728 : : // Examples
729 : : // --------
730 : : // minimum string
731 : : // - (bag.partition
732 : : // ((lambda ((x Int) (y Int)) (= 0 (+ x y)))
733 : : // (bag.union_disjoint
734 : : // (bag 1 20) (bag (- 1) 50)
735 : : // (bag 2 30) (bag (- 2) 60)
736 : : // (bag 3 40) (bag (- 3) 70)
737 : : // (bag 4 100)))
738 : : // = (bag.union_disjoint
739 : : // (bag (bag 4 100) 1)
740 : : // (bag (bag.union_disjoint (bag 1 20) (bag (- 1) 50)) 1)
741 : : // (bag (bag.union_disjoint (bag 2 30) (bag (- 2) 60)) 1)
742 : : // (bag (bag.union_disjoint (bag 3 40) (bag (- 3) 70)) 1)))
743 : :
744 : 8 : Node r = n[0]; // equivalence relation
745 : 8 : Node A = n[1]; // bag
746 : 8 : TypeNode bagType = A.getType();
747 : 8 : TypeNode partitionType = n.getType();
748 : 8 : std::map<Node, Rational> elements = BagsUtils::getBagElements(A);
749 [ + - ]: 4 : Trace("bags-partition") << "elements: " << elements << std::endl;
750 : : // a simple map from elements to equivalent classes with this invariant:
751 : : // each key element must appear exactly once in one of the values.
752 : 8 : std::map<Node, std::set<Node>> sets;
753 : 8 : std::set<Node> emptyClass;
754 [ + + ]: 18 : for (const auto& pair : elements)
755 : : {
756 : : // initially each singleton element is an equivalence class
757 : 28 : sets[pair.first] = {pair.first};
758 : : }
759 : 18 : for (std::map<Node, Rational>::iterator i = elements.begin();
760 [ + + ]: 18 : i != elements.end();
761 : 14 : ++i)
762 : : {
763 [ + + ]: 14 : if (sets[i->first].empty())
764 : : {
765 : : // skip this element since its equivalent class has already been processed
766 : 6 : continue;
767 : : }
768 : 8 : std::map<Node, Rational>::iterator j = i;
769 : 8 : ++j;
770 [ + + ]: 38 : while (j != elements.end())
771 : : {
772 : : Node sameClass =
773 : 60 : NodeManager::mkNode(Kind::APPLY_UF, r, i->first, j->first);
774 : 30 : sameClass = rewriter->rewrite(sameClass);
775 [ - + ]: 30 : if (!sameClass.isConst())
776 : : {
777 : : // we can not pursue further, so we return n itself
778 : 0 : return n;
779 : : }
780 [ + + ]: 30 : if (sameClass.getConst<bool>())
781 : : {
782 : : // add element j to the equivalent class
783 : 6 : sets[i->first].insert(j->first);
784 : : // mark the equivalent class of j as processed
785 : 6 : sets[j->first] = emptyClass;
786 : : }
787 : 30 : ++j;
788 : : }
789 : : }
790 : :
791 : : // construct the partition parts
792 : 8 : std::map<Node, Rational> parts;
793 [ + + ]: 18 : for (std::pair<Node, std::set<Node>> pair : sets)
794 : : {
795 : 14 : const std::set<Node>& eqc = pair.second;
796 [ + + ]: 14 : if (eqc.empty())
797 : : {
798 : 6 : continue;
799 : : }
800 : 16 : std::vector<Node> bags;
801 [ + + ]: 22 : for (const Node& node : eqc)
802 : : {
803 : : Node bag =
804 : 42 : nm->mkNode(Kind::BAG_MAKE, node, nm->mkConstInt(elements[node]));
805 : 14 : bags.push_back(bag);
806 : : }
807 : 8 : Node part = computeDisjointUnion(bagType, bags);
808 : : // each part in the partitions has multiplicity one
809 : 8 : parts[part] = Rational(1);
810 : : }
811 : 8 : Node ret = constructConstantBagFromElements(partitionType, parts);
812 [ + - ]: 4 : Trace("bags-partition") << "ret: " << ret << std::endl;
813 : 4 : return ret;
814 : : }
815 : :
816 : 2 : Node BagsUtils::evaluateTableAggregate(Rewriter* rewriter, TNode n)
817 : : {
818 [ - + ][ - + ]: 2 : Assert(n.getKind() == Kind::TABLE_AGGREGATE);
[ - - ]
819 : 2 : if (!(n[1].isConst() && n[2].isConst()))
820 : : {
821 : : // we can't proceed further.
822 : 0 : return n;
823 : : }
824 : :
825 : 4 : Node reduction = BagReduction::reduceAggregateOperator(n);
826 : 2 : return reduction;
827 : : }
828 : :
829 : 300 : Node BagsUtils::constructProductTuple(TNode n, TNode e1, TNode e2)
830 : : {
831 [ + + ][ - + ]: 300 : Assert(n.getKind() == Kind::TABLE_PRODUCT || n.getKind() == Kind::TABLE_JOIN);
[ - + ][ - - ]
832 : 600 : Node A = n[0];
833 : 600 : Node B = n[1];
834 : 600 : TypeNode typeA = A.getType().getBagElementType();
835 : 600 : TypeNode typeB = B.getType().getBagElementType();
836 [ - + ][ - + ]: 300 : Assert(e1.getType() == typeA);
[ - - ]
837 [ - + ][ - + ]: 300 : Assert(e2.getType() == typeB);
[ - - ]
838 : :
839 : 600 : TypeNode productTupleType = n.getType().getBagElementType();
840 : 600 : Node tuple = TupleUtils::concatTuples(productTupleType, e1, e2);
841 : 600 : return tuple;
842 : : }
843 : :
844 : 6 : Node BagsUtils::evaluateProduct(TNode n)
845 : : {
846 [ - + ][ - + ]: 6 : Assert(n.getKind() == Kind::TABLE_PRODUCT);
[ - - ]
847 : :
848 : : // Examples
849 : : // --------
850 : : //
851 : : // - (table.product (bag (tuple "a") 4) (bag (tuple true) 5)) =
852 : : // (bag (tuple "a" true) 20
853 : :
854 : 12 : Node A = n[0];
855 : 12 : Node B = n[1];
856 : :
857 : 12 : std::map<Node, Rational> elementsA = BagsUtils::getBagElements(A);
858 : 12 : std::map<Node, Rational> elementsB = BagsUtils::getBagElements(B);
859 : :
860 : 12 : std::map<Node, Rational> elements;
861 : :
862 [ + + ]: 10 : for (const auto& [a, countA] : elementsA)
863 : : {
864 [ + + ]: 10 : for (const auto& [b, countB] : elementsB)
865 : : {
866 : 12 : Node element = constructProductTuple(n, a, b);
867 : 6 : elements[element] = countA * countB;
868 : : }
869 : : }
870 : :
871 : 6 : Node ret = BagsUtils::constructConstantBagFromElements(n.getType(), elements);
872 : 12 : return ret;
873 : : }
874 : :
875 : 4 : Node BagsUtils::evaluateJoin(Rewriter* rewriter, TNode n)
876 : : {
877 [ - + ][ - + ]: 4 : Assert(n.getKind() == Kind::TABLE_JOIN);
[ - - ]
878 : :
879 : 8 : Node A = n[0];
880 : 8 : Node B = n[1];
881 : 12 : auto [aIndices, bIndices] = splitTableJoinIndices(n);
882 : :
883 : 8 : std::map<Node, Rational> elementsA = BagsUtils::getBagElements(A);
884 : 8 : std::map<Node, Rational> elementsB = BagsUtils::getBagElements(B);
885 : :
886 : 8 : std::map<Node, Rational> elements;
887 : :
888 [ + + ]: 13 : for (const auto& [a, countA] : elementsA)
889 : : {
890 : 18 : Node aProjection = TupleUtils::getTupleProjection(aIndices, a);
891 : 9 : aProjection = rewriter->rewrite(aProjection);
892 [ - + ][ - + ]: 9 : Assert(aProjection.isConst());
[ - - ]
893 [ + + ]: 36 : for (const auto& [b, countB] : elementsB)
894 : : {
895 : 54 : Node bProjection = TupleUtils::getTupleProjection(bIndices, b);
896 : 27 : bProjection = rewriter->rewrite(bProjection);
897 [ - + ][ - + ]: 27 : Assert(bProjection.isConst());
[ - - ]
898 [ + + ]: 27 : if (aProjection == bProjection)
899 : : {
900 : 20 : Node element = constructProductTuple(n, a, b);
901 : 10 : elements[element] = countA * countB;
902 : : }
903 : : }
904 : : }
905 : :
906 : 4 : Node ret = BagsUtils::constructConstantBagFromElements(n.getType(), elements);
907 : 8 : return ret;
908 : : }
909 : :
910 : 16 : Node BagsUtils::evaluateGroup(TNode n)
911 : : {
912 [ - + ][ - + ]: 16 : Assert(n.getKind() == Kind::TABLE_GROUP);
[ - - ]
913 : :
914 : 16 : NodeManager* nm = NodeManager::currentNM();
915 : :
916 : 32 : Node A = n[0];
917 : 32 : TypeNode bagType = A.getType();
918 : 32 : TypeNode partitionType = n.getType();
919 : :
920 : : std::vector<uint32_t> indices =
921 : 32 : n.getOperator().getConst<ProjectOp>().getIndices();
922 : :
923 : 32 : std::map<Node, Rational> elements = BagsUtils::getBagElements(A);
924 [ + - ]: 16 : Trace("bags-group") << "elements: " << elements << std::endl;
925 : : // a simple map from elements to equivalent classes with this invariant:
926 : : // each key element must appear exactly once in one of the values.
927 : 32 : std::map<Node, std::set<Node>> sets;
928 : 32 : std::set<Node> emptyClass;
929 [ + + ]: 123 : for (const auto& pair : elements)
930 : : {
931 : : // initially each singleton element is an equivalence class
932 : 214 : sets[pair.first] = {pair.first};
933 : : }
934 : 123 : for (std::map<Node, Rational>::iterator i = elements.begin();
935 [ + + ]: 123 : i != elements.end();
936 : 107 : ++i)
937 : : {
938 [ + + ]: 107 : if (sets[i->first].empty())
939 : : {
940 : : // skip this element since its equivalent class has already been processed
941 : 62 : continue;
942 : : }
943 : 45 : std::map<Node, Rational>::iterator j = i;
944 : 45 : ++j;
945 [ + + ]: 226 : while (j != elements.end())
946 : : {
947 [ + + ]: 181 : if (TupleUtils::sameProjection(indices, i->first, j->first))
948 : : {
949 : : // add element j to the equivalent class
950 : 62 : sets[i->first].insert(j->first);
951 : : // mark the equivalent class of j as processed
952 : 62 : sets[j->first] = emptyClass;
953 : : }
954 : 181 : ++j;
955 : : }
956 : : }
957 : :
958 : : // construct the partition parts
959 : 32 : std::map<Node, Rational> parts;
960 [ + + ]: 123 : for (std::pair<Node, std::set<Node>> pair : sets)
961 : : {
962 : 107 : const std::set<Node>& eqc = pair.second;
963 [ + + ]: 107 : if (eqc.empty())
964 : : {
965 : 62 : continue;
966 : : }
967 : 90 : std::vector<Node> bags;
968 [ + + ]: 152 : for (const Node& node : eqc)
969 : : {
970 : : Node bag =
971 : 321 : nm->mkNode(Kind::BAG_MAKE, node, nm->mkConstInt(elements[node]));
972 : 107 : bags.push_back(bag);
973 : : }
974 : 45 : Node part = computeDisjointUnion(bagType, bags);
975 : : // each part in the partitions has multiplicity one
976 : 45 : parts[part] = Rational(1);
977 : : }
978 [ - + ]: 16 : if (parts.empty())
979 : : {
980 : : // add an empty part
981 : 0 : Node emptyPart = nm->mkConst(EmptyBag(bagType));
982 : 0 : parts[emptyPart] = Rational(1);
983 : : }
984 : 16 : Node ret = constructConstantBagFromElements(partitionType, parts);
985 [ + - ]: 16 : Trace("bags-group") << "ret: " << ret << std::endl;
986 : 32 : return ret;
987 : : }
988 : :
989 : 12 : Node BagsUtils::evaluateTableProject(TNode n)
990 : : {
991 [ - + ][ - + ]: 12 : Assert(n.getKind() == Kind::TABLE_PROJECT);
[ - - ]
992 : 24 : Node bagMap = BagReduction::reduceProjectOperator(n);
993 : 12 : Node ret = evaluateBagMap(bagMap);
994 : 24 : return ret;
995 : : }
996 : :
997 : : std::pair<std::vector<uint32_t>, std::vector<uint32_t>>
998 : 26 : BagsUtils::splitTableJoinIndices(Node n)
999 : : {
1000 : 52 : ProjectOp op = n.getOperator().getConst<ProjectOp>();
1001 : 26 : const std::vector<uint32_t>& indices = op.getIndices();
1002 : 26 : size_t joinSize = indices.size() / 2;
1003 : 78 : std::vector<uint32_t> indices1(joinSize), indices2(joinSize);
1004 : :
1005 [ + + ]: 52 : for (size_t i = 0, index = 0; i < joinSize; i += 2, ++index)
1006 : : {
1007 : 26 : indices1[index] = indices[i];
1008 : 26 : indices2[index] = indices[i + 1];
1009 : : }
1010 : 52 : return std::make_pair(indices1, indices2);
1011 : : }
1012 : :
1013 : : } // namespace bags
1014 : : } // namespace theory
1015 : : } // namespace cvc5::internal
|