Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * Sygus inference module.
11 : : */
12 : :
13 : : #include "preprocessing/passes/sygus_inference.h"
14 : :
15 : : #include "options/quantifiers_options.h"
16 : : #include "preprocessing/assertion_pipeline.h"
17 : : #include "preprocessing/preprocessing_pass_context.h"
18 : : #include "smt/logic_exception.h"
19 : : #include "smt/solver_engine.h"
20 : : #include "theory/quantifiers/quantifiers_attributes.h"
21 : : #include "theory/quantifiers/quantifiers_preprocess.h"
22 : : #include "theory/quantifiers/sygus/sygus_utils.h"
23 : : #include "theory/rewriter.h"
24 : : #include "theory/smt_engine_subsolver.h"
25 : :
26 : : using namespace std;
27 : : using namespace cvc5::internal::kind;
28 : : using namespace cvc5::internal::theory;
29 : :
30 : : namespace cvc5::internal {
31 : : namespace preprocessing {
32 : : namespace passes {
33 : :
34 : 51756 : SygusInference::SygusInference(PreprocessingPassContext* preprocContext)
35 : 51756 : : PreprocessingPass(preprocContext, "sygus-infer") {};
36 : :
37 : 61 : PreprocessingPassResult SygusInference::applyInternal(
38 : : AssertionPipeline* assertionsToPreprocess)
39 : : {
40 [ + - ]: 61 : Trace("sygus-infer") << "Run sygus inference..." << std::endl;
41 : 61 : std::vector<Node> funs;
42 : 61 : std::vector<Node> sols;
43 : : // see if we can successfully solve the input as a sygus problem
44 [ + + ]: 61 : if (solveSygus(assertionsToPreprocess->ref(), funs, sols))
45 : : {
46 [ + - ]: 53 : Trace("sygus-infer") << "...Solved:" << std::endl;
47 [ - + ][ - + ]: 53 : Assert(funs.size() == sols.size());
[ - - ]
48 : : // if so, sygus gives us function definitions, which we add as substitutions
49 [ + + ]: 164 : for (unsigned i = 0, size = funs.size(); i < size; i++)
50 : : {
51 [ + - ]: 111 : Trace("sygus-infer") << funs[i] << " -> " << sols[i] << std::endl;
52 : 111 : d_preprocContext->addSubstitution(funs[i], sols[i]);
53 : : }
54 : :
55 : : // apply substitution to everything, should result in SAT
56 [ + + ]: 205 : for (unsigned i = 0, size = assertionsToPreprocess->ref().size(); i < size;
57 : : i++)
58 : : {
59 : 152 : Node prev = (*assertionsToPreprocess)[i];
60 : : Node curr =
61 : 152 : prev.substitute(funs.begin(), funs.end(), sols.begin(), sols.end());
62 [ + + ]: 152 : if (curr != prev)
63 : : {
64 : 99 : curr = rewrite(curr);
65 [ + - ]: 198 : Trace("sygus-infer-debug")
66 : 99 : << "...rewrote " << prev << " to " << curr << std::endl;
67 : 99 : assertionsToPreprocess->replace(i, curr);
68 : : }
69 : 152 : }
70 : : }
71 : 8 : else if (options().quantifiers.sygusInference
72 [ - + ]: 8 : == options::SygusInferenceMode::ON)
73 : : {
74 : 0 : std::stringstream ss;
75 : 0 : ss << "Cannot translate input to sygus for --sygus-inference";
76 : 0 : throw LogicException(ss.str());
77 : 0 : }
78 : 61 : return PreprocessingPassResult::NO_CONFLICT;
79 : 61 : }
80 : :
81 : 61 : bool SygusInference::solveSygus(const std::vector<Node>& assertions,
82 : : std::vector<Node>& funs,
83 : : std::vector<Node>& sols)
84 : : {
85 [ - + ]: 61 : if (assertions.empty())
86 : : {
87 [ - - ]: 0 : Trace("sygus-infer") << "...fail: empty assertions." << std::endl;
88 [ - - ]: 0 : Warning() << "Cannot convert to sygus since there are no assertions."
89 : 0 : << std::endl;
90 : 0 : return false;
91 : : }
92 : :
93 : 61 : NodeManager* nm = nodeManager();
94 : :
95 : : // collect free variables in all assertions
96 : 61 : std::vector<Node> qvars;
97 : 61 : std::map<TypeNode, std::vector<Node> > qtvars;
98 : 61 : std::vector<Node> free_functions;
99 : :
100 : 61 : std::vector<TNode> visit;
101 : 61 : std::unordered_set<TNode> visited;
102 : :
103 : : // add top-level conjuncts to eassertions
104 : 61 : std::vector<Node> assertions_proc = assertions;
105 : 61 : std::vector<Node> eassertions;
106 : 61 : unsigned index = 0;
107 [ + + ]: 241 : while (index < assertions_proc.size())
108 : : {
109 : 180 : Node ca = assertions_proc[index];
110 [ + + ]: 180 : if (ca.getKind() == Kind::AND)
111 : : {
112 [ + + ]: 6 : for (const Node& ai : ca)
113 : : {
114 : 4 : assertions_proc.push_back(ai);
115 : 4 : }
116 : : }
117 : : else
118 : : {
119 : 178 : eassertions.push_back(ca);
120 : : }
121 : 180 : index++;
122 : 180 : }
123 : :
124 : : // process eassertions
125 : 61 : std::vector<Node> processed_assertions;
126 : 61 : quantifiers::QuantifiersPreprocess qp(d_env);
127 [ + + ]: 237 : for (const Node& as : eassertions)
128 : : {
129 : : // substitution for this assertion
130 : 177 : std::vector<Node> vars;
131 : 177 : std::vector<Node> subs;
132 : 177 : std::map<TypeNode, unsigned> type_count;
133 : 177 : Node pas = as;
134 : : // rewrite
135 : 177 : pas = rewrite(pas);
136 [ + - ]: 177 : Trace("sygus-infer") << "assertion : " << pas << std::endl;
137 [ + + ]: 177 : if (pas.getKind() == Kind::FORALL)
138 : : {
139 : : // preprocess the quantified formula
140 : 15 : TrustNode trn = qp.preprocess(pas);
141 [ - + ]: 15 : if (!trn.isNull())
142 : : {
143 : 0 : pas = trn.getNode();
144 : : }
145 [ + - ]: 15 : Trace("sygus-infer-debug") << " ...preprocessed to " << pas << std::endl;
146 : 15 : }
147 [ + + ]: 177 : if (pas.getKind() == Kind::FORALL)
148 : : {
149 : : // it must be a standard quantifier
150 : 15 : theory::quantifiers::QAttributes qa;
151 : 15 : theory::quantifiers::QuantAttributes::computeQuantAttributes(pas, qa);
152 [ - + ]: 15 : if (!qa.isStandard())
153 : : {
154 [ - - ]: 0 : Trace("sygus-infer")
155 : 0 : << "...fail: non-standard top-level quantifier." << std::endl;
156 [ - - ]: 0 : Warning() << "Cannot convert to sygus since there is a non-standard "
157 : 0 : "top-level quantified formula: "
158 : 0 : << pas << std::endl;
159 : 0 : return false;
160 : : }
161 : : // infer prefix
162 [ + + ]: 37 : for (const Node& v : pas[0])
163 : : {
164 : 22 : TypeNode tnv = v.getType();
165 : 22 : unsigned vnum = type_count[tnv];
166 : 22 : type_count[tnv]++;
167 : 22 : vars.push_back(v);
168 [ + + ]: 22 : if (vnum < qtvars[tnv].size())
169 : : {
170 : 3 : subs.push_back(qtvars[tnv][vnum]);
171 : : }
172 : : else
173 : : {
174 [ - + ][ - + ]: 19 : Assert(vnum == qtvars[tnv].size());
[ - - ]
175 : 19 : Node bv = NodeManager::mkBoundVar(tnv);
176 : 19 : qtvars[tnv].push_back(bv);
177 : 19 : qvars.push_back(bv);
178 : 19 : subs.push_back(bv);
179 : 19 : }
180 : 37 : }
181 : 15 : pas = pas[1];
182 [ + - ]: 15 : if (!vars.empty())
183 : : {
184 : : pas =
185 : 15 : pas.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
186 : : }
187 [ + - ]: 15 : }
188 [ + - ]: 177 : Trace("sygus-infer-debug") << " ...substituted to " << pas << std::endl;
189 : :
190 : : // collect free functions, ensure no quantified formulas
191 : 177 : TNode cur = pas;
192 : : // compute free variables
193 : 177 : visit.push_back(cur);
194 : : do
195 : : {
196 : 1150 : cur = visit.back();
197 : 1150 : visit.pop_back();
198 [ + + ]: 1150 : if (visited.find(cur) == visited.end())
199 : : {
200 : 873 : visited.insert(cur);
201 [ + + ]: 873 : if (cur.getKind() == Kind::APPLY_UF)
202 : : {
203 : 45 : Node op = cur.getOperator();
204 : : // visit the operator, which might not be a variable
205 : 45 : visit.push_back(op);
206 : 45 : }
207 [ + + ][ + + ]: 828 : else if (cur.isVar() && cur.getKind() != Kind::BOUND_VARIABLE)
[ + + ]
208 : : {
209 : : // We are either in the case of a free first-order constant or a
210 : : // function in a higher-order context. We add to free_functions
211 : : // in either case. Note that a free constant that is not in a
212 : : // higher-order context is a 0-argument function-to-synthesize.
213 : : // We should not have traversed here before due to our visited cache.
214 [ - + ][ - + ]: 130 : Assert(std::find(free_functions.begin(), free_functions.end(), cur)
[ - - ]
215 : : == free_functions.end());
216 : 130 : free_functions.push_back(cur);
217 : : }
218 [ + + ]: 698 : else if (cur.isClosure())
219 : : {
220 [ + - ]: 2 : Trace("sygus-infer")
221 : 1 : << "...fail: non-top-level quantifier." << std::endl;
222 [ + - ]: 1 : Warning() << "Cannot convert to sygus since there is a non-top-level "
223 : 0 : "quantified formula: "
224 : 1 : << cur << std::endl;
225 : 1 : return false;
226 : : }
227 [ + + ]: 1801 : for (const TNode& cn : cur)
228 : : {
229 : 929 : visit.push_back(cn);
230 : 929 : }
231 : : }
232 [ + + ]: 1149 : } while (!visit.empty());
233 : 176 : processed_assertions.push_back(pas);
234 [ + + ][ + + ]: 181 : }
[ + + ][ + + ]
[ + + ]
235 : :
236 : : // no functions to synthesize
237 [ - + ]: 60 : if (free_functions.empty())
238 : : {
239 [ - - ]: 0 : Warning()
240 : 0 : << "Cannot convert to sygus since there are no free function symbols."
241 : 0 : << std::endl;
242 [ - - ]: 0 : Trace("sygus-infer") << "...fail: no free function symbols." << std::endl;
243 : 0 : return false;
244 : : }
245 : :
246 : : // Note that we do not restrict based on the types of free functions here,
247 : : // i.e. we assume that all types are handled in sygus grammar construction.
248 : :
249 [ - + ][ - + ]: 60 : Assert(!processed_assertions.empty());
[ - - ]
250 : : // conjunction of the assertions
251 [ + - ]: 60 : Trace("sygus-infer") << "Construct body..." << std::endl;
252 : 60 : Node body;
253 [ - + ]: 60 : if (processed_assertions.size() == 1)
254 : : {
255 : 0 : body = processed_assertions[0];
256 : : }
257 : : else
258 : : {
259 : 60 : body = nm->mkNode(Kind::AND, processed_assertions);
260 : : }
261 : :
262 : : // for each free function symbol, make a bound variable of the same type
263 [ + - ]: 60 : Trace("sygus-infer") << "Do free function substitution..." << std::endl;
264 : 60 : std::vector<Node> ff_vars;
265 : 60 : std::map<Node, Node> ff_var_to_ff;
266 [ + + ]: 190 : for (const Node& ff : free_functions)
267 : : {
268 : 130 : Node ffv = NodeManager::mkBoundVar(ff.getType());
269 : 130 : ff_vars.push_back(ffv);
270 [ + - ]: 130 : Trace("sygus-infer") << " synth-fun: " << ff << " as " << ffv << std::endl;
271 : 130 : ff_var_to_ff[ffv] = ff;
272 : 130 : }
273 : : // substitute free functions -> variables
274 : 120 : body = body.substitute(free_functions.begin(),
275 : : free_functions.end(),
276 : : ff_vars.begin(),
277 : 60 : ff_vars.end());
278 [ + - ]: 60 : Trace("sygus-infer-debug") << "...got : " << body << std::endl;
279 : :
280 : : // quantify the body
281 [ + - ]: 60 : Trace("sygus-infer") << "Make inner sygus conjecture..." << std::endl;
282 : 60 : body = body.negate();
283 [ + + ]: 60 : if (!qvars.empty())
284 : : {
285 : 12 : Node bvl = nm->mkNode(Kind::BOUND_VAR_LIST, qvars);
286 : 12 : body = nm->mkNode(Kind::EXISTS, bvl, body);
287 : 12 : }
288 : :
289 : : // sygus attribute to mark the conjecture as a sygus conjecture
290 [ + - ]: 60 : Trace("sygus-infer") << "Make outer sygus conjecture..." << std::endl;
291 : :
292 : : body =
293 : 60 : quantifiers::SygusUtils::mkSygusConjecture(nodeManager(), ff_vars, body);
294 : :
295 [ + - ]: 60 : Trace("sygus-infer") << "*** Return sygus inference : " << body << std::endl;
296 : :
297 : : // make a separate smt call
298 : 60 : std::unique_ptr<SolverEngine> rrSygus;
299 : 60 : theory::initializeSubsolver(rrSygus, d_env);
300 : 60 : rrSygus->assertFormula(body);
301 [ + - ]: 60 : Trace("sygus-infer") << "*** Check sat..." << std::endl;
302 : 60 : Result r = rrSygus->checkSat();
303 [ + - ]: 60 : Trace("sygus-infer") << "...result : " << r << std::endl;
304 : : // get the synthesis solutions
305 : 60 : std::map<Node, Node> synth_sols;
306 [ + + ]: 60 : if (!rrSygus->getSubsolverSynthSolutions(synth_sols))
307 : : {
308 : : // failed, conjecture was infeasible
309 [ - + ]: 7 : if (options().quantifiers.sygusInference == options::SygusInferenceMode::ON)
310 : : {
311 : 0 : std::stringstream ss;
312 : : ss << "Translated to sygus, but failed to show problem to be satisfiable "
313 : 0 : "with --sygus-inference.";
314 : 0 : throw LogicException(ss.str());
315 : 0 : }
316 : 7 : return false;
317 : : }
318 : :
319 : 53 : std::vector<Node> final_ff;
320 : 53 : std::vector<Node> final_ff_sol;
321 : 53 : for (std::map<Node, Node>::iterator it = synth_sols.begin();
322 [ + + ]: 164 : it != synth_sols.end();
323 : 111 : ++it)
324 : : {
325 [ + - ]: 222 : Trace("sygus-infer") << " synth sol : " << it->first << " -> "
326 : 111 : << it->second << std::endl;
327 : 111 : Node ffv = it->first;
328 : 111 : std::map<Node, Node>::iterator itffv = ff_var_to_ff.find(ffv);
329 : : // all synthesis solutions should correspond to a variable we introduced
330 [ - + ][ - + ]: 111 : Assert(itffv != ff_var_to_ff.end());
[ - - ]
331 [ + - ]: 111 : if (itffv != ff_var_to_ff.end())
332 : : {
333 : 111 : Node ff = itffv->second;
334 : 111 : Node body2 = it->second;
335 [ + - ]: 111 : Trace("sygus-infer") << "Define " << ff << " as " << body2 << std::endl;
336 : 111 : funs.push_back(ff);
337 : 111 : sols.push_back(body2);
338 : 111 : }
339 : 111 : }
340 : 53 : return true;
341 : 61 : }
342 : :
343 : : } // namespace passes
344 : : } // namespace preprocessing
345 : : } // namespace cvc5::internal
|