Branch data Line data Source code
1 : : /******************************************************************************
2 : : * Top contributors (to current version):
3 : : * Abdalrhman Mohamed, Andrew Reynolds, Aina Niemetz
4 : : *
5 : : * This file is part of the cvc5 project.
6 : : *
7 : : * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS
8 : : * in the top-level source directory and their institutional affiliations.
9 : : * All rights reserved. See the file COPYING in the top-level source
10 : : * directory for licensing information.
11 : : * ****************************************************************************
12 : : *
13 : : * Implementation of class for constructing SyGuS Grammars.
14 : : */
15 : :
16 : : #include "expr/sygus_grammar.h"
17 : :
18 : : #include <sstream>
19 : :
20 : : #include "expr/dtype.h"
21 : : #include "expr/dtype_cons.h"
22 : : #include "printer/printer.h"
23 : : #include "printer/smt2/smt2_printer.h"
24 : : #include "theory/datatypes/sygus_datatype_utils.h"
25 : : #include "expr/skolem_manager.h"
26 : : #include "util/hash.h"
27 : :
28 : : namespace cvc5::internal {
29 : :
30 : 3775 : SygusGrammar::SygusGrammar(const std::vector<Node>& sygusVars,
31 : 3775 : const std::vector<Node>& ntSyms)
32 : 3775 : : d_sygusVars(sygusVars), d_ntSyms(ntSyms)
33 : : {
34 [ + + ]: 10350 : for (const Node& ntSym : ntSyms)
35 : : {
36 : 6575 : d_rules.emplace(ntSym, std::vector<Node>{});
37 : : }
38 : 3775 : }
39 : :
40 : 0 : SygusGrammar::SygusGrammar(const std::vector<Node>& sygusVars,
41 : 0 : const TypeNode& sdt)
42 : 0 : : d_sygusVars(sygusVars)
43 : : {
44 : 0 : Assert(sdt.isSygusDatatype());
45 : 0 : std::vector<TypeNode> tnlist;
46 : : // ensure that sdt is first
47 : 0 : tnlist.push_back(sdt);
48 : 0 : std::map<TypeNode, Node> ntsyms;
49 [ - - ]: 0 : for (size_t i = 0; i < tnlist.size(); i++)
50 : : {
51 : 0 : TypeNode tn = tnlist[i];
52 : 0 : Assert(tn.isSygusDatatype());
53 : 0 : const DType& dt = tn.getDType();
54 : 0 : std::stringstream ss;
55 : 0 : ss << dt.getName();
56 : 0 : Node v = NodeManager::mkBoundVar(ss.str(), dt.getSygusType());
57 : 0 : ntsyms[tn] = v;
58 : 0 : d_ntSyms.push_back(v);
59 : 0 : d_rules.emplace(v, std::vector<Node>{});
60 : : // process the subfield types
61 : 0 : std::unordered_set<TypeNode> tns = dt.getSubfieldTypes();
62 [ - - ]: 0 : for (const TypeNode& tnsc : tns)
63 : : {
64 : 0 : if (tnsc.isSygusDatatype()
65 : 0 : && std::find(tnlist.begin(), tnlist.end(), tnsc) == tnlist.end())
66 : : {
67 : 0 : tnlist.push_back(tnsc);
68 : : }
69 : : }
70 : : }
71 : 0 : std::map<TypeNode, Node>::iterator itn;
72 [ - - ]: 0 : for (const TypeNode& tn : tnlist)
73 : : {
74 [ - - ]: 0 : if (!tn.isSygusDatatype())
75 : : {
76 : 0 : continue;
77 : : }
78 : 0 : Node nts = ntsyms[tn];
79 : 0 : const DType& dt = tn.getDType();
80 [ - - ]: 0 : for (size_t i = 0, ncons = dt.getNumConstructors(); i < ncons; i++)
81 : : {
82 : 0 : const DTypeConstructor& cons = dt[i];
83 [ - - ]: 0 : if (cons.isSygusAnyConstant())
84 : : {
85 : 0 : addAnyConstant(nts, cons[0].getRangeType());
86 : 0 : continue;
87 : : }
88 : 0 : Node op = cons.getSygusOp();
89 : 0 : std::vector<Node> args;
90 [ - - ]: 0 : for (size_t j = 0, nargs = cons.getNumArgs(); j < nargs; j++)
91 : : {
92 : 0 : TypeNode argType = cons[j].getRangeType();
93 : 0 : itn = ntsyms.find(argType);
94 : 0 : Assert(itn != ntsyms.end()) << "Missing " << argType << " in " << op;
95 : 0 : args.push_back(itn->second);
96 : : }
97 : 0 : Node rule = theory::datatypes::utils::mkSygusTerm(op, args, true);
98 : 0 : addRule(nts, rule);
99 : : }
100 : : }
101 : 0 : }
102 : :
103 : 41310 : void SygusGrammar::addRule(const Node& ntSym, const Node& rule)
104 : : {
105 [ - + ][ - + ]: 41310 : Assert(d_rules.find(ntSym) != d_rules.cend());
[ - - ]
106 [ - + ][ - + ]: 41310 : Assert(rule.getType().isInstanceOf(ntSym.getType()));
[ - - ]
107 : : // avoid duplication
108 : 41310 : std::vector<Node>& rs = d_rules[ntSym];
109 [ + + ]: 41310 : if (std::find(rs.begin(), rs.end(), rule) == rs.end())
110 : : {
111 : 32399 : rs.push_back(rule);
112 : : }
113 : 41310 : }
114 : :
115 : 56 : void SygusGrammar::addRules(const Node& ntSym, const std::vector<Node>& rules)
116 : : {
117 [ + + ]: 122 : for (const Node& rule : rules)
118 : : {
119 : 66 : addRule(ntSym, rule);
120 : : }
121 : 56 : }
122 : :
123 : 174 : void SygusGrammar::addAnyConstant(const Node& ntSym, const TypeNode& tn)
124 : : {
125 [ - + ][ - + ]: 174 : Assert(d_rules.find(ntSym) != d_rules.cend());
[ - - ]
126 [ - + ][ - + ]: 174 : Assert(tn.isInstanceOf(ntSym.getType()));
[ - - ]
127 : 174 : SkolemManager* sm = NodeManager::currentNM()->getSkolemManager();
128 : : Node anyConst =
129 : 522 : sm->mkInternalSkolemFunction(InternalSkolemId::SYGUS_ANY_CONSTANT, tn);
130 : 174 : addRule(ntSym, anyConst);
131 : 174 : }
132 : :
133 : 83 : void SygusGrammar::addAnyVariable(const Node& ntSym)
134 : : {
135 [ - + ][ - + ]: 83 : Assert(d_rules.find(ntSym) != d_rules.cend());
[ - - ]
136 : : // each variable of appropriate type becomes a rule.
137 [ + + ]: 253 : for (const Node& v : d_sygusVars)
138 : : {
139 [ + + ]: 170 : if (v.getType().isInstanceOf(ntSym.getType()))
140 : : {
141 : 125 : addRule(ntSym, v);
142 : : }
143 : : }
144 : 83 : }
145 : :
146 : 605 : void SygusGrammar::removeRule(const Node& ntSym, const Node& rule)
147 : : {
148 : : std::unordered_map<Node, std::vector<Node>>::iterator itr =
149 : 605 : d_rules.find(ntSym);
150 [ - + ][ - + ]: 605 : Assert(itr != d_rules.end());
[ - - ]
151 : : std::vector<Node>::iterator it =
152 : 605 : std::find(itr->second.begin(), itr->second.end(), rule);
153 [ - + ][ - + ]: 605 : Assert(it != itr->second.end());
[ - - ]
154 : 605 : itr->second.erase(it);
155 : 605 : }
156 : :
157 : : /**
158 : : * Purify SyGuS grammar node.
159 : : *
160 : : * This returns a node where all occurrences of non-terminal symbols (those in
161 : : * the domain of \p ntsToUnres) are replaced by fresh variables. For each
162 : : * variable replaced in this way, we add the fresh variable it is replaced with
163 : : * to \p args, and the unresolved types corresponding to the non-terminal symbol
164 : : * to \p cargs (constructor args). In other words, \p args contains the free
165 : : * variables in the node returned by this method (which should be bound by a
166 : : * lambda), and \p cargs contains the types of the arguments of the sygus
167 : : * constructor.
168 : : *
169 : : * @param n The node to purify.
170 : : * @param args The free variables in the node returned by this method.
171 : : * @param ntSymMap Map from each variable in args to the non-terminal they were
172 : : * introduced for.
173 : : * @param nts The list of non-terminal symbols
174 : : * @return The purfied node.
175 : : */
176 : 64335 : Node purifySygusGNode(const Node& n,
177 : : std::vector<Node>& args,
178 : : std::map<Node, Node>& ntSymMap,
179 : : const std::vector<Node>& nts)
180 : : {
181 : 64335 : NodeManager* nm = NodeManager::currentNM();
182 : : // if n is non-terminal
183 [ + + ]: 64335 : if (std::find(nts.begin(), nts.end(), n) != nts.end())
184 : : {
185 : 64544 : Node ret = NodeManager::mkBoundVar(n.getType());
186 : 32272 : ntSymMap[ret] = n;
187 : 32272 : args.push_back(ret);
188 : 32272 : return ret;
189 : : }
190 : 64126 : std::vector<Node> pchildren;
191 : 32063 : bool childChanged = false;
192 [ + + ]: 64984 : for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
193 : : {
194 : 32921 : Node ptermc = purifySygusGNode(n[i], args, ntSymMap, nts);
195 : 32921 : pchildren.push_back(ptermc);
196 [ + + ][ + + ]: 32921 : childChanged = childChanged || ptermc != n[i];
[ + + ][ - - ]
197 : : }
198 [ + + ]: 32063 : if (!childChanged)
199 : : {
200 : 15473 : return n;
201 : : }
202 : 33180 : internal::Node nret;
203 [ + + ]: 16590 : if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
204 : : {
205 : : // it's an indexed operator so we should provide the op
206 : 766 : internal::NodeBuilder nb(NodeManager::currentNM(), n.getKind());
207 : 766 : nb << n.getOperator();
208 : 766 : nb.append(pchildren);
209 : 766 : nret = nb.constructNode();
210 : : }
211 : : else
212 : : {
213 : 15824 : nret = nm->mkNode(n.getKind(), pchildren);
214 : : }
215 : 16590 : return nret;
216 : : }
217 : :
218 : 31414 : bool isId(const Node& n)
219 : : {
220 [ + + ][ - - ]: 48060 : return n.getKind() == Kind::LAMBDA && n[0].getNumChildren() == 1
221 : 79474 : && n[0][0] == n[1];
222 : : }
223 : :
224 : : /**
225 : : * Add \p rule to the set of constructors of \p dt.
226 : : *
227 : : * @param dt The datatype to which the rule is added.
228 : : * @param rule The rule to add.
229 : : * @param ntsToUnres Mapping from non-terminals to their unresolved types.
230 : : */
231 : 31571 : void addSygusConstructor(DType& dt,
232 : : const Node& rule,
233 : : const std::vector<Node>& nts,
234 : : const std::unordered_map<Node, TypeNode>& ntsToUnres)
235 : : {
236 : 31571 : NodeManager* nm = NodeManager::currentNM();
237 : 63142 : std::stringstream ss;
238 : 31571 : if (rule.getKind() == Kind::SKOLEM
239 [ + + ][ + + ]: 31571 : && rule.getInternalSkolemId() == InternalSkolemId::SYGUS_ANY_CONSTANT)
[ + + ]
240 : : {
241 : 157 : ss << dt.getName() << "_any_constant";
242 : 314 : dt.addSygusConstructor(rule, ss.str(), {rule.getType()}, 0);
243 : : }
244 : : else
245 : : {
246 : 62828 : std::vector<Node> args;
247 : 62828 : std::map<Node, Node> ntSymMap;
248 : 62828 : Node op = purifySygusGNode(rule, args, ntSymMap, nts);
249 : 31414 : std::vector<TypeNode> cargs;
250 : 31414 : std::unordered_map<Node, TypeNode>::const_iterator it;
251 [ + + ]: 63686 : for (const Node& a : args)
252 : : {
253 [ - + ][ - + ]: 32272 : Assert(ntSymMap.find(a) != ntSymMap.end());
[ - - ]
254 : 64544 : Node na = ntSymMap[a];
255 : 32272 : it = ntsToUnres.find(na);
256 [ - + ][ - + ]: 32272 : Assert(it != ntsToUnres.end());
[ - - ]
257 : 32272 : cargs.push_back(it->second);
258 : : }
259 : 31414 : ss << op.getKind();
260 [ + + ]: 31414 : if (!args.empty())
261 : : {
262 : 16646 : Node lbvl = nm->mkNode(Kind::BOUND_VAR_LIST, args);
263 : 16646 : op = nm->mkNode(Kind::LAMBDA, lbvl, op);
264 : : }
265 : : // assign identity rules a weight of 0.
266 [ + + ]: 31414 : dt.addSygusConstructor(op, ss.str(), cargs, isId(op) ? 0 : -1);
267 : : }
268 : 31571 : }
269 : :
270 : 0 : Node SygusGrammar::getLambdaForRule(const Node& r,
271 : : std::map<Node, Node>& ntSymMap) const
272 : : {
273 : 0 : std::vector<Node> args;
274 : 0 : Node rp = purifySygusGNode(r, args, ntSymMap, d_ntSyms);
275 [ - - ]: 0 : if (!args.empty())
276 : : {
277 : 0 : NodeManager* nm = NodeManager::currentNM();
278 : 0 : return nm->mkNode(Kind::LAMBDA, nm->mkNode(Kind::BOUND_VAR_LIST, args), rp);
279 : : }
280 : 0 : return r;
281 : : }
282 : :
283 : 58 : bool SygusGrammar::hasRules() const
284 : : {
285 [ + + ]: 87 : for (const auto& r : d_rules)
286 : : {
287 [ + + ]: 58 : if (r.second.size() > 0)
288 : : {
289 : 29 : return true;
290 : : }
291 : : }
292 : 29 : return false;
293 : : }
294 : :
295 : 3727 : TypeNode SygusGrammar::resolve(bool allowAny)
296 : : {
297 [ + + ]: 3727 : if (!isResolved())
298 : : {
299 : 3122 : NodeManager* nm = NodeManager::currentNM();
300 : 6244 : Node bvl;
301 [ + + ]: 3122 : if (!d_sygusVars.empty())
302 : : {
303 : 1611 : bvl = nm->mkNode(Kind::BOUND_VAR_LIST, d_sygusVars);
304 : : }
305 : 6244 : std::unordered_map<Node, TypeNode> ntsToUnres;
306 [ + + ]: 9027 : for (const Node& ntSym : d_ntSyms)
307 : : {
308 : : // make the unresolved type, used for referencing the final version of
309 : : // the ntSym's datatype
310 : 5905 : ntsToUnres.emplace(ntSym, nm->mkUnresolvedDatatypeSort(ntSym.getName()));
311 : : }
312 : : // Set of non-terminals that can be arbitrary constants.
313 : 6244 : std::unordered_set<Node> allowConsts;
314 : : // push the rules into the sygus datatypes
315 : 3122 : std::vector<DType> dts;
316 [ + + ]: 9027 : for (const Node& ntSym : d_ntSyms)
317 : : {
318 : : // make the datatype, which encodes terms generated by this non-terminal
319 : 11810 : DType dt(ntSym.getName());
320 : :
321 [ + + ]: 37476 : for (const Node& rule : d_rules[ntSym])
322 : : {
323 : 31571 : if (rule.getKind() == Kind::SKOLEM
324 [ + + ][ + + ]: 31571 : && rule.getInternalSkolemId() == InternalSkolemId::SYGUS_ANY_CONSTANT)
[ + + ]
325 : : {
326 : 157 : allowConsts.insert(ntSym);
327 : : }
328 : 31571 : addSygusConstructor(dt, rule, d_ntSyms, ntsToUnres);
329 : : }
330 : 5905 : bool allowConst = allowConsts.find(ntSym) != allowConsts.end();
331 [ + + ][ + + ]: 5905 : dt.setSygus(ntSym.getType(), bvl, allowConst || allowAny, allowAny);
332 : : // We can be in a case where the only rule specified was (Variable T)
333 : : // and there are no variables of type T, in which case this is a bogus
334 : : // grammar. This results in the error below.
335 [ - + ][ - + ]: 5905 : Assert(dt.getNumConstructors() != 0) << "Grouped rule listing for " << dt
[ - - ]
336 : 0 : << " produced an empty rule list";
337 : 5905 : dts.push_back(dt);
338 : : }
339 : 3122 : d_datatype = nm->mkMutualDatatypeTypes(dts)[0];
340 : : }
341 : : // return the first datatype
342 : 3727 : return d_datatype;
343 : : }
344 : :
345 : 12143 : bool SygusGrammar::isResolved() { return !d_datatype.isNull(); }
346 : :
347 : 0 : const std::vector<Node>& SygusGrammar::getSygusVars() const
348 : : {
349 : 0 : return d_sygusVars;
350 : : }
351 : :
352 : 13677 : const std::vector<Node>& SygusGrammar::getNtSyms() const { return d_ntSyms; }
353 : :
354 : 13683 : const std::vector<Node>& SygusGrammar::getRulesFor(const Node& ntSym) const
355 : : {
356 : : std::unordered_map<Node, std::vector<Node>>::const_iterator itr =
357 : 13683 : d_rules.find(ntSym);
358 [ - + ][ - + ]: 13683 : Assert(itr != d_rules.end());
[ - - ]
359 : 13683 : return itr->second;
360 : : }
361 : :
362 : 29 : std::string SygusGrammar::toString() const
363 : : {
364 : 29 : std::stringstream ss;
365 : : // clone this grammar before printing it to avoid freezing it.
366 : : return printer::smt2::Smt2Printer::sygusGrammarString(
367 : 87 : SygusGrammar(*this).resolve());
368 : : }
369 : :
370 : : } // namespace cvc5::internal
371 : :
372 : : namespace std {
373 : 2376 : size_t hash<cvc5::internal::SygusGrammar>::operator()(
374 : : const cvc5::internal::SygusGrammar& grammar) const
375 : : {
376 : 2376 : uint64_t ret = cvc5::internal::fnv1a::offsetBasis;
377 [ + + ]: 3385 : for (const auto& v : grammar.d_sygusVars)
378 : : {
379 : 1009 : ret = cvc5::internal::fnv1a::fnv1a_64(ret,
380 : 1009 : std::hash<cvc5::internal::Node>{}(v));
381 : : }
382 [ + + ]: 4752 : for (const auto& nts : grammar.d_ntSyms)
383 : : {
384 : 2376 : ret = cvc5::internal::fnv1a::fnv1a_64(
385 : 2376 : ret, std::hash<cvc5::internal::Node>{}(nts));
386 : : }
387 [ + + ]: 4752 : for (const auto& r : grammar.d_rules)
388 : : {
389 : 2376 : uint64_t rhash = cvc5::internal::fnv1a::offsetBasis;
390 [ + + ]: 2406 : for (const auto& n : r.second)
391 : : {
392 : 30 : rhash = cvc5::internal::fnv1a::fnv1a_64(
393 : 30 : rhash, std::hash<cvc5::internal::Node>{}(n));
394 : : }
395 : 2376 : rhash = cvc5::internal::fnv1a::fnv1a_64(
396 : 2376 : rhash, std::hash<cvc5::internal::Node>{}(r.first));
397 : 2376 : ret = cvc5::internal::fnv1a::fnv1a_64(ret, rhash);
398 : : }
399 : 2376 : return cvc5::internal::fnv1a::fnv1a_64(
400 : 4752 : ret, std::hash<cvc5::internal::TypeNode>{}(grammar.d_datatype));
401 : : }
402 : : } // namespace std
|