Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * Implementation of the higher-order extension of TheoryUF.
11 : : */
12 : :
13 : : #include "theory/uf/ho_extension.h"
14 : :
15 : : #include "expr/node_algorithm.h"
16 : : #include "expr/skolem_manager.h"
17 : : #include "options/uf_options.h"
18 : : #include "theory/smt_engine_subsolver.h"
19 : : #include "theory/theory_model.h"
20 : : #include "theory/uf/function_const.h"
21 : : #include "theory/uf/lambda_lift.h"
22 : : #include "theory/uf/theory_uf_rewriter.h"
23 : : #include "util/rational.h"
24 : :
25 : : using namespace std;
26 : : using namespace cvc5::internal::kind;
27 : :
28 : : namespace cvc5::internal {
29 : : namespace theory {
30 : : namespace uf {
31 : :
32 : 1201 : HoExtension::HoExtension(Env& env,
33 : : TheoryState& state,
34 : : TheoryInferenceManager& im,
35 : 1201 : LambdaLift& ll)
36 : : : EnvObj(env),
37 : 1201 : d_state(state),
38 : 1201 : d_im(im),
39 : 1201 : d_ll(ll),
40 : 1201 : d_extensionality(userContext()),
41 : 1201 : d_cachedLemmas(userContext()),
42 : 1201 : d_uf_std_skolem(userContext()),
43 : 3603 : d_lamEqProcessed(userContext())
44 : : {
45 : 1201 : d_true = nodeManager()->mkConst(true);
46 : : // don't send true lemma
47 : 1201 : d_cachedLemmas.insert(d_true);
48 : 1201 : }
49 : :
50 : 47916 : TrustNode HoExtension::ppRewrite(Node node, std::vector<SkolemLemma>& lems)
51 : : {
52 : 47916 : Kind k = node.getKind();
53 [ + + ]: 47916 : if (k == Kind::HO_APPLY)
54 : : {
55 : : // convert HO_APPLY to APPLY_UF if fully applied
56 [ + + ]: 4701 : if (node[0].getType().getNumChildren() == 2)
57 : : {
58 [ + - ]: 3628 : Trace("uf-ho") << "uf-ho : expanding definition : " << node << std::endl;
59 : 3628 : Node ret = getApplyUfForHoApply(node);
60 [ + - ]: 7256 : Trace("uf-ho") << "uf-ho : ppRewrite : " << node << " to " << ret
61 : 3628 : << std::endl;
62 : 3628 : return TrustNode::mkTrustRewrite(node, ret);
63 : 3628 : }
64 : : // partial beta reduction
65 : : // f ---> (lambda ((x Int) (y Int)) s[x, y]) then (@ f t) is preprocessed
66 : : // to (lambda ((y Int)) s[t, y]).
67 [ + - ]: 1073 : if (options().uf.ufHoLazyLambdaLift)
68 : : {
69 : 1073 : Node op = node[0];
70 : 1073 : Node opl = d_ll.getLambdaFor(op);
71 [ - + ][ - - ]: 1073 : if (!opl.isNull() && !d_ll.isLifted(opl))
[ - + ]
72 : : {
73 : 0 : NodeManager* nm = nodeManager();
74 : 0 : Node app = nm->mkNode(Kind::HO_APPLY, opl, node[1]);
75 : 0 : app = rewrite(app);
76 [ - - ]: 0 : Trace("uf-lazy-ll")
77 : 0 : << "Partial beta reduce: " << node << " -> " << app << std::endl;
78 : 0 : return TrustNode::mkTrustRewrite(node, app, nullptr);
79 : 0 : }
80 [ + - ][ + - ]: 1073 : }
81 : : }
82 [ + + ]: 43215 : else if (k == Kind::APPLY_UF)
83 : : {
84 : : // Say (lambda ((x Int)) t[x]) occurs in the input. We replace this
85 : : // by k during ppRewrite. In the following, if we see (k s), we replace
86 : : // it by t[s]. This maintains the invariant that the *only* occurrences
87 : : // of k are as arguments to other functions; k is not applied
88 : : // in any preprocessed constraints.
89 [ + + ]: 31657 : if (options().uf.ufHoLazyLambdaLift)
90 : : {
91 : : // if an application of the lambda lifted function, do beta reduction
92 : : // immediately
93 : 31541 : Node op = node.getOperator();
94 : 31541 : Node opl = d_ll.getLambdaFor(op);
95 [ + + ][ + + ]: 31541 : if (!opl.isNull() && !d_ll.isLifted(opl))
[ + + ]
96 : : {
97 [ - + ][ - + ]: 798 : Assert(opl.getKind() == Kind::LAMBDA);
[ - - ]
98 : 798 : std::vector<Node> args(node.begin(), node.end());
99 : 798 : Node app = d_ll.betaReduce(opl, args);
100 [ + - ]: 1596 : Trace("uf-lazy-ll")
101 : 798 : << "Beta reduce: " << node << " -> " << app << std::endl;
102 : 798 : return TrustNode::mkTrustRewrite(node, app, nullptr);
103 : 798 : }
104 : : // If an unlifted lambda occurs in an argument to APPLY_UF, it must be
105 : : // lifted. We do this only if the lambda needs lifting, i.e. it is one
106 : : // that may induce circular model dependencies.
107 [ + + ]: 78455 : for (const Node& nc : node)
108 : : {
109 [ + + ]: 47712 : if (nc.getType().isFunction())
110 : : {
111 : 2695 : Node lam = d_ll.getLambdaFor(nc);
112 [ + + ][ + + ]: 2695 : if (!lam.isNull() && d_ll.needsLift(lam))
[ + + ]
113 : : {
114 : 789 : TrustNode trn = d_ll.lift(lam);
115 [ + + ]: 789 : if (!trn.isNull())
116 : : {
117 : 155 : lems.push_back(SkolemLemma(trn, nc));
118 : : }
119 : 789 : }
120 : 2695 : }
121 : 47712 : }
122 [ + + ][ + + ]: 32339 : }
123 : : }
124 [ + + ][ + + ]: 11558 : else if (k == Kind::LAMBDA || k == Kind::FUNCTION_ARRAY_CONST)
125 : : {
126 [ + - ]: 546 : Trace("uf-lazy-ll") << "Preprocess lambda: " << node << std::endl;
127 [ + + ]: 546 : if (k == Kind::LAMBDA)
128 : : {
129 : 480 : Node elimLam = TheoryUfRewriter::canEliminateLambda(nodeManager(), node);
130 [ - + ]: 480 : if (!elimLam.isNull())
131 : : {
132 [ - - ]: 0 : Trace("uf-lazy-ll") << "...eliminates to " << elimLam << std::endl;
133 : 0 : return TrustNode::mkTrustRewrite(node, elimLam, nullptr);
134 : : }
135 [ + - ]: 480 : }
136 : 546 : TrustNode skTrn = d_ll.ppRewrite(node, lems);
137 [ + - ][ - + ]: 546 : Trace("uf-lazy-ll") << "...return " << skTrn.getNode() << std::endl;
[ - - ]
138 : 546 : return skTrn;
139 : 546 : }
140 : 42944 : return TrustNode::null();
141 : : }
142 : :
143 : 52772 : Node HoExtension::getExtensionalityDeq(TNode deq, bool isCached)
144 : : {
145 : 52772 : Assert(deq.getKind() == Kind::NOT && deq[0].getKind() == Kind::EQUAL);
146 [ - + ][ - + ]: 52772 : Assert(deq[0][0].getType().isFunction());
[ - - ]
147 [ + + ]: 52772 : if (isCached)
148 : : {
149 : 3821 : std::map<Node, Node>::iterator it = d_extensionality_deq.find(deq);
150 [ - + ]: 3821 : if (it != d_extensionality_deq.end())
151 : : {
152 : 0 : return it->second;
153 : : }
154 : : }
155 : 105544 : TypeNode tn = deq[0][0].getType();
156 : 52772 : std::vector<TypeNode> argTypes = tn.getArgTypes();
157 : 52772 : std::vector<Node> skolems;
158 : 52772 : NodeManager* nm = nodeManager();
159 : 52772 : SkolemManager* sm = nm->getSkolemManager();
160 : 52772 : std::vector<Node> cacheVals;
161 : 52772 : cacheVals.push_back(deq[0][0]);
162 : 52772 : cacheVals.push_back(deq[0][1]);
163 : 52772 : cacheVals.push_back(Node::null());
164 [ + + ]: 105645 : for (unsigned i = 0, nargs = argTypes.size(); i < nargs; i++)
165 : : {
166 : 52873 : cacheVals[2] = nm->mkConstInt(Rational(i));
167 : 52873 : Node k = sm->mkSkolemFunction(SkolemId::HO_DEQ_DIFF, cacheVals);
168 : 52873 : skolems.push_back(k);
169 : 52873 : }
170 [ + + ]: 316632 : Node t[2];
171 [ + + ]: 158316 : for (unsigned i = 0; i < 2; i++)
172 : : {
173 : 105544 : std::vector<Node> children;
174 : 211088 : Node curr = deq[0][i];
175 [ + + ]: 202401 : while (curr.getKind() == Kind::HO_APPLY)
176 : : {
177 : 96857 : children.push_back(curr[1]);
178 : 96857 : curr = curr[0];
179 : : }
180 : 105544 : children.push_back(curr);
181 : 105544 : std::reverse(children.begin(), children.end());
182 : 105544 : children.insert(children.end(), skolems.begin(), skolems.end());
183 : 105544 : t[i] = nm->mkNode(Kind::APPLY_UF, children);
184 : 105544 : }
185 : 52772 : Node conc = t[0].eqNode(t[1]).negate();
186 [ + + ]: 52772 : if (isCached)
187 : : {
188 : 3821 : d_extensionality_deq[deq] = conc;
189 : : }
190 : 52772 : return conc;
191 [ + + ][ - - ]: 263860 : }
192 : :
193 : 35098 : unsigned HoExtension::applyExtensionality(TNode deq)
194 : : {
195 : 35098 : Assert(deq.getKind() == Kind::NOT && deq[0].getKind() == Kind::EQUAL);
196 [ - + ][ - + ]: 35098 : Assert(deq[0][0].getType().isFunction());
[ - - ]
197 : : // apply extensionality
198 [ + + ]: 35098 : if (d_extensionality.find(deq) == d_extensionality.end())
199 : : {
200 : 3821 : d_extensionality.insert(deq);
201 : 3821 : Node conc = getExtensionalityDeq(deq);
202 : 7642 : Node lem = nodeManager()->mkNode(Kind::OR, deq[0], conc);
203 [ + - ]: 7642 : Trace("uf-ho-lemma") << "uf-ho-lemma : extensionality : " << lem
204 : 3821 : << std::endl;
205 : 3821 : d_im.lemma(lem, InferenceId::UF_HO_EXTENSIONALITY);
206 : 3821 : return 1;
207 : 3821 : }
208 : 31277 : return 0;
209 : : }
210 : :
211 : 3628 : Node HoExtension::getApplyUfForHoApply(Node node)
212 : : {
213 [ - + ][ - + ]: 3628 : Assert(node[0].getType().getNumChildren() == 2);
[ - - ]
214 : 3628 : std::vector<TNode> args;
215 : 3628 : Node f = TheoryUfRewriter::decomposeHoApply(node, args, true);
216 : 3628 : Node new_f = f;
217 : 3628 : NodeManager* nm = nodeManager();
218 [ - + ]: 3628 : if (!TheoryUfRewriter::canUseAsApplyUfOperator(f))
219 : : {
220 : 0 : NodeNodeMap::const_iterator itus = d_uf_std_skolem.find(f);
221 [ - - ]: 0 : if (itus == d_uf_std_skolem.end())
222 : : {
223 : 0 : std::unordered_set<Node> fvs;
224 : 0 : expr::getFreeVariables(f, fvs);
225 : 0 : Node lem;
226 [ - - ]: 0 : if (!fvs.empty())
227 : : {
228 : 0 : std::vector<TypeNode> newTypes;
229 : 0 : std::vector<Node> vs;
230 : 0 : std::vector<Node> nvs;
231 [ - - ]: 0 : for (const Node& v : fvs)
232 : : {
233 : 0 : TypeNode vt = v.getType();
234 : 0 : newTypes.push_back(vt);
235 : 0 : Node nv = NodeManager::mkBoundVar(vt);
236 : 0 : vs.push_back(v);
237 : 0 : nvs.push_back(nv);
238 : 0 : }
239 : 0 : TypeNode ft = f.getType();
240 : 0 : std::vector<TypeNode> argTypes = ft.getArgTypes();
241 : 0 : TypeNode rangeType = ft.getRangeType();
242 : :
243 : 0 : newTypes.insert(newTypes.end(), argTypes.begin(), argTypes.end());
244 : 0 : TypeNode nft = nm->mkFunctionType(newTypes, rangeType);
245 : 0 : new_f = NodeManager::mkDummySkolem("app_uf", nft);
246 [ - - ]: 0 : for (const Node& v : vs)
247 : : {
248 : 0 : new_f = nm->mkNode(Kind::HO_APPLY, new_f, v);
249 : : }
250 : 0 : Assert(new_f.getType() == f.getType());
251 : 0 : Node eq = new_f.eqNode(f);
252 : 0 : Node seq = eq.substitute(vs.begin(), vs.end(), nvs.begin(), nvs.end());
253 : 0 : lem = nm->mkNode(
254 : 0 : Kind::FORALL, nm->mkNode(Kind::BOUND_VAR_LIST, nvs), seq);
255 : 0 : }
256 : : else
257 : : {
258 : : // introduce skolem to make a standard APPLY_UF
259 : 0 : new_f = NodeManager::mkDummySkolem("app_uf", f.getType());
260 : 0 : lem = new_f.eqNode(f);
261 : : }
262 [ - - ]: 0 : Trace("uf-ho-lemma")
263 : 0 : << "uf-ho-lemma : Skolem definition for apply-conversion : " << lem
264 : 0 : << std::endl;
265 : 0 : d_im.lemma(lem, InferenceId::UF_HO_APP_CONV_SKOLEM);
266 : 0 : d_uf_std_skolem[f] = new_f;
267 : 0 : }
268 : : else
269 : : {
270 : 0 : new_f = (*itus).second;
271 : : }
272 : : // unroll the HO_APPLY, adding to the first argument position
273 : : // Note arguments in the vector args begin at position 1.
274 [ - - ]: 0 : while (new_f.getKind() == Kind::HO_APPLY)
275 : : {
276 : 0 : args.insert(args.begin() + 1, new_f[1]);
277 : 0 : new_f = new_f[0];
278 : : }
279 : : }
280 [ - + ][ - + ]: 3628 : Assert(TheoryUfRewriter::canUseAsApplyUfOperator(new_f));
[ - - ]
281 : 3628 : args[0] = new_f;
282 : 3628 : Node ret = nm->mkNode(Kind::APPLY_UF, args);
283 [ - + ][ - + ]: 3628 : Assert(ret.getType() == node.getType());
[ - - ]
284 : 7256 : return ret;
285 : 3628 : }
286 : :
287 : 1274 : void HoExtension::computeRelevantTerms(std::set<Node>& termSet)
288 : : {
289 [ + + ]: 47426 : for (const Node& t : termSet)
290 : : {
291 [ + + ]: 46152 : if (t.getKind() == Kind::APPLY_UF)
292 : : {
293 : 15405 : Node ht = TheoryUfRewriter::getHoApplyForApplyUf(t);
294 : : // also add all subterms
295 [ + + ]: 38863 : while (ht.getKind()==Kind::HO_APPLY)
296 : : {
297 : 23458 : termSet.insert(ht);
298 : 23458 : termSet.insert(ht[1]);
299 : 23458 : ht = ht[0];
300 : : }
301 : 15405 : }
302 : : }
303 : 1274 : }
304 : :
305 : 6621 : unsigned HoExtension::checkExtensionality(TheoryModel* m)
306 : : {
307 : : // if we are in collect model info, we require looking at the model's
308 : : // equality engine, so that we only consider "relevant" (see
309 : : // Theory::computeRelevantTerms) function terms.
310 : : eq::EqualityEngine* ee =
311 [ + + ]: 6621 : m != nullptr ? m->getEqualityEngine() : d_state.getEqualityEngine();
312 : 6621 : NodeManager* nm = nodeManager();
313 : 6621 : unsigned num_lemmas = 0;
314 : 6621 : bool isCollectModel = (m != nullptr);
315 [ + - ]: 13242 : Trace("uf-ho") << "HoExtension::checkExtensionality, collectModel="
316 : 6621 : << isCollectModel << "..." << std::endl;
317 : 6621 : std::map<TypeNode, std::vector<Node> > func_eqcs;
318 : 6621 : eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(ee);
319 : 6621 : bool hasFunctions = false;
320 [ + + ]: 137735 : while (!eqcs_i.isFinished())
321 : : {
322 : 131114 : Node eqc = (*eqcs_i);
323 : 131114 : TypeNode tn = eqc.getType();
324 [ + + ][ + + ]: 131114 : if (tn.isFunction() && d_lambdaEqc.find(eqc) == d_lambdaEqc.end())
[ + + ]
325 : : {
326 : 21140 : hasFunctions = true;
327 : 21140 : std::vector<TypeNode> argTypes = tn.getArgTypes();
328 : : // We classify a function here to determine whether we need to apply
329 : : // extensionality eagerly during solving. We apply extensionality
330 : : // eagerly during solving if
331 : : // (A) The function type has finite cardinality,
332 : : // (B) All of its arguments have finite cardinality, or
333 : : // (C) It has a function as an argument.
334 : : // The latter is required so that we recursively consider extensionality
335 : : // between function constants introduced for extensionality lemmas.
336 : 21140 : bool eagerExtType = true;
337 [ + + ]: 21140 : if (!d_env.isFiniteType(tn))
338 : : {
339 [ + + ]: 41485 : for (const TypeNode& tna : argTypes)
340 : : {
341 [ + + ]: 23938 : if (!d_env.isFiniteType(tna))
342 : : {
343 : 23306 : eagerExtType = false;
344 : : }
345 [ + + ]: 23938 : if (tna.isFunction())
346 : : {
347 : 2100 : eagerExtType = true;
348 : 2100 : break;
349 : : }
350 : : }
351 : : }
352 : : // Based on the above classification of finite vs infinite.
353 : : // If during collect model, must have an infinite function type, since
354 : : // such function are not necessary to be handled during solving.
355 : : // If not during collect model, must have a finite function type, since
356 : : // such function symbols must be handled during solving.
357 [ + + ]: 21140 : if (eagerExtType != isCollectModel)
358 : : {
359 : 7105 : func_eqcs[tn].push_back(eqc);
360 [ + - ]: 14210 : Trace("uf-ho-debug")
361 : 7105 : << " func eqc : " << tn << " : " << eqc << std::endl;
362 : : }
363 : 21140 : }
364 : 131114 : ++eqcs_i;
365 : 131114 : }
366 [ - + ]: 6621 : if (!options().uf.ufHoExt)
367 : : {
368 : : // we are not applying extensionality, thus we are incomplete if functions
369 : : // are present
370 [ - - ]: 0 : if (hasFunctions)
371 : : {
372 : 0 : d_im.setModelUnsound(IncompleteId::UF_HO_EXT_DISABLED);
373 : : }
374 : 0 : return 0;
375 : : }
376 : :
377 : 6621 : for (std::map<TypeNode, std::vector<Node> >::iterator itf = func_eqcs.begin();
378 [ + + ]: 11033 : itf != func_eqcs.end();
379 : 4412 : ++itf)
380 : : {
381 [ + + ]: 11494 : for (unsigned j = 0, sizej = itf->second.size(); j < sizej; j++)
382 : : {
383 [ + + ]: 57580 : for (unsigned k = (j + 1), sizek = itf->second.size(); k < sizek; k++)
384 : : {
385 : : // if these equivalence classes are not explicitly disequal, do
386 : : // extensionality to ensure distinctness. Notice that we always use
387 : : // the (local) equality engine for this check via the state, since the
388 : : // model's equality engine does not store any disequalities. This is
389 : : // an optimization to introduce fewer equalities during model
390 : : // construction, since we know such disequalities have already been
391 : : // witness via assertions.
392 [ + + ]: 50507 : if (!d_state.areDisequal(itf->second[j], itf->second[k]))
393 : : {
394 : 98402 : Node deq = rewrite(itf->second[j].eqNode(itf->second[k]).negate());
395 : : // either add to model, or add lemma
396 [ + + ]: 49201 : if (isCollectModel)
397 : : {
398 : : // Add extentionality disequality to the model.
399 : : // It is important that we construct new (unconstrained) variables
400 : : // k here, so that we do not generate any inconsistencies.
401 : 48951 : Node edeq = getExtensionalityDeq(deq, false);
402 : 48951 : Assert(edeq.getKind() == Kind::NOT
403 : : && edeq[0].getKind() == Kind::EQUAL);
404 : : // introducing terms, must add required constraints, e.g. to
405 : : // force equalities between APPLY_UF and HO_APPLY terms
406 : 48951 : bool success = true;
407 [ + + ]: 146851 : for (unsigned r = 0; r < 2; r++)
408 : : {
409 [ - + ]: 97902 : if (!collectModelInfoHoTerm(edeq[0][r], m))
410 : : {
411 : 0 : return 1;
412 : : }
413 : : // Ensure finite skolems are set to arbitrary values eagerly.
414 : : // This ensures that partial function applications are identified
415 : : // with one another based on this assignment.
416 [ + + ]: 291743 : for (const Node& hk : edeq[0][r])
417 : : {
418 : 193843 : TypeNode tnk = hk.getType();
419 [ + + ]: 193843 : if (d_env.isFiniteType(tnk))
420 : : {
421 : 2 : TypeEnumerator te(tnk);
422 : 2 : Node v = *te;
423 [ + - ]: 2 : if (!m->assertEquality(hk, v, true))
424 : : {
425 : 2 : success = false;
426 : 2 : break;
427 : : }
428 [ - + ][ - + ]: 4 : }
429 [ + + ][ + + ]: 291747 : }
430 [ + + ]: 97902 : if (!success)
431 : : {
432 : 2 : break;
433 : : }
434 : : }
435 [ + + ]: 48951 : if (success)
436 : : {
437 : 97898 : TypeNode tn = edeq[0][0].getType();
438 [ + - ]: 97898 : Trace("uf-ho-debug")
439 : 0 : << "Add extensionality deq to model for : " << edeq
440 : 48949 : << std::endl;
441 [ + + ]: 48949 : if (d_env.isFiniteType(tn))
442 : : {
443 : : // We are an infinite function type with a finite range sort.
444 : : // Model construction assigns the first value for all
445 : : // unconstrained variables for such sorts, which does not
446 : : // suffice in this context since we are trying to make the
447 : : // functions disequal. Thus, for such case we enumerate the first
448 : : // two values for this sort and set the extensionality index to
449 : : // be equal to these two distinct values. There must be at least
450 : : // two values since this is an infinite function sort.
451 : 41945 : TypeEnumerator te(tn);
452 : 41945 : Node v1 = *te;
453 : 41945 : te++;
454 : 41945 : Node v2 = *te;
455 [ + - ][ + - ]: 41945 : Assert(!v2.isNull() && v2 != v1);
[ - + ][ - + ]
[ - - ]
456 : 41945 : Trace("uf-ho-debug") << "Finite witness: " << edeq[0][0] << " == " << v1 << std::endl;
457 : 41945 : Trace("uf-ho-debug") << "Finite witness: " << edeq[0][1] << " == " << v2 << std::endl;
458 : 41945 : success = m->assertEquality(edeq[0][0], v1, true);
459 [ + - ]: 41945 : if (success)
460 : : {
461 : 41945 : success = m->assertEquality(edeq[0][1], v2, true);
462 : : }
463 : 41945 : }
464 : 48949 : }
465 [ + + ]: 48951 : if (!success)
466 : : {
467 : 18 : Node eq = edeq[0][0].eqNode(edeq[0][1]);
468 : 18 : Node lem = nm->mkNode(Kind::OR, deq.negate(), eq.negate());
469 [ + - ]: 18 : Trace("uf-ho") << "HoExtension: cmi extensionality lemma " << lem
470 : 9 : << std::endl;
471 : 9 : d_im.lemma(lem, InferenceId::UF_HO_MODEL_EXTENSIONALITY);
472 : 9 : return 1;
473 : 9 : }
474 [ + + ]: 48951 : }
475 : : else
476 : : {
477 : : // apply extensionality lemma
478 : 250 : num_lemmas += applyExtensionality(deq);
479 : : }
480 [ + + ]: 49201 : }
481 : : }
482 : : }
483 : : }
484 : 6612 : return num_lemmas;
485 : 6621 : }
486 : :
487 : 10669777 : unsigned HoExtension::applyAppCompletion(TNode n)
488 : : {
489 [ - + ][ - + ]: 10669777 : Assert(n.getKind() == Kind::APPLY_UF);
[ - - ]
490 : :
491 : 10669777 : eq::EqualityEngine* ee = d_state.getEqualityEngine();
492 : : // must expand into APPLY_HO version if not there already
493 : 10669777 : Node ret = TheoryUfRewriter::getHoApplyForApplyUf(n);
494 : 10669777 : if (!ee->hasTerm(ret) || !ee->areEqual(ret, n))
495 : : {
496 : 45912 : Node eq = n.eqNode(ret);
497 [ + - ]: 91824 : Trace("uf-ho-lemma") << "uf-ho-lemma : infer, by apply-expand : " << eq
498 : 45912 : << std::endl;
499 : 91824 : d_im.assertInternalFact(eq,
500 : : true,
501 : : InferenceId::UF_HO_APP_ENCODE,
502 : : ProofRule::HO_APP_ENCODE,
503 : : {},
504 : : {n});
505 : 45912 : return 1;
506 : 45912 : }
507 [ + - ]: 21247730 : Trace("uf-ho-debug") << " ...already have " << ret << " == " << n << "."
508 : 10623865 : << std::endl;
509 : 10623865 : return 0;
510 : 10669777 : }
511 : :
512 : 51259 : unsigned HoExtension::checkAppCompletion()
513 : : {
514 [ + - ]: 51259 : Trace("uf-ho") << "HoExtension::checkApplyCompletion..." << std::endl;
515 : : // compute the operators that are relevant (those for which an HO_APPLY exist)
516 : 51259 : std::set<TNode> rlvOp;
517 : 51259 : eq::EqualityEngine* ee = d_state.getEqualityEngine();
518 : 51259 : eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(ee);
519 : 51259 : std::map<TNode, std::vector<Node> > apply_uf;
520 [ + + ]: 562782 : while (!eqcs_i.isFinished())
521 : : {
522 : 557435 : Node eqc = (*eqcs_i);
523 [ + - ]: 1114870 : Trace("uf-ho-debug") << " apply completion : visit eqc " << eqc
524 : 557435 : << std::endl;
525 : 557435 : eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, ee);
526 [ + + ]: 42913049 : while (!eqc_i.isFinished())
527 : : {
528 : 42401526 : Node n = *eqc_i;
529 [ + + ][ + + ]: 42401526 : if (n.getKind() == Kind::APPLY_UF || n.getKind() == Kind::HO_APPLY)
[ + + ]
530 : : {
531 : 36419083 : int curr_sum = 0;
532 : 36419083 : std::map<TNode, bool> curr_rops;
533 [ + + ]: 36419083 : if (n.getKind() == Kind::APPLY_UF)
534 : : {
535 : 45668194 : TNode rop = ee->getRepresentative(n.getOperator());
536 [ + + ]: 22834097 : if (rlvOp.find(rop) != rlvOp.end())
537 : : {
538 : : // try if its operator is relevant
539 : 10470932 : curr_sum = applyAppCompletion(n);
540 [ + + ]: 10470932 : if (curr_sum > 0)
541 : : {
542 : 41177 : return curr_sum;
543 : : }
544 : : }
545 : : else
546 : : {
547 : : // add to pending list
548 : 12363165 : apply_uf[rop].push_back(n);
549 : : }
550 : : // Arguments are also relevant operators.
551 : : // It might be possible include fewer terms here, see #1115.
552 [ + + ]: 52324792 : for (unsigned k = 0; k < n.getNumChildren(); k++)
553 : : {
554 [ + + ]: 29531872 : if (n[k].getType().isFunction())
555 : : {
556 : 1046038 : TNode rop2 = ee->getRepresentative(n[k]);
557 : 523019 : curr_rops[rop2] = true;
558 : 523019 : }
559 : : }
560 [ + + ]: 22834097 : }
561 : : else
562 : : {
563 [ - + ][ - + ]: 13584986 : Assert(n.getKind() == Kind::HO_APPLY);
[ - - ]
564 : 27169972 : TNode rop = ee->getRepresentative(n[0]);
565 : 13584986 : curr_rops[rop] = true;
566 : 13584986 : }
567 : 36377906 : for (std::map<TNode, bool>::iterator itc = curr_rops.begin();
568 [ + + ]: 50481149 : itc != curr_rops.end();
569 : 14103243 : ++itc)
570 : : {
571 : 14107978 : TNode rop = itc->first;
572 [ + + ]: 14107978 : if (rlvOp.find(rop) == rlvOp.end())
573 : : {
574 : 341843 : rlvOp.insert(rop);
575 : : // now, try each pending APPLY_UF for this operator
576 : : std::map<TNode, std::vector<Node> >::iterator itu =
577 : 341843 : apply_uf.find(rop);
578 [ + + ]: 341843 : if (itu != apply_uf.end())
579 : : {
580 [ + + ]: 209008 : for (unsigned j = 0, size = itu->second.size(); j < size; j++)
581 : : {
582 : 198845 : curr_sum = applyAppCompletion(itu->second[j]);
583 [ + + ]: 198845 : if (curr_sum > 0)
584 : : {
585 : 4735 : return curr_sum;
586 : : }
587 : : }
588 : : }
589 : : }
590 [ + + ]: 14107978 : }
591 [ + + ]: 36419083 : }
592 : 42355614 : ++eqc_i;
593 [ + + ]: 42401526 : }
594 : 511523 : ++eqcs_i;
595 [ + + ]: 557435 : }
596 : 5347 : return 0;
597 : 51259 : }
598 : :
599 : 5266 : unsigned HoExtension::checkLazyLambda()
600 : : {
601 [ + + ]: 5266 : if (!options().uf.ufHoLazyLambdaLift)
602 : : {
603 : : // no lambdas are lazily lifted
604 : 438 : return 0;
605 : : }
606 [ + - ]: 4828 : Trace("uf-ho") << "HoExtension::checkLazyLambda..." << std::endl;
607 : 4828 : NodeManager* nm = nodeManager();
608 : 4828 : unsigned numLemmas = 0;
609 : 4828 : d_lambdaEqc.clear();
610 : 4828 : eq::EqualityEngine* ee = d_state.getEqualityEngine();
611 : 4828 : eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(ee);
612 : : // normal functions equated to lambda functions
613 : 4828 : std::unordered_set<Node> normalEqFuns;
614 : : // mapping from functions to terms
615 [ + + ]: 113552 : while (!eqcs_i.isFinished())
616 : : {
617 : 108724 : Node eqc = (*eqcs_i);
618 : 108724 : ++eqcs_i;
619 [ + + ]: 108724 : if (!eqc.getType().isFunction())
620 : : {
621 : 86763 : continue;
622 : : }
623 : 21961 : eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, ee);
624 : 21961 : Node lamRep; // the first lambda function we encounter in the equivalence
625 : : // class
626 : 21961 : bool needsLift = false;
627 : 21961 : bool doLift = false;
628 : 21961 : Node lamRepLam;
629 : 21961 : std::unordered_set<Node> normalEqFunWait;
630 [ + + ]: 49709 : while (!eqc_i.isFinished())
631 : : {
632 : 27748 : Node n = *eqc_i;
633 : 27748 : ++eqc_i;
634 : 27748 : Node lam = d_ll.getLambdaFor(n);
635 [ + + ][ + + ]: 27748 : if (lam.isNull() || d_ll.isLifted(lam))
[ + + ]
636 : : {
637 [ + + ]: 20039 : if (!lamRep.isNull())
638 : : {
639 : : // if we are equal to a lambda function, we must beta-reduce
640 : : // applications of this
641 : 53 : normalEqFuns.insert(n);
642 : 53 : doLift = needsLift;
643 : : }
644 : : else
645 : : {
646 : : // waiting to see if there is a lambda function in this equivalence
647 : : // class
648 : 19986 : normalEqFunWait.insert(n);
649 : : }
650 : : }
651 [ + + ]: 7709 : else if (lamRep.isNull())
652 : : {
653 : : // there is a lambda function in this equivalence class
654 : 7669 : lamRep = n;
655 : 7669 : lamRepLam = lam;
656 [ + + ][ + - ]: 7669 : needsLift = d_ll.needsLift(lam) && !d_ll.isLifted(lam);
657 [ + + ][ + + ]: 7669 : doLift = needsLift && !normalEqFunWait.empty();
658 : : // must consider all normal functions we've seen so far
659 : 7669 : normalEqFuns.insert(normalEqFunWait.begin(), normalEqFunWait.end());
660 : 7669 : normalEqFunWait.clear();
661 : : }
662 : : else
663 : : {
664 : : // two lambda functions are in same equivalence class
665 [ + - ]: 40 : Node f = lamRep < n ? lamRep : n;
666 [ + - ]: 40 : Node g = lamRep < n ? n : lamRep;
667 : : // swap based on order
668 [ - + ]: 40 : if (g<f)
669 : : {
670 : 0 : Node tmp = f;
671 : 0 : f = g;
672 : 0 : g = tmp;
673 : 0 : }
674 : 40 : Node fgEq = f.eqNode(g);
675 [ + + ]: 40 : if (d_lamEqProcessed.find(fgEq)!=d_lamEqProcessed.end())
676 : : {
677 : 18 : continue;
678 : : }
679 : 22 : d_lamEqProcessed.insert(fgEq);
680 : :
681 [ + - ]: 44 : Trace("uf-ho-debug") << " found equivalent lambda functions " << f
682 : 22 : << " and " << g << std::endl;
683 [ + - ]: 22 : Node flam = lamRep < n ? lamRepLam : lam;
684 [ + - ][ + - ]: 22 : Assert(!flam.isNull() && flam.getKind() == Kind::LAMBDA);
[ - + ][ - + ]
[ - - ]
685 : 22 : Node lhs = flam[1];
686 [ + - ]: 22 : Node glam = lamRep < n ? lam : lamRepLam;
687 [ + - ]: 44 : Trace("uf-ho-debug")
688 : 22 : << " lambda are " << flam << " and " << glam << std::endl;
689 : 44 : std::vector<Node> args(flam[0].begin(), flam[0].end());
690 : 22 : Node rhs = d_ll.betaReduce(glam, args);
691 : 44 : Node univ = nm->mkNode(Kind::FORALL, flam[0], lhs.eqNode(rhs));
692 : : // do quantifier elimination if the option is set
693 : : // For example, say (= f1 f2) where f1 is (lambda ((x Int)) (< x a))
694 : : // and f2 is (lambda ((x Int)) (< x b)).
695 : : // By default, we would generate the inference
696 : : // (=> (= f1 f2) (forall ((x Int)) (= (< x a) (< x b))),
697 : : // where quantified reasoning is introduced into the main solving
698 : : // procedure.
699 : : // With --uf-lambda-qe, we use a subsolver to compute the quantifier
700 : : // elimination of:
701 : : // (forall ((x Int)) (= (< x a) (< x b)),
702 : : // which is (and (<= a b) (<= b a)). We instead generate the lemma
703 : : // (=> (= f1 f2) (and (<= a b) (<= b a)).
704 : : // The motivation for this is to reduce the complexity of constraints
705 : : // in the main solver. This is motivated by usages of set.filter where
706 : : // the lambdas are over a decidable theory that admits quantifier
707 : : // elimination, e.g. LIA or BV.
708 [ + + ]: 22 : if (options().uf.ufHoLambdaQe)
709 : : {
710 [ + - ]: 2 : Trace("uf-lambda-qe") << "Given " << flam << " == " << glam << std::endl;
711 [ + - ]: 2 : Trace("uf-lambda-qe") << "Run QE on " << univ << std::endl;
712 : 2 : std::unique_ptr<SolverEngine> lqe;
713 : : // initialize the subsolver using the standard method
714 : 2 : initializeSubsolver(lqe, d_env);
715 : 2 : Node univQe = lqe->getQuantifierElimination(univ, true);
716 [ + - ]: 2 : Trace("uf-lambda-qe") << "QE is " << univQe << std::endl;
717 [ - + ][ - + ]: 2 : Assert (!univQe.isNull());
[ - - ]
718 : : // Note that if quantifier elimination failed, then univQe will
719 : : // be equal to univ, in which case this above code has no effect.
720 : 2 : univ = univQe;
721 : 2 : }
722 : : // f = g => forall x. reduce(lambda(f)(x)) = reduce(lambda(g)(x))
723 : : //
724 : : // For example, if f -> lambda z. z+1, g -> lambda y. y+3, this
725 : : // will infer: f = g => forall x. x+1 = x+3, which simplifies to
726 : : // f != g.
727 : 44 : Node lem = nm->mkNode(Kind::IMPLIES, fgEq, univ);
728 [ + - ]: 22 : if (cacheLemma(lem))
729 : : {
730 : 22 : d_im.lemma(lem, InferenceId::UF_HO_LAMBDA_UNIV_EQ);
731 : 22 : numLemmas++;
732 : : }
733 [ + + ][ + + ]: 76 : }
[ + + ]
734 [ + + ][ + + ]: 27766 : }
735 [ + + ]: 21961 : if (!lamRep.isNull())
736 : : {
737 : 7669 : d_lambdaEqc[eqc] = lamRep;
738 : : // Do the lambda lifting lemma if needed. This happens if a lambda
739 : : // needs lifting based on the symbols in its body and is equated to an
740 : : // ordinary function symbol. For example, this is what ensures we
741 : : // handle conflicts like f = (lambda ((x Int)) (+ 1 (f x))).
742 [ + + ]: 7669 : if (doLift)
743 : : {
744 : 39 : TrustNode tlift = d_ll.lift(lamRepLam);
745 [ - + ][ - + ]: 39 : Assert(!tlift.isNull());
[ - - ]
746 : 39 : d_im.trustedLemma(tlift, InferenceId::UF_HO_LAMBDA_LAZY_LIFT);
747 : 39 : }
748 : : }
749 [ + + ]: 108724 : }
750 [ + - ]: 9656 : Trace("uf-ho-debug")
751 : 4828 : << " found " << normalEqFuns.size()
752 : 4828 : << " ordinary functions that are equal to lambda functions" << std::endl;
753 [ + + ]: 4828 : if (normalEqFuns.empty())
754 : : {
755 : 4730 : return numLemmas;
756 : : }
757 : : // if we have normal functions that are equal to lambda functions, go back
758 : : // and ensure they are mapped properly
759 : : // mapping from functions to terms
760 : 98 : eq::EqClassesIterator eqcs_i2 = eq::EqClassesIterator(ee);
761 [ + + ]: 23278 : while (!eqcs_i2.isFinished())
762 : : {
763 : 23180 : Node eqc = (*eqcs_i2);
764 : 23180 : ++eqcs_i2;
765 [ + - ]: 23180 : Trace("uf-ho-debug") << "Check equivalence class " << eqc << std::endl;
766 : 23180 : eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, ee);
767 [ + + ]: 122482 : while (!eqc_i.isFinished())
768 : : {
769 : 99302 : Node n = *eqc_i;
770 : 99302 : ++eqc_i;
771 [ + - ]: 99302 : Trace("uf-ho-debug") << "Check term " << n << std::endl;
772 : 99302 : Node op;
773 : 99302 : Kind k = n.getKind();
774 : 99302 : std::vector<Node> args;
775 [ + + ]: 99302 : if (k == Kind::APPLY_UF)
776 : : {
777 : 49968 : op = n.getOperator();
778 : 49968 : args.insert(args.end(), n.begin(), n.end());
779 : : }
780 [ + + ]: 49334 : else if (k == Kind::HO_APPLY)
781 : : {
782 : 26627 : op = n[0];
783 : 26627 : args.push_back(n[1]);
784 : : }
785 : : else
786 : : {
787 : 22707 : continue;
788 : : }
789 [ + + ]: 76595 : if (normalEqFuns.find(op) == normalEqFuns.end())
790 : : {
791 : 69742 : continue;
792 : : }
793 [ + - ]: 13706 : Trace("uf-ho-debug") << " found relevant ordinary application " << n
794 : 6853 : << std::endl;
795 [ - + ][ - + ]: 6853 : Assert(ee->hasTerm(op));
[ - - ]
796 : 13706 : Node r = ee->getRepresentative(op);
797 [ - + ][ - + ]: 6853 : Assert(d_lambdaEqc.find(r) != d_lambdaEqc.end());
[ - - ]
798 : 6853 : Node lf = d_lambdaEqc[r];
799 : 6853 : Node lam = d_ll.getLambdaFor(lf);
800 [ + - ][ + - ]: 6853 : Assert(!lam.isNull() && lam.getKind() == Kind::LAMBDA);
[ - + ][ - + ]
[ - - ]
801 : : // a normal function g equal to a lambda, say f --> lambda(f)
802 : : // need to infer f = g => g(t) = f(t) for all terms g(t)
803 : : // that occur in the equality engine.
804 : 6853 : Node premise = op.eqNode(lf);
805 : 6853 : args.insert(args.begin(), lam);
806 : 6853 : Node rhs = nm->mkNode(n.getKind(), args);
807 : 6853 : rhs = rewrite(rhs);
808 : 6853 : Node conc = n.eqNode(rhs);
809 : 13706 : Node lem = nm->mkNode(Kind::IMPLIES, premise, conc);
810 [ + + ]: 6853 : if (cacheLemma(lem))
811 : : {
812 : 6057 : d_im.lemma(lem, InferenceId::UF_HO_LAMBDA_APP_REDUCE);
813 : 6057 : numLemmas++;
814 : : }
815 [ + + ][ + + ]: 284200 : }
[ + + ]
816 : 23180 : }
817 : 98 : return numLemmas;
818 : 4828 : }
819 : :
820 : 5474 : unsigned HoExtension::check()
821 : : {
822 [ + - ]: 5474 : Trace("uf-ho") << "HoExtension::checkHigherOrder..." << std::endl;
823 : :
824 : : // infer new facts based on apply completion until fixed point
825 : : unsigned num_facts;
826 : : do
827 : : {
828 : 51259 : num_facts = checkAppCompletion();
829 [ + + ]: 51259 : if (d_state.isInConflict())
830 : : {
831 [ + - ]: 127 : Trace("uf-ho") << "...conflict during app-completion." << std::endl;
832 : 127 : return 1;
833 : : }
834 [ + + ]: 51132 : } while (num_facts > 0);
835 : :
836 : : // Apply extensionality, lazy lambda schemas in order. We make lazy lambda
837 : : // handling come last as it may introduce quantifiers.
838 [ + + ]: 15797 : for (size_t i = 0; i < 2; i++)
839 : : {
840 : 10613 : unsigned num_lemmas = 0;
841 : : // apply the schema
842 [ + + ][ - ]: 10613 : switch (i)
843 : : {
844 : 5347 : case 0: num_lemmas = checkExtensionality(); break;
845 : 5266 : case 1: num_lemmas = checkLazyLambda(); break;
846 : 0 : default: break;
847 : : }
848 : : // finish if we added lemmas
849 [ + + ]: 10613 : if (num_lemmas > 0)
850 : : {
851 [ + - ]: 163 : Trace("uf-ho") << "...returned " << num_lemmas << " lemmas." << std::endl;
852 : 163 : return num_lemmas;
853 : : }
854 : : }
855 : :
856 [ + - ]: 5184 : Trace("uf-ho") << "...finished check higher order." << std::endl;
857 : :
858 : 5184 : return 0;
859 : : }
860 : :
861 : 1274 : bool HoExtension::collectModelInfoHo(TheoryModel* m,
862 : : const std::set<Node>& termSet)
863 : : {
864 [ + + ]: 47573 : for (std::set<Node>::iterator it = termSet.begin(); it != termSet.end(); ++it)
865 : : {
866 : 46299 : Node n = *it;
867 : : // For model-building with higher-order, we require that APPLY_UF is always
868 : : // expanded to HO_APPLY. That is, we always expand to a fully applicative
869 : : // encoding during model construction.
870 [ - + ]: 46299 : if (!collectModelInfoHoTerm(n, m))
871 : : {
872 : 0 : return false;
873 : : }
874 [ + - ]: 46299 : }
875 : : // We apply an explicit extensionality technique for asserting
876 : : // disequalities to the model to ensure that function values are distinct
877 : : // in the curried HO_APPLY version of model construction. This is a
878 : : // non-standard alternative to using a type enumerator over function
879 : : // values to assign unique values.
880 : 1274 : int addedLemmas = checkExtensionality(m);
881 : : // for equivalence classes that we know to assign a lambda directly
882 [ + + ]: 1734 : for (const std::pair<const Node, Node>& p : d_lambdaEqc)
883 : : {
884 : 460 : Node lam = d_ll.getLambdaFor(p.second);
885 : 460 : lam = rewrite(lam);
886 [ - + ][ - + ]: 460 : Assert(!lam.isNull());
[ - - ]
887 : 460 : m->assertEquality(p.second, lam, true);
888 : 460 : m->assertSkeleton(lam);
889 : : // we don't assign the function definition here, which is handled internally
890 : : // in the model builder.
891 : 460 : }
892 : 1274 : return addedLemmas == 0;
893 : : }
894 : :
895 : 144201 : bool HoExtension::collectModelInfoHoTerm(Node n, TheoryModel* m)
896 : : {
897 [ + + ]: 144201 : if (n.getKind() == Kind::APPLY_UF)
898 : : {
899 : 113307 : Node hn = TheoryUfRewriter::getHoApplyForApplyUf(n);
900 [ - + ]: 113307 : if (!m->assertEquality(n, hn, true))
901 : : {
902 : 0 : Node eq = n.eqNode(hn);
903 [ - - ]: 0 : Trace("uf-ho") << "HoExtension: cmi app completion lemma " << eq
904 : 0 : << std::endl;
905 : 0 : d_im.lemma(eq, InferenceId::UF_HO_MODEL_APP_ENCODE);
906 : 0 : return false;
907 : 0 : }
908 : : // also add all subterms
909 : 113307 : eq::EqualityEngine* ee = m->getEqualityEngine();
910 [ + + ]: 330610 : while (hn.getKind()==Kind::HO_APPLY)
911 : : {
912 : 217303 : ee->addTerm(hn);
913 : 217303 : ee->addTerm(hn[1]);
914 : 217303 : hn = hn[0];
915 : : }
916 [ + - ]: 113307 : }
917 : 144201 : return true;
918 : : }
919 : :
920 : 6875 : bool HoExtension::cacheLemma(TNode lem)
921 : : {
922 : 6875 : Node rewritten = rewrite(lem);
923 [ + + ]: 6875 : if (d_cachedLemmas.find(rewritten) != d_cachedLemmas.end())
924 : : {
925 : 796 : return false;
926 : : }
927 : 6079 : d_cachedLemmas.insert(rewritten);
928 : 6079 : return true;
929 : 6875 : }
930 : :
931 : : } // namespace uf
932 : : } // namespace theory
933 : : } // namespace cvc5::internal
|