Branch data Line data Source code
1 : : /******************************************************************************
2 : : * Top contributors (to current version):
3 : : * Andrew Reynolds, Aina Niemetz, Mathias Preiner
4 : : *
5 : : * This file is part of the cvc5 project.
6 : : *
7 : : * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS
8 : : * in the top-level source directory and their institutional affiliations.
9 : : * All rights reserved. See the file COPYING in the top-level source
10 : : * directory for licensing information.
11 : : * ****************************************************************************
12 : : *
13 : : * Class for applying the deep embedding for SyGuS
14 : : */
15 : :
16 : : #include "theory/quantifiers/sygus/embedding_converter.h"
17 : :
18 : : #include "options/base_options.h"
19 : : #include "options/quantifiers_options.h"
20 : : #include "printer/smt2/smt2_printer.h"
21 : : #include "theory/datatypes/sygus_datatype_utils.h"
22 : : #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
23 : : #include "theory/quantifiers/sygus/sygus_grammar_norm.h"
24 : : #include "theory/quantifiers/sygus/sygus_utils.h"
25 : : #include "theory/quantifiers/sygus/synth_conjecture.h"
26 : : #include "theory/quantifiers/sygus/term_database_sygus.h"
27 : : #include "util/rational.h"
28 : :
29 : : using namespace cvc5::internal::kind;
30 : :
31 : : namespace cvc5::internal {
32 : : namespace theory {
33 : : namespace quantifiers {
34 : :
35 : 7264 : EmbeddingConverter::EmbeddingConverter(Env& env,
36 : : TermDbSygus* tds,
37 : 7264 : SynthConjecture* p)
38 : 7264 : : EnvObj(env), d_tds(tds), d_parent(p), d_is_syntax_restricted(false)
39 : : {
40 : 7264 : }
41 : :
42 : 296 : bool EmbeddingConverter::hasSyntaxRestrictions(Node q)
43 : : {
44 [ - + ][ - + ]: 296 : Assert(q.getKind() == Kind::FORALL);
[ - - ]
45 [ + + ]: 442 : for (const Node& f : q[0])
46 : : {
47 : 296 : TypeNode tn = SygusUtils::getSygusType(f);
48 [ + + ][ + - ]: 296 : if (tn.isDatatype() && tn.getDType().isSygus())
[ + + ]
49 : : {
50 : 150 : return true;
51 : : }
52 : : }
53 : 146 : return false;
54 : : }
55 : :
56 : 1746 : void EmbeddingConverter::collectTerms(
57 : : Node n, std::map<TypeNode, std::unordered_set<Node>>& consts)
58 : : {
59 : 1746 : NodeManager* nm = nodeManager();
60 : 3492 : std::unordered_map<TNode, bool> visited;
61 : 1746 : std::unordered_map<TNode, bool>::iterator it;
62 : 3492 : std::vector<TNode> visit;
63 : 3492 : TNode cur;
64 : 1746 : visit.push_back(n);
65 : 39150 : do
66 : : {
67 : 40896 : cur = visit.back();
68 : 40896 : visit.pop_back();
69 : 40896 : it = visited.find(cur);
70 [ + + ]: 40896 : if (it == visited.end())
71 : : {
72 : 25724 : visited[cur] = true;
73 : : // is this a constant?
74 [ + + ]: 25724 : if (cur.isConst())
75 : : {
76 : 7584 : TypeNode tn = cur.getType();
77 : 7584 : Node c = cur;
78 [ + + ]: 3792 : if (tn.isRealOrInt())
79 : : {
80 : 1910 : c = nm->mkConstRealOrInt(tn, c.getConst<Rational>().abs());
81 : : }
82 : 3792 : consts[tn].insert(c);
83 [ + + ]: 3792 : if (tn.isInteger())
84 : : {
85 : 1880 : c = nm->mkConstReal(c.getConst<Rational>().abs());
86 : 3760 : TypeNode rtype = nm->realType();
87 : 1880 : consts[rtype].insert(c);
88 : : }
89 : : }
90 : : // recurse
91 : 25724 : visit.insert(visit.end(), cur.begin(), cur.end());
92 : : }
93 [ + + ]: 40896 : } while (!visit.empty());
94 : 1746 : }
95 : :
96 : 1761 : Node EmbeddingConverter::process(Node q,
97 : : const std::map<Node, Node>& templates,
98 : : const std::map<Node, Node>& templates_arg)
99 : : {
100 : : // convert to deep embedding and finalize single invocation here
101 : : // now, construct the grammar
102 [ + - ]: 3522 : Trace("cegqi") << "SynthConjecture : convert to deep embedding..."
103 : 1761 : << std::endl;
104 : 3522 : std::map<TypeNode, std::unordered_set<Node>> extra_cons;
105 : 1761 : if (options().quantifiers.sygusAddConstGrammar
106 [ + + ][ + + ]: 1761 : && options().quantifiers.sygusGrammarConsMode
[ + + ]
107 : : == options::SygusGrammarConsMode::SIMPLE)
108 : : {
109 [ + - ]: 1746 : Trace("cegqi") << "SynthConjecture : collect constants..." << std::endl;
110 : 1746 : collectTerms(q[1], extra_cons);
111 : : }
112 : 3522 : std::map<TypeNode, std::unordered_set<Node>> exc_cons;
113 : 3522 : std::map<TypeNode, std::unordered_set<Node>> inc_cons;
114 : :
115 : 1761 : NodeManager* nm = nodeManager();
116 : :
117 : 1761 : std::vector<Node> ebvl;
118 [ + + ]: 3856 : for (unsigned i = 0; i < q[0].getNumChildren(); i++)
119 : : {
120 : 4190 : Node sf = q[0][i];
121 : : // if non-null, v encodes the syntactic restrictions (via an inductive
122 : : // datatype) on sf from the input.
123 : 4190 : TypeNode preGrammarType = SygusUtils::getSygusType(sf);
124 [ + + ]: 2095 : if (preGrammarType.isNull())
125 : : {
126 : : // otherwise, the grammar is the default for the range of the function
127 : 1107 : preGrammarType = sf.getType();
128 [ + + ]: 1107 : if (preGrammarType.isFunction())
129 : : {
130 : 829 : preGrammarType = preGrammarType.getRangeType();
131 : : }
132 : : }
133 : :
134 : : // the actual sygus datatype we will use (normalized below)
135 : 4190 : TypeNode tn;
136 : 4190 : Node sfvl;
137 [ + + ][ + + ]: 2095 : if (preGrammarType.isDatatype() && preGrammarType.getDType().isSygus())
[ + + ]
138 : : {
139 : 988 : sfvl = preGrammarType.getDType().getSygusVarList();
140 : 988 : tn = preGrammarType;
141 : : // normalize type, if user-provided
142 : 988 : SygusGrammarNorm sygus_norm(d_env, d_tds);
143 : 988 : tn = sygus_norm.normalizeSygusType(tn, sfvl);
144 : : }
145 : : else
146 : : {
147 : 1107 : sfvl = SygusUtils::getOrMkSygusArgumentList(sf);
148 : : // check which arguments are irrelevant
149 : 2214 : std::unordered_set<unsigned> arg_irrelevant;
150 : 1107 : d_parent->getProcess()->getIrrelevantArgs(sf, arg_irrelevant);
151 : 1107 : std::vector<Node> trules;
152 : : // add the variables from the free variable list that we did not
153 : : // infer were irrelevant.
154 [ + + ]: 3028 : for (size_t j = 0, nargs = sfvl.getNumChildren(); j < nargs; j++)
155 : : {
156 [ + + ]: 1921 : if (arg_irrelevant.find(j) == arg_irrelevant.end())
157 : : {
158 : 1853 : trules.push_back(sfvl[j]);
159 : : }
160 : : }
161 : : // add the constants computed avove
162 : 1862 : for (const std::pair<const TypeNode, std::unordered_set<Node>>& c :
163 [ + + ]: 2969 : extra_cons)
164 : : {
165 : 1862 : trules.insert(trules.end(), c.second.begin(), c.second.end());
166 : : }
167 : 1107 : tn = SygusGrammarCons::mkDefaultSygusType(
168 : 1107 : d_env, preGrammarType, sfvl, trules);
169 : : }
170 : : // Ensure the expanded definition forms are set. This is done after
171 : : // normalization above.
172 : 2095 : datatypes::utils::computeExpandedDefinitionForms(d_env, tn);
173 : : // print the grammar
174 [ + + ]: 2095 : if (isOutputOn(OutputTag::SYGUS_GRAMMAR))
175 : : {
176 : 2 : output(OutputTag::SYGUS_GRAMMAR)
177 : 4 : << "(sygus-grammar " << sf << " "
178 : 4 : << printer::smt2::Smt2Printer::sygusGrammarString(tn) << ")"
179 : 2 : << std::endl;
180 : : }
181 : : // sfvl may be null for constant synthesis functions
182 [ + - ]: 4190 : Trace("cegqi-debug") << "...sygus var list associated with " << sf << " is "
183 : 2095 : << sfvl << std::endl;
184 : :
185 : 2095 : std::map<Node, Node>::const_iterator itt = templates.find(sf);
186 [ + + ]: 2095 : if (itt != templates.end())
187 : : {
188 : 88 : Node templ = itt->second;
189 : 44 : std::map<Node, Node>::const_iterator itta = templates_arg.find(sf);
190 [ - + ][ - + ]: 44 : Assert(itta != templates_arg.end());
[ - - ]
191 : 44 : TNode templ_arg = itta->second;
192 [ - + ][ - + ]: 44 : Assert(!templ_arg.isNull());
[ - - ]
193 : : }
194 : :
195 : : // ev is the first-order variable corresponding to this synth fun
196 : 4190 : Node ev = nm->mkBoundVar("f" + sf.getName(), tn);
197 : 2095 : ebvl.push_back(ev);
198 [ + - ]: 4190 : Trace("cegqi") << "...embedding synth fun : " << sf << " -> " << ev
199 : 2095 : << std::endl;
200 : : }
201 : 3522 : return process(q, templates, templates_arg, ebvl);
202 : : }
203 : :
204 : 1761 : Node EmbeddingConverter::process(Node q,
205 : : const std::map<Node, Node>& templates,
206 : : const std::map<Node, Node>& templates_arg,
207 : : const std::vector<Node>& ebvl)
208 : : {
209 [ - + ][ - + ]: 1761 : Assert(q[0].getNumChildren() == ebvl.size());
[ - - ]
210 [ - + ][ - + ]: 1761 : Assert(d_synth_fun_vars.empty());
[ - - ]
211 : :
212 : 1761 : NodeManager* nm = nodeManager();
213 : :
214 : 3522 : std::vector<Node> qchildren;
215 : 3522 : Node qbody_subs = q[1];
216 [ + + ]: 3856 : for (unsigned i = 0, size = q[0].getNumChildren(); i < size; i++)
217 : : {
218 : 4190 : Node sf = q[0][i];
219 : 2095 : d_synth_fun_vars[sf] = ebvl[i];
220 : 4190 : Node sfvl = SygusUtils::getOrMkSygusArgumentList(sf);
221 : 4190 : TypeNode tn = ebvl[i].getType();
222 : : // check if there is a template
223 : 2095 : std::map<Node, Node>::const_iterator itt = templates.find(sf);
224 [ + + ]: 2095 : if (itt != templates.end())
225 : : {
226 : 88 : Node templ = itt->second;
227 : 44 : std::map<Node, Node>::const_iterator itta = templates_arg.find(sf);
228 [ - + ][ - + ]: 44 : Assert(itta != templates_arg.end());
[ - - ]
229 : 88 : TNode templ_arg = itta->second;
230 [ - + ][ - + ]: 44 : Assert(!templ_arg.isNull());
[ - - ]
231 : : // if there is a template for this argument, make a sygus type on top of
232 : : // it
233 : : // otherwise, apply it as a preprocessing pass
234 [ + - ]: 88 : Trace("cegqi-debug") << "Template for " << sf << " is : " << templ
235 : 44 : << " with arg " << templ_arg << std::endl;
236 [ + - ]: 88 : Trace("cegqi-debug")
237 : 0 : << " apply this template as a substitution during preprocess..."
238 : 44 : << std::endl;
239 : 88 : std::vector<Node> schildren;
240 : 88 : std::vector<Node> largs;
241 [ + + ]: 261 : for (unsigned j = 0; j < sfvl.getNumChildren(); j++)
242 : : {
243 : 217 : schildren.push_back(sfvl[j]);
244 : 217 : largs.push_back(nm->mkBoundVar(sfvl[j].getType()));
245 : : }
246 : 88 : std::vector<Node> subsfn_children;
247 : 44 : subsfn_children.push_back(sf);
248 : : subsfn_children.insert(
249 : 44 : subsfn_children.end(), schildren.begin(), schildren.end());
250 : 88 : Node subsfn = nm->mkNode(Kind::APPLY_UF, subsfn_children);
251 : 88 : TNode subsf = subsfn;
252 [ + - ]: 88 : Trace("cegqi-debug") << " substitute arg : " << templ_arg << " -> "
253 : 44 : << subsf << std::endl;
254 : 44 : templ = templ.substitute(templ_arg, subsf);
255 : : // substitute lambda arguments
256 : 88 : templ = templ.substitute(
257 : 44 : schildren.begin(), schildren.end(), largs.begin(), largs.end());
258 : : Node subsn = nm->mkNode(
259 : 132 : Kind::LAMBDA, nm->mkNode(Kind::BOUND_VAR_LIST, largs), templ);
260 : 88 : TNode var = sf;
261 : 44 : TNode subs = subsn;
262 [ + - ]: 88 : Trace("cegqi-debug") << " substitute : " << var << " -> " << subs
263 : 44 : << std::endl;
264 : 44 : qbody_subs = qbody_subs.substitute(var, subs);
265 [ + - ]: 44 : Trace("cegqi-debug") << " body is now : " << qbody_subs << std::endl;
266 : : }
267 : 2095 : d_tds->registerSygusType(tn);
268 [ - + ][ - + ]: 2095 : Assert(tn.isDatatype());
[ - - ]
269 : 2095 : const DType& dt = tn.getDType();
270 [ - + ][ - + ]: 2095 : Assert(dt.isSygus());
[ - - ]
271 [ + + ]: 2095 : if (!dt.getSygusAllowAll())
272 : : {
273 : 986 : d_is_syntax_restricted = true;
274 : : }
275 : : }
276 : 1761 : qchildren.push_back(nm->mkNode(Kind::BOUND_VAR_LIST, ebvl));
277 [ + + ]: 1761 : if (qbody_subs != q[1])
278 : : {
279 [ + - ]: 44 : Trace("cegqi") << "...rewriting : " << qbody_subs << std::endl;
280 : 44 : qbody_subs = rewrite(qbody_subs);
281 [ + - ]: 44 : Trace("cegqi") << "...got : " << qbody_subs << std::endl;
282 : : }
283 : 1761 : qchildren.push_back(convertToEmbedding(qbody_subs));
284 [ + - ]: 1761 : if (q.getNumChildren() == 3)
285 : : {
286 : 1761 : qchildren.push_back(q[2]);
287 : : }
288 : 3522 : return nm->mkNode(Kind::FORALL, qchildren);
289 : : }
290 : :
291 : 2354 : Node EmbeddingConverter::convertToEmbedding(Node n)
292 : : {
293 : 2354 : NodeManager* nm = nodeManager();
294 : 4708 : std::unordered_map<TNode, Node> visited;
295 : 2354 : std::unordered_map<TNode, Node>::iterator it;
296 : 4708 : std::vector<TNode> visit;
297 : 2354 : TNode cur;
298 : 2354 : visit.push_back(n);
299 : 78942 : do
300 : : {
301 : 81296 : cur = visit.back();
302 : 81296 : visit.pop_back();
303 : 81296 : it = visited.find(cur);
304 [ + + ]: 81296 : if (it == visited.end())
305 : : {
306 : 31656 : visited[cur] = Node::null();
307 : 31656 : visit.push_back(cur);
308 : 31656 : visit.insert(visit.end(), cur.begin(), cur.end());
309 : : }
310 [ + + ]: 49640 : else if (it->second.isNull())
311 : : {
312 : 63312 : Node ret = cur;
313 : 31656 : Kind ret_k = cur.getKind();
314 : 63312 : Node op;
315 : 31656 : bool childChanged = false;
316 : 63312 : std::vector<Node> children;
317 : : // get the potential operator
318 [ + + ]: 31656 : if (cur.getNumChildren() > 0)
319 : : {
320 [ + + ]: 22846 : if (cur.getKind() == Kind::APPLY_UF)
321 : : {
322 : 2718 : op = cur.getOperator();
323 : : }
324 : : }
325 : : else
326 : : {
327 : 8810 : op = cur;
328 : : }
329 : : // is the operator a synth function?
330 : 31656 : bool makeEvalFun = false;
331 [ + + ]: 31656 : if (!op.isNull())
332 : : {
333 : 11528 : std::map<Node, Node>::iterator its = d_synth_fun_vars.find(op);
334 [ + + ]: 11528 : if (its != d_synth_fun_vars.end())
335 : : {
336 : 3180 : children.push_back(its->second);
337 : 3180 : makeEvalFun = true;
338 : : }
339 : : }
340 [ + + ]: 31656 : if (!makeEvalFun)
341 : : {
342 : : // otherwise, we apply the previous operator
343 [ + + ]: 28476 : if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
344 : : {
345 : 439 : children.push_back(cur.getOperator());
346 : : }
347 : : }
348 [ + + ]: 78942 : for (unsigned i = 0; i < cur.getNumChildren(); i++)
349 : : {
350 : 47286 : it = visited.find(cur[i]);
351 [ - + ][ - + ]: 47286 : Assert(it != visited.end());
[ - - ]
352 [ - + ][ - + ]: 47286 : Assert(!it->second.isNull());
[ - - ]
353 [ + + ][ + + ]: 47286 : childChanged = childChanged || cur[i] != it->second;
[ + + ][ - - ]
354 : 47286 : children.push_back(it->second);
355 : : }
356 [ + + ]: 31656 : if (makeEvalFun)
357 : : {
358 [ + + ]: 3180 : if (!cur.getType().isFunction())
359 : : {
360 : : // will make into an application of an evaluation function
361 : 3170 : ret = nm->mkNode(Kind::DT_SYGUS_EVAL, children);
362 : : }
363 : : else
364 : : {
365 [ - + ][ - + ]: 10 : Assert(children.size() == 1);
[ - - ]
366 : 20 : Node ef = children[0];
367 : : // Otherwise, we are using the function-to-synthesize itself in a
368 : : // higher-order setting. We must return the lambda term:
369 : : // lambda x1...xn. (DT_SYGUS_EVAL ef x1 ... xn)
370 : : // where ef is the first order variable for the
371 : : // function-to-synthesize.
372 : 10 : SygusTypeInfo& ti = d_tds->getTypeInfo(ef.getType());
373 : 10 : const std::vector<Node>& vars = ti.getVarList();
374 [ - + ][ - + ]: 10 : Assert(!vars.empty());
[ - - ]
375 : 20 : std::vector<Node> vs;
376 [ + + ]: 24 : for (const Node& v : vars)
377 : : {
378 : 14 : vs.push_back(nm->mkBoundVar(v.getType()));
379 : : }
380 : 20 : Node lvl = nm->mkNode(Kind::BOUND_VAR_LIST, vs);
381 : 10 : std::vector<Node> eargs;
382 : 10 : eargs.push_back(ef);
383 : 10 : eargs.insert(eargs.end(), vs.begin(), vs.end());
384 : 20 : ret = nm->mkNode(
385 : 30 : Kind::LAMBDA, lvl, nm->mkNode(Kind::DT_SYGUS_EVAL, eargs));
386 : : }
387 : : }
388 [ + + ]: 28476 : else if (childChanged)
389 : : {
390 : 10445 : ret = nm->mkNode(ret_k, children);
391 : : }
392 : 31656 : visited[cur] = ret;
393 : : }
394 [ + + ]: 81296 : } while (!visit.empty());
395 [ - + ][ - + ]: 2354 : Assert(visited.find(n) != visited.end());
[ - - ]
396 [ - + ][ - + ]: 2354 : Assert(!visited.find(n)->second.isNull());
[ - - ]
397 : 4708 : return visited[n];
398 : : }
399 : :
400 : : } // namespace quantifiers
401 : : } // namespace theory
402 : : } // namespace cvc5::internal
|