Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * Implementation of the higher-order extension of TheoryUF.
11 : : */
12 : :
13 : : #include "theory/uf/ho_extension.h"
14 : :
15 : : #include "expr/node_algorithm.h"
16 : : #include "expr/skolem_manager.h"
17 : : #include "options/uf_options.h"
18 : : #include "theory/smt_engine_subsolver.h"
19 : : #include "theory/theory_model.h"
20 : : #include "theory/uf/function_const.h"
21 : : #include "theory/uf/lambda_lift.h"
22 : : #include "theory/uf/theory_uf_rewriter.h"
23 : : #include "util/rational.h"
24 : :
25 : : using namespace std;
26 : : using namespace cvc5::internal::kind;
27 : :
28 : : namespace cvc5::internal {
29 : : namespace theory {
30 : : namespace uf {
31 : :
32 : 1963 : HoExtension::HoExtension(Env& env,
33 : : TheoryState& state,
34 : : TheoryInferenceManager& im,
35 : 1963 : LambdaLift& ll)
36 : : : EnvObj(env),
37 : 1963 : d_state(state),
38 : 1963 : d_im(im),
39 : 1963 : d_ll(ll),
40 : 1963 : d_extensionality(userContext()),
41 : 1963 : d_cachedLemmas(userContext()),
42 : 1963 : d_uf_std_skolem(userContext()),
43 : 5889 : d_lamEqProcessed(userContext())
44 : : {
45 : 1963 : d_true = nodeManager()->mkConst(true);
46 : : // don't send true lemma
47 : 1963 : d_cachedLemmas.insert(d_true);
48 : 1963 : }
49 : :
50 : 58169 : TrustNode HoExtension::ppRewrite(Node node, std::vector<SkolemLemma>& lems)
51 : : {
52 : 58169 : Kind k = node.getKind();
53 [ + + ]: 58169 : if (k == Kind::HO_APPLY)
54 : : {
55 : : // convert HO_APPLY to APPLY_UF if fully applied
56 [ + + ]: 4904 : if (node[0].getType().getNumChildren() == 2)
57 : : {
58 [ + - ]: 3705 : Trace("uf-ho") << "uf-ho : expanding definition : " << node << std::endl;
59 : 3705 : Node ret = getApplyUfForHoApply(node);
60 [ + - ]: 7410 : Trace("uf-ho") << "uf-ho : ppRewrite : " << node << " to " << ret
61 : 3705 : << std::endl;
62 : 3705 : return TrustNode::mkTrustRewrite(node, ret);
63 : 3705 : }
64 : : // partial beta reduction
65 : : // f ---> (lambda ((x Int) (y Int)) s[x, y]) then (@ f t) is preprocessed
66 : : // to (lambda ((y Int)) s[t, y]).
67 [ + - ]: 1199 : if (options().uf.ufHoLazyLambdaLift)
68 : : {
69 : 1199 : Node op = node[0];
70 : 1199 : Node opl = d_ll.getLambdaFor(op);
71 [ - + ][ - - ]: 1199 : if (!opl.isNull() && !d_ll.isLifted(opl))
[ - + ]
72 : : {
73 : 0 : NodeManager* nm = nodeManager();
74 : 0 : Node app = nm->mkNode(Kind::HO_APPLY, opl, node[1]);
75 : 0 : app = rewrite(app);
76 [ - - ]: 0 : Trace("uf-lazy-ll")
77 : 0 : << "Partial beta reduce: " << node << " -> " << app << std::endl;
78 : 0 : return TrustNode::mkTrustRewrite(node, app, nullptr);
79 : 0 : }
80 [ + - ][ + - ]: 1199 : }
81 : : }
82 [ + + ]: 53265 : else if (k == Kind::APPLY_UF)
83 : : {
84 : : // Say (lambda ((x Int)) t[x]) occurs in the input. We replace this
85 : : // by k during ppRewrite. In the following, if we see (k s), we replace
86 : : // it by t[s]. This maintains the invariant that the *only* occurrences
87 : : // of k are as arguments to other functions; k is not applied
88 : : // in any preprocessed constraints.
89 [ + + ]: 32971 : if (options().uf.ufHoLazyLambdaLift)
90 : : {
91 : : // if an application of the lambda lifted function, do beta reduction
92 : : // immediately
93 : 32855 : Node op = node.getOperator();
94 : 32855 : Node opl = d_ll.getLambdaFor(op);
95 [ + + ][ + + ]: 32855 : if (!opl.isNull() && !d_ll.isLifted(opl))
[ + + ]
96 : : {
97 [ - + ][ - + ]: 878 : Assert(opl.getKind() == Kind::LAMBDA);
[ - - ]
98 : 878 : std::vector<Node> args(node.begin(), node.end());
99 : 878 : Node app = d_ll.betaReduce(opl, args);
100 [ + - ]: 1756 : Trace("uf-lazy-ll")
101 : 878 : << "Beta reduce: " << node << " -> " << app << std::endl;
102 : 878 : return TrustNode::mkTrustRewrite(node, app, nullptr);
103 : 878 : }
104 : : // If an unlifted lambda occurs in an argument to APPLY_UF, it must be
105 : : // lifted. We do this only if the lambda needs lifting, i.e. it is one
106 : : // that may induce circular model dependencies.
107 [ + + ]: 81296 : for (const Node& nc : node)
108 : : {
109 [ + + ]: 49319 : if (nc.getType().isFunction())
110 : : {
111 : 2878 : Node lam = d_ll.getLambdaFor(nc);
112 [ + + ][ + + ]: 2878 : if (!lam.isNull() && d_ll.needsLift(lam))
[ + + ]
113 : : {
114 : 791 : TrustNode trn = d_ll.lift(lam);
115 [ + + ]: 791 : if (!trn.isNull())
116 : : {
117 : 157 : lems.push_back(SkolemLemma(trn, nc));
118 : : }
119 : 791 : }
120 : 2878 : }
121 : 49319 : }
122 [ + + ][ + + ]: 33733 : }
123 : : }
124 [ + + ][ + + ]: 20294 : else if (k == Kind::LAMBDA || k == Kind::FUNCTION_ARRAY_CONST)
125 : : {
126 [ + - ]: 624 : Trace("uf-lazy-ll") << "Preprocess lambda: " << node << std::endl;
127 [ + + ]: 624 : if (k == Kind::LAMBDA)
128 : : {
129 : 532 : Node elimLam = TheoryUfRewriter::canEliminateLambda(nodeManager(), node);
130 [ - + ]: 532 : if (!elimLam.isNull())
131 : : {
132 [ - - ]: 0 : Trace("uf-lazy-ll") << "...eliminates to " << elimLam << std::endl;
133 : 0 : return TrustNode::mkTrustRewrite(node, elimLam, nullptr);
134 : : }
135 [ + - ]: 532 : }
136 : 624 : TrustNode skTrn = d_ll.ppRewrite(node, lems);
137 [ + - ][ - + ]: 624 : Trace("uf-lazy-ll") << "...return " << skTrn.getNode() << std::endl;
[ - - ]
138 : 624 : return skTrn;
139 : 624 : }
140 : 52962 : return TrustNode::null();
141 : : }
142 : :
143 : 53031 : Node HoExtension::getExtensionalityDeq(TNode deq, bool isCached)
144 : : {
145 : 53031 : Assert(deq.getKind() == Kind::NOT && deq[0].getKind() == Kind::EQUAL);
146 [ - + ][ - + ]: 53031 : Assert(deq[0][0].getType().isFunction());
[ - - ]
147 [ + + ]: 53031 : if (isCached)
148 : : {
149 : 3917 : std::map<Node, Node>::iterator it = d_extensionality_deq.find(deq);
150 [ - + ]: 3917 : if (it != d_extensionality_deq.end())
151 : : {
152 : 0 : return it->second;
153 : : }
154 : : }
155 : 106062 : TypeNode tn = deq[0][0].getType();
156 : 53031 : std::vector<TypeNode> argTypes = tn.getArgTypes();
157 : 53031 : std::vector<Node> skolems;
158 : 53031 : NodeManager* nm = nodeManager();
159 : 53031 : SkolemManager* sm = nm->getSkolemManager();
160 : 53031 : std::vector<Node> cacheVals;
161 : 53031 : cacheVals.push_back(deq[0][0]);
162 : 53031 : cacheVals.push_back(deq[0][1]);
163 : 53031 : cacheVals.push_back(Node::null());
164 [ + + ]: 106163 : for (unsigned i = 0, nargs = argTypes.size(); i < nargs; i++)
165 : : {
166 : 53132 : cacheVals[2] = nm->mkConstInt(Rational(i));
167 : 53132 : Node k = sm->mkSkolemFunction(SkolemId::HO_DEQ_DIFF, cacheVals);
168 : 53132 : skolems.push_back(k);
169 : 53132 : }
170 [ + + ]: 318186 : Node t[2];
171 [ + + ]: 159093 : for (unsigned i = 0; i < 2; i++)
172 : : {
173 : 106062 : std::vector<Node> children;
174 : 212124 : Node curr = deq[0][i];
175 [ + + ]: 202893 : while (curr.getKind() == Kind::HO_APPLY)
176 : : {
177 : 96831 : children.push_back(curr[1]);
178 : 96831 : curr = curr[0];
179 : : }
180 : 106062 : children.push_back(curr);
181 : 106062 : std::reverse(children.begin(), children.end());
182 : 106062 : children.insert(children.end(), skolems.begin(), skolems.end());
183 : 106062 : t[i] = nm->mkNode(Kind::APPLY_UF, children);
184 : 106062 : }
185 : 53031 : Node conc = t[0].eqNode(t[1]).negate();
186 [ + + ]: 53031 : if (isCached)
187 : : {
188 : 3917 : d_extensionality_deq[deq] = conc;
189 : : }
190 : 53031 : return conc;
191 [ + + ][ - - ]: 265155 : }
192 : :
193 : 35683 : unsigned HoExtension::applyExtensionality(TNode deq)
194 : : {
195 : 35683 : Assert(deq.getKind() == Kind::NOT && deq[0].getKind() == Kind::EQUAL);
196 [ - + ][ - + ]: 35683 : Assert(deq[0][0].getType().isFunction());
[ - - ]
197 : : // apply extensionality
198 [ + + ]: 35683 : if (d_extensionality.find(deq) == d_extensionality.end())
199 : : {
200 : 3917 : d_extensionality.insert(deq);
201 : 3917 : Node conc = getExtensionalityDeq(deq);
202 : 7834 : Node lem = nodeManager()->mkNode(Kind::OR, deq[0], conc);
203 [ + - ]: 7834 : Trace("uf-ho-lemma") << "uf-ho-lemma : extensionality : " << lem
204 : 3917 : << std::endl;
205 : 3917 : d_im.lemma(lem, InferenceId::UF_HO_EXTENSIONALITY);
206 : 3917 : return 1;
207 : 3917 : }
208 : 31766 : return 0;
209 : : }
210 : :
211 : 3705 : Node HoExtension::getApplyUfForHoApply(Node node)
212 : : {
213 [ - + ][ - + ]: 3705 : Assert(node[0].getType().getNumChildren() == 2);
[ - - ]
214 : 3705 : std::vector<TNode> args;
215 : 3705 : Node f = TheoryUfRewriter::decomposeHoApply(node, args, true);
216 : 3705 : Node new_f = f;
217 : 3705 : NodeManager* nm = nodeManager();
218 [ - + ]: 3705 : if (!TheoryUfRewriter::canUseAsApplyUfOperator(f))
219 : : {
220 : 0 : NodeNodeMap::const_iterator itus = d_uf_std_skolem.find(f);
221 [ - - ]: 0 : if (itus == d_uf_std_skolem.end())
222 : : {
223 : 0 : std::unordered_set<Node> fvs;
224 : 0 : expr::getFreeVariables(f, fvs);
225 : 0 : Node lem;
226 [ - - ]: 0 : if (!fvs.empty())
227 : : {
228 : 0 : std::vector<TypeNode> newTypes;
229 : 0 : std::vector<Node> vs;
230 : 0 : std::vector<Node> nvs;
231 [ - - ]: 0 : for (const Node& v : fvs)
232 : : {
233 : 0 : TypeNode vt = v.getType();
234 : 0 : newTypes.push_back(vt);
235 : 0 : Node nv = NodeManager::mkBoundVar(vt);
236 : 0 : vs.push_back(v);
237 : 0 : nvs.push_back(nv);
238 : 0 : }
239 : 0 : TypeNode ft = f.getType();
240 : 0 : std::vector<TypeNode> argTypes = ft.getArgTypes();
241 : 0 : TypeNode rangeType = ft.getRangeType();
242 : :
243 : 0 : newTypes.insert(newTypes.end(), argTypes.begin(), argTypes.end());
244 : 0 : TypeNode nft = nm->mkFunctionType(newTypes, rangeType);
245 : 0 : new_f = NodeManager::mkDummySkolem("app_uf", nft);
246 [ - - ]: 0 : for (const Node& v : vs)
247 : : {
248 : 0 : new_f = nm->mkNode(Kind::HO_APPLY, new_f, v);
249 : : }
250 : 0 : AssertEqual(new_f.getType(), f.getType());
251 : 0 : Node eq = new_f.eqNode(f);
252 : 0 : Node seq = eq.substitute(vs.begin(), vs.end(), nvs.begin(), nvs.end());
253 : 0 : lem = nm->mkNode(
254 : 0 : Kind::FORALL, nm->mkNode(Kind::BOUND_VAR_LIST, nvs), seq);
255 : 0 : }
256 : : else
257 : : {
258 : : // introduce skolem to make a standard APPLY_UF
259 : 0 : new_f = NodeManager::mkDummySkolem("app_uf", f.getType());
260 : 0 : lem = new_f.eqNode(f);
261 : : }
262 [ - - ]: 0 : Trace("uf-ho-lemma")
263 : 0 : << "uf-ho-lemma : Skolem definition for apply-conversion : " << lem
264 : 0 : << std::endl;
265 : 0 : d_im.lemma(lem, InferenceId::UF_HO_APP_CONV_SKOLEM);
266 : 0 : d_uf_std_skolem[f] = new_f;
267 : 0 : }
268 : : else
269 : : {
270 : 0 : new_f = (*itus).second;
271 : : }
272 : : // unroll the HO_APPLY, adding to the first argument position
273 : : // Note arguments in the vector args begin at position 1.
274 [ - - ]: 0 : while (new_f.getKind() == Kind::HO_APPLY)
275 : : {
276 : 0 : args.insert(args.begin() + 1, new_f[1]);
277 : 0 : new_f = new_f[0];
278 : : }
279 : : }
280 [ - + ][ - + ]: 3705 : Assert(TheoryUfRewriter::canUseAsApplyUfOperator(new_f));
[ - - ]
281 : 3705 : args[0] = new_f;
282 : 3705 : Node ret = nm->mkNode(Kind::APPLY_UF, args);
283 [ - + ][ - + ]: 11115 : AssertEqual(ret.getType(), node.getType());
[ - - ]
284 : 7410 : return ret;
285 : 3705 : }
286 : :
287 : 2459 : void HoExtension::computeRelevantTerms(std::set<Node>& termSet)
288 : : {
289 [ + + ]: 72178 : for (const Node& t : termSet)
290 : : {
291 [ + + ]: 69719 : if (t.getKind() == Kind::APPLY_UF)
292 : : {
293 : 20097 : Node ht = TheoryUfRewriter::getHoApplyForApplyUf(t);
294 : : // also add all subterms
295 [ + + ]: 49672 : while (ht.getKind() == Kind::HO_APPLY)
296 : : {
297 : 29575 : termSet.insert(ht);
298 : 29575 : termSet.insert(ht[1]);
299 : 29575 : ht = ht[0];
300 : : }
301 : 20097 : }
302 : : }
303 : 2459 : }
304 : :
305 : 9538 : unsigned HoExtension::checkExtensionality(TheoryModel* m)
306 : : {
307 : : // if we are in collect model info, we require looking at the model's
308 : : // equality engine, so that we only consider "relevant" (see
309 : : // Theory::computeRelevantTerms) function terms.
310 : : eq::EqualityEngine* ee =
311 [ + + ]: 9538 : m != nullptr ? m->getEqualityEngine() : d_state.getEqualityEngine();
312 : 9538 : NodeManager* nm = nodeManager();
313 : 9538 : unsigned num_lemmas = 0;
314 : 9538 : bool isCollectModel = (m != nullptr);
315 [ + - ]: 19076 : Trace("uf-ho") << "HoExtension::checkExtensionality, collectModel="
316 : 9538 : << isCollectModel << "..." << std::endl;
317 : 9538 : std::map<TypeNode, std::vector<Node> > func_eqcs;
318 : 9538 : eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(ee);
319 : 9538 : bool hasFunctions = false;
320 [ + + ]: 198671 : while (!eqcs_i.isFinished())
321 : : {
322 : 189133 : Node eqc = (*eqcs_i);
323 : 189133 : TypeNode tn = eqc.getType();
324 [ + + ][ + + ]: 189133 : if (tn.isFunction() && d_lambdaEqc.find(eqc) == d_lambdaEqc.end())
[ + + ]
325 : : {
326 : 24949 : hasFunctions = true;
327 : 24949 : std::vector<TypeNode> argTypes = tn.getArgTypes();
328 : : // We classify a function here to determine whether we need to apply
329 : : // extensionality eagerly during solving. We apply extensionality
330 : : // eagerly during solving if
331 : : // (A) The function type has finite cardinality,
332 : : // (B) All of its arguments have finite cardinality, or
333 : : // (C) It has a function as an argument.
334 : : // The latter is required so that we recursively consider extensionality
335 : : // between function constants introduced for extensionality lemmas.
336 : 24949 : bool eagerExtType = true;
337 [ + + ]: 24949 : if (!d_env.isFiniteType(tn))
338 : : {
339 [ + + ]: 47490 : for (const TypeNode& tna : argTypes)
340 : : {
341 [ + + ]: 27420 : if (!d_env.isFiniteType(tna))
342 : : {
343 : 26788 : eagerExtType = false;
344 : : }
345 [ + + ]: 27420 : if (tna.isFunction())
346 : : {
347 : 2742 : eagerExtType = true;
348 : 2742 : break;
349 : : }
350 : : }
351 : : }
352 : : // Based on the above classification of finite vs infinite.
353 : : // If during collect model, must have an infinite function type, since
354 : : // such function are not necessary to be handled during solving.
355 : : // If not during collect model, must have a finite function type, since
356 : : // such function symbols must be handled during solving.
357 [ + + ]: 24949 : if (eagerExtType != isCollectModel)
358 : : {
359 : 8670 : func_eqcs[tn].push_back(eqc);
360 [ + - ]: 17340 : Trace("uf-ho-debug")
361 : 8670 : << " func eqc : " << tn << " : " << eqc << std::endl;
362 : : }
363 : 24949 : }
364 : 189133 : ++eqcs_i;
365 : 189133 : }
366 [ - + ]: 9538 : if (!options().uf.ufHoExt)
367 : : {
368 : : // we are not applying extensionality, thus we are incomplete if functions
369 : : // are present
370 [ - - ]: 0 : if (hasFunctions)
371 : : {
372 : 0 : d_im.setModelUnsound(IncompleteId::UF_HO_EXT_DISABLED);
373 : : }
374 : 0 : return 0;
375 : : }
376 : :
377 : 9538 : for (std::map<TypeNode, std::vector<Node> >::iterator itf = func_eqcs.begin();
378 [ + + ]: 15355 : itf != func_eqcs.end();
379 : 5817 : ++itf)
380 : : {
381 [ + + ]: 14485 : for (unsigned j = 0, sizej = itf->second.size(); j < sizej; j++)
382 : : {
383 [ + + ]: 59413 : for (unsigned k = (j + 1), sizek = itf->second.size(); k < sizek; k++)
384 : : {
385 : : // if these equivalence classes are not explicitly disequal, do
386 : : // extensionality to ensure distinctness. Notice that we always use
387 : : // the (local) equality engine for this check via the state, since the
388 : : // model's equality engine does not store any disequalities. This is
389 : : // an optimization to introduce fewer equalities during model
390 : : // construction, since we know such disequalities have already been
391 : : // witness via assertions.
392 [ + + ]: 50747 : if (!d_state.areDisequal(itf->second[j], itf->second[k]))
393 : : {
394 : 98742 : Node deq = rewrite(itf->second[j].eqNode(itf->second[k]).negate());
395 : : // either add to model, or add lemma
396 [ + + ]: 49371 : if (isCollectModel)
397 : : {
398 : : // Add extentionality disequality to the model.
399 : : // It is important that we construct new (unconstrained) variables
400 : : // k here, so that we do not generate any inconsistencies.
401 : 49114 : Node edeq = getExtensionalityDeq(deq, false);
402 : 49114 : Assert(edeq.getKind() == Kind::NOT
403 : : && edeq[0].getKind() == Kind::EQUAL);
404 : : // introducing terms, must add required constraints, e.g. to
405 : : // force equalities between APPLY_UF and HO_APPLY terms
406 : 49114 : bool success = true;
407 [ + + ]: 147340 : for (unsigned r = 0; r < 2; r++)
408 : : {
409 [ - + ]: 98228 : if (!collectModelInfoHoTerm(edeq[0][r], m))
410 : : {
411 : 0 : return 1;
412 : : }
413 : : // Ensure finite skolems are set to arbitrary values eagerly.
414 : : // This ensures that partial function applications are identified
415 : : // with one another based on this assignment.
416 [ + + ]: 292383 : for (const Node& hk : edeq[0][r])
417 : : {
418 : 194157 : TypeNode tnk = hk.getType();
419 [ + + ]: 194157 : if (d_env.isFiniteType(tnk))
420 : : {
421 : 2 : TypeEnumerator te(tnk);
422 : 2 : Node v = *te;
423 [ + - ]: 2 : if (!m->assertEquality(hk, v, true))
424 : : {
425 : 2 : success = false;
426 : 2 : break;
427 : : }
428 [ - + ][ - + ]: 4 : }
429 [ + + ][ + + ]: 292387 : }
430 [ + + ]: 98228 : if (!success)
431 : : {
432 : 2 : break;
433 : : }
434 : : }
435 [ + + ]: 49114 : if (success)
436 : : {
437 : 98224 : TypeNode tn = edeq[0][0].getType();
438 [ + - ]: 98224 : Trace("uf-ho-debug")
439 : 0 : << "Add extensionality deq to model for : " << edeq
440 : 49112 : << std::endl;
441 [ + + ]: 49112 : if (d_env.isFiniteType(tn))
442 : : {
443 : : // We are an infinite function type with a finite range sort.
444 : : // Model construction assigns the first value for all
445 : : // unconstrained variables for such sorts, which does not
446 : : // suffice in this context since we are trying to make the
447 : : // functions disequal. Thus, for such case we enumerate the
448 : : // first two values for this sort and set the extensionality
449 : : // index to be equal to these two distinct values. There must
450 : : // be at least two values since this is an infinite function
451 : : // sort.
452 : 42050 : TypeEnumerator te(tn);
453 : 42050 : Node v1 = *te;
454 : 42050 : te++;
455 : 42050 : Node v2 = *te;
456 [ + - ][ + - ]: 42050 : Assert(!v2.isNull() && v2 != v1);
[ - + ][ - + ]
[ - - ]
457 : 84100 : Trace("uf-ho-debug") << "Finite witness: " << edeq[0][0]
458 : 42050 : << " == " << v1 << std::endl;
459 : 84100 : Trace("uf-ho-debug") << "Finite witness: " << edeq[0][1]
460 : 42050 : << " == " << v2 << std::endl;
461 : 42050 : success = m->assertEquality(edeq[0][0], v1, true);
462 [ + - ]: 42050 : if (success)
463 : : {
464 : 42050 : success = m->assertEquality(edeq[0][1], v2, true);
465 : : }
466 : 42050 : }
467 : 49112 : }
468 [ + + ]: 49114 : if (!success)
469 : : {
470 : 4 : Node eq = edeq[0][0].eqNode(edeq[0][1]);
471 : 8 : Node lem = nm->mkNode(Kind::OR, {deq.negate(), eq.negate()});
472 [ + - ]: 4 : Trace("uf-ho") << "HoExtension: cmi extensionality lemma " << lem
473 : 2 : << std::endl;
474 : 2 : d_im.lemma(lem, InferenceId::UF_HO_MODEL_EXTENSIONALITY);
475 : 2 : return 1;
476 : 2 : }
477 [ + + ]: 49114 : }
478 : : else
479 : : {
480 : : // apply extensionality lemma
481 : 257 : num_lemmas += applyExtensionality(deq);
482 : : }
483 [ + + ]: 49371 : }
484 : : }
485 : : }
486 : : }
487 : 9536 : return num_lemmas;
488 : 9538 : }
489 : :
490 : 10721023 : unsigned HoExtension::applyAppCompletion(TNode n)
491 : : {
492 [ - + ][ - + ]: 10721023 : Assert(n.getKind() == Kind::APPLY_UF);
[ - - ]
493 : :
494 : 10721023 : eq::EqualityEngine* ee = d_state.getEqualityEngine();
495 : : // must expand into APPLY_HO version if not there already
496 : 10721023 : Node ret = TheoryUfRewriter::getHoApplyForApplyUf(n);
497 : 10721023 : if (!ee->hasTerm(ret) || !ee->areEqual(ret, n))
498 : : {
499 : 49051 : Node eq = n.eqNode(ret);
500 [ + - ]: 98102 : Trace("uf-ho-lemma") << "uf-ho-lemma : infer, by apply-expand : " << eq
501 : 49051 : << std::endl;
502 : 98102 : d_im.assertInternalFact(eq,
503 : : true,
504 : : InferenceId::UF_HO_APP_ENCODE,
505 : : ProofRule::HO_APP_ENCODE,
506 : : {},
507 : : {n});
508 : 49051 : return 1;
509 : 49051 : }
510 [ + - ]: 21343944 : Trace("uf-ho-debug") << " ...already have " << ret << " == " << n << "."
511 : 10671972 : << std::endl;
512 : 10671972 : return 0;
513 : 10721023 : }
514 : :
515 : 56130 : unsigned HoExtension::checkAppCompletion()
516 : : {
517 [ + - ]: 56130 : Trace("uf-ho") << "HoExtension::checkApplyCompletion..." << std::endl;
518 : : // compute the operators that are relevant (those for which an HO_APPLY exist)
519 : 56130 : std::set<TNode> rlvOp;
520 : 56130 : eq::EqualityEngine* ee = d_state.getEqualityEngine();
521 : 56130 : eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(ee);
522 : 56130 : std::map<TNode, std::vector<Node> > apply_uf;
523 [ + + ]: 613456 : while (!eqcs_i.isFinished())
524 : : {
525 : 606377 : Node eqc = (*eqcs_i);
526 [ + - ]: 1212754 : Trace("uf-ho-debug") << " apply completion : visit eqc " << eqc
527 : 606377 : << std::endl;
528 : 606377 : eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, ee);
529 [ + + ]: 43280613 : while (!eqc_i.isFinished())
530 : : {
531 : 42723287 : Node n = *eqc_i;
532 [ + + ][ + + ]: 42723287 : if (n.getKind() == Kind::APPLY_UF || n.getKind() == Kind::HO_APPLY)
[ + + ]
533 : : {
534 : 36563354 : int curr_sum = 0;
535 : 36563354 : std::map<TNode, bool> curr_rops;
536 [ + + ]: 36563354 : if (n.getKind() == Kind::APPLY_UF)
537 : : {
538 : 45855066 : TNode rop = ee->getRepresentative(n.getOperator());
539 [ + + ]: 22927533 : if (rlvOp.find(rop) != rlvOp.end())
540 : : {
541 : : // try if its operator is relevant
542 : 10518740 : curr_sum = applyAppCompletion(n);
543 [ + + ]: 10518740 : if (curr_sum > 0)
544 : : {
545 : 43913 : return curr_sum;
546 : : }
547 : : }
548 : : else
549 : : {
550 : : // add to pending list
551 : 12408793 : apply_uf[rop].push_back(n);
552 : : }
553 : : // Arguments are also relevant operators.
554 : : // It might be possible include fewer terms here, see #1115.
555 [ + + ]: 52526833 : for (unsigned k = 0; k < n.getNumChildren(); k++)
556 : : {
557 [ + + ]: 29643213 : if (n[k].getType().isFunction())
558 : : {
559 : 1088078 : TNode rop2 = ee->getRepresentative(n[k]);
560 : 544039 : curr_rops[rop2] = true;
561 : 544039 : }
562 : : }
563 [ + + ]: 22927533 : }
564 : : else
565 : : {
566 [ - + ][ - + ]: 13635821 : Assert(n.getKind() == Kind::HO_APPLY);
[ - - ]
567 : 27271642 : TNode rop = ee->getRepresentative(n[0]);
568 : 13635821 : curr_rops[rop] = true;
569 : 13635821 : }
570 : 36519441 : for (std::map<TNode, bool>::iterator itc = curr_rops.begin();
571 [ + + ]: 50694136 : itc != curr_rops.end();
572 : 14174695 : ++itc)
573 : : {
574 : 14179833 : TNode rop = itc->first;
575 [ + + ]: 14179833 : if (rlvOp.find(rop) == rlvOp.end())
576 : : {
577 : 353754 : rlvOp.insert(rop);
578 : : // now, try each pending APPLY_UF for this operator
579 : : std::map<TNode, std::vector<Node> >::iterator itu =
580 : 353754 : apply_uf.find(rop);
581 [ + + ]: 353754 : if (itu != apply_uf.end())
582 : : {
583 [ + + ]: 213983 : for (unsigned j = 0, size = itu->second.size(); j < size; j++)
584 : : {
585 : 202283 : curr_sum = applyAppCompletion(itu->second[j]);
586 [ + + ]: 202283 : if (curr_sum > 0)
587 : : {
588 : 5138 : return curr_sum;
589 : : }
590 : : }
591 : : }
592 : : }
593 [ + + ]: 14179833 : }
594 [ + + ]: 36563354 : }
595 : 42674236 : ++eqc_i;
596 [ + + ]: 42723287 : }
597 : 557326 : ++eqcs_i;
598 [ + + ]: 606377 : }
599 : 7079 : return 0;
600 : 56130 : }
601 : :
602 : 6991 : unsigned HoExtension::checkLazyLambda()
603 : : {
604 [ + + ]: 6991 : if (!options().uf.ufHoLazyLambdaLift)
605 : : {
606 : : // no lambdas are lazily lifted
607 : 438 : return 0;
608 : : }
609 [ + - ]: 6553 : Trace("uf-ho") << "HoExtension::checkLazyLambda..." << std::endl;
610 : 6553 : NodeManager* nm = nodeManager();
611 : 6553 : unsigned numLemmas = 0;
612 : 6553 : d_lambdaEqc.clear();
613 : 6553 : eq::EqualityEngine* ee = d_state.getEqualityEngine();
614 : 6553 : eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(ee);
615 : : // normal functions equated to lambda functions
616 : 6553 : std::unordered_set<Node> normalEqFuns;
617 : : // mapping from functions to terms
618 [ + + ]: 158883 : while (!eqcs_i.isFinished())
619 : : {
620 : 152330 : Node eqc = (*eqcs_i);
621 : 152330 : ++eqcs_i;
622 [ + + ]: 152330 : if (!eqc.getType().isFunction())
623 : : {
624 : 127512 : continue;
625 : : }
626 : 24818 : eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, ee);
627 : 24818 : Node lamRep; // the first lambda function we encounter in the equivalence
628 : : // class
629 : 24818 : bool needsLift = false;
630 : 24818 : bool doLift = false;
631 : 24818 : Node lamRepLam;
632 : 24818 : std::unordered_set<Node> normalEqFunWait;
633 [ + + ]: 57094 : while (!eqc_i.isFinished())
634 : : {
635 : 32276 : Node n = *eqc_i;
636 : 32276 : ++eqc_i;
637 : 32276 : Node lam = d_ll.getLambdaFor(n);
638 [ + + ][ + + ]: 32276 : if (lam.isNull() || d_ll.isLifted(lam))
[ + + ]
639 : : {
640 [ + + ]: 24377 : if (!lamRep.isNull())
641 : : {
642 : : // if we are equal to a lambda function, we must beta-reduce
643 : : // applications of this
644 : 54 : normalEqFuns.insert(n);
645 : 54 : doLift = needsLift;
646 : : }
647 : : else
648 : : {
649 : : // waiting to see if there is a lambda function in this equivalence
650 : : // class
651 : 24323 : normalEqFunWait.insert(n);
652 : : }
653 : : }
654 [ + + ]: 7899 : else if (lamRep.isNull())
655 : : {
656 : : // there is a lambda function in this equivalence class
657 : 7832 : lamRep = n;
658 : 7832 : lamRepLam = lam;
659 [ + + ][ + - ]: 7832 : needsLift = d_ll.needsLift(lam) && !d_ll.isLifted(lam);
660 [ + + ][ + + ]: 7832 : doLift = needsLift && !normalEqFunWait.empty();
661 : : // must consider all normal functions we've seen so far
662 : 7832 : normalEqFuns.insert(normalEqFunWait.begin(), normalEqFunWait.end());
663 : 7832 : normalEqFunWait.clear();
664 : : }
665 : : else
666 : : {
667 : : // two lambda functions are in same equivalence class
668 [ + - ]: 67 : Node f = lamRep < n ? lamRep : n;
669 [ + - ]: 67 : Node g = lamRep < n ? n : lamRep;
670 : : // swap based on order
671 [ - + ]: 67 : if (g < f)
672 : : {
673 : 0 : Node tmp = f;
674 : 0 : f = g;
675 : 0 : g = tmp;
676 : 0 : }
677 : 67 : Node fgEq = f.eqNode(g);
678 [ + + ]: 67 : if (d_lamEqProcessed.find(fgEq) != d_lamEqProcessed.end())
679 : : {
680 : 23 : continue;
681 : : }
682 : 44 : d_lamEqProcessed.insert(fgEq);
683 : :
684 [ + - ]: 88 : Trace("uf-ho-debug") << " found equivalent lambda functions " << f
685 : 44 : << " and " << g << std::endl;
686 [ + - ]: 44 : Node flam = lamRep < n ? lamRepLam : lam;
687 [ + - ][ + - ]: 44 : Assert(!flam.isNull() && flam.getKind() == Kind::LAMBDA);
[ - + ][ - + ]
[ - - ]
688 : 44 : Node lhs = flam[1];
689 [ + - ]: 44 : Node glam = lamRep < n ? lam : lamRepLam;
690 [ + - ]: 88 : Trace("uf-ho-debug")
691 : 44 : << " lambda are " << flam << " and " << glam << std::endl;
692 : 88 : std::vector<Node> args(flam[0].begin(), flam[0].end());
693 : 44 : Node rhs = d_ll.betaReduce(glam, args);
694 : 88 : Node univ = nm->mkNode(Kind::FORALL, flam[0], lhs.eqNode(rhs));
695 : : // do quantifier elimination if the option is set
696 : : // For example, say (= f1 f2) where f1 is (lambda ((x Int)) (< x a))
697 : : // and f2 is (lambda ((x Int)) (< x b)).
698 : : // By default, we would generate the inference
699 : : // (=> (= f1 f2) (forall ((x Int)) (= (< x a) (< x b))),
700 : : // where quantified reasoning is introduced into the main solving
701 : : // procedure.
702 : : // With --uf-lambda-qe, we use a subsolver to compute the quantifier
703 : : // elimination of:
704 : : // (forall ((x Int)) (= (< x a) (< x b)),
705 : : // which is (and (<= a b) (<= b a)). We instead generate the lemma
706 : : // (=> (= f1 f2) (and (<= a b) (<= b a)).
707 : : // The motivation for this is to reduce the complexity of constraints
708 : : // in the main solver. This is motivated by usages of set.filter where
709 : : // the lambdas are over a decidable theory that admits quantifier
710 : : // elimination, e.g. LIA or BV.
711 [ + + ]: 44 : if (options().uf.ufHoLambdaQe)
712 : : {
713 [ + - ]: 4 : Trace("uf-lambda-qe")
714 : 2 : << "Given " << flam << " == " << glam << std::endl;
715 [ + - ]: 2 : Trace("uf-lambda-qe") << "Run QE on " << univ << std::endl;
716 : 2 : std::unique_ptr<SolverEngine> lqe;
717 : : // initialize the subsolver using the standard method
718 : 2 : initializeSubsolver(lqe, d_env);
719 : 2 : Node univQe = lqe->getQuantifierElimination(univ, true);
720 [ + - ]: 2 : Trace("uf-lambda-qe") << "QE is " << univQe << std::endl;
721 [ - + ][ - + ]: 2 : Assert(!univQe.isNull());
[ - - ]
722 : : // Note that if quantifier elimination failed, then univQe will
723 : : // be equal to univ, in which case this above code has no effect.
724 : 2 : univ = univQe;
725 : 2 : }
726 : : // f = g => forall x. reduce(lambda(f)(x)) = reduce(lambda(g)(x))
727 : : //
728 : : // For example, if f -> lambda z. z+1, g -> lambda y. y+3, this
729 : : // will infer: f = g => forall x. x+1 = x+3, which simplifies to
730 : : // f != g.
731 : 88 : Node lem = nm->mkNode(Kind::IMPLIES, fgEq, univ);
732 [ + - ]: 44 : if (cacheLemma(lem))
733 : : {
734 : 44 : d_im.lemma(lem, InferenceId::UF_HO_LAMBDA_UNIV_EQ);
735 : 44 : numLemmas++;
736 : : }
737 [ + + ][ + + ]: 113 : }
[ + + ]
738 [ + + ][ + + ]: 32299 : }
739 [ + + ]: 24818 : if (!lamRep.isNull())
740 : : {
741 : 7832 : d_lambdaEqc[eqc] = lamRep;
742 : : // Do the lambda lifting lemma if needed. This happens if a lambda
743 : : // needs lifting based on the symbols in its body and is equated to an
744 : : // ordinary function symbol. For example, this is what ensures we
745 : : // handle conflicts like f = (lambda ((x Int)) (+ 1 (f x))).
746 [ + + ]: 7832 : if (doLift)
747 : : {
748 : 39 : TrustNode tlift = d_ll.lift(lamRepLam);
749 [ - + ][ - + ]: 39 : Assert(!tlift.isNull());
[ - - ]
750 : 39 : d_im.trustedLemma(tlift, InferenceId::UF_HO_LAMBDA_LAZY_LIFT);
751 : 39 : }
752 : : }
753 [ + + ]: 152330 : }
754 [ + - ]: 13106 : Trace("uf-ho-debug")
755 : 6553 : << " found " << normalEqFuns.size()
756 : 6553 : << " ordinary functions that are equal to lambda functions" << std::endl;
757 [ + + ]: 6553 : if (normalEqFuns.empty())
758 : : {
759 : 6448 : return numLemmas;
760 : : }
761 : : // if we have normal functions that are equal to lambda functions, go back
762 : : // and ensure they are mapped properly
763 : : // mapping from functions to terms
764 : 105 : eq::EqClassesIterator eqcs_i2 = eq::EqClassesIterator(ee);
765 [ + + ]: 23356 : while (!eqcs_i2.isFinished())
766 : : {
767 : 23251 : Node eqc = (*eqcs_i2);
768 : 23251 : ++eqcs_i2;
769 [ + - ]: 23251 : Trace("uf-ho-debug") << "Check equivalence class " << eqc << std::endl;
770 : 23251 : eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, ee);
771 [ + + ]: 122706 : while (!eqc_i.isFinished())
772 : : {
773 : 99455 : Node n = *eqc_i;
774 : 99455 : ++eqc_i;
775 [ + - ]: 99455 : Trace("uf-ho-debug") << "Check term " << n << std::endl;
776 : 99455 : Node op;
777 : 99455 : Kind k = n.getKind();
778 : 99455 : std::vector<Node> args;
779 [ + + ]: 99455 : if (k == Kind::APPLY_UF)
780 : : {
781 : 49996 : op = n.getOperator();
782 : 49996 : args.insert(args.end(), n.begin(), n.end());
783 : : }
784 [ + + ]: 49459 : else if (k == Kind::HO_APPLY)
785 : : {
786 : 26627 : op = n[0];
787 : 26627 : args.push_back(n[1]);
788 : : }
789 : : else
790 : : {
791 : 22832 : continue;
792 : : }
793 [ + + ]: 76623 : if (normalEqFuns.find(op) == normalEqFuns.end())
794 : : {
795 : 69756 : continue;
796 : : }
797 [ + - ]: 13734 : Trace("uf-ho-debug") << " found relevant ordinary application " << n
798 : 6867 : << std::endl;
799 [ - + ][ - + ]: 6867 : Assert(ee->hasTerm(op));
[ - - ]
800 : 13734 : Node r = ee->getRepresentative(op);
801 [ - + ][ - + ]: 6867 : Assert(d_lambdaEqc.find(r) != d_lambdaEqc.end());
[ - - ]
802 : 6867 : Node lf = d_lambdaEqc[r];
803 : 6867 : Node lam = d_ll.getLambdaFor(lf);
804 [ + - ][ + - ]: 6867 : Assert(!lam.isNull() && lam.getKind() == Kind::LAMBDA);
[ - + ][ - + ]
[ - - ]
805 : : // a normal function g equal to a lambda, say f --> lambda(f)
806 : : // need to infer f = g => g(t) = f(t) for all terms g(t)
807 : : // that occur in the equality engine.
808 : 6867 : Node premise = op.eqNode(lf);
809 : 6867 : args.insert(args.begin(), lam);
810 : 6867 : Node rhs = nm->mkNode(n.getKind(), args);
811 : 6867 : rhs = rewrite(rhs);
812 : 6867 : Node conc = n.eqNode(rhs);
813 : 13734 : Node lem = nm->mkNode(Kind::IMPLIES, premise, conc);
814 [ + + ]: 6867 : if (cacheLemma(lem))
815 : : {
816 : 6071 : d_im.lemma(lem, InferenceId::UF_HO_LAMBDA_APP_REDUCE);
817 : 6071 : numLemmas++;
818 : : }
819 [ + + ][ + + ]: 284631 : }
[ + + ]
820 : 23251 : }
821 : 105 : return numLemmas;
822 : 6553 : }
823 : :
824 : 7199 : unsigned HoExtension::check()
825 : : {
826 [ + - ]: 7199 : Trace("uf-ho") << "HoExtension::checkHigherOrder..." << std::endl;
827 : :
828 : : // infer new facts based on apply completion until fixed point
829 : : unsigned num_facts;
830 : : do
831 : : {
832 : 56130 : num_facts = checkAppCompletion();
833 [ + + ]: 56130 : if (d_state.isInConflict())
834 : : {
835 [ + - ]: 120 : Trace("uf-ho") << "...conflict during app-completion." << std::endl;
836 : 120 : return 1;
837 : : }
838 [ + + ]: 56010 : } while (num_facts > 0);
839 : :
840 : : // Apply extensionality, lazy lambda schemas in order. We make lazy lambda
841 : : // handling come last as it may introduce quantifiers.
842 [ + + ]: 20955 : for (size_t i = 0; i < 2; i++)
843 : : {
844 : 14070 : unsigned num_lemmas = 0;
845 : : // apply the schema
846 [ + + ][ - ]: 14070 : switch (i)
847 : : {
848 : 7079 : case 0: num_lemmas = checkExtensionality(); break;
849 : 6991 : case 1: num_lemmas = checkLazyLambda(); break;
850 : 0 : default: break;
851 : : }
852 : : // finish if we added lemmas
853 [ + + ]: 14070 : if (num_lemmas > 0)
854 : : {
855 [ + - ]: 194 : Trace("uf-ho") << "...returned " << num_lemmas << " lemmas." << std::endl;
856 : 194 : return num_lemmas;
857 : : }
858 : : }
859 : :
860 [ + - ]: 6885 : Trace("uf-ho") << "...finished check higher order." << std::endl;
861 : :
862 : 6885 : return 0;
863 : : }
864 : :
865 : 2459 : bool HoExtension::collectModelInfoHo(TheoryModel* m,
866 : : const std::set<Node>& termSet)
867 : : {
868 [ + + ]: 72325 : for (std::set<Node>::iterator it = termSet.begin(); it != termSet.end(); ++it)
869 : : {
870 : 69866 : Node n = *it;
871 : : // For model-building with higher-order, we require that APPLY_UF is always
872 : : // expanded to HO_APPLY. That is, we always expand to a fully applicative
873 : : // encoding during model construction.
874 [ - + ]: 69866 : if (!collectModelInfoHoTerm(n, m))
875 : : {
876 : 0 : return false;
877 : : }
878 [ + - ]: 69866 : }
879 : : // We apply an explicit extensionality technique for asserting
880 : : // disequalities to the model to ensure that function values are distinct
881 : : // in the curried HO_APPLY version of model construction. This is a
882 : : // non-standard alternative to using a type enumerator over function
883 : : // values to assign unique values.
884 : 2459 : int addedLemmas = checkExtensionality(m);
885 : : // for equivalence classes that we know to assign a lambda directly
886 [ + + ]: 2965 : for (const std::pair<const Node, Node>& p : d_lambdaEqc)
887 : : {
888 : 506 : Node lam = d_ll.getLambdaFor(p.second);
889 : 506 : lam = rewrite(lam);
890 [ - + ][ - + ]: 506 : Assert(!lam.isNull());
[ - - ]
891 : 506 : m->assertEquality(p.second, lam, true);
892 : 506 : m->assertSkeleton(lam);
893 : : // we don't assign the function definition here, which is handled internally
894 : : // in the model builder.
895 : 506 : }
896 : 2459 : return addedLemmas == 0;
897 : : }
898 : :
899 : 168094 : bool HoExtension::collectModelInfoHoTerm(Node n, TheoryModel* m)
900 : : {
901 [ + + ]: 168094 : if (n.getKind() == Kind::APPLY_UF)
902 : : {
903 : 118325 : Node hn = TheoryUfRewriter::getHoApplyForApplyUf(n);
904 [ - + ]: 118325 : if (!m->assertEquality(n, hn, true))
905 : : {
906 : 0 : Node eq = n.eqNode(hn);
907 [ - - ]: 0 : Trace("uf-ho") << "HoExtension: cmi app completion lemma " << eq
908 : 0 : << std::endl;
909 : 0 : d_im.lemma(eq, InferenceId::UF_HO_MODEL_APP_ENCODE);
910 : 0 : return false;
911 : 0 : }
912 : : // also add all subterms
913 : 118325 : eq::EqualityEngine* ee = m->getEqualityEngine();
914 [ + + ]: 342059 : while (hn.getKind() == Kind::HO_APPLY)
915 : : {
916 : 223734 : ee->addTerm(hn);
917 : 223734 : ee->addTerm(hn[1]);
918 : 223734 : hn = hn[0];
919 : : }
920 [ + - ]: 118325 : }
921 : 168094 : return true;
922 : : }
923 : :
924 : 6911 : bool HoExtension::cacheLemma(TNode lem)
925 : : {
926 : 6911 : Node rewritten = rewrite(lem);
927 [ + + ]: 6911 : if (d_cachedLemmas.find(rewritten) != d_cachedLemmas.end())
928 : : {
929 : 796 : return false;
930 : : }
931 : 6115 : d_cachedLemmas.insert(rewritten);
932 : 6115 : return true;
933 : 6911 : }
934 : :
935 : : } // namespace uf
936 : : } // namespace theory
937 : : } // namespace cvc5::internal
|