Branch data Line data Source code
1 : : /******************************************************************************
2 : : * Top contributors (to current version):
3 : : * Andrew Reynolds, Aina Niemetz, Mathias Preiner
4 : : *
5 : : * This file is part of the cvc5 project.
6 : : *
7 : : * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS
8 : : * in the top-level source directory and their institutional affiliations.
9 : : * All rights reserved. See the file COPYING in the top-level source
10 : : * directory for licensing information.
11 : : * ****************************************************************************
12 : : *
13 : : * Function definition processor for finite model finding.
14 : : */
15 : :
16 : : #include "preprocessing/passes/fun_def_fmf.h"
17 : :
18 : : #include <sstream>
19 : :
20 : : #include "expr/skolem_manager.h"
21 : : #include "options/smt_options.h"
22 : : #include "preprocessing/assertion_pipeline.h"
23 : : #include "preprocessing/preprocessing_pass_context.h"
24 : : #include "theory/quantifiers/quantifiers_attributes.h"
25 : : #include "theory/quantifiers/term_database.h"
26 : : #include "theory/quantifiers/term_util.h"
27 : : #include "theory/rewriter.h"
28 : :
29 : : using namespace std;
30 : : using namespace cvc5::internal::kind;
31 : : using namespace cvc5::internal::theory;
32 : : using namespace cvc5::internal::theory::quantifiers;
33 : :
34 : : namespace cvc5::internal {
35 : : namespace preprocessing {
36 : : namespace passes {
37 : :
38 : 50591 : FunDefFmf::FunDefFmf(PreprocessingPassContext* preprocContext)
39 : : : PreprocessingPass(preprocContext, "fun-def-fmf"),
40 : 50591 : d_fmfRecFunctionsDefined(nullptr)
41 : : {
42 : 50591 : d_fmfRecFunctionsDefined = new (true) NodeList(userContext());
43 : 50591 : d_fmfFunSc = nodeManager()->mkSortConstructor("@fmf-fun-sort", 1);
44 : 50591 : }
45 : :
46 : 100670 : FunDefFmf::~FunDefFmf() { d_fmfRecFunctionsDefined->deleteSelf(); }
47 : :
48 : 139 : PreprocessingPassResult FunDefFmf::applyInternal(
49 : : AssertionPipeline* assertionsToPreprocess)
50 : : {
51 [ - + ][ - + ]: 139 : Assert(d_fmfRecFunctionsDefined != nullptr);
[ - - ]
52 : : // reset
53 : 139 : d_sorts.clear();
54 : 139 : d_input_arg_inj.clear();
55 : 139 : d_funcs.clear();
56 : : // must carry over current definitions (in case of incremental)
57 : 127 : for (context::CDList<Node>::const_iterator fit =
58 : 139 : d_fmfRecFunctionsDefined->begin();
59 [ + + ]: 266 : fit != d_fmfRecFunctionsDefined->end();
60 : 127 : ++fit)
61 : : {
62 : 254 : Node f = (*fit);
63 [ - + ][ - + ]: 127 : Assert(d_fmfRecFunctionsAbs.find(f) != d_fmfRecFunctionsAbs.end());
[ - - ]
64 : 254 : TypeNode ft = d_fmfRecFunctionsAbs[f];
65 : 127 : d_sorts[f] = ft;
66 : : std::map<Node, std::vector<Node>>::iterator fcit =
67 : 127 : d_fmfRecFunctionsConcrete.find(f);
68 [ - + ][ - + ]: 127 : Assert(fcit != d_fmfRecFunctionsConcrete.end());
[ - - ]
69 [ + + ]: 795 : for (const Node& fcc : fcit->second)
70 : : {
71 : 668 : d_input_arg_inj[f].push_back(fcc);
72 : : }
73 : : }
74 : 139 : process(assertionsToPreprocess);
75 : : // must store new definitions (in case of incremental)
76 [ + + ]: 252 : for (const Node& f : d_funcs)
77 : : {
78 : 113 : d_fmfRecFunctionsAbs[f] = d_sorts[f];
79 : 113 : d_fmfRecFunctionsConcrete[f].clear();
80 [ + + ]: 351 : for (const Node& fcc : d_input_arg_inj[f])
81 : : {
82 : 238 : d_fmfRecFunctionsConcrete[f].push_back(fcc);
83 : : }
84 : 113 : d_fmfRecFunctionsDefined->push_back(f);
85 : : }
86 : 139 : return PreprocessingPassResult::NO_CONFLICT;
87 : : }
88 : :
89 : 139 : void FunDefFmf::process(AssertionPipeline* assertionsToPreprocess)
90 : : {
91 : 139 : const std::vector<Node>& assertions = assertionsToPreprocess->ref();
92 : 278 : std::vector<int> fd_assertions;
93 : 278 : std::map<int, Node> subs_head;
94 : : // first pass : find defined functions, transform quantifiers
95 : 139 : NodeManager* nm = nodeManager();
96 [ + + ]: 774 : for (size_t i = 0, asize = assertions.size(); i < asize; i++)
97 : : {
98 : 1270 : Node n = QuantAttributes::getFunDefHead(assertions[i]);
99 [ + + ]: 635 : if (!n.isNull())
100 : : {
101 [ - + ][ - + ]: 113 : Assert(n.getKind() == Kind::APPLY_UF);
[ - - ]
102 : 226 : Node f = n.getOperator();
103 : :
104 : : // check if already defined, if so, throw error
105 [ - + ]: 113 : if (d_sorts.find(f) != d_sorts.end())
106 : : {
107 : 0 : Unhandled() << "Cannot define function " << f << " more than once.";
108 : : }
109 : :
110 : 226 : Node bd = QuantAttributes::getFunDefBody(assertions[i]);
111 [ + - ]: 226 : Trace("fmf-fun-def-debug")
112 : 113 : << "Process function " << n << ", body = " << bd << std::endl;
113 [ + - ]: 113 : if (!bd.isNull())
114 : : {
115 : 113 : d_funcs.push_back(f);
116 : 113 : bd = nm->mkNode(Kind::EQUAL, n, bd);
117 : :
118 : : // create a sort S that represents the inputs of the function
119 : 226 : std::stringstream ss;
120 : 113 : ss << f;
121 : : // We make an uninterpreted sort whose name is the same as the
122 : : // function.
123 : 226 : TypeNode iType = nm->mkSort(ss.str());
124 : : // We then make the sort constructor applied to that type. For example,
125 : : // this is (@fmf-fun-sort f), where here f is an uninterpreted sort.
126 : : // This is done to have a clear name for this sort, and to support
127 : : // proof printing in ALF where @fmf-fun-sort is a type constructor
128 : : // parameterized by a function.
129 : 226 : iType = nm->mkSort(d_fmfFunSc, {iType});
130 : : AbsTypeFunDefAttribute atfda;
131 : 113 : iType.setAttribute(atfda, true);
132 : 113 : d_sorts[f] = iType;
133 : :
134 : : // create functions f1...fn mapping from this sort to concrete elements
135 : 113 : size_t nchildn = n.getNumChildren();
136 [ + + ]: 351 : for (size_t j = 0; j < nchildn; j++)
137 : : {
138 : 714 : TypeNode typ = nm->mkFunctionType(iType, n[j].getType());
139 : 238 : std::stringstream ssf;
140 : 238 : ssf << f << "_arg_" << j;
141 : 714 : d_input_arg_inj[f].push_back(NodeManager::mkDummySkolem(
142 : 476 : ssf.str(), typ, "op created during fun def fmf"));
143 : : }
144 : :
145 : : // construct new quantifier forall S. F[f1(S)/x1....fn(S)/xn]
146 : 226 : std::vector<Node> children;
147 : 339 : Node bv = NodeManager::mkBoundVar("?i", iType);
148 : 226 : Node bvl = nm->mkNode(Kind::BOUND_VAR_LIST, bv);
149 : 226 : std::vector<Node> subs;
150 : 226 : std::vector<Node> vars;
151 [ + + ]: 351 : for (size_t j = 0; j < nchildn; j++)
152 : : {
153 : 238 : vars.push_back(n[j]);
154 : 238 : subs.push_back(nm->mkNode(Kind::APPLY_UF, d_input_arg_inj[f][j], bv));
155 : : }
156 : 113 : bd = bd.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
157 : 113 : subs_head[i] =
158 : 226 : n.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
159 : :
160 [ + - ]: 226 : Trace("fmf-fun-def")
161 : 113 : << "FMF fun def: FUNCTION : rewrite " << assertions[i] << std::endl;
162 [ + - ]: 113 : Trace("fmf-fun-def") << " to " << std::endl;
163 : 226 : Node new_q = nm->mkNode(Kind::FORALL, bvl, bd);
164 : 113 : assertionsToPreprocess->replace(
165 : : i, new_q, nullptr, TrustId::PREPROCESS_FUN_DEF_FMF);
166 : 113 : assertionsToPreprocess->ensureRewritten(i);
167 [ + - ]: 113 : Trace("fmf-fun-def") << " " << assertions[i] << std::endl;
168 : 113 : fd_assertions.push_back(i);
169 : : }
170 : : else
171 : : {
172 : : // can be, e.g. in corner cases forall x. f(x)=f(x), forall x.
173 : : // f(x)=f(x)+1
174 : : }
175 : : }
176 : : }
177 : : // second pass : rewrite assertions
178 : 278 : std::map<int, std::map<Node, Node>> visited;
179 : 278 : std::map<int, std::map<Node, Node>> visited_cons;
180 [ + + ]: 774 : for (size_t i = 0, asize = assertions.size(); i < asize; i++)
181 : : {
182 : 635 : bool is_fd = std::find(fd_assertions.begin(), fd_assertions.end(), i)
183 : 1270 : != fd_assertions.end();
184 : 1270 : std::vector<Node> constraints;
185 [ + - ]: 1270 : Trace("fmf-fun-def-rewrite")
186 : 0 : << "Rewriting " << assertions[i]
187 : 635 : << ", is function definition = " << is_fd << std::endl;
188 : 635 : Node n = simplifyFormula(assertions[i],
189 : : true,
190 : : true,
191 : : constraints,
192 : 1270 : is_fd ? subs_head[i] : Node::null(),
193 : : is_fd,
194 : : visited,
195 [ + + ]: 2540 : visited_cons);
196 [ - + ][ - + ]: 635 : Assert(constraints.empty());
[ - - ]
197 [ + + ]: 635 : if (n != assertions[i])
198 : : {
199 : 201 : n = rewrite(n);
200 [ + - ]: 402 : Trace("fmf-fun-def-rewrite")
201 : 201 : << "FMF fun def : rewrite " << assertions[i] << std::endl;
202 [ + - ]: 201 : Trace("fmf-fun-def-rewrite") << " to " << std::endl;
203 [ + - ]: 201 : Trace("fmf-fun-def-rewrite") << " " << n << std::endl;
204 : 201 : assertionsToPreprocess->replace(
205 : : i, n, nullptr, TrustId::PREPROCESS_FUN_DEF_FMF);
206 : : }
207 : : }
208 : 139 : }
209 : :
210 : 1817 : Node FunDefFmf::simplifyFormula(
211 : : Node n,
212 : : bool pol,
213 : : bool hasPol,
214 : : std::vector<Node>& constraints,
215 : : Node hd,
216 : : bool is_fun_def,
217 : : std::map<int, std::map<Node, Node>>& visited,
218 : : std::map<int, std::map<Node, Node>>& visited_cons)
219 : : {
220 [ - + ][ - + ]: 1817 : Assert(constraints.empty());
[ - - ]
221 [ + + ][ + + ]: 1817 : int index = (is_fun_def ? 1 : 0) + 2 * (hasPol ? (pol ? 1 : -1) : 0);
[ + + ]
222 : 1817 : std::map<Node, Node>::iterator itv = visited[index].find(n);
223 [ + + ]: 1817 : if (itv != visited[index].end())
224 : : {
225 : : // constraints.insert( visited_cons[index]
226 : 231 : std::map<Node, Node>::iterator itvc = visited_cons[index].find(n);
227 [ - + ]: 231 : if (itvc != visited_cons[index].end())
228 : : {
229 : 0 : constraints.push_back(itvc->second);
230 : : }
231 : 231 : return itv->second;
232 : : }
233 : 1586 : NodeManager* nm = nodeManager();
234 : 3172 : Node ret;
235 [ + - ]: 3172 : Trace("fmf-fun-def-debug2") << "Simplify " << n << " " << pol << " " << hasPol
236 : 1586 : << " " << is_fun_def << std::endl;
237 [ + + ]: 1586 : if (n.getKind() == Kind::FORALL)
238 : : {
239 : : Node c = simplifyFormula(
240 : 447 : n[1], pol, hasPol, constraints, hd, is_fun_def, visited, visited_cons);
241 : : // append prenex to constraints
242 [ - + ]: 149 : for (unsigned i = 0; i < constraints.size(); i++)
243 : : {
244 : 0 : constraints[i] = nm->mkNode(Kind::FORALL, n[0], constraints[i]);
245 : 0 : constraints[i] = rewrite(constraints[i]);
246 : : }
247 [ + + ]: 149 : if (c != n[1])
248 : : {
249 : 77 : ret = nm->mkNode(Kind::FORALL, n[0], c);
250 : : }
251 : : else
252 : : {
253 : 72 : ret = n;
254 : : }
255 : : }
256 : : else
257 : : {
258 : 2874 : Node nn = n;
259 : 1437 : bool isBool = n.getType().isBoolean();
260 [ + + ][ + + ]: 1437 : if (isBool && n.getKind() != Kind::APPLY_UF)
[ + + ]
261 : : {
262 : 1518 : std::vector<Node> children;
263 : 759 : bool childChanged = false;
264 : : // are we at a branch position (not all children are necessarily
265 : : // relevant)?
266 [ + + ]: 1518 : bool branch_pos = (n.getKind() == Kind::ITE || n.getKind() == Kind::OR
267 [ + - ][ + + ]: 1518 : || n.getKind() == Kind::AND);
268 : 1518 : std::vector<Node> branch_constraints;
269 [ + + ]: 1904 : for (unsigned i = 0; i < n.getNumChildren(); i++)
270 : : {
271 : 1145 : Node c = n[i];
272 : : // do not process LHS of definition
273 [ + + ][ + + ]: 1145 : if (!is_fun_def || c != hd)
[ + + ]
274 : : {
275 : : bool newHasPol;
276 : : bool newPol;
277 : 1033 : QuantPhaseReq::getPolarity(n, i, hasPol, pol, newHasPol, newPol);
278 : : // get child constraints
279 : 1033 : std::vector<Node> cconstraints;
280 : 2066 : c = simplifyFormula(n[i],
281 : : newPol,
282 : : newHasPol,
283 : : cconstraints,
284 : : hd,
285 : : false,
286 : : visited,
287 : 1033 : visited_cons);
288 [ + + ]: 1033 : if (branch_pos)
289 : : {
290 : : // if at a branching position, the other constraints don't matter
291 : : // if this is satisfied
292 : 142 : Node bcons = nm->mkAnd(cconstraints);
293 : 142 : branch_constraints.push_back(bcons);
294 [ + - ]: 284 : Trace("fmf-fun-def-debug2") << "Branching constraint at arg " << i
295 : 142 : << " is " << bcons << std::endl;
296 : : }
297 : : constraints.insert(
298 : 1033 : constraints.end(), cconstraints.begin(), cconstraints.end());
299 : : }
300 : 1145 : children.push_back(c);
301 [ + + ][ - + ]: 1145 : childChanged = c != n[i] || childChanged;
[ + - ][ - - ]
302 : : }
303 [ + + ]: 759 : if (childChanged)
304 : : {
305 : 34 : nn = nm->mkNode(n.getKind(), children);
306 : : }
307 [ + + ][ + + ]: 759 : if (branch_pos && !constraints.empty())
[ + + ]
308 : : {
309 : : // if we are at a branching position in the formula, we can
310 : : // minimize recursive constraints on recursively defined predicates if
311 : : // we know one child forces the overall evaluation of this formula.
312 : 36 : Node branch_cond;
313 [ - + ]: 18 : if (n.getKind() == Kind::ITE)
314 : : {
315 : : // always care about constraints on the head of the ITE, but only
316 : : // care about one of the children depending on how it evaluates
317 : 0 : branch_cond = nm->mkNode(Kind::AND,
318 : 0 : branch_constraints[0],
319 : 0 : nm->mkNode(Kind::ITE,
320 : : n[0],
321 : 0 : branch_constraints[1],
322 : 0 : branch_constraints[2]));
323 : : }
324 : : else
325 : : {
326 : : // in the default case, we care about all conditions
327 : 18 : branch_cond = nm->mkAnd(constraints);
328 [ + + ]: 79 : for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
329 : : {
330 : : // if this child holds with forcing polarity (true child of OR or
331 : : // false child of AND), then we only care about its associated
332 : : // recursive conditions
333 : : branch_cond =
334 : 305 : nm->mkNode(Kind::ITE,
335 [ + + ][ + + ]: 122 : (n.getKind() == Kind::OR ? n[i] : n[i].negate()),
[ - - ]
336 : 61 : branch_constraints[i],
337 : 61 : branch_cond);
338 : : }
339 : : }
340 [ + - ]: 36 : Trace("fmf-fun-def-debug2")
341 : 18 : << "Made branching condition " << branch_cond << std::endl;
342 : 18 : constraints.clear();
343 : 18 : constraints.push_back(branch_cond);
344 : : }
345 : : }
346 : : else
347 : : {
348 : : // simplify term
349 : 678 : std::map<Node, Node> visitedT;
350 : 678 : getConstraints(n, constraints, visitedT);
351 : : }
352 [ + + ][ + + ]: 1437 : if (!constraints.empty() && isBool && hasPol)
[ + + ][ + + ]
353 : : {
354 : : // conjoin with current
355 : 402 : Node cons = nm->mkAnd(constraints);
356 [ + + ]: 201 : if (pol)
357 : : {
358 : 168 : ret = nm->mkNode(Kind::AND, nn, cons);
359 : : }
360 : : else
361 : : {
362 : 33 : ret = nm->mkNode(Kind::OR, nn, cons.negate());
363 : : }
364 [ + - ]: 402 : Trace("fmf-fun-def-debug2")
365 : 201 : << "Add constraint to obtain " << ret << std::endl;
366 : 201 : constraints.clear();
367 : : }
368 : : else
369 : : {
370 : 1236 : ret = nn;
371 : : }
372 : : }
373 [ + + ]: 1586 : if (!constraints.empty())
374 : : {
375 : 218 : Node cons;
376 : : // flatten to AND node for the purposes of caching
377 [ - + ]: 218 : if (constraints.size() > 1)
378 : : {
379 : 0 : cons = nm->mkNode(Kind::AND, constraints);
380 : 0 : cons = rewrite(cons);
381 : 0 : constraints.clear();
382 : 0 : constraints.push_back(cons);
383 : : }
384 : : else
385 : : {
386 : 218 : cons = constraints[0];
387 : : }
388 : 218 : visited_cons[index][n] = cons;
389 [ + - ][ + - ]: 436 : Assert(constraints.size() == 1 && constraints[0] == cons);
[ - + ][ - - ]
390 : : }
391 : 1586 : visited[index][n] = ret;
392 : 1586 : return ret;
393 : : }
394 : :
395 : 3291 : void FunDefFmf::getConstraints(Node n,
396 : : std::vector<Node>& constraints,
397 : : std::map<Node, Node>& visited)
398 : : {
399 : 3291 : std::map<Node, Node>::iterator itv = visited.find(n);
400 [ + + ]: 3291 : if (itv != visited.end())
401 : : {
402 : : // already visited
403 [ + + ]: 758 : if (!itv->second.isNull())
404 : : {
405 : : // add the cached constraint if it does not already occur
406 : 384 : if (std::find(constraints.begin(), constraints.end(), itv->second)
407 [ + - ]: 768 : == constraints.end())
408 : : {
409 : 384 : constraints.push_back(itv->second);
410 : : }
411 : : }
412 : 758 : return;
413 : : }
414 : 2533 : visited[n] = Node::null();
415 : 5066 : std::vector<Node> currConstraints;
416 : 2533 : NodeManager* nm = nodeManager();
417 [ + + ]: 2533 : if (n.getKind() == Kind::ITE)
418 : : {
419 : : // collect constraints for the condition
420 : 116 : getConstraints(n[0], currConstraints, visited);
421 : : // collect constraints for each branch
422 [ + + ][ + + ]: 696 : Node cs[2];
[ - - ]
423 [ + + ]: 348 : for (unsigned i = 0; i < 2; i++)
424 : : {
425 : 232 : std::vector<Node> ccons;
426 : 232 : getConstraints(n[i + 1], ccons, visited);
427 : 232 : cs[i] = nm->mkAnd(ccons);
428 : : }
429 [ + + ][ + + ]: 116 : if (!cs[0].isConst() || !cs[1].isConst())
[ + + ]
430 : : {
431 : 124 : Node itec = nm->mkNode(Kind::ITE, n[0], cs[0], cs[1]);
432 : 62 : currConstraints.push_back(itec);
433 [ + - ]: 124 : Trace("fmf-fun-def-debug")
434 : 62 : << "---> add constraint " << itec << " for " << n << std::endl;
435 : : }
436 : : }
437 : : else
438 : : {
439 [ + + ]: 2417 : if (n.getKind() == Kind::APPLY_UF)
440 : : {
441 : : // check if f is defined, if so, we must enforce domain constraints for
442 : : // this f-application
443 : 824 : Node f = n.getOperator();
444 : 412 : std::map<Node, TypeNode>::iterator it = d_sorts.find(f);
445 [ + + ]: 412 : if (it != d_sorts.end())
446 : : {
447 : : // create existential
448 : 630 : Node z = NodeManager::mkBoundVar("?z", it->second);
449 : 420 : Node bvl = nm->mkNode(Kind::BOUND_VAR_LIST, z);
450 : 420 : std::vector<Node> children;
451 [ + + ]: 484 : for (unsigned j = 0, size = n.getNumChildren(); j < size; j++)
452 : : {
453 : 548 : Node uz = nm->mkNode(Kind::APPLY_UF, d_input_arg_inj[f][j], z);
454 : 274 : children.push_back(uz.eqNode(n[j]));
455 : : }
456 : 420 : Node bd = nm->mkAnd(children);
457 : 210 : bd = bd.negate();
458 : 420 : Node ex = nm->mkNode(Kind::FORALL, bvl, bd);
459 : 210 : ex = ex.negate();
460 : 210 : currConstraints.push_back(ex);
461 [ + - ]: 420 : Trace("fmf-fun-def-debug")
462 : 210 : << "---> add constraint " << ex << " for " << n << std::endl;
463 : : }
464 : : }
465 [ + + ]: 4298 : for (const Node& cn : n)
466 : : {
467 : 1881 : getConstraints(cn, currConstraints, visited);
468 : : }
469 : : }
470 : : // set the visited cache
471 [ + + ]: 2533 : if (!currConstraints.empty())
472 : : {
473 : 384 : Node finalc = nm->mkAnd(currConstraints);
474 : 384 : visited[n] = finalc;
475 : : // add to constraints
476 : 384 : getConstraints(n, constraints, visited);
477 : : }
478 : : }
479 : :
480 : : } // namespace passes
481 : : } // namespace preprocessing
482 : : } // namespace cvc5::internal
|