Branch data Line data Source code
1 : : /******************************************************************************
2 : : * Top contributors (to current version):
3 : : * Andrew Reynolds, Aina Niemetz, Mathias Preiner
4 : : *
5 : : * This file is part of the cvc5 project.
6 : : *
7 : : * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
8 : : * in the top-level source directory and their institutional affiliations.
9 : : * All rights reserved. See the file COPYING in the top-level source
10 : : * directory for licensing information.
11 : : * ****************************************************************************
12 : : *
13 : : * Implementation of substitution minimization.
14 : : */
15 : :
16 : : #include "theory/subs_minimize.h"
17 : :
18 : : #include "expr/node_algorithm.h"
19 : : #include "theory/bv/theory_bv_utils.h"
20 : : #include "theory/rewriter.h"
21 : : #include "theory/strings/word.h"
22 : : #include "util/rational.h"
23 : :
24 : : using namespace std;
25 : : using namespace cvc5::internal::kind;
26 : :
27 : : namespace cvc5::internal {
28 : : namespace theory {
29 : :
30 : 202 : SubstitutionMinimize::SubstitutionMinimize(Env& env) : EnvObj(env) {}
31 : :
32 : 199 : bool SubstitutionMinimize::find(Node t,
33 : : Node target,
34 : : const std::vector<Node>& vars,
35 : : const std::vector<Node>& subs,
36 : : std::vector<Node>& reqVars)
37 : : {
38 : 199 : return findInternal(t, target, vars, subs, reqVars);
39 : : }
40 : :
41 : 26 : void getConjuncts(Node n, std::vector<Node>& conj)
42 : : {
43 [ + + ]: 26 : if (n.getKind() == Kind::AND)
44 : : {
45 [ + + ]: 24 : for (const Node& nc : n)
46 : : {
47 : 19 : conj.push_back(nc);
48 : : }
49 : : }
50 : : else
51 : : {
52 : 21 : conj.push_back(n);
53 : : }
54 : 26 : }
55 : :
56 : 3 : bool SubstitutionMinimize::findWithImplied(Node t,
57 : : const std::vector<Node>& vars,
58 : : const std::vector<Node>& subs,
59 : : std::vector<Node>& reqVars,
60 : : std::vector<Node>& impliedVars)
61 : : {
62 : 3 : NodeManager* nm = nodeManager();
63 : 6 : Node truen = nm->mkConst(true);
64 [ - + ]: 3 : if (!findInternal(t, truen, vars, subs, reqVars))
65 : : {
66 : 0 : return false;
67 : : }
68 [ - + ]: 3 : if (reqVars.empty())
69 : : {
70 : 0 : return true;
71 : : }
72 : :
73 : : // map from conjuncts of t to whether they may be used to show an implied var
74 : 6 : std::vector<Node> tconj;
75 : 3 : getConjuncts(t, tconj);
76 : : // map from conjuncts to their free symbols
77 : 6 : std::map<Node, std::unordered_set<Node> > tcFv;
78 : :
79 : 6 : std::unordered_set<Node> reqSet;
80 : 6 : std::vector<Node> reqSubs;
81 : 6 : std::map<Node, unsigned> reqVarToIndex;
82 [ + + ]: 10 : for (const Node& v : reqVars)
83 : : {
84 : 7 : reqVarToIndex[v] = reqSubs.size();
85 : : const std::vector<Node>::const_iterator& it =
86 : 7 : std::find(vars.begin(), vars.end(), v);
87 [ - + ][ - + ]: 7 : Assert(it != vars.end());
[ - - ]
88 : 7 : ptrdiff_t pos = std::distance(vars.begin(), it);
89 : 7 : reqSubs.push_back(subs[pos]);
90 : : }
91 : 3 : std::vector<Node> finalReqVars;
92 [ + + ]: 10 : for (const Node& v : vars)
93 : : {
94 [ - + ]: 7 : if (reqVarToIndex.find(v) == reqVarToIndex.end())
95 : : {
96 : : // not a required variable, nothing to do
97 : 0 : continue;
98 : : }
99 : 7 : unsigned vindex = reqVarToIndex[v];
100 : 14 : Node prev = reqSubs[vindex];
101 : : // make identity substitution
102 : 7 : reqSubs[vindex] = v;
103 : 7 : bool madeImplied = false;
104 : : // it is a required variable, can we make an implied variable?
105 [ + + ]: 50 : for (const Node& tc : tconj)
106 : : {
107 : : // ensure we've computed its free symbols
108 : 43 : std::map<Node, std::unordered_set<Node> >::iterator itf = tcFv.find(tc);
109 [ + + ]: 43 : if (itf == tcFv.end())
110 : : {
111 : 11 : expr::getSymbols(tc, tcFv[tc]);
112 : 11 : itf = tcFv.find(tc);
113 : : }
114 : : // only have a chance if contains v
115 [ + + ]: 43 : if (itf->second.find(v) == itf->second.end())
116 : : {
117 : 20 : continue;
118 : : }
119 : : // try the current substitution
120 : : Node tcs = tc.substitute(
121 : 23 : reqVars.begin(), reqVars.end(), reqSubs.begin(), reqSubs.end());
122 : 23 : Node tcsr = rewrite(tcs);
123 : 23 : std::vector<Node> tcsrConj;
124 : 23 : getConjuncts(tcsr, tcsrConj);
125 [ + + ]: 48 : for (const Node& tcc : tcsrConj)
126 : : {
127 [ - + ]: 25 : if (tcc.getKind() == Kind::EQUAL)
128 : : {
129 [ - - ]: 0 : for (unsigned r = 0; r < 2; r++)
130 : : {
131 [ - - ]: 0 : if (tcc[r] == v)
132 : : {
133 : 0 : Node res = tcc[1 - r];
134 [ - - ]: 0 : if (res.isConst())
135 : : {
136 : 0 : Assert(res == prev);
137 : 0 : madeImplied = true;
138 : 0 : break;
139 : : }
140 : : }
141 : : }
142 : : }
143 [ - + ]: 25 : if (madeImplied)
144 : : {
145 : 0 : break;
146 : : }
147 : : }
148 [ - + ]: 23 : if (madeImplied)
149 : : {
150 : 0 : break;
151 : : }
152 : : }
153 [ + - ]: 7 : if (!madeImplied)
154 : : {
155 : : // revert the substitution
156 : 7 : reqSubs[vindex] = prev;
157 : 7 : finalReqVars.push_back(v);
158 : : }
159 : : else
160 : : {
161 : 0 : impliedVars.push_back(v);
162 : : }
163 : : }
164 : 3 : reqVars.clear();
165 : 3 : reqVars.insert(reqVars.end(), finalReqVars.begin(), finalReqVars.end());
166 : :
167 : 3 : return true;
168 : : }
169 : :
170 : 202 : bool SubstitutionMinimize::findInternal(Node n,
171 : : Node target,
172 : : const std::vector<Node>& vars,
173 : : const std::vector<Node>& subs,
174 : : std::vector<Node>& reqVars)
175 : : {
176 [ + - ]: 202 : Trace("subs-min") << "Substitution minimize : " << std::endl;
177 [ + - ]: 404 : Trace("subs-min") << " substitution : " << vars << " -> " << subs
178 : 202 : << std::endl;
179 [ + - ]: 202 : Trace("subs-min") << " node : " << n << std::endl;
180 [ + - ]: 202 : Trace("subs-min") << " target : " << target << std::endl;
181 : :
182 [ + - ]: 202 : Trace("subs-min") << "--- Compute values for subterms..." << std::endl;
183 : : // the value of each subterm in n under the substitution
184 : 404 : std::unordered_map<TNode, Node> value;
185 : 202 : std::unordered_map<TNode, Node>::iterator it;
186 : 404 : std::vector<TNode> visit;
187 : 404 : TNode cur;
188 : 202 : visit.push_back(n);
189 : 1186 : do
190 : : {
191 : 1388 : cur = visit.back();
192 : 1388 : visit.pop_back();
193 : 1388 : it = value.find(cur);
194 : :
195 [ + + ]: 1388 : if (it == value.end())
196 : : {
197 [ + + ]: 875 : if (cur.isVar())
198 : : {
199 : : const std::vector<Node>::const_iterator& iit =
200 : 407 : std::find(vars.begin(), vars.end(), cur);
201 [ - + ]: 407 : if (iit == vars.end())
202 : : {
203 : 0 : value[cur] = cur;
204 : : }
205 : : else
206 : : {
207 : 407 : ptrdiff_t pos = std::distance(vars.begin(), iit);
208 : 407 : value[cur] = subs[pos];
209 : : }
210 : : }
211 : : else
212 : : {
213 : 468 : value[cur] = Node::null();
214 : 468 : visit.push_back(cur);
215 [ + + ]: 468 : if (cur.getKind() == Kind::APPLY_UF)
216 : : {
217 : 2 : visit.push_back(cur.getOperator());
218 : : }
219 : 468 : visit.insert(visit.end(), cur.begin(), cur.end());
220 : : }
221 : : }
222 [ + + ]: 513 : else if (it->second.isNull())
223 : : {
224 : 936 : Node ret = cur;
225 [ + + ]: 468 : if (cur.getNumChildren() > 0)
226 : : {
227 : 914 : std::vector<Node> children;
228 : 457 : NodeBuilder nb(nodeManager(), cur.getKind());
229 [ + + ]: 457 : if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
230 : : {
231 [ + - ]: 2 : if (cur.getKind() == Kind::APPLY_UF)
232 : : {
233 : 2 : children.push_back(cur.getOperator());
234 : : }
235 : : else
236 : : {
237 : 0 : nb << cur.getOperator();
238 : : }
239 : : }
240 : 457 : children.insert(children.end(), cur.begin(), cur.end());
241 [ + + ]: 1175 : for (const Node& cn : children)
242 : : {
243 : 718 : it = value.find(cn);
244 [ - + ][ - + ]: 718 : Assert(it != value.end());
[ - - ]
245 [ - + ][ - + ]: 718 : Assert(!it->second.isNull());
[ - - ]
246 : 718 : nb << it->second;
247 : : }
248 : 457 : ret = nb.constructNode();
249 : 457 : ret = rewrite(ret);
250 : : }
251 : 468 : value[cur] = ret;
252 : : }
253 [ + + ]: 1388 : } while (!visit.empty());
254 [ - + ][ - + ]: 202 : Assert(value.find(n) != value.end());
[ - - ]
255 [ - + ][ - + ]: 202 : Assert(!value.find(n)->second.isNull());
[ - - ]
256 : :
257 [ + - ][ - + ]: 202 : Trace("subs-min") << "... got " << value[n] << std::endl;
[ - - ]
258 [ + + ]: 202 : if (value[n] != target)
259 : : {
260 [ + - ]: 2 : Trace("subs-min") << "... not equal to target " << target << std::endl;
261 : : // depends on all variables
262 [ + + ]: 14 : for (const std::pair<const TNode, Node>& v : value)
263 : : {
264 [ + + ]: 12 : if (v.first.isVar())
265 : : {
266 : 2 : reqVars.push_back(v.first);
267 : : }
268 : : }
269 : 2 : return false;
270 : : }
271 : :
272 [ + - ]: 200 : Trace("subs-min") << "--- Compute relevant variables..." << std::endl;
273 : 400 : std::unordered_set<Node> rlvFv;
274 : : // only variables that occur in assertions are relevant
275 : :
276 : 200 : visit.push_back(n);
277 : 200 : std::unordered_set<TNode> visited;
278 : 200 : std::unordered_set<TNode>::iterator itv;
279 : 679 : do
280 : : {
281 : 879 : cur = visit.back();
282 : 879 : visit.pop_back();
283 : 879 : itv = visited.find(cur);
284 [ + + ]: 879 : if (itv == visited.end())
285 : : {
286 : 848 : visited.insert(cur);
287 : 848 : it = value.find(cur);
288 [ + + ]: 848 : if (it->second == cur)
289 : : {
290 : : // if its value is the same as current, there is nothing to do
291 : : }
292 [ + + ]: 840 : else if (cur.isVar())
293 : : {
294 : : // must include
295 : 401 : rlvFv.insert(cur);
296 : : }
297 [ - + ]: 439 : else if (cur.getKind() == Kind::ITE)
298 : : {
299 : : // only recurse on relevant branch
300 : 0 : Node bval = value[cur[0]];
301 : 0 : if (!bval.isNull() && bval.isConst())
302 : : {
303 [ - - ]: 0 : unsigned cindex = bval.getConst<bool>() ? 1 : 2;
304 : 0 : visit.push_back(cur[0]);
305 : 0 : visit.push_back(cur[cindex]);
306 : 0 : continue;
307 : : }
308 : : // otherwise, we handle it normally below
309 : : }
310 [ + + ]: 848 : if (cur.getNumChildren() > 0)
311 : : {
312 : 439 : Kind ck = cur.getKind();
313 : 439 : bool alreadyJustified = false;
314 : :
315 : : // if the operator is an apply uf, check its value
316 [ + + ]: 439 : if (cur.getKind() == Kind::APPLY_UF)
317 : : {
318 : 4 : Node op = cur.getOperator();
319 : 2 : it = value.find(op);
320 [ - + ][ - + ]: 2 : Assert(it != value.end());
[ - - ]
321 : 4 : TNode vop = it->second;
322 [ + - ]: 2 : if (vop.getKind() == Kind::LAMBDA)
323 : : {
324 : 2 : visit.push_back(op);
325 : : // do iterative partial evaluation on the body of the lambda
326 : 2 : Node curr = vop[1];
327 [ + + ]: 4 : for (unsigned i = 0, size = cur.getNumChildren(); i < size; i++)
328 : : {
329 : 2 : it = value.find(cur[i]);
330 [ - + ][ - + ]: 2 : Assert(it != value.end());
[ - - ]
331 : 6 : Node scurr = curr.substitute(vop[0][i], it->second);
332 : : // if the valuation of the i^th argument changes the
333 : : // interpretation of the body of the lambda, then the i^th
334 : : // argument is relevant to the substitution. Hence, we add
335 : : // i to visit, and update curr below.
336 [ - + ]: 2 : if (scurr != curr)
337 : : {
338 : 0 : curr = rewrite(scurr);
339 : 0 : visit.push_back(cur[i]);
340 : : }
341 : : }
342 : 2 : alreadyJustified = true;
343 : : }
344 : : }
345 [ + + ]: 439 : if (!alreadyJustified)
346 : : {
347 : : // a subset of the arguments of cur that fully justify the evaluation
348 : 874 : std::vector<unsigned> justifyArgs;
349 [ + + ]: 437 : if (cur.getNumChildren() > 1)
350 : : {
351 [ + + ]: 717 : for (unsigned i = 0, size = cur.getNumChildren(); i < size; i++)
352 : : {
353 : 481 : Node cn = cur[i];
354 : 481 : it = value.find(cn);
355 [ - + ][ - + ]: 481 : Assert(it != value.end());
[ - - ]
356 [ - + ][ - + ]: 481 : Assert(!it->second.isNull());
[ - - ]
357 [ + + ]: 481 : if (isSingularArg(it->second, ck, i))
358 : : {
359 : : // have we seen this argument already? if so, we are done
360 [ - + ]: 6 : if (visited.find(cn) != visited.end())
361 : : {
362 : 0 : alreadyJustified = true;
363 : 0 : break;
364 : : }
365 : 6 : justifyArgs.push_back(i);
366 : : }
367 : : }
368 : : }
369 : : // we need to recurse on at most one child
370 [ + - ][ + + ]: 437 : if (!alreadyJustified && !justifyArgs.empty())
[ + + ]
371 : : {
372 : 5 : unsigned sindex = justifyArgs[0];
373 : : // could choose a best index, for now, we just take the first
374 : 5 : visit.push_back(cur[sindex]);
375 : 5 : alreadyJustified = true;
376 : : }
377 : : }
378 [ + + ]: 439 : if (!alreadyJustified)
379 : : {
380 : : // must recurse on all arguments, including operator
381 [ - + ]: 432 : if (cur.getKind() == Kind::APPLY_UF)
382 : : {
383 : 0 : visit.push_back(cur.getOperator());
384 : : }
385 [ + + ]: 1104 : for (const Node& cn : cur)
386 : : {
387 : 672 : visit.push_back(cn);
388 : : }
389 : : }
390 : : }
391 : : }
392 [ + + ]: 879 : } while (!visit.empty());
393 : :
394 [ + + ]: 601 : for (const Node& v : rlvFv)
395 : : {
396 [ - + ][ - + ]: 401 : Assert(std::find(vars.begin(), vars.end(), v) != vars.end());
[ - - ]
397 : 401 : reqVars.push_back(v);
398 : : }
399 : :
400 [ + - ]: 400 : Trace("subs-min") << "... requires " << reqVars.size() << "/" << vars.size()
401 : 200 : << " : " << reqVars << std::endl;
402 : :
403 : 200 : return true;
404 : : }
405 : :
406 : 481 : bool SubstitutionMinimize::isSingularArg(Node n, Kind k, unsigned arg)
407 : : {
408 : : // Notice that this function is hardcoded. We could compute this function
409 : : // in a theory-independent way using partial evaluation. However, we
410 : : // prefer performance to generality here.
411 : :
412 : : // TODO: a variant of this code is implemented in quantifiers::TermUtil.
413 : : // These implementations should be merged (see #1216).
414 [ - + ]: 481 : if (!n.isConst())
415 : : {
416 : 0 : return false;
417 : : }
418 [ + + ]: 481 : if (k == Kind::AND)
419 : : {
420 : 27 : return !n.getConst<bool>();
421 : : }
422 [ + + ]: 454 : else if (k == Kind::OR)
423 : : {
424 : 10 : return n.getConst<bool>();
425 : : }
426 [ - + ]: 444 : else if (k == Kind::IMPLIES)
427 : : {
428 [ - - ]: 0 : return arg == (n.getConst<bool>() ? 1 : 0);
429 : : }
430 [ + + ]: 444 : if (k == Kind::MULT
431 [ + + ]: 440 : || (arg == 0
432 [ + - ][ + - ]: 220 : && (k == Kind::DIVISION_TOTAL || k == Kind::INTS_DIVISION_TOTAL
433 [ + - ]: 220 : || k == Kind::INTS_MODULUS_TOTAL))
434 [ - + ][ - - ]: 440 : || (arg == 2 && k == Kind::STRING_SUBSTR))
435 : : {
436 : : // zero
437 [ - + ]: 4 : if (n.getConst<Rational>().sgn() == 0)
438 : : {
439 : 0 : return true;
440 : : }
441 : : }
442 [ + - ][ + - ]: 444 : if (k == Kind::BITVECTOR_AND || k == Kind::BITVECTOR_MULT
443 [ + - ][ + - ]: 444 : || k == Kind::BITVECTOR_UDIV || k == Kind::BITVECTOR_UREM
444 [ + + ]: 444 : || (arg == 0
445 [ + - ][ + - ]: 222 : && (k == Kind::BITVECTOR_SHL || k == Kind::BITVECTOR_LSHR
446 [ - + ]: 222 : || k == Kind::BITVECTOR_ASHR)))
447 : : {
448 [ - - ]: 0 : if (bv::utils::isZero(n))
449 : : {
450 : 0 : return true;
451 : : }
452 : : }
453 [ - + ]: 444 : if (k == Kind::BITVECTOR_OR)
454 : : {
455 : : // bit-vector ones
456 [ - - ]: 0 : if (bv::utils::isOnes(n))
457 : : {
458 : 0 : return true;
459 : : }
460 : : }
461 : :
462 [ + + ][ + - ]: 444 : if ((arg == 1 && k == Kind::STRING_CONTAINS)
463 [ + + ][ - + ]: 444 : || (arg == 0 && k == Kind::STRING_SUBSTR))
464 : : {
465 : : // empty string
466 [ - - ]: 0 : if (strings::Word::getLength(n) == 0)
467 : : {
468 : 0 : return true;
469 : : }
470 : : }
471 [ + + ][ + - ]: 444 : if ((arg != 0 && k == Kind::STRING_SUBSTR)
472 [ - + ][ - - ]: 444 : || (arg == 2 && k == Kind::STRING_INDEXOF))
473 : : {
474 : : // negative integer
475 [ - - ]: 0 : if (n.getConst<Rational>().sgn() < 0)
476 : : {
477 : 0 : return true;
478 : : }
479 : : }
480 : 444 : return false;
481 : : }
482 : :
483 : : } // namespace theory
484 : : } // namespace cvc5::internal
|