Branch data Line data Source code
1 : : /******************************************************************************
2 : : * Top contributors (to current version):
3 : : * Abdalrhman Mohamed, Andrew Reynolds, Aina Niemetz
4 : : *
5 : : * This file is part of the cvc5 project.
6 : : *
7 : : * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS
8 : : * in the top-level source directory and their institutional affiliations.
9 : : * All rights reserved. See the file COPYING in the top-level source
10 : : * directory for licensing information.
11 : : * ****************************************************************************
12 : : *
13 : : * Implementation of class for constructing SyGuS Grammars.
14 : : */
15 : :
16 : : #include "expr/sygus_grammar.h"
17 : :
18 : : #include <sstream>
19 : :
20 : : #include "expr/dtype.h"
21 : : #include "expr/dtype_cons.h"
22 : : #include "expr/skolem_manager.h"
23 : : #include "printer/printer.h"
24 : : #include "printer/smt2/smt2_printer.h"
25 : : #include "theory/datatypes/sygus_datatype_utils.h"
26 : : #include "util/hash.h"
27 : :
28 : : namespace cvc5::internal {
29 : :
30 : 3775 : SygusGrammar::SygusGrammar(const std::vector<Node>& sygusVars,
31 : 3775 : const std::vector<Node>& ntSyms)
32 : 3775 : : d_sygusVars(sygusVars), d_ntSyms(ntSyms)
33 : : {
34 [ + + ]: 10351 : for (const Node& ntSym : ntSyms)
35 : : {
36 : 6576 : d_rules.emplace(ntSym, std::vector<Node>{});
37 : : }
38 : 3775 : }
39 : :
40 : 0 : SygusGrammar::SygusGrammar(const std::vector<Node>& sygusVars,
41 : 0 : const TypeNode& sdt)
42 : 0 : : d_sygusVars(sygusVars)
43 : : {
44 : 0 : Assert(sdt.isSygusDatatype());
45 : 0 : std::vector<TypeNode> tnlist;
46 : : // ensure that sdt is first
47 : 0 : tnlist.push_back(sdt);
48 : 0 : std::map<TypeNode, Node> ntsyms;
49 : 0 : NodeManager* nm = NodeManager::currentNM();
50 [ - - ]: 0 : for (size_t i = 0; i < tnlist.size(); i++)
51 : : {
52 : 0 : TypeNode tn = tnlist[i];
53 : 0 : Assert(tn.isSygusDatatype());
54 : 0 : const DType& dt = tn.getDType();
55 : 0 : std::stringstream ss;
56 : 0 : ss << dt.getName();
57 : 0 : Node v = nm->mkBoundVar(ss.str(), dt.getSygusType());
58 : 0 : ntsyms[tn] = v;
59 : 0 : d_ntSyms.push_back(v);
60 : 0 : d_rules.emplace(v, std::vector<Node>{});
61 : : // process the subfield types
62 : 0 : std::unordered_set<TypeNode> tns = dt.getSubfieldTypes();
63 [ - - ]: 0 : for (const TypeNode& tnsc : tns)
64 : : {
65 : 0 : if (tnsc.isSygusDatatype()
66 : 0 : && std::find(tnlist.begin(), tnlist.end(), tnsc) == tnlist.end())
67 : : {
68 : 0 : tnlist.push_back(tnsc);
69 : : }
70 : : }
71 : : }
72 : 0 : std::map<TypeNode, Node>::iterator itn;
73 [ - - ]: 0 : for (const TypeNode& tn : tnlist)
74 : : {
75 [ - - ]: 0 : if (!tn.isSygusDatatype())
76 : : {
77 : 0 : continue;
78 : : }
79 : 0 : Node nts = ntsyms[tn];
80 : 0 : const DType& dt = tn.getDType();
81 [ - - ]: 0 : for (size_t i = 0, ncons = dt.getNumConstructors(); i < ncons; i++)
82 : : {
83 : 0 : const DTypeConstructor& cons = dt[i];
84 [ - - ]: 0 : if (cons.isSygusAnyConstant())
85 : : {
86 : 0 : addAnyConstant(nts, cons[0].getRangeType());
87 : 0 : continue;
88 : : }
89 : 0 : Node op = cons.getSygusOp();
90 : 0 : std::vector<Node> args;
91 [ - - ]: 0 : for (size_t j = 0, nargs = cons.getNumArgs(); j < nargs; j++)
92 : : {
93 : 0 : TypeNode argType = cons[j].getRangeType();
94 : 0 : itn = ntsyms.find(argType);
95 : 0 : Assert(itn != ntsyms.end()) << "Missing " << argType << " in " << op;
96 : 0 : args.push_back(itn->second);
97 : : }
98 : 0 : Node rule = theory::datatypes::utils::mkSygusTerm(op, args, true);
99 : 0 : addRule(nts, rule);
100 : : }
101 : : }
102 : 0 : }
103 : :
104 : 41304 : void SygusGrammar::addRule(const Node& ntSym, const Node& rule)
105 : : {
106 [ - + ][ - + ]: 41304 : Assert(d_rules.find(ntSym) != d_rules.cend());
[ - - ]
107 [ - + ][ - + ]: 41304 : Assert(rule.getType().isInstanceOf(ntSym.getType()));
[ - - ]
108 : : // avoid duplication
109 : 41304 : std::vector<Node>& rs = d_rules[ntSym];
110 [ + + ]: 41304 : if (std::find(rs.begin(), rs.end(), rule) == rs.end())
111 : : {
112 : 32393 : rs.push_back(rule);
113 : : }
114 : 41304 : }
115 : :
116 : 56 : void SygusGrammar::addRules(const Node& ntSym, const std::vector<Node>& rules)
117 : : {
118 [ + + ]: 122 : for (const Node& rule : rules)
119 : : {
120 : 66 : addRule(ntSym, rule);
121 : : }
122 : 56 : }
123 : :
124 : 174 : void SygusGrammar::addAnyConstant(const Node& ntSym, const TypeNode& tn)
125 : : {
126 [ - + ][ - + ]: 174 : Assert(d_rules.find(ntSym) != d_rules.cend());
[ - - ]
127 [ - + ][ - + ]: 174 : Assert(tn.isInstanceOf(ntSym.getType()));
[ - - ]
128 : 174 : SkolemManager* sm = NodeManager::currentNM()->getSkolemManager();
129 : : Node anyConst =
130 : 522 : sm->mkInternalSkolemFunction(InternalSkolemId::SYGUS_ANY_CONSTANT, tn);
131 : 174 : addRule(ntSym, anyConst);
132 : 174 : }
133 : :
134 : 83 : void SygusGrammar::addAnyVariable(const Node& ntSym)
135 : : {
136 [ - + ][ - + ]: 83 : Assert(d_rules.find(ntSym) != d_rules.cend());
[ - - ]
137 : : // each variable of appropriate type becomes a rule.
138 [ + + ]: 253 : for (const Node& v : d_sygusVars)
139 : : {
140 [ + + ]: 170 : if (v.getType().isInstanceOf(ntSym.getType()))
141 : : {
142 : 125 : addRule(ntSym, v);
143 : : }
144 : : }
145 : 83 : }
146 : :
147 : 605 : void SygusGrammar::removeRule(const Node& ntSym, const Node& rule)
148 : : {
149 : : std::unordered_map<Node, std::vector<Node>>::iterator itr =
150 : 605 : d_rules.find(ntSym);
151 [ - + ][ - + ]: 605 : Assert(itr != d_rules.end());
[ - - ]
152 : : std::vector<Node>::iterator it =
153 : 605 : std::find(itr->second.begin(), itr->second.end(), rule);
154 [ - + ][ - + ]: 605 : Assert(it != itr->second.end());
[ - - ]
155 : 605 : itr->second.erase(it);
156 : 605 : }
157 : :
158 : : /**
159 : : * Purify SyGuS grammar node.
160 : : *
161 : : * This returns a node where all occurrences of non-terminal symbols (those in
162 : : * the domain of \p ntsToUnres) are replaced by fresh variables. For each
163 : : * variable replaced in this way, we add the fresh variable it is replaced with
164 : : * to \p args, and the unresolved types corresponding to the non-terminal symbol
165 : : * to \p cargs (constructor args). In other words, \p args contains the free
166 : : * variables in the node returned by this method (which should be bound by a
167 : : * lambda), and \p cargs contains the types of the arguments of the sygus
168 : : * constructor.
169 : : *
170 : : * @param n The node to purify.
171 : : * @param args The free variables in the node returned by this method.
172 : : * @param ntSymMap Map from each variable in args to the non-terminal they were
173 : : * introduced for.
174 : : * @param nts The list of non-terminal symbols
175 : : * @return The purfied node.
176 : : */
177 : 64300 : Node purifySygusGNode(const Node& n,
178 : : std::vector<Node>& args,
179 : : std::map<Node, Node>& ntSymMap,
180 : : const std::vector<Node>& nts)
181 : : {
182 : 64300 : NodeManager* nm = NodeManager::currentNM();
183 : : // if n is non-terminal
184 [ + + ]: 64300 : if (std::find(nts.begin(), nts.end(), n) != nts.end())
185 : : {
186 : 64486 : Node ret = nm->mkBoundVar(n.getType());
187 : 32243 : ntSymMap[ret] = n;
188 : 32243 : args.push_back(ret);
189 : 32243 : return ret;
190 : : }
191 : 64114 : std::vector<Node> pchildren;
192 : 32057 : bool childChanged = false;
193 [ + + ]: 64949 : for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
194 : : {
195 : 32892 : Node ptermc = purifySygusGNode(n[i], args, ntSymMap, nts);
196 : 32892 : pchildren.push_back(ptermc);
197 [ + + ][ + + ]: 32892 : childChanged = childChanged || ptermc != n[i];
[ + + ][ - - ]
198 : : }
199 [ + + ]: 32057 : if (!childChanged)
200 : : {
201 : 15482 : return n;
202 : : }
203 : 33150 : internal::Node nret;
204 [ + + ]: 16575 : if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
205 : : {
206 : : // it's an indexed operator so we should provide the op
207 : 767 : internal::NodeBuilder nb(n.getKind());
208 : 767 : nb << n.getOperator();
209 : 767 : nb.append(pchildren);
210 : 767 : nret = nb.constructNode();
211 : : }
212 : : else
213 : : {
214 : 15808 : nret = nm->mkNode(n.getKind(), pchildren);
215 : : }
216 : 16575 : return nret;
217 : : }
218 : :
219 : 31408 : bool isId(const Node& n)
220 : : {
221 [ + + ][ - - ]: 48039 : return n.getKind() == Kind::LAMBDA && n[0].getNumChildren() == 1
222 : 79447 : && n[0][0] == n[1];
223 : : }
224 : :
225 : : /**
226 : : * Add \p rule to the set of constructors of \p dt.
227 : : *
228 : : * @param dt The datatype to which the rule is added.
229 : : * @param rule The rule to add.
230 : : * @param ntsToUnres Mapping from non-terminals to their unresolved types.
231 : : */
232 : 31565 : void addSygusConstructor(DType& dt,
233 : : const Node& rule,
234 : : const std::vector<Node>& nts,
235 : : const std::unordered_map<Node, TypeNode>& ntsToUnres)
236 : : {
237 : 31565 : NodeManager* nm = NodeManager::currentNM();
238 : 31565 : SkolemManager* sm = nm->getSkolemManager();
239 : 63130 : std::stringstream ss;
240 : 94695 : if (rule.getKind() == Kind::SKOLEM
241 [ + + ][ + + ]: 31565 : && sm->getInternalId(rule) == InternalSkolemId::SYGUS_ANY_CONSTANT)
[ + + ][ + + ]
[ - - ]
242 : : {
243 : 157 : ss << dt.getName() << "_any_constant";
244 : 314 : dt.addSygusConstructor(rule, ss.str(), {rule.getType()}, 0);
245 : : }
246 : : else
247 : : {
248 : 62816 : std::vector<Node> args;
249 : 62816 : std::map<Node, Node> ntSymMap;
250 : 62816 : Node op = purifySygusGNode(rule, args, ntSymMap, nts);
251 : 31408 : std::vector<TypeNode> cargs;
252 : 31408 : std::unordered_map<Node, TypeNode>::const_iterator it;
253 [ + + ]: 63651 : for (const Node& a : args)
254 : : {
255 [ - + ][ - + ]: 32243 : Assert(ntSymMap.find(a) != ntSymMap.end());
[ - - ]
256 : 64486 : Node na = ntSymMap[a];
257 : 32243 : it = ntsToUnres.find(na);
258 [ - + ][ - + ]: 32243 : Assert(it != ntsToUnres.end());
[ - - ]
259 : 32243 : cargs.push_back(it->second);
260 : : }
261 : 31408 : ss << op.getKind();
262 [ + + ]: 31408 : if (!args.empty())
263 : : {
264 : 16631 : Node lbvl = nm->mkNode(Kind::BOUND_VAR_LIST, args);
265 : 16631 : op = nm->mkNode(Kind::LAMBDA, lbvl, op);
266 : : }
267 : : // assign identity rules a weight of 0.
268 [ + + ]: 31408 : dt.addSygusConstructor(op, ss.str(), cargs, isId(op) ? 0 : -1);
269 : : }
270 : 31565 : }
271 : :
272 : 0 : Node SygusGrammar::getLambdaForRule(const Node& r,
273 : : std::map<Node, Node>& ntSymMap) const
274 : : {
275 : 0 : std::vector<Node> args;
276 : 0 : Node rp = purifySygusGNode(r, args, ntSymMap, d_ntSyms);
277 [ - - ]: 0 : if (!args.empty())
278 : : {
279 : 0 : NodeManager* nm = NodeManager::currentNM();
280 : 0 : return nm->mkNode(Kind::LAMBDA, nm->mkNode(Kind::BOUND_VAR_LIST, args), rp);
281 : : }
282 : 0 : return r;
283 : : }
284 : :
285 : 58 : bool SygusGrammar::hasRules() const
286 : : {
287 [ + + ]: 87 : for (const auto& r : d_rules)
288 : : {
289 [ + + ]: 58 : if (r.second.size() > 0)
290 : : {
291 : 29 : return true;
292 : : }
293 : : }
294 : 29 : return false;
295 : : }
296 : :
297 : 3727 : TypeNode SygusGrammar::resolve(bool allowAny)
298 : : {
299 [ + + ]: 3727 : if (!isResolved())
300 : : {
301 : 3122 : NodeManager* nm = NodeManager::currentNM();
302 : 3122 : SkolemManager* sm = nm->getSkolemManager();
303 : 6244 : Node bvl;
304 [ + + ]: 3122 : if (!d_sygusVars.empty())
305 : : {
306 : 1612 : bvl = nm->mkNode(Kind::BOUND_VAR_LIST, d_sygusVars);
307 : : }
308 : 6244 : std::unordered_map<Node, TypeNode> ntsToUnres;
309 [ + + ]: 9028 : for (const Node& ntSym : d_ntSyms)
310 : : {
311 : : // make the unresolved type, used for referencing the final version of
312 : : // the ntSym's datatype
313 : 5906 : ntsToUnres.emplace(ntSym, nm->mkUnresolvedDatatypeSort(ntSym.getName()));
314 : : }
315 : : // Set of non-terminals that can be arbitrary constants.
316 : 6244 : std::unordered_set<Node> allowConsts;
317 : : // push the rules into the sygus datatypes
318 : 3122 : std::vector<DType> dts;
319 [ + + ]: 9028 : for (const Node& ntSym : d_ntSyms)
320 : : {
321 : : // make the datatype, which encodes terms generated by this non-terminal
322 : 11812 : DType dt(ntSym.getName());
323 : :
324 [ + + ]: 37471 : for (const Node& rule : d_rules[ntSym])
325 : : {
326 : 94695 : if (rule.getKind() == Kind::SKOLEM
327 [ + + ][ + + ]: 31565 : && sm->getInternalId(rule) == InternalSkolemId::SYGUS_ANY_CONSTANT)
[ + + ][ + + ]
[ - - ]
328 : : {
329 : 157 : allowConsts.insert(ntSym);
330 : : }
331 : 31565 : addSygusConstructor(dt, rule, d_ntSyms, ntsToUnres);
332 : : }
333 : 5906 : bool allowConst = allowConsts.find(ntSym) != allowConsts.end();
334 [ + + ][ + + ]: 5906 : dt.setSygus(ntSym.getType(), bvl, allowConst || allowAny, allowAny);
335 : : // We can be in a case where the only rule specified was (Variable T)
336 : : // and there are no variables of type T, in which case this is a bogus
337 : : // grammar. This results in the error below.
338 [ - + ][ - + ]: 5906 : Assert(dt.getNumConstructors() != 0) << "Grouped rule listing for " << dt
[ - - ]
339 : 0 : << " produced an empty rule list";
340 : 5906 : dts.push_back(dt);
341 : : }
342 : 3122 : d_datatype = nm->mkMutualDatatypeTypes(dts)[0];
343 : : }
344 : : // return the first datatype
345 : 3727 : return d_datatype;
346 : : }
347 : :
348 : 12143 : bool SygusGrammar::isResolved() { return !d_datatype.isNull(); }
349 : :
350 : 0 : const std::vector<Node>& SygusGrammar::getSygusVars() const
351 : : {
352 : 0 : return d_sygusVars;
353 : : }
354 : :
355 : 13677 : const std::vector<Node>& SygusGrammar::getNtSyms() const { return d_ntSyms; }
356 : :
357 : 13688 : const std::vector<Node>& SygusGrammar::getRulesFor(const Node& ntSym) const
358 : : {
359 : : std::unordered_map<Node, std::vector<Node>>::const_iterator itr =
360 : 13688 : d_rules.find(ntSym);
361 [ - + ][ - + ]: 13688 : Assert(itr != d_rules.end());
[ - - ]
362 : 13688 : return itr->second;
363 : : }
364 : :
365 : 29 : std::string SygusGrammar::toString() const
366 : : {
367 : 29 : std::stringstream ss;
368 : : // clone this grammar before printing it to avoid freezing it.
369 : : return printer::smt2::Smt2Printer::sygusGrammarString(
370 : 87 : SygusGrammar(*this).resolve());
371 : : }
372 : :
373 : : } // namespace cvc5::internal
374 : :
375 : : namespace std {
376 : 2376 : size_t hash<cvc5::internal::SygusGrammar>::operator()(
377 : : const cvc5::internal::SygusGrammar& grammar) const
378 : : {
379 : 2376 : uint64_t ret = cvc5::internal::fnv1a::offsetBasis;
380 [ + + ]: 3385 : for (const auto& v : grammar.d_sygusVars)
381 : : {
382 : 1009 : ret = cvc5::internal::fnv1a::fnv1a_64(ret,
383 : 1009 : std::hash<cvc5::internal::Node>{}(v));
384 : : }
385 [ + + ]: 4752 : for (const auto& nts : grammar.d_ntSyms)
386 : : {
387 : 2376 : ret = cvc5::internal::fnv1a::fnv1a_64(
388 : 2376 : ret, std::hash<cvc5::internal::Node>{}(nts));
389 : : }
390 [ + + ]: 4752 : for (const auto& r : grammar.d_rules)
391 : : {
392 : 2376 : uint64_t rhash = cvc5::internal::fnv1a::offsetBasis;
393 [ + + ]: 2406 : for (const auto& n : r.second)
394 : : {
395 : 30 : rhash = cvc5::internal::fnv1a::fnv1a_64(
396 : 30 : rhash, std::hash<cvc5::internal::Node>{}(n));
397 : : }
398 : 2376 : rhash = cvc5::internal::fnv1a::fnv1a_64(
399 : 2376 : rhash, std::hash<cvc5::internal::Node>{}(r.first));
400 : 2376 : ret = cvc5::internal::fnv1a::fnv1a_64(ret, rhash);
401 : : }
402 : 2376 : return cvc5::internal::fnv1a::fnv1a_64(
403 : 4752 : ret, std::hash<cvc5::internal::TypeNode>{}(grammar.d_datatype));
404 : : }
405 : : } // namespace std
|