Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * Implementation of sygus_sampler.
11 : : */
12 : :
13 : : #include "theory/quantifiers/sygus_sampler.h"
14 : :
15 : : #include <sstream>
16 : :
17 : : #include "expr/dtype.h"
18 : : #include "expr/dtype_cons.h"
19 : : #include "expr/node_algorithm.h"
20 : : #include "options/base_options.h"
21 : : #include "options/quantifiers_options.h"
22 : : #include "printer/printer.h"
23 : : #include "theory/datatypes/sygus_datatype_utils.h"
24 : : #include "theory/quantifiers/lazy_trie.h"
25 : : #include "theory/quantifiers/sygus/term_database_sygus.h"
26 : : #include "theory/rewriter.h"
27 : : #include "util/bitvector.h"
28 : : #include "util/random.h"
29 : : #include "util/rational.h"
30 : : #include "util/sampler.h"
31 : : #include "util/string.h"
32 : :
33 : : using namespace cvc5::internal::kind;
34 : :
35 : : namespace cvc5::internal {
36 : : namespace theory {
37 : : namespace quantifiers {
38 : :
39 : 18872 : SygusSampler::SygusSampler(Env& env)
40 : 18872 : : EnvObj(env), d_use_sygus_type(false), d_is_valid(false)
41 : : {
42 : 18872 : }
43 : :
44 : 141 : void SygusSampler::initialize(CVC5_UNUSED TypeNode tn,
45 : : const std::vector<Node>& vars,
46 : : unsigned nsamples,
47 : : bool unique_type_ids)
48 : : {
49 : 141 : d_use_sygus_type = false;
50 : 141 : d_is_valid = true;
51 : 141 : d_ftn = TypeNode::null();
52 : 141 : d_type_vars.clear();
53 : 141 : d_vars.clear();
54 : 141 : d_rvalue_cindices.clear();
55 : 141 : d_rvalue_null_cindices.clear();
56 : 141 : d_rstring_alphabet.clear();
57 : 141 : d_var_sygus_types.clear();
58 : 141 : d_const_sygus_types.clear();
59 : 141 : d_vars.insert(d_vars.end(), vars.begin(), vars.end());
60 : 141 : std::map<TypeNode, unsigned> type_to_type_id;
61 : 141 : unsigned type_id_counter = 0;
62 [ + + ]: 663 : for (const Node& sv : d_vars)
63 : : {
64 : 522 : TypeNode svt = sv.getType();
65 : 522 : unsigned tnid = 0;
66 [ - + ]: 522 : if (unique_type_ids)
67 : : {
68 : 0 : tnid = type_id_counter;
69 : 0 : type_id_counter++;
70 : : }
71 : : else
72 : : {
73 : 522 : std::map<TypeNode, unsigned>::iterator itt = type_to_type_id.find(svt);
74 [ + + ]: 522 : if (itt == type_to_type_id.end())
75 : : {
76 : 134 : type_to_type_id[svt] = type_id_counter;
77 : 134 : type_id_counter++;
78 : : }
79 : : else
80 : : {
81 : 388 : tnid = itt->second;
82 : : }
83 : : }
84 [ + - ]: 1044 : Trace("sygus-sample-debug")
85 : 522 : << "Type id for " << sv << " is " << tnid << std::endl;
86 : 522 : d_var_index[sv] = d_type_vars[tnid].size();
87 : 522 : d_type_vars[tnid].push_back(sv);
88 : 522 : d_type_ids[sv] = tnid;
89 : 522 : }
90 : 141 : initializeSamples(nsamples);
91 : 141 : }
92 : :
93 : 53 : void SygusSampler::initializeSygus(TypeNode ftn, unsigned nsamples)
94 : : {
95 : 53 : d_is_valid = true;
96 : 53 : d_ftn = ftn;
97 [ - + ][ - + ]: 53 : Assert(d_ftn.isDatatype());
[ - - ]
98 : 53 : const DType& dt = d_ftn.getDType();
99 [ - + ][ - + ]: 53 : Assert(dt.isSygus());
[ - - ]
100 : :
101 [ + - ]: 53 : Trace("sygus-sample") << "Register sampler for " << ftn << std::endl;
102 : :
103 : 53 : d_vars.clear();
104 : 53 : d_type_vars.clear();
105 : 53 : d_var_index.clear();
106 : 53 : d_type_vars.clear();
107 : 53 : d_rvalue_cindices.clear();
108 : 53 : d_rvalue_null_cindices.clear();
109 : 53 : d_var_sygus_types.clear();
110 : : // get the sygus variable list
111 : 53 : Node var_list = dt.getSygusVarList();
112 [ + - ]: 53 : if (!var_list.isNull())
113 : : {
114 [ + + ]: 344 : for (const Node& sv : var_list)
115 : : {
116 : 291 : d_vars.push_back(sv);
117 : 291 : }
118 : : }
119 : : // register sygus type
120 : 53 : registerSygusType(d_ftn);
121 : : // Variables are associated with type ids based on the set of sygus types they
122 : : // appear in.
123 : 53 : std::map<Node, unsigned> var_to_type_id;
124 : 53 : unsigned type_id_counter = 0;
125 [ + + ]: 344 : for (const Node& sv : d_vars)
126 : : {
127 : 291 : TypeNode svt = sv.getType();
128 : : // is it equivalent to a previous variable?
129 [ + + ]: 1884 : for (const auto& v : var_to_type_id)
130 : : {
131 : 1593 : Node svc = v.first;
132 [ + + ]: 1593 : if (svc.getType() == svt)
133 : : {
134 [ + - ]: 429 : if (d_var_sygus_types[sv].size() == d_var_sygus_types[svc].size())
135 : : {
136 : 429 : bool success = true;
137 [ + + ]: 1004 : for (unsigned t = 0, size = d_var_sygus_types[sv].size(); t < size;
138 : : t++)
139 : : {
140 [ + + ]: 599 : if (d_var_sygus_types[sv][t] != d_var_sygus_types[svc][t])
141 : : {
142 : 24 : success = false;
143 : 24 : break;
144 : : }
145 : : }
146 [ + + ]: 429 : if (success)
147 : : {
148 : 405 : var_to_type_id[sv] = var_to_type_id[svc];
149 : : }
150 : : }
151 : : }
152 : 1593 : }
153 [ + + ]: 291 : if (var_to_type_id.find(sv) == var_to_type_id.end())
154 : : {
155 : 125 : var_to_type_id[sv] = type_id_counter;
156 : 125 : type_id_counter++;
157 : : }
158 : 291 : unsigned tnid = var_to_type_id[sv];
159 [ + - ]: 582 : Trace("sygus-sample-debug")
160 : 291 : << "Type id for " << sv << " is " << tnid << std::endl;
161 : 291 : d_var_index[sv] = d_type_vars[tnid].size();
162 : 291 : d_type_vars[tnid].push_back(sv);
163 : 291 : d_type_ids[sv] = tnid;
164 : 291 : }
165 : :
166 : 53 : initializeSamples(nsamples);
167 : 53 : }
168 : :
169 : 194 : void SygusSampler::initializeSamples(unsigned nsamples)
170 : : {
171 : 194 : d_samples.clear();
172 : 194 : std::vector<TypeNode> types;
173 [ + + ]: 1007 : for (const Node& v : d_vars)
174 : : {
175 : 813 : TypeNode vt = v.getType();
176 : 813 : types.push_back(vt);
177 [ + - ]: 1626 : Trace("sygus-sample") << " var #" << types.size() << " : " << v << " : "
178 : 813 : << vt << std::endl;
179 : 813 : }
180 : 194 : std::map<unsigned, std::map<Node, std::vector<TypeNode> >::iterator> sts;
181 [ + - ]: 194 : if (options().quantifiers.sygusSampleGrammar)
182 : : {
183 [ + + ]: 1007 : for (unsigned j = 0, size = types.size(); j < size; j++)
184 : : {
185 : 813 : sts[j] = d_var_sygus_types.find(d_vars[j]);
186 : : }
187 : : }
188 : :
189 : 194 : unsigned nduplicates = 0;
190 [ + + ]: 68194 : for (unsigned i = 0; i < nsamples; i++)
191 : : {
192 : 68000 : std::vector<Node> sample_pt;
193 [ + + ]: 499000 : for (unsigned j = 0, size = types.size(); j < size; j++)
194 : : {
195 : 431000 : Node v = d_vars[j];
196 : 431000 : Node r;
197 [ + - ]: 431000 : if (options().quantifiers.sygusSampleGrammar)
198 : : {
199 : : // choose a random start sygus type, if possible
200 [ + + ]: 431000 : if (sts[j] != d_var_sygus_types.end())
201 : : {
202 : 282000 : unsigned ntypes = sts[j]->second.size();
203 [ + + ]: 282000 : if(ntypes > 0)
204 : : {
205 : 192000 : unsigned index = Random::getRandom().pick(0, ntypes - 1);
206 [ + - ]: 192000 : if (index < ntypes)
207 : : {
208 : : // currently hard coded to 0.0, 0.5
209 : 192000 : r = getSygusRandomValue(sts[j]->second[index], 0.0, 0.5);
210 : : }
211 : : }
212 : : }
213 : : }
214 [ + + ]: 431000 : if (r.isNull())
215 : : {
216 : 239000 : r = getRandomValue(types[j]);
217 [ - + ]: 239000 : if (r.isNull())
218 : : {
219 : 0 : d_is_valid = false;
220 : : }
221 : : }
222 : 431000 : sample_pt.push_back(r);
223 : 431000 : }
224 [ + - ]: 68000 : if (d_samples_trie.add(sample_pt))
225 : : {
226 [ - + ]: 68000 : if (TraceIsOn("sygus-sample"))
227 : : {
228 [ - - ]: 0 : Trace("sygus-sample") << "Sample point #" << i << " : ";
229 [ - - ]: 0 : for (const Node& r : sample_pt)
230 : : {
231 [ - - ]: 0 : Trace("sygus-sample") << r << " ";
232 : : }
233 [ - - ]: 0 : Trace("sygus-sample") << std::endl;
234 : : }
235 : 68000 : d_samples.push_back(sample_pt);
236 : : }
237 : : else
238 : : {
239 : 0 : i--;
240 : 0 : nduplicates++;
241 [ - - ]: 0 : if (nduplicates == nsamples * 10)
242 : : {
243 [ - - ]: 0 : Trace("sygus-sample")
244 : 0 : << "...WARNING: excessive duplicates, cut off sampling at " << i
245 : 0 : << "/" << nsamples << " points." << std::endl;
246 : 0 : break;
247 : : }
248 : : }
249 [ + - ]: 68000 : }
250 : :
251 : 194 : d_trie.clear();
252 : 194 : }
253 : :
254 : 68000 : bool SygusSampler::PtTrie::add(std::vector<Node>& pt)
255 : : {
256 : 68000 : PtTrie* curr = this;
257 [ + + ]: 499000 : for (unsigned i = 0, size = pt.size(); i < size; i++)
258 : : {
259 : 431000 : curr = &(curr->d_children[pt[i]]);
260 : : }
261 : 68000 : bool retVal = curr->d_children.empty();
262 : 68000 : return retVal;
263 : : }
264 : :
265 : 2580 : Node SygusSampler::registerTerm(Node n, bool forceKeep)
266 : : {
267 [ - + ]: 2580 : if (!d_is_valid)
268 : : {
269 : : // do nothing
270 : 0 : return n;
271 : : }
272 : 2580 : TypeNode tn = n.getType();
273 : : // cache based on the (original) type of n
274 : 2580 : return d_trie[tn].add(n, this, 0, d_samples.size(), forceKeep);
275 : 2580 : }
276 : :
277 : 0 : bool SygusSampler::isContiguous(Node n)
278 : : {
279 : : // compute free variables in n
280 : 0 : std::vector<Node> fvs;
281 : 0 : computeFreeVariables(n, fvs);
282 : : // compute contiguous condition
283 [ - - ]: 0 : for (const std::pair<const unsigned, std::vector<Node> >& p : d_type_vars)
284 : : {
285 : 0 : bool foundNotFv = false;
286 [ - - ]: 0 : for (const Node& v : p.second)
287 : : {
288 : 0 : bool hasFv = std::find(fvs.begin(), fvs.end(), v) != fvs.end();
289 [ - - ]: 0 : if (!hasFv)
290 : : {
291 : 0 : foundNotFv = true;
292 : : }
293 [ - - ]: 0 : else if (foundNotFv)
294 : : {
295 : 0 : return false;
296 : : }
297 : : }
298 : : }
299 : 0 : return true;
300 : 0 : }
301 : :
302 : 2 : void SygusSampler::computeFreeVariables(Node n, std::vector<Node>& fvs)
303 : : {
304 : 2 : std::unordered_set<TNode> visited;
305 : 2 : std::unordered_set<TNode>::iterator it;
306 : 2 : std::vector<TNode> visit;
307 : 2 : TNode cur;
308 : 2 : visit.push_back(n);
309 : : do
310 : : {
311 : 4 : cur = visit.back();
312 : 4 : visit.pop_back();
313 [ + - ]: 4 : if (visited.find(cur) == visited.end())
314 : : {
315 : 4 : visited.insert(cur);
316 [ + + ]: 4 : if (cur.isVar())
317 : : {
318 [ + - ]: 3 : if (d_var_index.find(cur) != d_var_index.end())
319 : : {
320 : 3 : fvs.push_back(cur);
321 : : }
322 : : }
323 [ + + ]: 6 : for (const Node& cn : cur)
324 : : {
325 : 2 : visit.push_back(cn);
326 : 2 : }
327 : : }
328 [ + + ]: 4 : } while (!visit.empty());
329 : 2 : }
330 : :
331 : 0 : bool SygusSampler::isOrdered(Node n) { return checkVariables(n, true, false); }
332 : :
333 : 0 : bool SygusSampler::isLinear(Node n) { return checkVariables(n, false, true); }
334 : :
335 : 6 : bool SygusSampler::checkVariables(Node n, bool checkOrder, bool checkLinear)
336 : : {
337 : : // compute free variables in n for each type
338 : 6 : std::map<unsigned, std::vector<Node> > fvs;
339 : :
340 : 6 : std::unordered_set<TNode> visited;
341 : 6 : std::unordered_set<TNode>::iterator it;
342 : 6 : std::vector<TNode> visit;
343 : 6 : TNode cur;
344 : 6 : visit.push_back(n);
345 : : do
346 : : {
347 : 11 : cur = visit.back();
348 : 11 : visit.pop_back();
349 [ + + ]: 11 : if (visited.find(cur) == visited.end())
350 : : {
351 : 10 : visited.insert(cur);
352 [ + + ]: 10 : if (cur.isVar())
353 : : {
354 : 6 : std::map<Node, unsigned>::iterator itv = d_var_index.find(cur);
355 [ + - ]: 6 : if (itv != d_var_index.end())
356 : : {
357 [ + - ]: 6 : if (checkOrder)
358 : : {
359 : 6 : unsigned tnid = d_type_ids[cur];
360 : : // if this variable is out of order
361 [ + + ]: 6 : if (itv->second != fvs[tnid].size())
362 : : {
363 : 2 : return false;
364 : : }
365 : 4 : fvs[tnid].push_back(cur);
366 : : }
367 [ - + ]: 4 : if (checkLinear)
368 : : {
369 [ - - ]: 0 : if (expr::hasSubtermMulti(n, cur))
370 : : {
371 : 0 : return false;
372 : : }
373 : : }
374 : : }
375 : : }
376 [ + + ]: 14 : for (unsigned j = 0, nchildren = cur.getNumChildren(); j < nchildren; j++)
377 : : {
378 : 6 : visit.push_back(cur[(nchildren - j) - 1]);
379 : : }
380 : : }
381 [ + + ]: 9 : } while (!visit.empty());
382 : 4 : return true;
383 : 6 : }
384 : :
385 : 2 : bool SygusSampler::containsFreeVariables(Node a, Node b, bool strict)
386 : : {
387 : : // compute free variables in a
388 : 2 : std::vector<Node> fvs;
389 : 2 : computeFreeVariables(a, fvs);
390 : 2 : std::vector<Node> fv_found;
391 : :
392 : 2 : std::unordered_set<TNode> visited;
393 : 2 : std::unordered_set<TNode>::iterator it;
394 : 2 : std::vector<TNode> visit;
395 : 2 : TNode cur;
396 : 2 : visit.push_back(b);
397 : : do
398 : : {
399 : 4 : cur = visit.back();
400 : 4 : visit.pop_back();
401 [ + - ]: 4 : if (visited.find(cur) == visited.end())
402 : : {
403 : 4 : visited.insert(cur);
404 [ + + ]: 4 : if (cur.isVar())
405 : : {
406 [ + + ]: 3 : if (std::find(fvs.begin(), fvs.end(), cur) == fvs.end())
407 : : {
408 : 1 : return false;
409 : : }
410 [ + - ]: 2 : else if (strict)
411 : : {
412 [ + + ]: 2 : if (fv_found.size() + 1 == fvs.size())
413 : : {
414 : 1 : return false;
415 : : }
416 : : // cur should only be visited once
417 [ - + ][ - + ]: 1 : Assert(std::find(fv_found.begin(), fv_found.end(), cur)
[ - - ]
418 : : == fv_found.end());
419 : 1 : fv_found.push_back(cur);
420 : : }
421 : : }
422 [ + + ]: 4 : for (const Node& cn : cur)
423 : : {
424 : 2 : visit.push_back(cn);
425 : 2 : }
426 : : }
427 [ + - ]: 2 : } while (!visit.empty());
428 : 0 : return true;
429 : 2 : }
430 : :
431 : 232 : void SygusSampler::getVariables(std::vector<Node>& vars) const
432 : : {
433 : 232 : vars.insert(vars.end(), d_vars.begin(), d_vars.end());
434 : 232 : }
435 : :
436 : 1014 : const std::vector<Node>& SygusSampler::getSamplePoint(size_t index) const
437 : : {
438 [ - + ][ - + ]: 1014 : Assert(index < d_samples.size());
[ - - ]
439 : 1014 : return d_samples[index];
440 : : }
441 : :
442 : 232 : void SygusSampler::addSamplePoint(const std::vector<Node>& pt)
443 : : {
444 [ - + ][ - + ]: 232 : Assert(pt.size() == d_vars.size());
[ - - ]
445 : 232 : d_samples.push_back(pt);
446 : 232 : }
447 : :
448 : 608370 : Node SygusSampler::evaluate(Node n, unsigned index)
449 : : {
450 [ - + ][ - + ]: 608370 : Assert(index < d_samples.size());
[ - - ]
451 : : // do beta-reductions in n first
452 : 608370 : n = d_env.getRewriter()->rewrite(n);
453 : : // use efficient rewrite for substitution + rewrite
454 : 608370 : Node ev = d_env.evaluate(n, d_vars, d_samples[index], true);
455 [ - + ][ - + ]: 608370 : Assert(!ev.isNull());
[ - - ]
456 [ + - ]: 608370 : Trace("sygus-sample-ev") << "Evaluate ( " << n << ", " << index << " ) -> ";
457 [ + - ]: 608370 : Trace("sygus-sample-ev") << ev << std::endl;
458 : 608370 : return ev;
459 : 0 : }
460 : :
461 : 0 : int SygusSampler::getDiffSamplePointIndex(Node a, Node b)
462 : : {
463 [ - - ]: 0 : for (unsigned i = 0, nsamp = d_samples.size(); i < nsamp; i++)
464 : : {
465 : 0 : Node ae = evaluate(a, i);
466 : 0 : Node be = evaluate(b, i);
467 [ - - ]: 0 : if (ae != be)
468 : : {
469 : 0 : return i;
470 : : }
471 [ - - ][ - - ]: 0 : }
472 : 0 : return -1;
473 : : }
474 : :
475 : 438212 : Node SygusSampler::getRandomValue(TypeNode tn)
476 : : {
477 : 438212 : NodeManager* nm = nodeManager();
478 [ + + ]: 438212 : if (tn.isBoolean())
479 : : {
480 : 264512 : return nm->mkConst(Random::getRandom().pickWithProb(0.5));
481 : : }
482 [ + + ]: 305956 : else if (tn.isBitVector())
483 : : {
484 : 123280 : unsigned sz = tn.getBitVectorSize();
485 : 246560 : return nm->mkConst(Sampler::pickBvUniform(sz));
486 : : }
487 [ + + ]: 182676 : else if (tn.isFloatingPoint())
488 : : {
489 : 4916 : unsigned e = tn.getFloatingPointExponentSize();
490 : 4916 : unsigned s = tn.getFloatingPointSignificandSize();
491 : 9832 : return nm->mkConst(options().quantifiers.sygusSampleFpUniform
492 [ - + ]: 9832 : ? Sampler::pickFpUniform(e, s)
493 : 9832 : : Sampler::pickFpBiased(e, s));
494 : : }
495 [ + + ][ + + ]: 177760 : else if (tn.isString() || tn.isInteger())
[ + + ]
496 : : {
497 : : // if string, determine the alphabet
498 [ + + ][ + + ]: 108076 : if (tn.isString() && d_rstring_alphabet.empty())
[ + + ]
499 : : {
500 [ + - ]: 36 : Trace("sygus-sample-str-alpha")
501 : 18 : << "Setting string alphabet..." << std::endl;
502 : 18 : std::unordered_set<unsigned> alphas;
503 : 18 : for (const std::pair<const Node, std::vector<TypeNode> >& c :
504 [ + + ]: 86 : d_const_sygus_types)
505 : : {
506 [ + + ]: 50 : if (c.first.getType().isString())
507 : : {
508 [ + - ]: 44 : Trace("sygus-sample-str-alpha")
509 : 22 : << "...have constant " << c.first << std::endl;
510 [ - + ][ - + ]: 22 : Assert(c.first.isConst());
[ - - ]
511 : 22 : std::vector<unsigned> svec = c.first.getConst<String>().getVec();
512 [ + + ]: 35 : for (unsigned ch : svec)
513 : : {
514 : 13 : alphas.insert(ch);
515 : : }
516 : 22 : }
517 : : }
518 : : // can limit to 1 extra characters beyond those in the grammar (2 if
519 : : // there are none in the grammar)
520 [ + + ]: 18 : unsigned num_fresh_char = alphas.empty() ? 2 : 1;
521 : 18 : unsigned fresh_char = 0;
522 [ + + ]: 45 : for (unsigned i = 0; i < num_fresh_char; i++)
523 : : {
524 [ + + ]: 36 : while (alphas.find(fresh_char) != alphas.end())
525 : : {
526 : 9 : fresh_char++;
527 : : }
528 : 27 : alphas.insert(fresh_char);
529 : : }
530 [ + - ]: 36 : Trace("sygus-sample-str-alpha")
531 : 0 : << "Sygus sampler: limit strings alphabet to : " << std::endl
532 : 18 : << " ";
533 [ + + ]: 58 : for (unsigned ch : alphas)
534 : : {
535 : 40 : d_rstring_alphabet.push_back(ch);
536 [ + - ]: 40 : Trace("sygus-sample-str-alpha") << " \\u" << ch;
537 : : }
538 [ + - ]: 18 : Trace("sygus-sample-str-alpha") << std::endl;
539 : 18 : }
540 : :
541 : 108076 : std::vector<unsigned> vec;
542 : 108076 : double ext_freq = .5;
543 [ + + ]: 108076 : unsigned base = tn.isString() ? d_rstring_alphabet.size() : 10;
544 [ + + ]: 215520 : while (Random::getRandom().pickWithProb(ext_freq))
545 : : {
546 : : // add a digit
547 : 107444 : unsigned digit = Random::getRandom().pick(0, base - 1);
548 [ + + ]: 107444 : if (tn.isString())
549 : : {
550 : 46691 : digit = d_rstring_alphabet[digit];
551 : : }
552 : 107444 : vec.push_back(digit);
553 : : }
554 [ + + ]: 108076 : if (tn.isString())
555 : : {
556 : 93594 : return nm->mkConst(String(vec));
557 : : }
558 [ + - ]: 61279 : else if (tn.isInteger())
559 : : {
560 : 61279 : Rational baser(base);
561 : 61279 : Rational curr(1);
562 : 61279 : std::vector<Node> sum;
563 [ + + ]: 122032 : for (unsigned j = 0, size = vec.size(); j < size; j++)
564 : : {
565 : 121506 : Node digit = nm->mkConstInt(Rational(vec[j]) * curr);
566 : 60753 : sum.push_back(digit);
567 : 60753 : curr = curr * baser;
568 : 60753 : }
569 : 61279 : Node ret;
570 [ + + ]: 61279 : if (sum.empty())
571 : : {
572 : 30448 : ret = nm->mkConstInt(Rational(0));
573 : : }
574 [ + + ]: 30831 : else if (sum.size() == 1)
575 : : {
576 : 15543 : ret = sum[0];
577 : : }
578 : : else
579 : : {
580 : 15288 : ret = nm->mkNode(Kind::ADD, sum);
581 : : }
582 : :
583 [ + + ]: 61279 : if (Random::getRandom().pickWithProb(0.5))
584 : : {
585 : : // negative
586 : 30624 : ret = nm->mkNode(Kind::NEG, ret);
587 : : }
588 : 61279 : ret = d_env.getRewriter()->rewrite(ret);
589 [ - + ][ - + ]: 61279 : Assert(ret.isConst());
[ - - ]
590 [ - + ][ - + ]: 61279 : Assert(ret.getType()==tn);
[ - - ]
591 : 61279 : return ret;
592 : 61279 : }
593 [ - + ]: 108076 : }
594 [ - + ]: 69684 : else if (tn.isReal())
595 : : {
596 : 0 : Node s = getRandomValue(nm->integerType());
597 : 0 : Node r = getRandomValue(nm->integerType());
598 : 0 : if (!s.isNull() && !r.isNull())
599 : : {
600 : 0 : Rational sr = s.getConst<Rational>();
601 : 0 : Rational rr = r.getConst<Rational>();
602 [ - - ]: 0 : if (rr.sgn() == 0)
603 : : {
604 : 0 : return nm->mkConstReal(s.getConst<Rational>());
605 : : }
606 : 0 : return nm->mkConstReal(sr / rr);
607 : 0 : }
608 [ - - ][ - - ]: 0 : }
609 : : // default: use type enumerator
610 : 69684 : unsigned counter = 0;
611 [ + + ]: 138815 : while (Random::getRandom().pickWithProb(0.5))
612 : : {
613 : 69131 : counter++;
614 : : }
615 : 69684 : Node ret = d_tenum.getEnumerateTerm(tn, counter);
616 [ + + ]: 69684 : if (ret.isNull())
617 : : {
618 : : // beyond bounds, return the first
619 : 1713 : ret = d_tenum.getEnumerateTerm(tn, 0);
620 : : }
621 : 69684 : return ret;
622 : 69684 : }
623 : :
624 : 391145 : Node SygusSampler::getSygusRandomValue(TypeNode tn,
625 : : double rchance,
626 : : double rinc,
627 : : unsigned depth)
628 : : {
629 [ - + ]: 391145 : if (!tn.isDatatype())
630 : : {
631 : 0 : return getRandomValue(tn);
632 : : }
633 : 391145 : const DType& dt = tn.getDType();
634 [ - + ]: 391145 : if (!dt.isSygus())
635 : : {
636 : 0 : return getRandomValue(tn);
637 : : }
638 [ - + ][ - + ]: 391145 : Assert(d_rvalue_cindices.find(tn) != d_rvalue_cindices.end());
[ - - ]
639 [ + - ]: 782290 : Trace("sygus-sample-grammar")
640 : 0 : << "Sygus random value " << tn << ", depth = " << depth
641 : 391145 : << ", rchance = " << rchance << std::endl;
642 : : // check if we terminate on this call
643 : : // we refuse to enumerate terms of 10+ depth as a hard limit
644 [ + + ][ - + ]: 391145 : bool terminate = Random::getRandom().pickWithProb(rchance) || depth >= 10;
645 : : // if we terminate, only nullary constructors can be chosen
646 : : std::vector<unsigned>& cindices =
647 [ + + ]: 391145 : terminate ? d_rvalue_null_cindices[tn] : d_rvalue_cindices[tn];
648 : 391145 : unsigned ncons = cindices.size();
649 : : // select a random constructor, or random value when index=ncons.
650 : 391145 : unsigned index = Random::getRandom().pick(0, ncons);
651 [ + - ]: 782290 : Trace("sygus-sample-grammar")
652 : 391145 : << "Random index 0..." << ncons << " was : " << index << std::endl;
653 [ + + ]: 391145 : if (index < ncons)
654 : : {
655 [ + - ]: 386260 : Trace("sygus-sample-grammar")
656 : 193130 : << "Recurse constructor index #" << index << std::endl;
657 : 193130 : unsigned cindex = cindices[index];
658 [ - + ][ - + ]: 193130 : Assert(cindex < dt.getNumConstructors());
[ - - ]
659 : 193130 : const DTypeConstructor& dtc = dt[cindex];
660 : : // more likely to terminate in recursive calls
661 : 193130 : double rchance_new = rchance + (1.0 - rchance) * rinc;
662 : 193130 : bool success = true;
663 : : // generate random values for all arguments
664 : 193130 : std::vector<Node> children;
665 [ + + ]: 392275 : for (size_t i = 0, nargs = dtc.getNumArgs(); i < nargs; i++)
666 : : {
667 : 199145 : TypeNode tnc = dtc.getArgType(i);
668 : 199145 : Node c = getSygusRandomValue(tnc, rchance_new, rinc, depth + 1);
669 [ - + ]: 199145 : if (c.isNull())
670 : : {
671 : 0 : success = false;
672 [ - - ]: 0 : Trace("sygus-sample-grammar") << "...fail." << std::endl;
673 : 0 : break;
674 : : }
675 [ + - ]: 398290 : Trace("sygus-sample-grammar")
676 : 199145 : << " child #" << i << " : " << c << std::endl;
677 : 199145 : children.emplace_back(c);
678 [ + - ][ + - ]: 199145 : }
679 [ + - ]: 193130 : if (success)
680 : : {
681 [ + - ]: 193130 : Trace("sygus-sample-grammar") << "utils::mkSygusTerm" << std::endl;
682 : 193130 : Node ret = datatypes::utils::mkSygusTerm(dt, cindex, children);
683 [ + - ]: 193130 : Trace("sygus-sample-grammar") << "...returned " << ret << std::endl;
684 : 193130 : ret = d_env.getRewriter()->rewrite(ret);
685 [ + - ]: 193130 : Trace("sygus-sample-grammar") << "...after rewrite " << ret << std::endl;
686 : : // A rare case where we generate a non-constant value from constant
687 : : // leaves is (/ n 0).
688 [ + + ]: 193130 : if(ret.isConst())
689 : : {
690 : 191933 : return ret;
691 : : }
692 [ + + ]: 193130 : }
693 [ + + ]: 193130 : }
694 [ + - ]: 199212 : Trace("sygus-sample-grammar") << "...resort to random value" << std::endl;
695 : : // if we did not generate based on the grammar, pick a random value
696 : 398424 : return getRandomValue(dt.getSygusType());
697 : : }
698 : :
699 : : // recursion depth bounded by number of types in grammar (small)
700 : 761 : void SygusSampler::registerSygusType(TypeNode tn)
701 : : {
702 [ + + ]: 761 : if (d_rvalue_cindices.find(tn) == d_rvalue_cindices.end())
703 : : {
704 : 236 : d_rvalue_cindices[tn].clear();
705 [ - + ]: 236 : if (!tn.isDatatype())
706 : : {
707 : 0 : return;
708 : : }
709 : 236 : const DType& dt = tn.getDType();
710 [ - + ]: 236 : if (!dt.isSygus())
711 : : {
712 : 0 : return;
713 : : }
714 [ + + ]: 1203 : for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++)
715 : : {
716 : 967 : const DTypeConstructor& dtc = dt[i];
717 : 967 : Node sop = dtc.getSygusOp();
718 : 967 : bool isVar = std::find(d_vars.begin(), d_vars.end(), sop) != d_vars.end();
719 [ + + ]: 967 : if (isVar)
720 : : {
721 : : // if it is a variable, add it to the list of sygus types for that var
722 : 384 : d_var_sygus_types[sop].push_back(tn);
723 : : }
724 : : else
725 : : {
726 : : // otherwise, it is a constructor for sygus random value
727 : 583 : d_rvalue_cindices[tn].push_back(i);
728 [ + + ]: 583 : if (dtc.getNumArgs() == 0)
729 : : {
730 : 134 : d_rvalue_null_cindices[tn].push_back(i);
731 [ + + ]: 134 : if (sop.isConst())
732 : : {
733 : 87 : d_const_sygus_types[sop].push_back(tn);
734 : : }
735 : : }
736 : : }
737 : : // recurse on all subfields
738 [ + + ]: 1675 : for (unsigned j = 0, nargs = dtc.getNumArgs(); j < nargs; j++)
739 : : {
740 : 708 : TypeNode tnc = dtc.getArgType(j);
741 : 708 : registerSygusType(tnc);
742 : 708 : }
743 : 967 : }
744 : : }
745 : : }
746 : :
747 : : } // namespace quantifiers
748 : : } // namespace theory
749 : : } // namespace cvc5::internal
|