Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * Bags theory rewriter.
11 : : */
12 : :
13 : : #include "theory/bags/bags_rewriter.h"
14 : :
15 : : #include "expr/emptybag.h"
16 : : #include "theory/bags/bags_utils.h"
17 : : #include "theory/rewriter.h"
18 : : #include "util/rational.h"
19 : : #include "util/statistics_registry.h"
20 : :
21 : : using namespace cvc5::internal::kind;
22 : :
23 : : namespace cvc5::internal {
24 : : namespace theory {
25 : : namespace bags {
26 : :
27 : 61369 : BagsRewriteResponse::BagsRewriteResponse()
28 : 61369 : : d_node(Node::null()), d_rewrite(Rewrite::NONE)
29 : : {
30 : 61369 : }
31 : :
32 : 61369 : BagsRewriteResponse::BagsRewriteResponse(Node n, Rewrite rewrite)
33 : 61369 : : d_node(n), d_rewrite(rewrite)
34 : : {
35 : 61369 : }
36 : :
37 : 50934 : BagsRewriter::BagsRewriter(NodeManager* nm,
38 : : Rewriter* r,
39 : 50934 : HistogramStat<Rewrite>* statistics)
40 : 50934 : : TheoryRewriter(nm), d_rewriter(r), d_statistics(statistics)
41 : : {
42 : 50934 : d_zero = d_nm->mkConstInt(Rational(0));
43 : 50934 : d_one = d_nm->mkConstInt(Rational(1));
44 : 50934 : }
45 : :
46 : 28253 : RewriteResponse BagsRewriter::postRewrite(TNode n)
47 : : {
48 : 28253 : BagsRewriteResponse response;
49 [ + + ]: 28253 : if (n.isConst())
50 : : {
51 : : // no need to rewrite n if it is already in a normal form
52 : 1956 : response = BagsRewriteResponse(n, Rewrite::NONE);
53 : : }
54 [ + + ]: 26297 : else if (n.getKind() == Kind::EQUAL)
55 : : {
56 : 11872 : response = postRewriteEqual(n);
57 : : }
58 [ + + ]: 14425 : else if (n.getKind() == Kind::BAG_CHOOSE)
59 : : {
60 : 25 : response = rewriteChoose(n);
61 : : }
62 [ + + ]: 14400 : else if (BagsUtils::areChildrenConstants(n))
63 : : {
64 : 897 : Node value = BagsUtils::evaluate(d_rewriter, n);
65 : 897 : response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION);
66 : 897 : }
67 : : else
68 : : {
69 : 13503 : Kind k = n.getKind();
70 [ + + ][ + + ]: 13503 : switch (k)
[ + + ][ + + ]
[ + + ][ + + ]
[ + + ][ + + ]
[ + + ]
71 : : {
72 : 611 : case Kind::BAG_MAKE: response = rewriteMakeBag(n); break;
73 : 10947 : case Kind::BAG_COUNT: response = rewriteBagCount(n); break;
74 : 29 : case Kind::BAG_SETOF: response = rewriteSetof(n); break;
75 : 107 : case Kind::BAG_UNION_MAX: response = rewriteUnionMax(n); break;
76 : 597 : case Kind::BAG_UNION_DISJOINT: response = rewriteUnionDisjoint(n); break;
77 : 80 : case Kind::BAG_INTER_MIN: response = rewriteIntersectionMin(n); break;
78 : 95 : case Kind::BAG_DIFFERENCE_SUBTRACT:
79 : 95 : response = rewriteDifferenceSubtract(n);
80 : 95 : break;
81 : 59 : case Kind::BAG_DIFFERENCE_REMOVE:
82 : 59 : response = rewriteDifferenceRemove(n);
83 : 59 : break;
84 : 83 : case Kind::BAG_CARD: response = rewriteCard(n); break;
85 : 303 : case Kind::BAG_MAP: response = postRewriteMap(n); break;
86 : 349 : case Kind::BAG_FILTER: response = postRewriteFilter(n); break;
87 : 23 : case Kind::BAG_ALL: response = postRewriteAll(n); break;
88 : 5 : case Kind::BAG_SOME: response = postRewriteSome(n); break;
89 : 39 : case Kind::BAG_FOLD: response = postRewriteFold(n); break;
90 : 8 : case Kind::BAG_PARTITION: response = postRewritePartition(n); break;
91 : 38 : case Kind::TABLE_PRODUCT: response = postRewriteProduct(n); break;
92 : 2 : case Kind::TABLE_AGGREGATE: response = postRewriteAggregate(n); break;
93 : 128 : default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
94 : : }
95 : : }
96 : :
97 [ + - ]: 56506 : Trace("bags-rewrite") << "postRewrite " << n << " to " << response.d_node
98 : 28253 : << " by " << response.d_rewrite << "." << std::endl;
99 : :
100 [ + + ]: 28253 : if (d_statistics != nullptr)
101 : : {
102 : 28162 : (*d_statistics) << response.d_rewrite;
103 : : }
104 [ + + ]: 28253 : if (response.d_node != n)
105 : : {
106 : 4089 : return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
107 : : }
108 : 24164 : return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
109 : 28253 : }
110 : :
111 : 33116 : RewriteResponse BagsRewriter::preRewrite(TNode n)
112 : : {
113 : 33116 : BagsRewriteResponse response;
114 : 33116 : Kind k = n.getKind();
115 [ + + ][ + + ]: 33116 : switch (k)
116 : : {
117 : 15143 : case Kind::EQUAL: response = preRewriteEqual(n); break;
118 : 28 : case Kind::BAG_SUBBAG: response = rewriteSubBag(n); break;
119 : 147 : case Kind::BAG_MEMBER: response = rewriteMember(n); break;
120 : 17798 : default: response = BagsRewriteResponse(n, Rewrite::NONE);
121 : : }
122 : :
123 [ + - ]: 66232 : Trace("bags-rewrite") << "preRewrite " << n << " to " << response.d_node
124 : 33116 : << " by " << response.d_rewrite << "." << std::endl;
125 : :
126 [ + + ]: 33116 : if (d_statistics != nullptr)
127 : : {
128 : 33115 : (*d_statistics) << response.d_rewrite;
129 : : }
130 [ + + ]: 33116 : if (response.d_node != n)
131 : : {
132 : 726 : return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
133 : : }
134 : 32390 : return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
135 : 33116 : }
136 : :
137 : 15143 : BagsRewriteResponse BagsRewriter::preRewriteEqual(const TNode& n) const
138 : : {
139 [ - + ][ - + ]: 15143 : Assert(n.getKind() == Kind::EQUAL);
[ - - ]
140 [ + + ]: 15143 : if (n[0] == n[1])
141 : : {
142 : : // (= A A) = true where A is a bag
143 : 1102 : return BagsRewriteResponse(d_nm->mkConst(true), Rewrite::IDENTICAL_NODES);
144 : : }
145 : 14592 : return BagsRewriteResponse(n, Rewrite::NONE);
146 : : }
147 : :
148 : 28 : BagsRewriteResponse BagsRewriter::rewriteSubBag(const TNode& n) const
149 : : {
150 [ - + ][ - + ]: 28 : Assert(n.getKind() == Kind::BAG_SUBBAG);
[ - - ]
151 : :
152 : : // (bag.subbag A B) = ((bag.difference_subtract A B) == bag.empty)
153 : 56 : Node emptybag = d_nm->mkConst(EmptyBag(n[0].getType()));
154 : 56 : Node subtract = d_nm->mkNode(Kind::BAG_DIFFERENCE_SUBTRACT, n[0], n[1]);
155 : 28 : Node equal = subtract.eqNode(emptybag);
156 : 56 : return BagsRewriteResponse(equal, Rewrite::SUB_BAG);
157 : 28 : }
158 : :
159 : 147 : BagsRewriteResponse BagsRewriter::rewriteMember(const TNode& n) const
160 : : {
161 [ - + ][ - + ]: 147 : Assert(n.getKind() == Kind::BAG_MEMBER);
[ - - ]
162 : :
163 : : // - (bag.member x A) = (>= (bag.count x A) 1)
164 : 294 : Node count = d_nm->mkNode(Kind::BAG_COUNT, n[0], n[1]);
165 : 294 : Node geq = d_nm->mkNode(Kind::GEQ, count, d_one);
166 : 294 : return BagsRewriteResponse(geq, Rewrite::MEMBER);
167 : 147 : }
168 : :
169 : 611 : BagsRewriteResponse BagsRewriter::rewriteMakeBag(const TNode& n) const
170 : : {
171 [ - + ][ - + ]: 611 : Assert(n.getKind() == Kind::BAG_MAKE);
[ - - ]
172 : : // return bag.empty for negative or zero multiplicity
173 : 611 : if (n[1].isConst() && n[1].getConst<Rational>().sgn() != 1)
174 : : {
175 : : // (bag x c) = bag.empty where c <= 0
176 : 130 : Node emptybag = d_nm->mkConst(EmptyBag(n.getType()));
177 : 65 : return BagsRewriteResponse(emptybag, Rewrite::BAG_MAKE_COUNT_NEGATIVE);
178 : 65 : }
179 : 546 : return BagsRewriteResponse(n, Rewrite::NONE);
180 : : }
181 : :
182 : 10947 : BagsRewriteResponse BagsRewriter::rewriteBagCount(const TNode& n) const
183 : : {
184 [ - + ][ - + ]: 10947 : Assert(n.getKind() == Kind::BAG_COUNT);
[ - - ]
185 : 10947 : if (n[1].isConst() && n[1].getKind() == Kind::BAG_EMPTY)
186 : : {
187 : : // (bag.count x bag.empty) = 0
188 : 372 : return BagsRewriteResponse(d_zero, Rewrite::COUNT_EMPTY);
189 : : }
190 : 11534 : if (n[1].getKind() == Kind::BAG_MAKE && n[0] == n[1][0] && n[1][1].isConst()
191 : 11534 : && n[1][1].getConst<Rational>() > Rational(0))
192 : : {
193 : : // (bag.count x (bag x c)) = c, c > 0 is a constant
194 : 114 : Node c = n[1][1];
195 : 57 : return BagsRewriteResponse(c, Rewrite::COUNT_BAG_MAKE);
196 : 57 : }
197 : 10518 : return BagsRewriteResponse(n, Rewrite::NONE);
198 : : }
199 : :
200 : 29 : BagsRewriteResponse BagsRewriter::rewriteSetof(const TNode& n) const
201 : : {
202 [ - + ][ - + ]: 29 : Assert(n.getKind() == Kind::BAG_SETOF);
[ - - ]
203 : 30 : if (n[0].getKind() == Kind::BAG_MAKE && n[0][1].isConst()
204 : 30 : && n[0][1].getConst<Rational>().sgn() == 1)
205 : : {
206 : : // (bag.setof (bag x n)) = (bag x 1)
207 : : // where n is a positive constant
208 : 2 : Node bag = d_nm->mkNode(Kind::BAG_MAKE, n[0][0], d_one);
209 : 1 : return BagsRewriteResponse(bag, Rewrite::SETOF_BAG_MAKE);
210 : 1 : }
211 : 28 : return BagsRewriteResponse(n, Rewrite::NONE);
212 : : }
213 : :
214 : 107 : BagsRewriteResponse BagsRewriter::rewriteUnionMax(const TNode& n) const
215 : : {
216 [ - + ][ - + ]: 107 : Assert(n.getKind() == Kind::BAG_UNION_MAX);
[ - - ]
217 : 107 : if (n[1].getKind() == Kind::BAG_EMPTY || n[0] == n[1])
218 : : {
219 : : // (bag.union_max A A) = A
220 : : // (bag.union_max A bag.empty) = A
221 : 2 : return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_SAME_OR_EMPTY);
222 : : }
223 [ + + ]: 105 : if (n[0].getKind() == Kind::BAG_EMPTY)
224 : : {
225 : : // (bag.union_max bag.empty A) = A
226 : 1 : return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_EMPTY);
227 : : }
228 : :
229 [ + + ][ - - ]: 104 : if ((n[1].getKind() == Kind::BAG_UNION_MAX
230 [ + + ][ + - ]: 206 : || n[1].getKind() == Kind::BAG_UNION_DISJOINT)
[ - - ]
231 : 206 : && (n[0] == n[1][0] || n[0] == n[1][1]))
232 : : {
233 : : // (bag.union_max A (bag.union_max A B)) = (bag.union_max A B)
234 : : // (bag.union_max A (bag.union_max B A)) = (bag.union_max B A)
235 : : // (bag.union_max A (bag.union_disjoint A B)) = (bag.union_disjoint A B)
236 : : // (bag.union_max A (bag.union_disjoint B A)) = (bag.union_disjoint B A)
237 : 4 : return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_UNION_LEFT);
238 : : }
239 : :
240 [ + + ][ - - ]: 100 : if ((n[0].getKind() == Kind::BAG_UNION_MAX
241 [ + + ][ + - ]: 198 : || n[0].getKind() == Kind::BAG_UNION_DISJOINT)
[ - - ]
242 : 198 : && (n[0][0] == n[1] || n[0][1] == n[1]))
243 : : {
244 : : // (bag.union_max (bag.union_max A B) A)) = (bag.union_max A B)
245 : : // (bag.union_max (bag.union_max B A) A)) = (bag.union_max B A)
246 : : // (bag.union_max (bag.union_disjoint A B) A)) = (bag.union_disjoint A B)
247 : : // (bag.union_max (bag.union_disjoint B A) A)) = (bag.union_disjoint B A)
248 : 4 : return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_UNION_RIGHT);
249 : : }
250 : 96 : return BagsRewriteResponse(n, Rewrite::NONE);
251 : : }
252 : :
253 : 597 : BagsRewriteResponse BagsRewriter::rewriteUnionDisjoint(const TNode& n) const
254 : : {
255 [ - + ][ - + ]: 597 : Assert(n.getKind() == Kind::BAG_UNION_DISJOINT);
[ - - ]
256 [ + + ]: 597 : if (n[1].getKind() == Kind::BAG_EMPTY)
257 : : {
258 : : // (bag.union_disjoint A bag.empty) = A
259 : 9 : return BagsRewriteResponse(n[0], Rewrite::UNION_DISJOINT_EMPTY_RIGHT);
260 : : }
261 [ + + ]: 588 : if (n[0].getKind() == Kind::BAG_EMPTY)
262 : : {
263 : : // (bag.union_disjoint bag.empty A) = A
264 : 37 : return BagsRewriteResponse(n[1], Rewrite::UNION_DISJOINT_EMPTY_LEFT);
265 : : }
266 [ + + ][ - - ]: 551 : if ((n[0].getKind() == Kind::BAG_UNION_MAX
267 [ - + ][ + - ]: 554 : && n[1].getKind() == Kind::BAG_INTER_MIN)
[ - - ]
268 [ + + ][ - + ]: 1105 : || (n[1].getKind() == Kind::BAG_UNION_MAX
[ + + ][ - - ]
269 [ - - ][ - + ]: 551 : && n[0].getKind() == Kind::BAG_INTER_MIN))
[ + + ][ - - ]
270 : :
271 : : {
272 : : // (bag.union_disjoint (bag.union_max A B) (bag.inter_min A B)) =
273 : : // (bag.union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
274 : : // check if the operands of bag.union_max and bag.inter_min are the
275 : : // same
276 : 6 : std::set<Node> left(n[0].begin(), n[0].end());
277 : 6 : std::set<Node> right(n[1].begin(), n[1].end());
278 [ + + ]: 3 : if (left == right)
279 : : {
280 : 4 : Node rewritten = d_nm->mkNode(Kind::BAG_UNION_DISJOINT, n[0][0], n[0][1]);
281 : 2 : return BagsRewriteResponse(rewritten, Rewrite::UNION_DISJOINT_MAX_MIN);
282 : 2 : }
283 [ + + ][ + + ]: 5 : }
284 : 549 : return BagsRewriteResponse(n, Rewrite::NONE);
285 : : }
286 : :
287 : 80 : BagsRewriteResponse BagsRewriter::rewriteIntersectionMin(const TNode& n) const
288 : : {
289 [ - + ][ - + ]: 80 : Assert(n.getKind() == Kind::BAG_INTER_MIN);
[ - - ]
290 [ + + ]: 80 : if (n[0].getKind() == Kind::BAG_EMPTY)
291 : : {
292 : : // (bag.inter_min bag.empty A) = bag.empty
293 : 1 : return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_EMPTY_LEFT);
294 : : }
295 [ + + ]: 79 : if (n[1].getKind() == Kind::BAG_EMPTY)
296 : : {
297 : : // (bag.inter_min A bag.empty) = bag.empty
298 : 4 : return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_EMPTY_RIGHT);
299 : : }
300 [ + + ]: 75 : if (n[0] == n[1])
301 : : {
302 : : // (bag.inter_min A A) = A
303 : 3 : return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SAME);
304 : : }
305 [ + + ][ - - ]: 72 : if (n[1].getKind() == Kind::BAG_UNION_DISJOINT
306 [ + + ][ + + ]: 72 : || n[1].getKind() == Kind::BAG_UNION_MAX)
[ + + ][ + - ]
[ - - ]
307 : : {
308 : 4 : if (n[0] == n[1][0] || n[0] == n[1][1])
309 : : {
310 : : // (bag.inter_min A (bag.union_disjoint A B)) = A
311 : : // (bag.inter_min A (bag.union_disjoint B A)) = A
312 : : // (bag.inter_min A (bag.union_max A B)) = A
313 : : // (bag.inter_min A (bag.union_max B A)) = A
314 : 4 : return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SHARED_LEFT);
315 : : }
316 : : }
317 : :
318 [ + + ][ - - ]: 68 : if (n[0].getKind() == Kind::BAG_UNION_DISJOINT
319 [ + + ][ + + ]: 68 : || n[0].getKind() == Kind::BAG_UNION_MAX)
[ + + ][ + - ]
[ - - ]
320 : : {
321 : 4 : if (n[1] == n[0][0] || n[1] == n[0][1])
322 : : {
323 : : // (bag.inter_min (bag.union_disjoint A B) A) = A
324 : : // (bag.inter_min (bag.union_disjoint B A) A) = A
325 : : // (bag.inter_min (bag.union_max A B) A) = A
326 : : // (bag.inter_min (bag.union_max B A) A) = A
327 : 4 : return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_SHARED_RIGHT);
328 : : }
329 : : }
330 : :
331 : 64 : return BagsRewriteResponse(n, Rewrite::NONE);
332 : : }
333 : :
334 : 95 : BagsRewriteResponse BagsRewriter::rewriteDifferenceSubtract(
335 : : const TNode& n) const
336 : : {
337 [ - + ][ - + ]: 95 : Assert(n.getKind() == Kind::BAG_DIFFERENCE_SUBTRACT);
[ - - ]
338 : 95 : if (n[0].getKind() == Kind::BAG_EMPTY || n[1].getKind() == Kind::BAG_EMPTY)
339 : : {
340 : : // (bag.difference_subtract A bag.empty) = A
341 : : // (bag.difference_subtract bag.empty A) = bag.empty
342 : 2 : return BagsRewriteResponse(n[0], Rewrite::SUBTRACT_RETURN_LEFT);
343 : : }
344 [ + + ]: 93 : if (n[0] == n[1])
345 : : {
346 : : // (bag.difference_subtract A A) = bag.empty
347 : 2 : Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
348 : 1 : return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_SAME);
349 : 1 : }
350 : :
351 [ + + ]: 92 : if (n[0].getKind() == Kind::BAG_UNION_DISJOINT)
352 : : {
353 [ + + ]: 2 : if (n[1] == n[0][0])
354 : : {
355 : : // (bag.difference_subtract (bag.union_disjoint A B) A) = B
356 : : return BagsRewriteResponse(n[0][1],
357 : 1 : Rewrite::SUBTRACT_DISJOINT_SHARED_LEFT);
358 : : }
359 [ + - ]: 1 : if (n[1] == n[0][1])
360 : : {
361 : : // (bag.difference_subtract (bag.union_disjoint B A) A) = B
362 : : return BagsRewriteResponse(n[0][0],
363 : 1 : Rewrite::SUBTRACT_DISJOINT_SHARED_RIGHT);
364 : : }
365 : : }
366 : :
367 [ + + ][ - - ]: 90 : if (n[1].getKind() == Kind::BAG_UNION_DISJOINT
368 [ + + ][ + + ]: 90 : || n[1].getKind() == Kind::BAG_UNION_MAX)
[ + + ][ + - ]
[ - - ]
369 : : {
370 : 4 : if (n[0] == n[1][0] || n[0] == n[1][1])
371 : : {
372 : : // (bag.difference_subtract A (bag.union_disjoint A B)) = bag.empty
373 : : // (bag.difference_subtract A (bag.union_disjoint B A)) = bag.empty
374 : : // (bag.difference_subtract A (bag.union_max A B)) = bag.empty
375 : : // (bag.difference_subtract A (bag.union_max B A)) = bag.empty
376 : 8 : Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
377 : 4 : return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_FROM_UNION);
378 : 4 : }
379 : : }
380 : :
381 [ + + ]: 86 : if (n[0].getKind() == Kind::BAG_INTER_MIN)
382 : : {
383 : 2 : if (n[1] == n[0][0] || n[1] == n[0][1])
384 : : {
385 : : // (bag.difference_subtract (bag.inter_min A B) A) = bag.empty
386 : : // (bag.difference_subtract (bag.inter_min B A) A) = bag.empty
387 : 4 : Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
388 : 2 : return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_MIN);
389 : 2 : }
390 : : }
391 : :
392 : 84 : return BagsRewriteResponse(n, Rewrite::NONE);
393 : : }
394 : :
395 : 59 : BagsRewriteResponse BagsRewriter::rewriteDifferenceRemove(const TNode& n) const
396 : : {
397 [ - + ][ - + ]: 59 : Assert(n.getKind() == Kind::BAG_DIFFERENCE_REMOVE);
[ - - ]
398 : :
399 : 59 : if (n[0].getKind() == Kind::BAG_EMPTY || n[1].getKind() == Kind::BAG_EMPTY)
400 : : {
401 : : // (bag.difference_remove A bag.empty) = A
402 : : // (bag.difference_remove bag.empty B) = bag.empty
403 : 2 : return BagsRewriteResponse(n[0], Rewrite::REMOVE_RETURN_LEFT);
404 : : }
405 : :
406 [ + + ]: 57 : if (n[0] == n[1])
407 : : {
408 : : // (bag.difference_remove A A) = bag.empty
409 : 6 : Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
410 : 3 : return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_SAME);
411 : 3 : }
412 : :
413 [ + + ][ - - ]: 54 : if (n[1].getKind() == Kind::BAG_UNION_DISJOINT
414 [ + + ][ + + ]: 54 : || n[1].getKind() == Kind::BAG_UNION_MAX)
[ + + ][ + - ]
[ - - ]
415 : : {
416 : 4 : if (n[0] == n[1][0] || n[0] == n[1][1])
417 : : {
418 : : // (bag.difference_remove A (bag.union_disjoint A B)) = bag.empty
419 : : // (bag.difference_remove A (bag.union_disjoint B A)) = bag.empty
420 : : // (bag.difference_remove A (bag.union_max A B)) = bag.empty
421 : : // (bag.difference_remove A (bag.union_max B A)) = bag.empty
422 : 8 : Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
423 : 4 : return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_FROM_UNION);
424 : 4 : }
425 : : }
426 : :
427 [ + + ]: 50 : if (n[0].getKind() == Kind::BAG_INTER_MIN)
428 : : {
429 : 2 : if (n[1] == n[0][0] || n[1] == n[0][1])
430 : : {
431 : : // (bag.difference_remove (bag.inter_min A B) A) = bag.empty
432 : : // (bag.difference_remove (bag.inter_min B A) A) = bag.empty
433 : 4 : Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
434 : 2 : return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_MIN);
435 : 2 : }
436 : : }
437 : :
438 : 48 : return BagsRewriteResponse(n, Rewrite::NONE);
439 : : }
440 : :
441 : 25 : BagsRewriteResponse BagsRewriter::rewriteChoose(const TNode& n) const
442 : : {
443 [ - + ][ - + ]: 25 : Assert(n.getKind() == Kind::BAG_CHOOSE);
[ - - ]
444 : 26 : if (n[0].getKind() == Kind::BAG_MAKE && n[0][1].isConst()
445 : 26 : && n[0][1].getConst<Rational>() > 0)
446 : : {
447 : : // (bag.choose (bag x c)) = x where c is a constant > 0
448 : 1 : return BagsRewriteResponse(n[0][0], Rewrite::CHOOSE_BAG_MAKE);
449 : : }
450 : 24 : return BagsRewriteResponse(n, Rewrite::NONE);
451 : : }
452 : :
453 : 83 : BagsRewriteResponse BagsRewriter::rewriteCard(const TNode& n) const
454 : : {
455 [ - + ][ - + ]: 83 : Assert(n.getKind() == Kind::BAG_CARD);
[ - - ]
456 : 83 : if (n[0].getKind() == Kind::BAG_MAKE && n[0][1].isConst())
457 : : {
458 : : // (bag.card (bag x c)) = c where c is a constant > 0
459 : 1 : return BagsRewriteResponse(n[0][1], Rewrite::CARD_BAG_MAKE);
460 : : }
461 : :
462 : 82 : return BagsRewriteResponse(n, Rewrite::NONE);
463 : : }
464 : :
465 : 11872 : BagsRewriteResponse BagsRewriter::postRewriteEqual(const TNode& n) const
466 : : {
467 [ - + ][ - + ]: 11872 : Assert(n.getKind() == Kind::EQUAL);
[ - - ]
468 [ + + ]: 11872 : if (n[0] == n[1])
469 : : {
470 : 34 : Node ret = d_nm->mkConst(true);
471 : 34 : return BagsRewriteResponse(ret, Rewrite::EQ_REFL);
472 : 34 : }
473 : :
474 : 11838 : if (n[0].isConst() && n[1].isConst())
475 : : {
476 : 108 : Node ret = d_nm->mkConst(false);
477 : 108 : return BagsRewriteResponse(ret, Rewrite::EQ_CONST_FALSE);
478 : 108 : }
479 : :
480 : : // standard ordering
481 [ + + ]: 11730 : if (n[0] > n[1])
482 : : {
483 : 4726 : Node ret = d_nm->mkNode(Kind::EQUAL, n[1], n[0]);
484 : 2363 : return BagsRewriteResponse(ret, Rewrite::EQ_SYM);
485 : 2363 : }
486 : 9367 : return BagsRewriteResponse(n, Rewrite::NONE);
487 : : }
488 : :
489 : 303 : BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const
490 : : {
491 [ - + ][ - + ]: 303 : Assert(n.getKind() == Kind::BAG_MAP);
[ - - ]
492 [ + + ]: 303 : if (n[1].isConst())
493 : : {
494 : : // (bag.map f (as bag.empty (Bag T1)) = (as bag.empty (Bag T2))
495 : : // (bag.map f (bag "a" 3)) = (bag (f "a") 3)
496 : 30 : std::map<Node, Rational> elements = BagsUtils::getBagElements(n[1]);
497 : 30 : std::map<Node, Rational> mappedElements;
498 : 30 : std::map<Node, Rational>::iterator it = elements.begin();
499 [ + + ]: 48 : while (it != elements.end())
500 : : {
501 : 36 : Node mappedElement = d_nm->mkNode(Kind::APPLY_UF, n[0], it->first);
502 : 18 : mappedElements[mappedElement] = it->second;
503 : 18 : ++it;
504 : 18 : }
505 : 60 : TypeNode t = d_nm->mkBagType(n[0].getType().getRangeType());
506 : 30 : Node ret = BagsUtils::constructConstantBagFromElements(t, mappedElements);
507 : 30 : return BagsRewriteResponse(ret, Rewrite::MAP_CONST);
508 : 30 : }
509 : 273 : Kind k = n[1].getKind();
510 [ - + ][ + ]: 273 : switch (k)
511 : : {
512 : 0 : case Kind::BAG_MAKE:
513 : : {
514 : : // (bag.map f (bag x y)) = (bag (apply f x) y)
515 : 0 : Node mappedElement = d_nm->mkNode(Kind::APPLY_UF, n[0], n[1][0]);
516 : 0 : Node ret = d_nm->mkNode(Kind::BAG_MAKE, mappedElement, n[1][1]);
517 : 0 : return BagsRewriteResponse(ret, Rewrite::MAP_BAG_MAKE);
518 : 0 : }
519 : :
520 : 1 : case Kind::BAG_UNION_DISJOINT:
521 : : {
522 : : // (bag.map f (bag.union_disjoint A B)) =
523 : : // (bag.union_disjoint (bag.map f A) (bag.map f B))
524 : 2 : Node a = d_nm->mkNode(Kind::BAG_MAP, n[0], n[1][0]);
525 : 2 : Node b = d_nm->mkNode(Kind::BAG_MAP, n[0], n[1][1]);
526 : 2 : Node ret = d_nm->mkNode(Kind::BAG_UNION_DISJOINT, a, b);
527 : 1 : return BagsRewriteResponse(ret, Rewrite::MAP_UNION_DISJOINT);
528 : 1 : }
529 : :
530 : 272 : default: return BagsRewriteResponse(n, Rewrite::NONE);
531 : : }
532 : : }
533 : :
534 : 349 : BagsRewriteResponse BagsRewriter::postRewriteFilter(const TNode& n) const
535 : : {
536 [ - + ][ - + ]: 349 : Assert(n.getKind() == Kind::BAG_FILTER);
[ - - ]
537 : 349 : Node P = n[0];
538 : 349 : Node A = n[1];
539 : 349 : TypeNode t = A.getType();
540 [ + + ]: 349 : if (A.isConst())
541 : : {
542 : : // (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T))
543 : : // (bag.filter p (bag "a" 3) ((bag "b" 2))) =
544 : : // (bag.union_disjoint
545 : : // (ite (p "a") (bag "a" 3) (as bag.empty (Bag T)))
546 : : // (ite (p "b") (bag "b" 2) (as bag.empty (Bag T)))
547 : :
548 : 15 : Node ret = BagsUtils::evaluateBagFilter(n);
549 : 15 : return BagsRewriteResponse(ret, Rewrite::FILTER_CONST);
550 : 15 : }
551 : 334 : Kind k = A.getKind();
552 [ - - ][ + ]: 334 : switch (k)
553 : : {
554 : 0 : case Kind::BAG_MAKE:
555 : : {
556 : : // (bag.filter p (bag x y)) = (ite (p x) (bag x y) (as bag.empty (Bag T)))
557 : 0 : Node empty = d_nm->mkConst(EmptyBag(t));
558 : 0 : Node pOfe = d_nm->mkNode(Kind::APPLY_UF, P, A[0]);
559 : 0 : Node ret = d_nm->mkNode(Kind::ITE, pOfe, A, empty);
560 : 0 : return BagsRewriteResponse(ret, Rewrite::FILTER_BAG_MAKE);
561 : 0 : }
562 : :
563 : 0 : case Kind::BAG_UNION_DISJOINT:
564 : : {
565 : : // (bag.filter p (bag.union_disjoint A B)) =
566 : : // (bag.union_disjoint (bag.filter p A) (bag.filter p B))
567 : 0 : Node a = d_nm->mkNode(Kind::BAG_FILTER, n[0], n[1][0]);
568 : 0 : Node b = d_nm->mkNode(Kind::BAG_FILTER, n[0], n[1][1]);
569 : 0 : Node ret = d_nm->mkNode(Kind::BAG_UNION_DISJOINT, a, b);
570 : 0 : return BagsRewriteResponse(ret, Rewrite::FILTER_UNION_DISJOINT);
571 : 0 : }
572 : :
573 : 334 : default: return BagsRewriteResponse(n, Rewrite::NONE);
574 : : }
575 : 349 : }
576 : :
577 : 23 : BagsRewriteResponse BagsRewriter::postRewriteAll(TNode n)
578 : : {
579 [ - + ][ - + ]: 23 : Assert(n.getKind() == Kind::BAG_ALL);
[ - - ]
580 : 23 : NodeManager* nm = nodeManager();
581 : 23 : Kind k = n[1].getKind();
582 [ - + ][ + + ]: 23 : switch (k)
583 : : {
584 : 0 : case Kind::BAG_EMPTY:
585 : : {
586 : : // (bag.all p (as bag.empty (Bag T)) = true)
587 : 0 : return BagsRewriteResponse(nm->mkConst(true), Rewrite::ALL_EMPTY);
588 : : }
589 : 4 : case Kind::BAG_MAKE:
590 : : {
591 : : // (bag.all p (bag x n)) = (or (p x) (<= n 0)
592 : 8 : Node px = nm->mkNode(Kind::APPLY_UF, n[0], n[1][0]);
593 : 8 : Node leq = nm->mkNode(Kind::LEQ, n[1][1], d_zero);
594 : 4 : Node ret = px.orNode(leq);
595 : 4 : return BagsRewriteResponse(ret, Rewrite::ALL_BAG_MAKE);
596 : 4 : }
597 : 3 : case Kind::BAG_UNION_DISJOINT:
598 : : {
599 : : // (bag.all p (bag.union_disjoint A B)) =
600 : : // (and (bag.all p A) (bag.all p B))
601 : 6 : Node a = nm->mkNode(Kind::BAG_ALL, n[0], n[1][0]);
602 : 6 : Node b = nm->mkNode(Kind::BAG_ALL, n[0], n[1][1]);
603 : 3 : Node ret = a.andNode(b);
604 : 3 : return BagsRewriteResponse(ret, Rewrite::ALL_UNION_DISJOINT);
605 : 3 : }
606 : 16 : default:
607 : : {
608 : : // (bag.all p A) is rewritten as (bag.filter p A) = A
609 : 32 : Node filter = nm->mkNode(Kind::BAG_FILTER, n[0], n[1]);
610 : 16 : Node all = filter.eqNode(n[1]);
611 : 16 : return BagsRewriteResponse(all, Rewrite::ALL_FILTER);
612 : 16 : }
613 : : }
614 : : }
615 : :
616 : 5 : BagsRewriteResponse BagsRewriter::postRewriteSome(TNode n)
617 : : {
618 [ - + ][ - + ]: 5 : Assert(n.getKind() == Kind::BAG_SOME);
[ - - ]
619 : 5 : NodeManager* nm = nodeManager();
620 : 5 : Kind k = n[1].getKind();
621 [ + + ][ - + ]: 5 : switch (k)
622 : : {
623 : 1 : case Kind::BAG_EMPTY:
624 : : {
625 : : // (bag.some p (as bag.empty (Set T)) = false)
626 : 2 : return BagsRewriteResponse(nm->mkConst(false), Rewrite::SOME_EMPTY);
627 : : }
628 : 2 : case Kind::BAG_MAKE:
629 : : {
630 : : // (bag.some p (bag x n)) = (and (> n 0) (p x))
631 : 4 : Node px = nm->mkNode(Kind::APPLY_UF, n[0], n[1][0]);
632 : 4 : Node leq = nm->mkNode(Kind::GT, n[1][1], d_zero);
633 : 2 : Node ret = px.andNode(leq);
634 : 2 : return BagsRewriteResponse(ret, Rewrite::SOME_BAG_MAKE);
635 : 2 : }
636 : 0 : case Kind::BAG_UNION_DISJOINT:
637 : : {
638 : : // (bag.some p (bag.union_disjoint A B)) =
639 : : // (or (bag.some p A) (bag.union_disjoint p B))
640 : 0 : Node a = nm->mkNode(Kind::BAG_SOME, n[0], n[1][0]);
641 : 0 : Node b = nm->mkNode(Kind::BAG_SOME, n[0], n[1][1]);
642 : 0 : Node ret = a.orNode(b);
643 : 0 : return BagsRewriteResponse(ret, Rewrite::SOME_UNION_DISJOINT);
644 : 0 : }
645 : 2 : default:
646 : : {
647 : : // (bag.some p A) is rewritten as (distinct (bag.filter p A) bag.empty))
648 : 4 : Node filter = nm->mkNode(Kind::BAG_FILTER, n[0], n[1]);
649 : 4 : Node empty = nm->mkConst(EmptyBag(n[1].getType()));
650 : 2 : Node some = filter.eqNode(empty).notNode();
651 : 2 : return BagsRewriteResponse(some, Rewrite::SOME_FILTER);
652 : 2 : }
653 : : }
654 : : }
655 : :
656 : 39 : BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const
657 : : {
658 [ - + ][ - + ]: 39 : Assert(n.getKind() == Kind::BAG_FOLD);
[ - - ]
659 : 39 : Node f = n[0];
660 : 39 : Node t = n[1];
661 : 39 : Node bag = n[2];
662 [ + + ]: 39 : if (bag.isConst())
663 : : {
664 : 9 : Node value = BagsUtils::evaluateBagFold(n);
665 : 9 : return BagsRewriteResponse(value, Rewrite::FOLD_CONST);
666 : 9 : }
667 : 30 : Kind k = bag.getKind();
668 [ + + ][ + ]: 30 : switch (k)
669 : : {
670 : 1 : case Kind::BAG_MAKE:
671 : : {
672 : 1 : if (bag[1].isConst() && bag[1].getConst<Rational>() > Rational(0))
673 : : {
674 : : // (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times, n > 0
675 : 1 : Node value = BagsUtils::evaluateBagFold(n);
676 : 1 : return BagsRewriteResponse(value, Rewrite::FOLD_BAG);
677 : 1 : }
678 : 0 : break;
679 : : }
680 : 1 : case Kind::BAG_UNION_DISJOINT:
681 : : {
682 : : // (bag.fold f t (bag.union_disjoint A B)) =
683 : : // (bag.fold f (bag.fold f t A) B) where A < B to break symmetry
684 [ + - ]: 2 : Node A = bag[0] < bag[1] ? bag[0] : bag[1];
685 [ + - ]: 2 : Node B = bag[0] < bag[1] ? bag[1] : bag[0];
686 : 2 : Node foldA = d_nm->mkNode(Kind::BAG_FOLD, f, t, A);
687 : 2 : Node fold = d_nm->mkNode(Kind::BAG_FOLD, f, foldA, B);
688 : 1 : return BagsRewriteResponse(fold, Rewrite::FOLD_UNION_DISJOINT);
689 : 1 : }
690 : 28 : default: return BagsRewriteResponse(n, Rewrite::NONE);
691 : : }
692 : 0 : return BagsRewriteResponse(n, Rewrite::NONE);
693 : 39 : }
694 : :
695 : 8 : BagsRewriteResponse BagsRewriter::postRewritePartition(const TNode& n) const
696 : : {
697 [ - + ][ - + ]: 8 : Assert(n.getKind() == Kind::BAG_PARTITION);
[ - - ]
698 [ + + ]: 8 : if (n[1].isConst())
699 : : {
700 : 4 : Node ret = BagsUtils::evaluateBagPartition(d_rewriter, n);
701 [ + - ]: 4 : if (ret != n)
702 : : {
703 : 4 : return BagsRewriteResponse(ret, Rewrite::PARTITION_CONST);
704 : : }
705 [ - + ]: 4 : }
706 : :
707 : 4 : return BagsRewriteResponse(n, Rewrite::NONE);
708 : : }
709 : :
710 : 2 : BagsRewriteResponse BagsRewriter::postRewriteAggregate(const TNode& n) const
711 : : {
712 [ - + ][ - + ]: 2 : Assert(n.getKind() == Kind::TABLE_AGGREGATE);
[ - - ]
713 : 2 : if (n[1].isConst() && n[2].isConst())
714 : : {
715 : 2 : Node ret = BagsUtils::evaluateTableAggregate(n);
716 [ + - ]: 2 : if (ret != n)
717 : : {
718 : 2 : return BagsRewriteResponse(ret, Rewrite::AGGREGATE_CONST);
719 : : }
720 [ - + ]: 2 : }
721 : :
722 : 0 : return BagsRewriteResponse(n, Rewrite::NONE);
723 : : }
724 : :
725 : 38 : BagsRewriteResponse BagsRewriter::postRewriteProduct(const TNode& n) const
726 : : {
727 [ - + ][ - + ]: 38 : Assert(n.getKind() == Kind::TABLE_PRODUCT);
[ - - ]
728 : 38 : TypeNode tableType = n.getType();
729 : 38 : Node empty = d_nm->mkConst(EmptyBag(tableType));
730 : 38 : if (n[0].getKind() == Kind::BAG_EMPTY || n[1].getKind() == Kind::BAG_EMPTY)
731 : : {
732 : 2 : return BagsRewriteResponse(empty, Rewrite::PRODUCT_EMPTY);
733 : : }
734 : :
735 : 36 : return BagsRewriteResponse(n, Rewrite::NONE);
736 : 38 : }
737 : :
738 : : } // namespace bags
739 : : } // namespace theory
740 : : } // namespace cvc5::internal
|