Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * A class for augmenting model-based instantiations via fast sygus enumeration.
11 : : */
12 : :
13 : : #include "theory/quantifiers/mbqi_enum.h"
14 : :
15 : : #include "expr/node_algorithm.h"
16 : : #include "expr/skolem_manager.h"
17 : : #include "printer/smt2/smt2_printer.h"
18 : : #include "theory/datatypes/sygus_datatype_utils.h"
19 : : #include "theory/quantifiers/inst_strategy_mbqi.h"
20 : : #include "theory/quantifiers/sygus/sygus_enumerator.h"
21 : : #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
22 : : #include "theory/smt_engine_subsolver.h"
23 : : #include "util/random.h"
24 : : #include "smt/set_defaults.h"
25 : :
26 : : namespace cvc5::internal {
27 : : namespace theory {
28 : : namespace quantifiers {
29 : :
30 : 49 : void MVarInfo::initialize(Env& env,
31 : : const Node& q,
32 : : const Node& v,
33 : : const std::vector<Node>& etrules)
34 : : {
35 : 49 : NodeManager* nm = env.getNodeManager();
36 : 49 : TypeNode tn = v.getType();
37 [ - + ][ - + ]: 49 : Assert(MQuantInfo::shouldEnumerate(tn));
[ - - ]
38 : 49 : TypeNode retType = tn;
39 : 49 : std::vector<Node> trules;
40 [ + + ]: 49 : if (tn.isFunction())
41 : : {
42 : 42 : std::vector<TypeNode> argTypes = tn.getArgTypes();
43 : 42 : retType = tn.getRangeType();
44 : 42 : std::vector<Node> vs;
45 [ + + ]: 89 : for (const TypeNode& tnc : argTypes)
46 : : {
47 : 47 : Node vc = NodeManager::mkBoundVar(tnc);
48 : 47 : vs.push_back(vc);
49 : 47 : }
50 : 42 : d_lamVars = nm->mkNode(Kind::BOUND_VAR_LIST, vs);
51 : 42 : trules.insert(trules.end(), vs.begin(), vs.end());
52 : 42 : }
53 : : // include free symbols from body of quantified formula if applicable
54 [ + - ]: 49 : if (env.getOptions().quantifiers.mbqiEnumFreeSymsGrammar)
55 : : {
56 : 49 : std::unordered_set<Node> syms;
57 : 49 : expr::getSymbols(q[1], syms);
58 : 49 : trules.insert(trules.end(), syms.begin(), syms.end());
59 : 49 : }
60 : : // include the external terminal rules
61 [ + + ]: 63 : for (const Node& symbol : etrules)
62 : : {
63 [ + - ]: 14 : if (std::find(trules.begin(), trules.end(), symbol) == trules.end())
64 : : {
65 : 14 : trules.push_back(symbol);
66 : : }
67 : : }
68 [ + - ]: 49 : Trace("mbqi-fast-enum") << "Symbols: " << trules << std::endl;
69 : : SygusGrammarCons sgc;
70 : 49 : Node bvl;
71 : 49 : TypeNode tng = sgc.mkDefaultSygusType(env, retType, bvl, trules);
72 [ - + ]: 49 : if (TraceIsOn("mbqi-fast-enum"))
73 : : {
74 [ - - ]: 0 : Trace("mbqi-fast-enum") << "Enumerate terms for " << retType;
75 [ - - ]: 0 : if (!d_lamVars.isNull())
76 : : {
77 [ - - ]: 0 : Trace("mbqi-fast-enum") << ", variable list " << d_lamVars;
78 : : }
79 [ - - ]: 0 : Trace("mbqi-fast-enum") << std::endl;
80 [ - - ]: 0 : Trace("mbqi-fast-enum") << "Based on grammar:" << std::endl;
81 [ - - ]: 0 : Trace("mbqi-fast-enum")
82 : 0 : << printer::smt2::Smt2Printer::sygusGrammarString(tng) << std::endl;
83 : : }
84 : 49 : d_senum.reset(new SygusTermEnumerator(env, tng));
85 : 49 : }
86 : :
87 : 4995 : Node MVarInfo::getEnumeratedTerm(NodeManager* nm, size_t i)
88 : : {
89 : 4995 : size_t nullCount = 0;
90 [ + + ]: 11841 : while (i >= d_enum.size())
91 : : {
92 : 6874 : Node curr = d_senum->getCurrent();
93 [ + - ]: 6874 : Trace("mbqi-sygus-enum") << "Enumerate: " << curr << std::endl;
94 [ + + ]: 6874 : if (!curr.isNull())
95 : : {
96 [ + + ]: 2049 : if (!d_lamVars.isNull())
97 : : {
98 : 236 : curr = nm->mkNode(Kind::LAMBDA, d_lamVars, curr);
99 : : }
100 : 2049 : d_enum.push_back(curr);
101 : 2049 : nullCount = 0;
102 : : }
103 : : else
104 : : {
105 : 4825 : nullCount++;
106 [ + + ]: 4825 : if (nullCount > 100)
107 : : {
108 : : // break if we aren't making progress
109 : 28 : break;
110 : : }
111 : : }
112 [ - + ]: 6846 : if (!d_senum->incrementPartial())
113 : : {
114 : : // enumeration is finished
115 : 0 : break;
116 : : }
117 [ + + ]: 6874 : }
118 [ + + ]: 4995 : if (i >= d_enum.size())
119 : : {
120 : 28 : return Node::null();
121 : : }
122 : 4967 : return d_enum[i];
123 : : }
124 : :
125 : 63 : void MQuantInfo::initialize(Env& env, InstStrategyMbqi& parent, const Node& q)
126 : : {
127 : : // The externally provided terminal rules. This set is shared between
128 : : // all variables we instantiate.
129 : 63 : std::vector<Node> etrules;
130 [ + + ]: 154 : for (const Node& v : q[0])
131 : : {
132 : 91 : size_t index = d_vinfo.size();
133 : 91 : d_vinfo.emplace_back();
134 : 91 : TypeNode vtn = v.getType();
135 : : // if enumerated, add to list
136 [ + + ]: 91 : if (shouldEnumerate(vtn))
137 : : {
138 : 49 : d_indices.push_back(index);
139 : : }
140 : : else
141 : : {
142 : 42 : d_nindices.push_back(index);
143 : : // include variables defined in terms of others if applicable
144 [ + - ]: 42 : if (env.getOptions().quantifiers.mbqiEnumExtVarsGrammar)
145 : : {
146 : 42 : etrules.push_back(v);
147 : : }
148 : : }
149 : 154 : }
150 : : // include the global symbols if applicable
151 [ - + ]: 63 : if (env.getOptions().quantifiers.mbqiEnumGlobalSymGrammar)
152 : : {
153 : 0 : const context::CDHashSet<Node>& gsyms = parent.getGlobalSyms();
154 [ - - ]: 0 : for (const Node& v : gsyms)
155 : : {
156 : 0 : etrules.push_back(v);
157 : 0 : }
158 : : }
159 : : // initialize the variables we are instantiating
160 [ + + ]: 112 : for (size_t index : d_indices)
161 : : {
162 : 49 : d_vinfo[index].initialize(env, q, q[0][index], etrules);
163 : : }
164 : 63 : }
165 : :
166 : 183 : MVarInfo& MQuantInfo::getVarInfo(size_t index)
167 : : {
168 [ - + ][ - + ]: 183 : Assert(index < d_vinfo.size());
[ - - ]
169 : 183 : return d_vinfo[index];
170 : : }
171 : :
172 : 204 : std::vector<size_t> MQuantInfo::getInstIndices() { return d_indices; }
173 : 204 : std::vector<size_t> MQuantInfo::getNoInstIndices() { return d_nindices; }
174 : :
175 : 140 : bool MQuantInfo::shouldEnumerate(const TypeNode& tn)
176 : : {
177 [ + + ]: 140 : if (tn.isUninterpretedSort())
178 : : {
179 : 42 : return false;
180 : : }
181 : 98 : return true;
182 : : }
183 : :
184 : 353 : MbqiEnum::MbqiEnum(Env& env, InstStrategyMbqi& parent)
185 : 353 : : EnvObj(env), d_parent(parent)
186 : : {
187 : 353 : d_subOptions.copyValues(options());
188 : 353 : smt::SetDefaults::disableChecking(d_subOptions);
189 : 353 : }
190 : :
191 : 204 : MQuantInfo& MbqiEnum::getOrMkQuantInfo(const Node& q)
192 : : {
193 : 204 : auto [it, inserted] = d_qinfo.try_emplace(q);
194 [ + + ]: 204 : if (inserted)
195 : : {
196 : 63 : it->second.initialize(d_env, d_parent, q);
197 : : }
198 : 204 : return it->second;
199 : : }
200 : :
201 : 204 : bool MbqiEnum::constructInstantiation(const Node& q,
202 : : const Node& query,
203 : : const std::vector<Node>& vars,
204 : : std::vector<Node>& mvs,
205 : : const std::map<Node, Node>& mvFreshVar)
206 : : {
207 [ - + ][ - + ]: 204 : Assert(q[0].getNumChildren() == vars.size());
[ - - ]
208 [ - + ][ - + ]: 204 : Assert(vars.size() == mvs.size());
[ - - ]
209 [ - + ]: 204 : if (TraceIsOn("mbqi-model-enum"))
210 : : {
211 [ - - ]: 0 : Trace("mbqi-model-enum") << "Instantiate " << q << std::endl;
212 [ - - ]: 0 : for (size_t i = 0, nvars = vars.size(); i < nvars; i++)
213 : : {
214 [ - - ]: 0 : Trace("mbqi-model-enum")
215 : 0 : << " " << q[0][i] << " -> " << mvs[i] << std::endl;
216 : : }
217 : : }
218 : 204 : SubsolverSetupInfo ssi(d_env, d_subOptions);
219 : 204 : MQuantInfo& qi = getOrMkQuantInfo(q);
220 : 204 : std::vector<size_t> indices = qi.getInstIndices();
221 : 204 : std::vector<size_t> nindices = qi.getNoInstIndices();
222 : 204 : Subs inst;
223 : 204 : Subs vinst;
224 : 204 : std::unordered_map<Node, Node> tmpCMap;
225 [ + + ]: 260 : for (size_t i : nindices)
226 : : {
227 : 56 : Node v = mvs[i];
228 : 56 : v = d_parent.convertFromModel(v, tmpCMap, mvFreshVar);
229 [ - + ]: 56 : if (v.isNull())
230 : : {
231 : 0 : return false;
232 : : }
233 [ + - ]: 112 : Trace("mbqi-model-enum")
234 : 56 : << "* Assume: " << q[0][i] << " -> " << v << std::endl;
235 : : // if we don't enumerate it, we are already considering this instantiation
236 : 56 : inst.add(vars[i], v);
237 : 56 : vinst.add(q[0][i], v);
238 [ + - ]: 56 : }
239 : 204 : Node queryCurr = query;
240 [ + - ]: 204 : Trace("mbqi-model-enum") << "...query is " << queryCurr << std::endl;
241 : 204 : queryCurr = rewrite(inst.apply(queryCurr));
242 [ + - ]: 204 : Trace("mbqi-model-enum") << "...processed is " << queryCurr << std::endl;
243 : : // consider variables in random order, for diversity of instantiations
244 : 204 : std::shuffle(indices.begin(), indices.end(), Random::getRandom());
245 [ + + ]: 387 : for (size_t i = 0, isize = indices.size(); i < isize; i++)
246 : : {
247 : 183 : size_t ii = indices[i];
248 : 183 : TNode v = vars[ii];
249 : 183 : MVarInfo& vi = qi.getVarInfo(ii);
250 : 183 : size_t cindex = 0;
251 : 183 : bool success = false;
252 : : bool successEnum;
253 : : do
254 : : {
255 : 4995 : Node ret = vi.getEnumeratedTerm(nodeManager(), cindex);
256 : 4995 : cindex++;
257 : 4995 : Node retc;
258 [ + + ]: 4995 : if (!ret.isNull())
259 : : {
260 [ + - ]: 4967 : Trace("mbqi-model-enum") << "- Try candidate: " << ret << std::endl;
261 : : // apply current substitution (to account for cases where ret has
262 : : // other variables in its grammar).
263 : 4967 : ret = vinst.apply(ret);
264 : 4967 : retc = ret;
265 : 4967 : successEnum = true;
266 : : // now convert the value
267 : 4967 : std::unordered_map<Node, Node> tmpConvertMap;
268 : 4967 : std::map<TypeNode, std::unordered_set<Node> > freshVarType;
269 : 4967 : retc = d_parent.convertToQuery(retc, tmpConvertMap, freshVarType);
270 : 4967 : }
271 : : else
272 : : {
273 [ + - ]: 56 : Trace("mbqi-model-enum")
274 : 28 : << "- Failed to enumerate candidate" << std::endl;
275 : : // if we failed to enumerate, just try the original
276 : 28 : Node mc = d_parent.convertFromModel(mvs[ii], tmpCMap, mvFreshVar);
277 [ - + ]: 28 : if (mc.isNull())
278 : : {
279 : : // if failed to convert, we fail
280 : 0 : return false;
281 : : }
282 : 28 : ret = mc;
283 : 28 : retc = mc;
284 : 28 : successEnum = false;
285 [ + - ]: 28 : }
286 [ + - ]: 9990 : Trace("mbqi-model-enum")
287 : 4995 : << "- Converted candidate: " << v << " -> " << retc << std::endl;
288 : : // see if it is still satisfiable, if still SAT, we replace
289 : 9990 : Node queryCheck = queryCurr.substitute(v, TNode(retc));
290 : 4995 : queryCheck = rewrite(queryCheck);
291 [ + - ]: 4995 : Trace("mbqi-model-enum") << "...check " << queryCheck << std::endl;
292 : 4995 : Result r = checkWithSubsolver(queryCheck, ssi);
293 [ + + ]: 4995 : if (r == Result::SAT)
294 : : {
295 : : // remember the updated query
296 : 183 : queryCurr = queryCheck;
297 [ + - ]: 183 : Trace("mbqi-model-enum") << "...success" << std::endl;
298 [ + - ]: 366 : Trace("mbqi-model-enum")
299 : 183 : << "* Enumerated " << q[0][ii] << " -> " << ret << std::endl;
300 : 183 : mvs[ii] = ret;
301 : 183 : vinst.add(q[0][ii], ret);
302 : 183 : success = true;
303 : : }
304 [ - + ]: 4812 : else if (!successEnum)
305 : : {
306 : : // we did not enumerate a candidate, and tried the original, which
307 : : // failed.
308 : 0 : return false;
309 : : }
310 [ + - ][ + - ]: 9990 : } while (!success);
[ + - ][ + - ]
[ + + ]
311 [ + - ]: 183 : }
312 : : // try the instantiation
313 : 204 : return d_parent.tryInstantiation(
314 : 204 : q, mvs, InferenceId::QUANTIFIERS_INST_MBQI_ENUM, mvFreshVar);
315 : 204 : }
316 : : } // namespace quantifiers
317 : : } // namespace theory
318 : : } // namespace cvc5::internal
|