Branch data Line data Source code
1 : : /******************************************************************************
2 : : * Top contributors (to current version):
3 : : * Mudathir Mohamed, Aina Niemetz, Andrew Reynolds
4 : : *
5 : : * This file is part of the cvc5 project.
6 : : *
7 : : * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
8 : : * in the top-level source directory and their institutional affiliations.
9 : : * All rights reserved. See the file COPYING in the top-level source
10 : : * directory for licensing information.
11 : : * ****************************************************************************
12 : : *
13 : : * Bags theory rewriter.
14 : : */
15 : :
16 : : #include "theory/bags/bags_rewriter.h"
17 : :
18 : : #include "expr/emptybag.h"
19 : : #include "theory/bags/bags_utils.h"
20 : : #include "theory/rewriter.h"
21 : : #include "util/rational.h"
22 : : #include "util/statistics_registry.h"
23 : :
24 : : using namespace cvc5::internal::kind;
25 : :
26 : : namespace cvc5::internal {
27 : : namespace theory {
28 : : namespace bags {
29 : :
30 : 63272 : BagsRewriteResponse::BagsRewriteResponse()
31 : 63272 : : d_node(Node::null()), d_rewrite(Rewrite::NONE)
32 : : {
33 : 63272 : }
34 : :
35 : 63272 : BagsRewriteResponse::BagsRewriteResponse(Node n, Rewrite rewrite)
36 : 63272 : : d_node(n), d_rewrite(rewrite)
37 : : {
38 : 63272 : }
39 : :
40 : 0 : BagsRewriteResponse::BagsRewriteResponse(const BagsRewriteResponse& r)
41 : 0 : : d_node(r.d_node), d_rewrite(r.d_rewrite)
42 : : {
43 : 0 : }
44 : :
45 : 51373 : BagsRewriter::BagsRewriter(NodeManager* nm,
46 : : Rewriter* r,
47 : 51373 : HistogramStat<Rewrite>* statistics)
48 : 51373 : : TheoryRewriter(nm), d_rewriter(r), d_statistics(statistics)
49 : : {
50 : 51373 : d_zero = d_nm->mkConstInt(Rational(0));
51 : 51373 : d_one = d_nm->mkConstInt(Rational(1));
52 : 51373 : }
53 : :
54 : 29054 : RewriteResponse BagsRewriter::postRewrite(TNode n)
55 : : {
56 : 58108 : BagsRewriteResponse response;
57 [ + + ]: 29054 : if (n.isConst())
58 : : {
59 : : // no need to rewrite n if it is already in a normal form
60 : 1730 : response = BagsRewriteResponse(n, Rewrite::NONE);
61 : : }
62 [ + + ]: 27324 : else if (n.getKind() == Kind::EQUAL)
63 : : {
64 : 13327 : response = postRewriteEqual(n);
65 : : }
66 [ + + ]: 13997 : else if (n.getKind() == Kind::BAG_CHOOSE)
67 : : {
68 : 29 : response = rewriteChoose(n);
69 : : }
70 [ + + ]: 13968 : else if (BagsUtils::areChildrenConstants(n))
71 : : {
72 : 706 : Node value = BagsUtils::evaluate(d_rewriter, n);
73 : 706 : response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION);
74 : : }
75 : : else
76 : : {
77 : 13262 : Kind k = n.getKind();
78 [ + + ][ + + ]: 13262 : switch (k)
[ + + ][ + + ]
[ + + ][ + + ]
[ + + ][ + + ]
79 : : {
80 : 598 : case Kind::BAG_MAKE: response = rewriteMakeBag(n); break;
81 : 10607 : case Kind::BAG_COUNT: response = rewriteBagCount(n); break;
82 : 33 : case Kind::BAG_SETOF: response = rewriteSetof(n); break;
83 : 117 : case Kind::BAG_UNION_MAX: response = rewriteUnionMax(n); break;
84 : 515 : case Kind::BAG_UNION_DISJOINT: response = rewriteUnionDisjoint(n); break;
85 : 84 : case Kind::BAG_INTER_MIN: response = rewriteIntersectionMin(n); break;
86 : 103 : case Kind::BAG_DIFFERENCE_SUBTRACT:
87 : 103 : response = rewriteDifferenceSubtract(n);
88 : 103 : break;
89 : 119 : case Kind::BAG_DIFFERENCE_REMOVE:
90 : 119 : response = rewriteDifferenceRemove(n);
91 : 119 : break;
92 : 95 : case Kind::BAG_CARD: response = rewriteCard(n); break;
93 : 403 : case Kind::BAG_MAP: response = postRewriteMap(n); break;
94 : 329 : case Kind::BAG_FILTER: response = postRewriteFilter(n); break;
95 : 39 : case Kind::BAG_FOLD: response = postRewriteFold(n); break;
96 : 8 : case Kind::BAG_PARTITION: response = postRewritePartition(n); break;
97 : 42 : case Kind::TABLE_PRODUCT: response = postRewriteProduct(n); break;
98 : 2 : case Kind::TABLE_AGGREGATE: response = postRewriteAggregate(n); break;
99 : 168 : default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
100 : : }
101 : : }
102 : :
103 [ + - ]: 58108 : Trace("bags-rewrite") << "postRewrite " << n << " to " << response.d_node
104 : 29054 : << " by " << response.d_rewrite << "." << std::endl;
105 : :
106 [ + + ]: 29054 : if (d_statistics != nullptr)
107 : : {
108 : 28963 : (*d_statistics) << response.d_rewrite;
109 : : }
110 [ + + ]: 29054 : if (response.d_node != n)
111 : : {
112 : 4250 : return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
113 : : }
114 : 24804 : return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
115 : : }
116 : :
117 : 34218 : RewriteResponse BagsRewriter::preRewrite(TNode n)
118 : : {
119 : 68436 : BagsRewriteResponse response;
120 : 34218 : Kind k = n.getKind();
121 [ + + ][ + + ]: 34218 : switch (k)
122 : : {
123 : 17099 : case Kind::EQUAL: response = preRewriteEqual(n); break;
124 : 32 : case Kind::BAG_SUBBAG: response = rewriteSubBag(n); break;
125 : 172 : case Kind::BAG_MEMBER: response = rewriteMember(n); break;
126 : 16915 : default: response = BagsRewriteResponse(n, Rewrite::NONE);
127 : : }
128 : :
129 [ + - ]: 68436 : Trace("bags-rewrite") << "preRewrite " << n << " to " << response.d_node
130 : 34218 : << " by " << response.d_rewrite << "." << std::endl;
131 : :
132 [ + + ]: 34218 : if (d_statistics != nullptr)
133 : : {
134 : 34217 : (*d_statistics) << response.d_rewrite;
135 : : }
136 [ + + ]: 34218 : if (response.d_node != n)
137 : : {
138 : 787 : return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
139 : : }
140 : 33431 : return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
141 : : }
142 : :
143 : 17099 : BagsRewriteResponse BagsRewriter::preRewriteEqual(const TNode& n) const
144 : : {
145 [ - + ][ - + ]: 17099 : Assert(n.getKind() == Kind::EQUAL);
[ - - ]
146 [ + + ]: 17099 : if (n[0] == n[1])
147 : : {
148 : : // (= A A) = true where A is a bag
149 : 1166 : return BagsRewriteResponse(d_nm->mkConst(true), Rewrite::IDENTICAL_NODES);
150 : : }
151 : 16516 : return BagsRewriteResponse(n, Rewrite::NONE);
152 : : }
153 : :
154 : 32 : BagsRewriteResponse BagsRewriter::rewriteSubBag(const TNode& n) const
155 : : {
156 [ - + ][ - + ]: 32 : Assert(n.getKind() == Kind::BAG_SUBBAG);
[ - - ]
157 : :
158 : : // (bag.subbag A B) = ((bag.difference_subtract A B) == bag.empty)
159 : 96 : Node emptybag = d_nm->mkConst(EmptyBag(n[0].getType()));
160 : 96 : Node subtract = d_nm->mkNode(Kind::BAG_DIFFERENCE_SUBTRACT, n[0], n[1]);
161 : 32 : Node equal = subtract.eqNode(emptybag);
162 : 64 : return BagsRewriteResponse(equal, Rewrite::SUB_BAG);
163 : : }
164 : :
165 : 172 : BagsRewriteResponse BagsRewriter::rewriteMember(const TNode& n) const
166 : : {
167 [ - + ][ - + ]: 172 : Assert(n.getKind() == Kind::BAG_MEMBER);
[ - - ]
168 : :
169 : : // - (bag.member x A) = (>= (bag.count x A) 1)
170 : 516 : Node count = d_nm->mkNode(Kind::BAG_COUNT, n[0], n[1]);
171 : 344 : Node geq = d_nm->mkNode(Kind::GEQ, count, d_one);
172 : 344 : return BagsRewriteResponse(geq, Rewrite::MEMBER);
173 : : }
174 : :
175 : 598 : BagsRewriteResponse BagsRewriter::rewriteMakeBag(const TNode& n) const
176 : : {
177 [ - + ][ - + ]: 598 : Assert(n.getKind() == Kind::BAG_MAKE);
[ - - ]
178 : : // return bag.empty for negative or zero multiplicity
179 : 598 : if (n[1].isConst() && n[1].getConst<Rational>().sgn() != 1)
180 : : {
181 : : // (bag x c) = bag.empty where c <= 0
182 : 128 : Node emptybag = d_nm->mkConst(EmptyBag(n.getType()));
183 : 64 : return BagsRewriteResponse(emptybag, Rewrite::BAG_MAKE_COUNT_NEGATIVE);
184 : : }
185 : 534 : return BagsRewriteResponse(n, Rewrite::NONE);
186 : : }
187 : :
188 : 10607 : BagsRewriteResponse BagsRewriter::rewriteBagCount(const TNode& n) const
189 : : {
190 [ - + ][ - + ]: 10607 : Assert(n.getKind() == Kind::BAG_COUNT);
[ - - ]
191 : 10607 : if (n[1].isConst() && n[1].getKind() == Kind::BAG_EMPTY)
192 : : {
193 : : // (bag.count x bag.empty) = 0
194 : 369 : return BagsRewriteResponse(d_zero, Rewrite::COUNT_EMPTY);
195 : : }
196 : 11152 : if (n[1].getKind() == Kind::BAG_MAKE && n[0] == n[1][0] && n[1][1].isConst()
197 : 11152 : && n[1][1].getConst<Rational>() > Rational(0))
198 : : {
199 : : // (bag.count x (bag x c)) = c, c > 0 is a constant
200 : 124 : Node c = n[1][1];
201 : 62 : return BagsRewriteResponse(c, Rewrite::COUNT_BAG_MAKE);
202 : : }
203 : 10176 : return BagsRewriteResponse(n, Rewrite::NONE);
204 : : }
205 : :
206 : 33 : BagsRewriteResponse BagsRewriter::rewriteSetof(const TNode& n) const
207 : : {
208 [ - + ][ - + ]: 33 : Assert(n.getKind() == Kind::BAG_SETOF);
[ - - ]
209 : 34 : if (n[0].getKind() == Kind::BAG_MAKE && n[0][1].isConst()
210 : 34 : && n[0][1].getConst<Rational>().sgn() == 1)
211 : : {
212 : : // (bag.setof (bag x n)) = (bag x 1)
213 : : // where n is a positive constant
214 : 2 : Node bag = d_nm->mkNode(Kind::BAG_MAKE, n[0][0], d_one);
215 : 1 : return BagsRewriteResponse(bag, Rewrite::SETOF_BAG_MAKE);
216 : : }
217 : 32 : return BagsRewriteResponse(n, Rewrite::NONE);
218 : : }
219 : :
220 : 117 : BagsRewriteResponse BagsRewriter::rewriteUnionMax(const TNode& n) const
221 : : {
222 [ - + ][ - + ]: 117 : Assert(n.getKind() == Kind::BAG_UNION_MAX);
[ - - ]
223 : 117 : if (n[1].getKind() == Kind::BAG_EMPTY || n[0] == n[1])
224 : : {
225 : : // (bag.union_max A A) = A
226 : : // (bag.union_max A bag.empty) = A
227 : 2 : return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_SAME_OR_EMPTY);
228 : : }
229 [ + + ]: 115 : if (n[0].getKind() == Kind::BAG_EMPTY)
230 : : {
231 : : // (bag.union_max bag.empty A) = A
232 : 1 : return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_EMPTY);
233 : : }
234 : :
235 [ + + ][ - - ]: 114 : if ((n[1].getKind() == Kind::BAG_UNION_MAX
236 [ + + ][ + - ]: 226 : || n[1].getKind() == Kind::BAG_UNION_DISJOINT)
[ - - ]
237 : 226 : && (n[0] == n[1][0] || n[0] == n[1][1]))
238 : : {
239 : : // (bag.union_max A (bag.union_max A B)) = (bag.union_max A B)
240 : : // (bag.union_max A (bag.union_max B A)) = (bag.union_max B A)
241 : : // (bag.union_max A (bag.union_disjoint A B)) = (bag.union_disjoint A B)
242 : : // (bag.union_max A (bag.union_disjoint B A)) = (bag.union_disjoint B A)
243 : 4 : return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_UNION_LEFT);
244 : : }
245 : :
246 [ + + ][ - - ]: 110 : if ((n[0].getKind() == Kind::BAG_UNION_MAX
247 [ + + ][ + - ]: 218 : || n[0].getKind() == Kind::BAG_UNION_DISJOINT)
[ - - ]
248 : 218 : && (n[0][0] == n[1] || n[0][1] == n[1]))
249 : : {
250 : : // (bag.union_max (bag.union_max A B) A)) = (bag.union_max A B)
251 : : // (bag.union_max (bag.union_max B A) A)) = (bag.union_max B A)
252 : : // (bag.union_max (bag.union_disjoint A B) A)) = (bag.union_disjoint A B)
253 : : // (bag.union_max (bag.union_disjoint B A) A)) = (bag.union_disjoint B A)
254 : 4 : return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_UNION_RIGHT);
255 : : }
256 : 106 : return BagsRewriteResponse(n, Rewrite::NONE);
257 : : }
258 : :
259 : 515 : BagsRewriteResponse BagsRewriter::rewriteUnionDisjoint(const TNode& n) const
260 : : {
261 [ - + ][ - + ]: 515 : Assert(n.getKind() == Kind::BAG_UNION_DISJOINT);
[ - - ]
262 [ + + ]: 515 : if (n[1].getKind() == Kind::BAG_EMPTY)
263 : : {
264 : : // (bag.union_disjoint A bag.empty) = A
265 : 17 : return BagsRewriteResponse(n[0], Rewrite::UNION_DISJOINT_EMPTY_RIGHT);
266 : : }
267 [ + + ]: 498 : if (n[0].getKind() == Kind::BAG_EMPTY)
268 : : {
269 : : // (bag.union_disjoint bag.empty A) = A
270 : 21 : return BagsRewriteResponse(n[1], Rewrite::UNION_DISJOINT_EMPTY_LEFT);
271 : : }
272 [ + + ][ - - ]: 477 : if ((n[0].getKind() == Kind::BAG_UNION_MAX
273 [ - + ][ + - ]: 480 : && n[1].getKind() == Kind::BAG_INTER_MIN)
[ - - ]
274 [ + + ][ - + ]: 957 : || (n[1].getKind() == Kind::BAG_UNION_MAX
[ + + ][ - - ]
275 [ - - ][ - + ]: 477 : && n[0].getKind() == Kind::BAG_INTER_MIN))
[ + + ][ - - ]
276 : :
277 : : {
278 : : // (bag.union_disjoint (bag.union_max A B) (bag.inter_min A B)) =
279 : : // (bag.union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
280 : : // check if the operands of bag.union_max and bag.inter_min are the
281 : : // same
282 : 6 : std::set<Node> left(n[0].begin(), n[0].end());
283 : 6 : std::set<Node> right(n[1].begin(), n[1].end());
284 [ + + ]: 3 : if (left == right)
285 : : {
286 : 4 : Node rewritten = d_nm->mkNode(Kind::BAG_UNION_DISJOINT, n[0][0], n[0][1]);
287 : 2 : return BagsRewriteResponse(rewritten, Rewrite::UNION_DISJOINT_MAX_MIN);
288 : : }
289 : : }
290 : 475 : return BagsRewriteResponse(n, Rewrite::NONE);
291 : : }
292 : :
293 : 84 : BagsRewriteResponse BagsRewriter::rewriteIntersectionMin(const TNode& n) const
294 : : {
295 [ - + ][ - + ]: 84 : Assert(n.getKind() == Kind::BAG_INTER_MIN);
[ - - ]
296 [ + + ]: 84 : if (n[0].getKind() == Kind::BAG_EMPTY)
297 : : {
298 : : // (bag.inter_min bag.empty A) = bag.empty
299 : 1 : return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_EMPTY_LEFT);
300 : : }
301 [ + + ]: 83 : if (n[1].getKind() == Kind::BAG_EMPTY)
302 : : {
303 : : // (bag.inter_min A bag.empty) = bag.empty
304 : 4 : return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_EMPTY_RIGHT);
305 : : }
306 [ + + ]: 79 : if (n[0] == n[1])
307 : : {
308 : : // (bag.inter_min A A) = A
309 : 3 : return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SAME);
310 : : }
311 [ + + ][ - - ]: 76 : if (n[1].getKind() == Kind::BAG_UNION_DISJOINT
312 [ + + ][ + + ]: 76 : || n[1].getKind() == Kind::BAG_UNION_MAX)
[ + + ][ + - ]
[ - - ]
313 : : {
314 : 4 : if (n[0] == n[1][0] || n[0] == n[1][1])
315 : : {
316 : : // (bag.inter_min A (bag.union_disjoint A B)) = A
317 : : // (bag.inter_min A (bag.union_disjoint B A)) = A
318 : : // (bag.inter_min A (bag.union_max A B)) = A
319 : : // (bag.inter_min A (bag.union_max B A)) = A
320 : 4 : return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SHARED_LEFT);
321 : : }
322 : : }
323 : :
324 [ + + ][ - - ]: 72 : if (n[0].getKind() == Kind::BAG_UNION_DISJOINT
325 [ + + ][ + + ]: 72 : || n[0].getKind() == Kind::BAG_UNION_MAX)
[ + + ][ + - ]
[ - - ]
326 : : {
327 : 4 : if (n[1] == n[0][0] || n[1] == n[0][1])
328 : : {
329 : : // (bag.inter_min (bag.union_disjoint A B) A) = A
330 : : // (bag.inter_min (bag.union_disjoint B A) A) = A
331 : : // (bag.inter_min (bag.union_max A B) A) = A
332 : : // (bag.inter_min (bag.union_max B A) A) = A
333 : 4 : return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_SHARED_RIGHT);
334 : : }
335 : : }
336 : :
337 : 68 : return BagsRewriteResponse(n, Rewrite::NONE);
338 : : }
339 : :
340 : 103 : BagsRewriteResponse BagsRewriter::rewriteDifferenceSubtract(
341 : : const TNode& n) const
342 : : {
343 [ - + ][ - + ]: 103 : Assert(n.getKind() == Kind::BAG_DIFFERENCE_SUBTRACT);
[ - - ]
344 : 103 : if (n[0].getKind() == Kind::BAG_EMPTY || n[1].getKind() == Kind::BAG_EMPTY)
345 : : {
346 : : // (bag.difference_subtract A bag.empty) = A
347 : : // (bag.difference_subtract bag.empty A) = bag.empty
348 : 2 : return BagsRewriteResponse(n[0], Rewrite::SUBTRACT_RETURN_LEFT);
349 : : }
350 [ + + ]: 101 : if (n[0] == n[1])
351 : : {
352 : : // (bag.difference_subtract A A) = bag.empty
353 : 2 : Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
354 : 1 : return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_SAME);
355 : : }
356 : :
357 [ + + ]: 100 : if (n[0].getKind() == Kind::BAG_UNION_DISJOINT)
358 : : {
359 [ + + ]: 2 : if (n[1] == n[0][0])
360 : : {
361 : : // (bag.difference_subtract (bag.union_disjoint A B) A) = B
362 : : return BagsRewriteResponse(n[0][1],
363 : 1 : Rewrite::SUBTRACT_DISJOINT_SHARED_LEFT);
364 : : }
365 [ + - ]: 1 : if (n[1] == n[0][1])
366 : : {
367 : : // (bag.difference_subtract (bag.union_disjoint B A) A) = B
368 : : return BagsRewriteResponse(n[0][0],
369 : 1 : Rewrite::SUBTRACT_DISJOINT_SHARED_RIGHT);
370 : : }
371 : : }
372 : :
373 [ + + ][ - - ]: 98 : if (n[1].getKind() == Kind::BAG_UNION_DISJOINT
374 [ + + ][ + + ]: 98 : || n[1].getKind() == Kind::BAG_UNION_MAX)
[ + + ][ + - ]
[ - - ]
375 : : {
376 : 4 : if (n[0] == n[1][0] || n[0] == n[1][1])
377 : : {
378 : : // (bag.difference_subtract A (bag.union_disjoint A B)) = bag.empty
379 : : // (bag.difference_subtract A (bag.union_disjoint B A)) = bag.empty
380 : : // (bag.difference_subtract A (bag.union_max A B)) = bag.empty
381 : : // (bag.difference_subtract A (bag.union_max B A)) = bag.empty
382 : 8 : Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
383 : 4 : return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_FROM_UNION);
384 : : }
385 : : }
386 : :
387 [ + + ]: 94 : if (n[0].getKind() == Kind::BAG_INTER_MIN)
388 : : {
389 : 2 : if (n[1] == n[0][0] || n[1] == n[0][1])
390 : : {
391 : : // (bag.difference_subtract (bag.inter_min A B) A) = bag.empty
392 : : // (bag.difference_subtract (bag.inter_min B A) A) = bag.empty
393 : 4 : Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
394 : 2 : return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_MIN);
395 : : }
396 : : }
397 : :
398 : 92 : return BagsRewriteResponse(n, Rewrite::NONE);
399 : : }
400 : :
401 : 119 : BagsRewriteResponse BagsRewriter::rewriteDifferenceRemove(const TNode& n) const
402 : : {
403 [ - + ][ - + ]: 119 : Assert(n.getKind() == Kind::BAG_DIFFERENCE_REMOVE);
[ - - ]
404 : :
405 : 119 : if (n[0].getKind() == Kind::BAG_EMPTY || n[1].getKind() == Kind::BAG_EMPTY)
406 : : {
407 : : // (bag.difference_remove A bag.empty) = A
408 : : // (bag.difference_remove bag.empty B) = bag.empty
409 : 2 : return BagsRewriteResponse(n[0], Rewrite::REMOVE_RETURN_LEFT);
410 : : }
411 : :
412 [ + + ]: 117 : if (n[0] == n[1])
413 : : {
414 : : // (bag.difference_remove A A) = bag.empty
415 : 6 : Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
416 : 3 : return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_SAME);
417 : : }
418 : :
419 [ + + ][ - - ]: 114 : if (n[1].getKind() == Kind::BAG_UNION_DISJOINT
420 [ + + ][ + + ]: 114 : || n[1].getKind() == Kind::BAG_UNION_MAX)
[ + + ][ + - ]
[ - - ]
421 : : {
422 : 4 : if (n[0] == n[1][0] || n[0] == n[1][1])
423 : : {
424 : : // (bag.difference_remove A (bag.union_disjoint A B)) = bag.empty
425 : : // (bag.difference_remove A (bag.union_disjoint B A)) = bag.empty
426 : : // (bag.difference_remove A (bag.union_max A B)) = bag.empty
427 : : // (bag.difference_remove A (bag.union_max B A)) = bag.empty
428 : 8 : Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
429 : 4 : return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_FROM_UNION);
430 : : }
431 : : }
432 : :
433 [ + + ]: 110 : if (n[0].getKind() == Kind::BAG_INTER_MIN)
434 : : {
435 : 2 : if (n[1] == n[0][0] || n[1] == n[0][1])
436 : : {
437 : : // (bag.difference_remove (bag.inter_min A B) A) = bag.empty
438 : : // (bag.difference_remove (bag.inter_min B A) A) = bag.empty
439 : 4 : Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
440 : 2 : return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_MIN);
441 : : }
442 : : }
443 : :
444 : 108 : return BagsRewriteResponse(n, Rewrite::NONE);
445 : : }
446 : :
447 : 29 : BagsRewriteResponse BagsRewriter::rewriteChoose(const TNode& n) const
448 : : {
449 [ - + ][ - + ]: 29 : Assert(n.getKind() == Kind::BAG_CHOOSE);
[ - - ]
450 : 30 : if (n[0].getKind() == Kind::BAG_MAKE && n[0][1].isConst()
451 : 30 : && n[0][1].getConst<Rational>() > 0)
452 : : {
453 : : // (bag.choose (bag x c)) = x where c is a constant > 0
454 : 1 : return BagsRewriteResponse(n[0][0], Rewrite::CHOOSE_BAG_MAKE);
455 : : }
456 : 28 : return BagsRewriteResponse(n, Rewrite::NONE);
457 : : }
458 : :
459 : 95 : BagsRewriteResponse BagsRewriter::rewriteCard(const TNode& n) const
460 : : {
461 [ - + ][ - + ]: 95 : Assert(n.getKind() == Kind::BAG_CARD);
[ - - ]
462 : 95 : if (n[0].getKind() == Kind::BAG_MAKE && n[0][1].isConst())
463 : : {
464 : : // (bag.card (bag x c)) = c where c is a constant > 0
465 : 1 : return BagsRewriteResponse(n[0][1], Rewrite::CARD_BAG_MAKE);
466 : : }
467 : :
468 : 94 : return BagsRewriteResponse(n, Rewrite::NONE);
469 : : }
470 : :
471 : 13327 : BagsRewriteResponse BagsRewriter::postRewriteEqual(const TNode& n) const
472 : : {
473 [ - + ][ - + ]: 13327 : Assert(n.getKind() == Kind::EQUAL);
[ - - ]
474 [ + + ]: 13327 : if (n[0] == n[1])
475 : : {
476 : 33 : Node ret = d_nm->mkConst(true);
477 : 33 : return BagsRewriteResponse(ret, Rewrite::EQ_REFL);
478 : : }
479 : :
480 : 13294 : if (n[0].isConst() && n[1].isConst())
481 : : {
482 : 88 : Node ret = d_nm->mkConst(false);
483 : 88 : return BagsRewriteResponse(ret, Rewrite::EQ_CONST_FALSE);
484 : : }
485 : :
486 : : // standard ordering
487 [ + + ]: 13206 : if (n[0] > n[1])
488 : : {
489 : 5570 : Node ret = d_nm->mkNode(Kind::EQUAL, n[1], n[0]);
490 : 2785 : return BagsRewriteResponse(ret, Rewrite::EQ_SYM);
491 : : }
492 : 10421 : return BagsRewriteResponse(n, Rewrite::NONE);
493 : : }
494 : :
495 : 403 : BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const
496 : : {
497 [ - + ][ - + ]: 403 : Assert(n.getKind() == Kind::BAG_MAP);
[ - - ]
498 [ + + ]: 403 : if (n[1].isConst())
499 : : {
500 : : // (bag.map f (as bag.empty (Bag T1)) = (as bag.empty (Bag T2))
501 : : // (bag.map f (bag "a" 3)) = (bag (f "a") 3)
502 : 56 : std::map<Node, Rational> elements = BagsUtils::getBagElements(n[1]);
503 : 56 : std::map<Node, Rational> mappedElements;
504 : 28 : std::map<Node, Rational>::iterator it = elements.begin();
505 [ + + ]: 37 : while (it != elements.end())
506 : : {
507 : 27 : Node mappedElement = d_nm->mkNode(Kind::APPLY_UF, n[0], it->first);
508 : 9 : mappedElements[mappedElement] = it->second;
509 : 9 : ++it;
510 : : }
511 : 84 : TypeNode t = d_nm->mkBagType(n[0].getType().getRangeType());
512 : 28 : Node ret = BagsUtils::constructConstantBagFromElements(t, mappedElements);
513 : 28 : return BagsRewriteResponse(ret, Rewrite::MAP_CONST);
514 : : }
515 : 375 : Kind k = n[1].getKind();
516 [ - + ][ + ]: 375 : switch (k)
517 : : {
518 : 0 : case Kind::BAG_MAKE:
519 : : {
520 : : // (bag.map f (bag x y)) = (bag (apply f x) y)
521 : 0 : Node mappedElement = d_nm->mkNode(Kind::APPLY_UF, n[0], n[1][0]);
522 : 0 : Node ret = d_nm->mkNode(Kind::BAG_MAKE, mappedElement, n[1][1]);
523 : 0 : return BagsRewriteResponse(ret, Rewrite::MAP_BAG_MAKE);
524 : : }
525 : :
526 : 1 : case Kind::BAG_UNION_DISJOINT:
527 : : {
528 : : // (bag.map f (bag.union_disjoint A B)) =
529 : : // (bag.union_disjoint (bag.map f A) (bag.map f B))
530 : 3 : Node a = d_nm->mkNode(Kind::BAG_MAP, n[0], n[1][0]);
531 : 3 : Node b = d_nm->mkNode(Kind::BAG_MAP, n[0], n[1][1]);
532 : 2 : Node ret = d_nm->mkNode(Kind::BAG_UNION_DISJOINT, a, b);
533 : 1 : return BagsRewriteResponse(ret, Rewrite::MAP_UNION_DISJOINT);
534 : : }
535 : :
536 : 374 : default: return BagsRewriteResponse(n, Rewrite::NONE);
537 : : }
538 : : }
539 : :
540 : 329 : BagsRewriteResponse BagsRewriter::postRewriteFilter(const TNode& n) const
541 : : {
542 [ - + ][ - + ]: 329 : Assert(n.getKind() == Kind::BAG_FILTER);
[ - - ]
543 : 658 : Node P = n[0];
544 : 658 : Node A = n[1];
545 : 658 : TypeNode t = A.getType();
546 [ + + ]: 329 : if (A.isConst())
547 : : {
548 : : // (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T))
549 : : // (bag.filter p (bag "a" 3) ((bag "b" 2))) =
550 : : // (bag.union_disjoint
551 : : // (ite (p "a") (bag "a" 3) (as bag.empty (Bag T)))
552 : : // (ite (p "b") (bag "b" 2) (as bag.empty (Bag T)))
553 : :
554 : 3 : Node ret = BagsUtils::evaluateBagFilter(n);
555 : 3 : return BagsRewriteResponse(ret, Rewrite::FILTER_CONST);
556 : : }
557 : 326 : Kind k = A.getKind();
558 [ - - ][ + ]: 326 : switch (k)
559 : : {
560 : 0 : case Kind::BAG_MAKE:
561 : : {
562 : : // (bag.filter p (bag x y)) = (ite (p x) (bag x y) (as bag.empty (Bag T)))
563 : 0 : Node empty = d_nm->mkConst(EmptyBag(t));
564 : 0 : Node pOfe = d_nm->mkNode(Kind::APPLY_UF, P, A[0]);
565 : 0 : Node ret = d_nm->mkNode(Kind::ITE, pOfe, A, empty);
566 : 0 : return BagsRewriteResponse(ret, Rewrite::FILTER_BAG_MAKE);
567 : : }
568 : :
569 : 0 : case Kind::BAG_UNION_DISJOINT:
570 : : {
571 : : // (bag.filter p (bag.union_disjoint A B)) =
572 : : // (bag.union_disjoint (bag.filter p A) (bag.filter p B))
573 : 0 : Node a = d_nm->mkNode(Kind::BAG_FILTER, n[0], n[1][0]);
574 : 0 : Node b = d_nm->mkNode(Kind::BAG_FILTER, n[0], n[1][1]);
575 : 0 : Node ret = d_nm->mkNode(Kind::BAG_UNION_DISJOINT, a, b);
576 : 0 : return BagsRewriteResponse(ret, Rewrite::FILTER_UNION_DISJOINT);
577 : : }
578 : :
579 : 326 : default: return BagsRewriteResponse(n, Rewrite::NONE);
580 : : }
581 : : }
582 : :
583 : 39 : BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const
584 : : {
585 [ - + ][ - + ]: 39 : Assert(n.getKind() == Kind::BAG_FOLD);
[ - - ]
586 : 78 : Node f = n[0];
587 : 78 : Node t = n[1];
588 : 78 : Node bag = n[2];
589 [ + + ]: 39 : if (bag.isConst())
590 : : {
591 : 9 : Node value = BagsUtils::evaluateBagFold(n);
592 : 9 : return BagsRewriteResponse(value, Rewrite::FOLD_CONST);
593 : : }
594 : 30 : Kind k = bag.getKind();
595 [ + + ][ + ]: 30 : switch (k)
596 : : {
597 : 1 : case Kind::BAG_MAKE:
598 : : {
599 : 1 : if (bag[1].isConst() && bag[1].getConst<Rational>() > Rational(0))
600 : : {
601 : : // (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times, n > 0
602 : 1 : Node value = BagsUtils::evaluateBagFold(n);
603 : 1 : return BagsRewriteResponse(value, Rewrite::FOLD_BAG);
604 : : }
605 : 0 : break;
606 : : }
607 : 1 : case Kind::BAG_UNION_DISJOINT:
608 : : {
609 : : // (bag.fold f t (bag.union_disjoint A B)) =
610 : : // (bag.fold f (bag.fold f t A) B) where A < B to break symmetry
611 [ + - ]: 3 : Node A = bag[0] < bag[1] ? bag[0] : bag[1];
612 [ + - ]: 3 : Node B = bag[0] < bag[1] ? bag[1] : bag[0];
613 : 3 : Node foldA = d_nm->mkNode(Kind::BAG_FOLD, f, t, A);
614 : 2 : Node fold = d_nm->mkNode(Kind::BAG_FOLD, f, foldA, B);
615 : 1 : return BagsRewriteResponse(fold, Rewrite::FOLD_UNION_DISJOINT);
616 : : }
617 : 28 : default: return BagsRewriteResponse(n, Rewrite::NONE);
618 : : }
619 : 0 : return BagsRewriteResponse(n, Rewrite::NONE);
620 : : }
621 : :
622 : 8 : BagsRewriteResponse BagsRewriter::postRewritePartition(const TNode& n) const
623 : : {
624 [ - + ][ - + ]: 8 : Assert(n.getKind() == Kind::BAG_PARTITION);
[ - - ]
625 [ + + ]: 8 : if (n[1].isConst())
626 : : {
627 : 4 : Node ret = BagsUtils::evaluateBagPartition(d_rewriter, n);
628 [ + - ]: 4 : if (ret != n)
629 : : {
630 : 4 : return BagsRewriteResponse(ret, Rewrite::PARTITION_CONST);
631 : : }
632 : : }
633 : :
634 : 4 : return BagsRewriteResponse(n, Rewrite::NONE);
635 : : }
636 : :
637 : 2 : BagsRewriteResponse BagsRewriter::postRewriteAggregate(const TNode& n) const
638 : : {
639 [ - + ][ - + ]: 2 : Assert(n.getKind() == Kind::TABLE_AGGREGATE);
[ - - ]
640 : 2 : if (n[1].isConst() && n[2].isConst())
641 : : {
642 : 2 : Node ret = BagsUtils::evaluateTableAggregate(d_rewriter, n);
643 [ + - ]: 2 : if (ret != n)
644 : : {
645 : 2 : return BagsRewriteResponse(ret, Rewrite::AGGREGATE_CONST);
646 : : }
647 : : }
648 : :
649 : 0 : return BagsRewriteResponse(n, Rewrite::NONE);
650 : : }
651 : :
652 : 42 : BagsRewriteResponse BagsRewriter::postRewriteProduct(const TNode& n) const
653 : : {
654 [ - + ][ - + ]: 42 : Assert(n.getKind() == Kind::TABLE_PRODUCT);
[ - - ]
655 : 84 : TypeNode tableType = n.getType();
656 : 84 : Node empty = d_nm->mkConst(EmptyBag(tableType));
657 : 42 : if (n[0].getKind() == Kind::BAG_EMPTY || n[1].getKind() == Kind::BAG_EMPTY)
658 : : {
659 : 2 : return BagsRewriteResponse(empty, Rewrite::PRODUCT_EMPTY);
660 : : }
661 : :
662 : 40 : return BagsRewriteResponse(n, Rewrite::NONE);
663 : : }
664 : :
665 : : } // namespace bags
666 : : } // namespace theory
667 : : } // namespace cvc5::internal
|