Branch data Line data Source code
1 : : /******************************************************************************
2 : : * Top contributors (to current version):
3 : : * Andrew Reynolds
4 : : *
5 : : * This file is part of the cvc5 project.
6 : : *
7 : : * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS
8 : : * in the top-level source directory and their institutional affiliations.
9 : : * All rights reserved. See the file COPYING in the top-level source
10 : : * directory for licensing information.
11 : : * ****************************************************************************
12 : : *
13 : : * A class for augmenting model-based instantiations via fast sygus enumeration.
14 : : */
15 : :
16 : : #include "theory/quantifiers/mbqi_fast_sygus.h"
17 : :
18 : : #include "expr/node_algorithm.h"
19 : : #include "expr/skolem_manager.h"
20 : : #include "printer/smt2/smt2_printer.h"
21 : : #include "theory/datatypes/sygus_datatype_utils.h"
22 : : #include "theory/quantifiers/inst_strategy_mbqi.h"
23 : : #include "theory/quantifiers/sygus/sygus_enumerator.h"
24 : : #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
25 : : #include "theory/smt_engine_subsolver.h"
26 : : #include "util/random.h"
27 : : #include "smt/set_defaults.h"
28 : :
29 : : namespace cvc5::internal {
30 : : namespace theory {
31 : : namespace quantifiers {
32 : :
33 : 56 : void MVarInfo::initialize(Env& env,
34 : : const Node& q,
35 : : const Node& v,
36 : : const std::vector<Node>& etrules)
37 : : {
38 : 56 : NodeManager* nm = NodeManager::currentNM();
39 : 112 : TypeNode tn = v.getType();
40 [ - + ][ - + ]: 56 : Assert(MQuantInfo::shouldEnumerate(tn));
[ - - ]
41 : 112 : TypeNode retType = tn;
42 : 112 : std::vector<Node> trules;
43 [ + + ]: 56 : if (tn.isFunction())
44 : : {
45 : 96 : std::vector<TypeNode> argTypes = tn.getArgTypes();
46 : 48 : retType = tn.getRangeType();
47 : 48 : std::vector<Node> vs;
48 [ + + ]: 102 : for (const TypeNode& tnc : argTypes)
49 : : {
50 : 108 : Node vc = nm->mkBoundVar(tnc);
51 : 54 : vs.push_back(vc);
52 : : }
53 : 48 : d_lamVars = nm->mkNode(Kind::BOUND_VAR_LIST, vs);
54 : 48 : trules.insert(trules.end(), vs.begin(), vs.end());
55 : : }
56 : : // include free symbols from body of quantified formula if applicable
57 : 112 : std::unordered_set<Node> syms;
58 : 56 : expr::getSymbols(q[1], syms);
59 : 56 : trules.insert(trules.end(), syms.begin(), syms.end());
60 : : // include the external terminal rules
61 [ + + ]: 190 : for (const Node& symbol : etrules)
62 : : {
63 [ + + ]: 134 : if (std::find(trules.begin(), trules.end(), symbol) == trules.end())
64 : : {
65 : 16 : trules.push_back(symbol);
66 : : }
67 : : }
68 : : SygusGrammarCons sgc;
69 : 112 : Node bvl;
70 : 112 : TypeNode tng = sgc.mkDefaultSygusType(env, retType, bvl, trules);
71 [ - + ]: 56 : if (TraceIsOn("mbqi-model-enum"))
72 : : {
73 [ - - ]: 0 : Trace("mbqi-model-enum") << "Enumerate terms for " << retType;
74 [ - - ]: 0 : if (!d_lamVars.isNull())
75 : : {
76 [ - - ]: 0 : Trace("mbqi-model-enum") << ", variable list " << d_lamVars;
77 : : }
78 [ - - ]: 0 : Trace("mbqi-model-enum") << std::endl;
79 [ - - ]: 0 : Trace("mbqi-model-enum") << "Based on grammar:" << std::endl;
80 [ - - ]: 0 : Trace("mbqi-model-enum")
81 : 0 : << printer::smt2::Smt2Printer::sygusGrammarString(tng) << std::endl;
82 : : }
83 : 56 : d_senum.reset(new SygusTermEnumerator(env, tng));
84 : 56 : }
85 : :
86 : 5673 : Node MVarInfo::getEnumeratedTerm(size_t i)
87 : : {
88 : 5673 : NodeManager* nm = NodeManager::currentNM();
89 : 5673 : size_t nullCount = 0;
90 [ + + ]: 11921 : while (i >= d_enum.size())
91 : : {
92 : 6264 : Node curr = d_senum->getCurrent();
93 [ + - ]: 6264 : Trace("mbqi-sygus-enum") << "Enumerate: " << curr << std::endl;
94 [ + + ]: 6264 : if (!curr.isNull())
95 : : {
96 [ + + ]: 2351 : if (!d_lamVars.isNull())
97 : : {
98 : 279 : curr = nm->mkNode(Kind::LAMBDA, d_lamVars, curr);
99 : : }
100 : 2351 : d_enum.push_back(curr);
101 : 2351 : nullCount = 0;
102 : : }
103 : : else
104 : : {
105 : 3913 : nullCount++;
106 [ + + ]: 3913 : if (nullCount > 100)
107 : : {
108 : : // break if we aren't making progress
109 : 16 : break;
110 : : }
111 : : }
112 [ - + ]: 6248 : if (!d_senum->incrementPartial())
113 : : {
114 : : // enumeration is finished
115 : 0 : break;
116 : : }
117 : : }
118 [ + + ]: 5673 : if (i >= d_enum.size())
119 : : {
120 : 16 : return Node::null();
121 : : }
122 : 5657 : return d_enum[i];
123 : : }
124 : :
125 : 56 : void MQuantInfo::initialize(Env& env, InstStrategyMbqi& parent, const Node& q)
126 : : {
127 : : // The externally provided terminal rules. This set is shared between
128 : : // all variables we instantiate.
129 : 112 : std::vector<Node> etrules;
130 [ + + ]: 128 : for (const Node& v : q[0])
131 : : {
132 : 72 : size_t index = d_vinfo.size();
133 : 72 : d_vinfo.emplace_back();
134 : 144 : TypeNode vtn = v.getType();
135 : : // if enumerated, add to list
136 [ + + ]: 72 : if (shouldEnumerate(vtn))
137 : : {
138 : 56 : d_indices.push_back(index);
139 : : }
140 : : else
141 : : {
142 : 16 : d_nindices.push_back(index);
143 : : // Include variables defined in terms of others we are not enumerating.
144 : 16 : etrules.push_back(v);
145 : : }
146 : : }
147 : : // Get free symbols from body of quantified formula here
148 : 112 : std::unordered_set<Node> syms;
149 : 56 : expr::getSymbols(q[1], syms);
150 [ + + ]: 174 : for (const Node& symbol : syms)
151 : : {
152 [ + - ]: 118 : if (std::find(etrules.begin(), etrules.end(), symbol) == etrules.end())
153 : : {
154 : 118 : etrules.push_back(symbol);
155 : : }
156 : : }
157 [ + - ]: 56 : Trace("mbqi-model-enum") << "Terminals: " << etrules << std::endl;
158 : : // initialize the variables we are instantiating
159 [ + + ]: 112 : for (size_t index : d_indices)
160 : : {
161 : 56 : d_vinfo[index].initialize(env, q, q[0][index], etrules);
162 : : }
163 : 56 : }
164 : :
165 : 195 : MVarInfo& MQuantInfo::getVarInfo(size_t index)
166 : : {
167 [ - + ][ - + ]: 195 : Assert(index < d_vinfo.size());
[ - - ]
168 : 195 : return d_vinfo[index];
169 : : }
170 : :
171 : 195 : std::vector<size_t> MQuantInfo::getInstIndices() { return d_indices; }
172 : 195 : std::vector<size_t> MQuantInfo::getNoInstIndices() { return d_nindices; }
173 : :
174 : 128 : bool MQuantInfo::shouldEnumerate(const TypeNode& tn)
175 : : {
176 [ + + ]: 128 : if (tn.isUninterpretedSort())
177 : : {
178 : 16 : return false;
179 : : }
180 : 112 : return true;
181 : : }
182 : :
183 : 333 : MbqiFastSygus::MbqiFastSygus(Env& env, InstStrategyMbqi& parent)
184 : 333 : : EnvObj(env), d_parent(parent)
185 : : {
186 : 333 : d_subOptions.copyValues(options());
187 : 333 : smt::SetDefaults::disableChecking(d_subOptions);
188 : 333 : }
189 : :
190 : 195 : MQuantInfo& MbqiFastSygus::getOrMkQuantInfo(const Node& q)
191 : : {
192 : 195 : auto [it, inserted] = d_qinfo.try_emplace(q);
193 [ + + ]: 195 : if (inserted)
194 : : {
195 : 56 : it->second.initialize(d_env, d_parent, q);
196 : : }
197 : 195 : return it->second;
198 : : }
199 : :
200 : 195 : bool MbqiFastSygus::constructInstantiation(
201 : : const Node& q,
202 : : const Node& query,
203 : : const std::vector<Node>& vars,
204 : : std::vector<Node>& mvs,
205 : : const std::map<Node, Node>& mvFreshVar)
206 : : {
207 [ - + ][ - + ]: 195 : Assert(q[0].getNumChildren() == vars.size());
[ - - ]
208 [ - + ][ - + ]: 195 : Assert(vars.size() == mvs.size());
[ - - ]
209 [ - + ]: 195 : if (TraceIsOn("mbqi-model-enum"))
210 : : {
211 [ - - ]: 0 : Trace("mbqi-model-enum") << "Instantiate " << q << std::endl;
212 [ - - ]: 0 : for (size_t i = 0, nvars = vars.size(); i < nvars; i++)
213 : : {
214 [ - - ]: 0 : Trace("mbqi-model-enum")
215 : 0 : << " " << q[0][i] << " -> " << mvs[i] << std::endl;
216 : : }
217 : : }
218 : 390 : SubsolverSetupInfo ssi(d_env, d_subOptions);
219 : 195 : MQuantInfo& qi = getOrMkQuantInfo(q);
220 : 390 : std::vector<size_t> indices = qi.getInstIndices();
221 : 390 : std::vector<size_t> nindices = qi.getNoInstIndices();
222 : 390 : Subs inst;
223 : 390 : Subs vinst;
224 : 390 : std::unordered_map<Node, Node> tmpCMap;
225 [ + + ]: 219 : for (size_t i : nindices)
226 : : {
227 : 24 : Node v = mvs[i];
228 : 24 : v = d_parent.convertFromModel(v, tmpCMap, mvFreshVar);
229 [ - + ]: 24 : if (v.isNull())
230 : : {
231 : 0 : return false;
232 : : }
233 [ + - ]: 48 : Trace("mbqi-model-enum")
234 : 24 : << "* Assume: " << q[0][i] << " -> " << v << std::endl;
235 : : // if we don't enumerate it, we are already considering this instantiation
236 : 24 : inst.add(vars[i], v);
237 : 24 : vinst.add(q[0][i], v);
238 : : }
239 : 390 : Node queryCurr = query;
240 [ + - ]: 195 : Trace("mbqi-model-enum") << "...query is " << queryCurr << std::endl;
241 : 195 : queryCurr = rewrite(inst.apply(queryCurr));
242 [ + - ]: 195 : Trace("mbqi-model-enum") << "...processed is " << queryCurr << std::endl;
243 : : // consider variables in random order, for diversity of instantiations
244 : 195 : std::shuffle(indices.begin(), indices.end(), Random::getRandom());
245 [ + + ]: 382 : for (size_t i = 0, isize = indices.size(); i < isize; i++)
246 : : {
247 : 195 : size_t ii = indices[i];
248 : 195 : TNode v = vars[ii];
249 : 195 : MVarInfo& vi = qi.getVarInfo(ii);
250 : 195 : size_t cindex = 0;
251 : 195 : bool success = false;
252 : : bool successEnum;
253 : 5478 : do
254 : : {
255 : 5673 : Node ret = vi.getEnumeratedTerm(cindex);
256 : 5673 : cindex++;
257 : 5673 : Node retc;
258 [ + + ]: 5673 : if (!ret.isNull())
259 : : {
260 [ + - ]: 5657 : Trace("mbqi-model-enum") << "- Try candidate: " << ret << std::endl;
261 : : // apply current substitution (to account for cases where ret has
262 : : // other variables in its grammar).
263 : 5657 : ret = vinst.apply(ret);
264 : 5657 : retc = ret;
265 : 5657 : successEnum = true;
266 : : // now convert the value
267 : 11314 : std::unordered_map<Node, Node> tmpConvertMap;
268 : 5657 : std::map<TypeNode, std::unordered_set<Node> > freshVarType;
269 : 5657 : retc = d_parent.convertToQuery(retc, tmpConvertMap, freshVarType);
270 : : }
271 : : else
272 : : {
273 [ + - ]: 32 : Trace("mbqi-model-enum")
274 : 16 : << "- Failed to enumerate candidate" << std::endl;
275 : : // if we failed to enumerate, just try the original
276 : 16 : Node mc = d_parent.convertFromModel(mvs[ii], tmpCMap, mvFreshVar);
277 [ + + ]: 16 : if (mc.isNull())
278 : : {
279 : : // if failed to convert, we fail
280 : 8 : return false;
281 : : }
282 : 8 : ret = mc;
283 : 8 : retc = mc;
284 : 8 : successEnum = false;
285 : : }
286 [ + - ]: 11330 : Trace("mbqi-model-enum")
287 : 5665 : << "- Converted candidate: " << v << " -> " << retc << std::endl;
288 : : // see if it is still satisfiable, if still SAT, we replace
289 : 11330 : Node queryCheck = queryCurr.substitute(v, TNode(retc));
290 : 5665 : queryCheck = rewrite(queryCheck);
291 [ + - ]: 5665 : Trace("mbqi-model-enum") << "...check " << queryCheck << std::endl;
292 : 5665 : Result r = checkWithSubsolver(queryCheck, ssi);
293 [ + + ]: 5665 : if (r == Result::SAT)
294 : : {
295 : : // remember the updated query
296 : 187 : queryCurr = queryCheck;
297 [ + - ]: 187 : Trace("mbqi-model-enum") << "...success" << std::endl;
298 [ + - ]: 374 : Trace("mbqi-model-enum")
299 : 187 : << "* Enumerated " << q[0][ii] << " -> " << ret << std::endl;
300 : 187 : mvs[ii] = ret;
301 : 187 : vinst.add(q[0][ii], ret);
302 : 187 : success = true;
303 : : }
304 [ - + ]: 5478 : else if (!successEnum)
305 : : {
306 : : // we did not enumerate a candidate, and tried the original, which
307 : : // failed.
308 : 0 : return false;
309 : : }
310 [ + + ]: 5665 : } while (!success);
311 : : }
312 : 187 : return true;
313 : : }
314 : : } // namespace quantifiers
315 : : } // namespace theory
316 : : } // namespace cvc5::internal
|