Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * bag reduction.
11 : : */
12 : :
13 : : #include "theory/bags/bag_reduction.h"
14 : :
15 : : #include "expr/bound_var_manager.h"
16 : : #include "expr/emptybag.h"
17 : : #include "expr/skolem_manager.h"
18 : : #include "theory/datatypes/project_op.h"
19 : : #include "theory/datatypes/tuple_utils.h"
20 : : #include "theory/quantifiers/fmf/bounded_integers.h"
21 : : #include "util/rational.h"
22 : :
23 : : using namespace cvc5::internal;
24 : : using namespace cvc5::internal::kind;
25 : :
26 : : namespace cvc5::internal {
27 : : namespace theory {
28 : : namespace bags {
29 : :
30 : 0 : BagReduction::BagReduction() {}
31 : :
32 : 0 : BagReduction::~BagReduction() {}
33 : :
34 : 8 : Node BagReduction::reduceFoldOperator(Node node, std::vector<Node>& asserts)
35 : : {
36 [ - + ][ - + ]: 8 : Assert(node.getKind() == Kind::BAG_FOLD);
[ - - ]
37 : 8 : NodeManager* nm = node.getNodeManager();
38 : 8 : SkolemManager* sm = nm->getSkolemManager();
39 : 8 : Node f = node[0];
40 : 8 : Node t = node[1];
41 : 8 : Node A = node[2];
42 : 8 : Node zero = nm->mkConstInt(Rational(0));
43 : 8 : Node one = nm->mkConstInt(Rational(1));
44 : : // skolem functions
45 : 8 : Node n = sm->mkSkolemFunction(SkolemId::BAGS_FOLD_CARD, A);
46 : 8 : Node elements = sm->mkSkolemFunction(SkolemId::BAGS_FOLD_ELEMENTS, A);
47 : : Node unionDisjoint =
48 : 8 : sm->mkSkolemFunction(SkolemId::BAGS_FOLD_UNION_DISJOINT, A);
49 : 40 : Node combine = sm->mkSkolemFunction(SkolemId::BAGS_FOLD_COMBINE, {f, t, A});
50 : :
51 : 8 : BoundVarManager* bvm = nm->getBoundVarManager();
52 : : Node i = bvm->mkBoundVar(
53 : 16 : BoundVarId::BAGS_FIRST_INDEX, node, "i", nm->integerType());
54 : 8 : Node iList = nm->mkNode(Kind::BOUND_VAR_LIST, i);
55 : 16 : Node iMinusOne = nm->mkNode(Kind::SUB, i, one);
56 : 16 : Node elements_i = nm->mkNode(Kind::APPLY_UF, elements, i);
57 : 16 : Node combine_0 = nm->mkNode(Kind::APPLY_UF, combine, zero);
58 : 16 : Node combine_iMinusOne = nm->mkNode(Kind::APPLY_UF, combine, iMinusOne);
59 : 16 : Node combine_i = nm->mkNode(Kind::APPLY_UF, combine, i);
60 : 16 : Node combine_n = nm->mkNode(Kind::APPLY_UF, combine, n);
61 : 16 : Node unionDisjoint_0 = nm->mkNode(Kind::APPLY_UF, unionDisjoint, zero);
62 : : Node unionDisjoint_iMinusOne =
63 : 16 : nm->mkNode(Kind::APPLY_UF, unionDisjoint, iMinusOne);
64 : 16 : Node unionDisjoint_i = nm->mkNode(Kind::APPLY_UF, unionDisjoint, i);
65 : 16 : Node unionDisjoint_n = nm->mkNode(Kind::APPLY_UF, unionDisjoint, n);
66 : 8 : Node combine_0_equal = combine_0.eqNode(t);
67 : : Node combine_i_equal = combine_i.eqNode(
68 : 16 : nm->mkNode(Kind::APPLY_UF, f, elements_i, combine_iMinusOne));
69 : : Node unionDisjoint_0_equal =
70 : 16 : unionDisjoint_0.eqNode(nm->mkConst(EmptyBag(A.getType())));
71 : 16 : Node singleton = nm->mkNode(Kind::BAG_MAKE, elements_i, one);
72 : :
73 : : Node unionDisjoint_i_equal = unionDisjoint_i.eqNode(
74 : 16 : nm->mkNode(Kind::BAG_UNION_DISJOINT, singleton, unionDisjoint_iMinusOne));
75 : : Node interval_i = nm->mkNode(
76 : 16 : Kind::AND, nm->mkNode(Kind::GEQ, i, one), nm->mkNode(Kind::LEQ, i, n));
77 : :
78 : : Node body_i =
79 : : nm->mkNode(Kind::IMPLIES,
80 : : interval_i,
81 : 16 : nm->mkNode(Kind::AND, combine_i_equal, unionDisjoint_i_equal));
82 : : Node forAll_i =
83 : 16 : quantifiers::BoundedIntegers::mkBoundedForall(nm, iList, body_i);
84 : 16 : Node nonNegative = nm->mkNode(Kind::GEQ, n, zero);
85 : 8 : Node unionDisjoint_n_equal = A.eqNode(unionDisjoint_n);
86 : 8 : asserts.push_back(forAll_i);
87 : 8 : asserts.push_back(combine_0_equal);
88 : 8 : asserts.push_back(unionDisjoint_0_equal);
89 : 8 : asserts.push_back(unionDisjoint_n_equal);
90 : 8 : asserts.push_back(nonNegative);
91 : 16 : return combine_n;
92 : 8 : }
93 : :
94 : 33 : Node BagReduction::reduceCardOperator(Node node, std::vector<Node>& asserts)
95 : : {
96 [ - + ][ - + ]: 33 : Assert(node.getKind() == Kind::BAG_CARD);
[ - - ]
97 : 33 : NodeManager* nm = node.getNodeManager();
98 : 33 : SkolemManager* sm = nm->getSkolemManager();
99 : 33 : Node A = node[0];
100 : 33 : Node zero = nm->mkConstInt(Rational(0));
101 : 33 : Node one = nm->mkConstInt(Rational(1));
102 : : // types
103 : 33 : TypeNode bagType = A.getType();
104 : : // skolem functions
105 : 33 : Node n = sm->mkSkolemFunction(SkolemId::BAGS_DISTINCT_ELEMENTS_SIZE, A);
106 : 33 : Node elements = sm->mkSkolemFunction(SkolemId::BAGS_DISTINCT_ELEMENTS, A);
107 : : Node unionDisjoint =
108 : 33 : sm->mkSkolemFunction(SkolemId::BAGS_DISTINCT_ELEMENTS_UNION_DISJOINT, A);
109 : 33 : Node combine = sm->mkSkolemFunction(SkolemId::BAGS_CARD_COMBINE, A);
110 : :
111 : 33 : BoundVarManager* bvm = nm->getBoundVarManager();
112 : : Node i = bvm->mkBoundVar(
113 : 66 : BoundVarId::BAGS_FIRST_INDEX, node, "i", nm->integerType());
114 : : Node j = bvm->mkBoundVar(
115 : 66 : BoundVarId::BAGS_SECOND_INDEX, node, "j", nm->integerType());
116 : 33 : Node iList = nm->mkNode(Kind::BOUND_VAR_LIST, i);
117 : 33 : Node jList = nm->mkNode(Kind::BOUND_VAR_LIST, j);
118 : 66 : Node iMinusOne = nm->mkNode(Kind::SUB, i, one);
119 : 66 : Node elements_i = nm->mkNode(Kind::APPLY_UF, elements, i);
120 : 66 : Node elements_j = nm->mkNode(Kind::APPLY_UF, elements, j);
121 : 66 : Node combine_0 = nm->mkNode(Kind::APPLY_UF, combine, zero);
122 : 66 : Node combine_iMinusOne = nm->mkNode(Kind::APPLY_UF, combine, iMinusOne);
123 : 66 : Node combine_i = nm->mkNode(Kind::APPLY_UF, combine, i);
124 : 66 : Node combine_n = nm->mkNode(Kind::APPLY_UF, combine, n);
125 : 66 : Node unionDisjoint_0 = nm->mkNode(Kind::APPLY_UF, unionDisjoint, zero);
126 : : Node unionDisjoint_iMinusOne =
127 : 66 : nm->mkNode(Kind::APPLY_UF, unionDisjoint, iMinusOne);
128 : 66 : Node unionDisjoint_i = nm->mkNode(Kind::APPLY_UF, unionDisjoint, i);
129 : 66 : Node unionDisjoint_n = nm->mkNode(Kind::APPLY_UF, unionDisjoint, n);
130 : 33 : Node combine_0_equal = combine_0.eqNode(zero);
131 : 66 : Node elements_i_multiplicity = nm->mkNode(Kind::BAG_COUNT, elements_i, A);
132 : : Node combine_i_equal = combine_i.eqNode(
133 : 66 : nm->mkNode(Kind::ADD, elements_i_multiplicity, combine_iMinusOne));
134 : : Node unionDisjoint_0_equal =
135 : 66 : unionDisjoint_0.eqNode(nm->mkConst(EmptyBag(bagType)));
136 : 66 : Node bag = nm->mkNode(Kind::BAG_MAKE, elements_i, elements_i_multiplicity);
137 : :
138 : : Node unionDisjoint_i_equal = unionDisjoint_i.eqNode(
139 : 66 : nm->mkNode(Kind::BAG_UNION_DISJOINT, bag, unionDisjoint_iMinusOne));
140 : : // 1 <= i <= n
141 : : Node interval_i = nm->mkNode(
142 : 66 : Kind::AND, nm->mkNode(Kind::GEQ, i, one), nm->mkNode(Kind::LEQ, i, n));
143 : :
144 : : // i < j <= n
145 : : Node interval_j = nm->mkNode(
146 : 66 : Kind::AND, nm->mkNode(Kind::LT, i, j), nm->mkNode(Kind::LEQ, j, n));
147 : : // elements(i) != elements(j)
148 : : Node elements_i_equals_elements_j =
149 : 66 : nm->mkNode(Kind::EQUAL, elements_i, elements_j);
150 : 66 : Node notEqual = nm->mkNode(Kind::EQUAL, elements_i, elements_j).negate();
151 : 66 : Node body_j = nm->mkNode(Kind::OR, interval_j.negate(), notEqual);
152 : : Node forAll_j =
153 : 66 : quantifiers::BoundedIntegers::mkBoundedForall(nm, jList, body_j);
154 : : Node body_i = nm->mkNode(
155 : : Kind::IMPLIES,
156 : : interval_i,
157 : 66 : nm->mkNode(Kind::AND, combine_i_equal, unionDisjoint_i_equal, forAll_j));
158 : : Node forAll_i =
159 : 66 : quantifiers::BoundedIntegers::mkBoundedForall(nm, iList, body_i);
160 : 66 : Node nonNegative = nm->mkNode(Kind::GEQ, n, zero);
161 : 33 : Node unionDisjoint_n_equal = A.eqNode(unionDisjoint_n);
162 : 33 : asserts.push_back(forAll_i);
163 : 33 : asserts.push_back(combine_0_equal);
164 : 33 : asserts.push_back(unionDisjoint_0_equal);
165 : 33 : asserts.push_back(unionDisjoint_n_equal);
166 : 33 : asserts.push_back(nonNegative);
167 : 66 : return combine_n;
168 : 33 : }
169 : :
170 : 2 : Node BagReduction::reduceAggregateOperator(Node node)
171 : : {
172 [ - + ][ - + ]: 2 : Assert(node.getKind() == Kind::TABLE_AGGREGATE);
[ - - ]
173 : 2 : NodeManager* nm = node.getNodeManager();
174 : 2 : BoundVarManager* bvm = nm->getBoundVarManager();
175 : 2 : Node function = node[0];
176 : 4 : TypeNode elementType = function.getType().getArgTypes()[0];
177 : 2 : Node initialValue = node[1];
178 : 2 : Node A = node[2];
179 : 2 : ProjectOp op = node.getOperator().getConst<ProjectOp>();
180 : :
181 : 2 : Node groupOp = nm->mkConst(Kind::TABLE_GROUP_OP, op);
182 : 6 : Node group = nm->mkNode(Kind::TABLE_GROUP, {groupOp, A});
183 : :
184 : : Node bag = bvm->mkBoundVar(
185 : 4 : BoundVarId::BAGS_FIRST_INDEX, group, "bag", nm->mkBagType(elementType));
186 : 2 : Node foldList = nm->mkNode(Kind::BOUND_VAR_LIST, bag);
187 : 4 : Node foldBody = nm->mkNode(Kind::BAG_FOLD, function, initialValue, bag);
188 : :
189 : 4 : Node fold = nm->mkNode(Kind::LAMBDA, foldList, foldBody);
190 : 4 : Node map = nm->mkNode(Kind::BAG_MAP, fold, group);
191 : 4 : return map;
192 : 2 : }
193 : :
194 : 25 : Node BagReduction::reduceProjectOperator(Node n)
195 : : {
196 [ - + ][ - + ]: 25 : Assert(n.getKind() == Kind::TABLE_PROJECT);
[ - - ]
197 : 25 : NodeManager* nm = n.getNodeManager();
198 : 25 : Node A = n[0];
199 : 25 : TypeNode elementType = A.getType().getBagElementType();
200 : 25 : ProjectOp projectOp = n.getOperator().getConst<ProjectOp>();
201 : 25 : Node op = nm->mkConst(Kind::TUPLE_PROJECT_OP, projectOp);
202 : 50 : Node t = NodeManager::mkBoundVar("t", elementType);
203 : 50 : Node projection = nm->mkNode(Kind::TUPLE_PROJECT, op, t);
204 : : Node lambda =
205 : 50 : nm->mkNode(Kind::LAMBDA, nm->mkNode(Kind::BOUND_VAR_LIST, t), projection);
206 : 50 : Node setMap = nm->mkNode(Kind::BAG_MAP, lambda, A);
207 : 50 : return setMap;
208 : 25 : }
209 : :
210 : : } // namespace bags
211 : : } // namespace theory
212 : : } // namespace cvc5::internal
|