Branch data Line data Source code
1 : : /******************************************************************************
2 : : * Top contributors (to current version):
3 : : * Andrew Reynolds, Aina Niemetz, Andres Noetzli
4 : : *
5 : : * This file is part of the cvc5 project.
6 : : *
7 : : * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS
8 : : * in the top-level source directory and their institutional affiliations.
9 : : * All rights reserved. See the file COPYING in the top-level source
10 : : * directory for licensing information.
11 : : * ****************************************************************************
12 : : *
13 : : * Arithmetic utilities regarding monomial sums.
14 : : */
15 : :
16 : : #include "theory/arith/arith_msum.h"
17 : :
18 : : #include "theory/rewriter.h"
19 : : #include "util/rational.h"
20 : :
21 : : using namespace cvc5::internal::kind;
22 : :
23 : : namespace cvc5::internal {
24 : : namespace theory {
25 : :
26 : 978 : bool ArithMSum::getMonomial(Node n, Node& c, Node& v)
27 : : {
28 [ + - ][ + - ]: 978 : if (n.getKind() == Kind::MULT && n.getNumChildren() == 2 && n[0].isConst())
[ + - ][ + - ]
[ + - ][ - - ]
29 : : {
30 : 978 : c = n[0];
31 : 978 : v = n[1];
32 : 978 : return true;
33 : : }
34 : 0 : return false;
35 : : }
36 : :
37 : 2001730 : bool ArithMSum::getMonomial(Node n, std::map<Node, Node>& msum)
38 : : {
39 [ + + ]: 2001730 : if (n.isConst())
40 : : {
41 [ + - ]: 533831 : if (msum.find(Node::null()) == msum.end())
42 : : {
43 : 533831 : msum[Node::null()] = n;
44 : 533831 : return true;
45 : : }
46 : : }
47 [ + - ]: 559003 : else if (n.getKind() == Kind::MULT && n.getNumChildren() == 2
48 [ + + ][ + - ]: 2026900 : && n[0].isConst())
[ + + ][ + + ]
[ - - ]
49 : : {
50 [ + - ]: 559003 : if (msum.find(n[1]) == msum.end())
51 : : {
52 : 559003 : msum[n[1]] = n[0];
53 : 559003 : return true;
54 : : }
55 : : }
56 : : else
57 : : {
58 [ + - ]: 908899 : if (msum.find(n) == msum.end())
59 : : {
60 : 908899 : msum[n] = Node::null();
61 : 908899 : return true;
62 : : }
63 : : }
64 : 0 : return false;
65 : : }
66 : :
67 : 1219610 : bool ArithMSum::getMonomialSum(Node n, std::map<Node, Node>& msum)
68 : : {
69 [ + + ]: 1219610 : if (n.getKind() == Kind::ADD)
70 : : {
71 [ + + ]: 1839630 : for (Node nc : n)
72 : : {
73 [ - + ]: 1310880 : if (!getMonomial(nc, msum))
74 : : {
75 : 0 : return false;
76 : : }
77 : : }
78 : 528757 : return true;
79 : : }
80 : 690856 : return getMonomial(n, msum);
81 : : }
82 : :
83 : 443029 : bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
84 : : {
85 : 1329090 : if (lit.getKind() == Kind::GEQ
86 : 443029 : || (lit.getKind() == Kind::EQUAL && lit[0].getType().isRealOrInt()))
87 : : {
88 [ + - ]: 364755 : if (getMonomialSum(lit[0], msum))
89 : : {
90 : 364755 : if (lit[1].isConst() && lit[1].getConst<Rational>().isZero())
91 : : {
92 : 75844 : return true;
93 : : }
94 : : else
95 : : {
96 : : // subtract the other side
97 : 288911 : std::map<Node, Node> msum2;
98 : 288911 : NodeManager* nm = NodeManager::currentNM();
99 [ + - ]: 288911 : if (getMonomialSum(lit[1], msum2))
100 : : {
101 : 608881 : for (std::map<Node, Node>::iterator it = msum2.begin();
102 [ + + ]: 608881 : it != msum2.end();
103 : 319970 : ++it)
104 : : {
105 : 319970 : std::map<Node, Node>::iterator it2 = msum.find(it->first);
106 [ + + ]: 319970 : if (it2 != msum.end())
107 : : {
108 : 130 : Rational r1 = it2->second.isNull()
109 : : ? Rational(1)
110 [ + + ]: 260 : : it2->second.getConst<Rational>();
111 : 130 : Rational r2 = it->second.isNull()
112 : : ? Rational(1)
113 [ + + ]: 130 : : it->second.getConst<Rational>();
114 : 130 : msum[it->first] = nm->mkConstRealOrInt(r1 - r2);
115 : : }
116 : : else
117 : : {
118 : 639680 : msum[it->first] = it->second.isNull()
119 [ + + ][ - - ]: 1113710 : ? nm->mkConstInt(Rational(-1))
120 : : : nm->mkConstRealOrInt(
121 [ + + ][ + + ]: 793868 : -it->second.getConst<Rational>());
[ - - ]
122 : : }
123 : : }
124 : 288911 : return true;
125 : : }
126 : : }
127 : : }
128 : : }
129 : 78274 : return false;
130 : : }
131 : :
132 : 1779 : Node ArithMSum::mkNode(const std::map<Node, Node>& msum)
133 : : {
134 : 1779 : NodeManager* nm = NodeManager::currentNM();
135 : 1779 : std::vector<Node> children;
136 [ + + ]: 3558 : for (std::map<Node, Node>::const_iterator it = msum.begin(); it != msum.end();
137 : 1779 : ++it)
138 : : {
139 : 3558 : Node m;
140 [ + + ]: 1779 : if (!it->first.isNull())
141 : : {
142 : 921 : m = mkCoeffTerm(it->second, it->first);
143 : : }
144 : : else
145 : : {
146 [ - + ][ - + ]: 858 : Assert(!it->second.isNull());
[ - - ]
147 : 858 : m = it->second;
148 : : }
149 : 1779 : children.push_back(m);
150 : : }
151 : 1779 : return children.size() > 1
152 : : ? nm->mkNode(Kind::ADD, children)
153 : 3558 : : (children.size() == 1 ? children[0]
154 [ - + ][ + - ]: 8895 : : nm->mkConstInt(Rational(0)));
[ - + ][ - - ]
155 : : }
156 : :
157 : 245998 : int ArithMSum::isolate(
158 : : Node v, const std::map<Node, Node>& msum, Node& veq_c, Node& val, Kind k)
159 : : {
160 [ - + ][ - + ]: 245998 : Assert(veq_c.isNull());
[ - - ]
161 : 245998 : std::map<Node, Node>::const_iterator itv = msum.find(v);
162 [ + + ]: 245998 : if (itv != msum.end())
163 : : {
164 : 239495 : NodeManager* nm = NodeManager::currentNM();
165 : 239495 : std::vector<Node> children;
166 : : Rational r =
167 [ + + ]: 239495 : itv->second.isNull() ? Rational(1) : itv->second.getConst<Rational>();
168 [ + + ]: 239495 : if (r.sgn() != 0)
169 : : {
170 : 239493 : TypeNode vtn = v.getType();
171 : 814878 : for (std::map<Node, Node>::const_iterator it = msum.begin();
172 [ + + ]: 814878 : it != msum.end();
173 : 575385 : ++it)
174 : : {
175 [ + + ]: 575385 : if (it->first != v)
176 : : {
177 : 671784 : Node m;
178 [ + + ]: 335892 : if (!it->first.isNull())
179 : : {
180 : 243265 : m = mkCoeffTerm(it->second, it->first);
181 : : }
182 : : else
183 : : {
184 : 92627 : m = it->second;
185 : : }
186 : 335892 : children.push_back(m);
187 : : }
188 : : }
189 : 239493 : val = children.size() > 1
190 [ + + ][ + + ]: 822756 : ? nm->mkNode(Kind::ADD, children)
191 : 318663 : : (children.size() == 1 ? children[0]
192 [ + + ][ - - ]: 504093 : : nm->mkConstInt(Rational(0)));
193 [ + + ][ + + ]: 239493 : if (!r.isOne() && !r.isNegativeOne())
[ + + ]
194 : : {
195 [ + + ]: 12677 : if (vtn.isInteger())
196 : : {
197 : 6460 : veq_c = nm->mkConstRealOrInt(r.abs());
198 : : }
199 : : else
200 : : {
201 : 12434 : val = nm->mkNode(
202 : 18651 : Kind::MULT, val, nm->mkConstReal(Rational(1) / r.abs()));
203 : : }
204 : : }
205 : 810448 : val = r.sgn() == 1 ? nm->mkNode(
206 : 405224 : Kind::MULT, nm->mkConstRealOrInt(Rational(-1)), val)
207 : 239493 : : val;
208 [ + + ][ + + ]: 239493 : return (r.sgn() == 1 || k == Kind::EQUAL) ? 1 : -1;
209 : : }
210 : : }
211 : 6505 : return 0;
212 : : }
213 : :
214 : 14011 : int ArithMSum::isolate(
215 : : Node v, const std::map<Node, Node>& msum, Node& veq, Kind k, bool doCoeff)
216 : : {
217 : 28022 : Node veq_c;
218 : 28022 : Node val;
219 : : // isolate v in the (in)equality
220 : 14011 : int ires = isolate(v, msum, veq_c, val, k);
221 [ + + ]: 14011 : if (ires != 0)
222 : : {
223 : 13972 : NodeManager* nm = NodeManager::currentNM();
224 : 13972 : Node vc = v;
225 [ + + ]: 13972 : if (!veq_c.isNull())
226 : : {
227 [ + + ]: 147 : if (doCoeff)
228 : : {
229 : 119 : vc = nm->mkNode(Kind::MULT, veq_c, vc);
230 : : }
231 : : else
232 : : {
233 : 28 : return 0;
234 : : }
235 : : }
236 : 13944 : bool inOrder = ires == 1;
237 : : // ensure type is correct for equality
238 [ + + ]: 13944 : if (k == Kind::EQUAL)
239 : : {
240 : 5827 : bool vci = vc.getType().isInteger();
241 : 5827 : bool vi = val.getType().isInteger();
242 [ + + ][ + + ]: 5827 : if (!vci && vi)
243 : : {
244 : 12 : val = nm->mkNode(Kind::TO_REAL, val);
245 : : }
246 [ + + ][ + + ]: 5815 : else if (vci && !vi)
247 : : {
248 : 4 : val = nm->mkNode(Kind::TO_INTEGER, val);
249 : : }
250 [ - + ][ - - ]: 11654 : Assert(val.getType() == vc.getType())
251 : 5827 : << val << " " << vc << " " << val.getType() << " " << vc.getType();
252 : : }
253 [ + + ][ + + ]: 13944 : veq = nm->mkNode(k, inOrder ? vc : val, inOrder ? val : vc);
254 : : }
255 : 13983 : return ires;
256 : : }
257 : :
258 : 140 : Node ArithMSum::solveEqualityFor(Node lit, Node v)
259 : : {
260 [ - + ][ - + ]: 140 : Assert(lit.getKind() == Kind::EQUAL);
[ - - ]
261 : : // first look directly at sides
262 : 280 : TypeNode tn = lit[0].getType();
263 [ + + ]: 256 : for (unsigned r = 0; r < 2; r++)
264 : : {
265 [ + + ]: 198 : if (lit[r] == v)
266 : : {
267 : 82 : return lit[1 - r];
268 : : }
269 : : }
270 [ + - ]: 58 : if (tn.isRealOrInt())
271 : : {
272 : 58 : std::map<Node, Node> msum;
273 [ + - ]: 58 : if (ArithMSum::getMonomialSumLit(lit, msum))
274 : : {
275 : 58 : Node val, veqc;
276 [ + + ]: 58 : if (ArithMSum::isolate(v, msum, veqc, val, Kind::EQUAL) != 0)
277 : : {
278 [ + - ]: 50 : if (veqc.isNull())
279 : : {
280 : : // in this case, we have an integer equality with a coefficient
281 : : // on the variable we solved for that could not be eliminated,
282 : : // hence we fail.
283 : 50 : return val;
284 : : }
285 : : }
286 : : }
287 : : }
288 : 8 : return Node::null();
289 : : }
290 : :
291 : 0 : bool ArithMSum::decompose(Node n, Node v, Node& coeff, Node& rem)
292 : : {
293 : 0 : std::map<Node, Node> msum;
294 [ - - ]: 0 : if (getMonomialSum(n, msum))
295 : : {
296 : 0 : std::map<Node, Node>::iterator it = msum.find(v);
297 [ - - ]: 0 : if (it == msum.end())
298 : : {
299 : 0 : return false;
300 : : }
301 : : else
302 : : {
303 : 0 : coeff = it->second;
304 : 0 : msum.erase(v);
305 : 0 : rem = mkNode(msum);
306 : 0 : return true;
307 : : }
308 : : }
309 : 0 : return false;
310 : : }
311 : :
312 : 5420 : void ArithMSum::debugPrintMonomialSum(std::map<Node, Node>& msum, const char* c)
313 : : {
314 [ + + ]: 16474 : for (std::map<Node, Node>::iterator it = msum.begin(); it != msum.end(); ++it)
315 : : {
316 [ + - ]: 11054 : Trace(c) << " ";
317 [ + + ]: 11054 : if (!it->second.isNull())
318 : : {
319 [ + - ]: 4977 : Trace(c) << it->second;
320 [ + + ]: 4977 : if (!it->first.isNull())
321 : : {
322 [ + - ]: 2778 : Trace(c) << " * ";
323 : : }
324 : : }
325 [ + + ]: 11054 : if (!it->first.isNull())
326 : : {
327 [ + - ]: 8855 : Trace(c) << it->first;
328 : : }
329 [ + - ]: 11054 : Trace(c) << std::endl;
330 : : }
331 [ + - ]: 5420 : Trace(c) << std::endl;
332 : 5420 : }
333 : :
334 : : } // namespace theory
335 : : } // namespace cvc5::internal
|