Branch data Line data Source code
1 : : /******************************************************************************
2 : : * Top contributors (to current version):
3 : : * Andrew Reynolds, Aina Niemetz, Mathias Preiner
4 : : *
5 : : * This file is part of the cvc5 project.
6 : : *
7 : : * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
8 : : * in the top-level source directory and their institutional affiliations.
9 : : * All rights reserved. See the file COPYING in the top-level source
10 : : * directory for licensing information.
11 : : * ****************************************************************************
12 : : *
13 : : * Sygus inference module.
14 : : */
15 : :
16 : : #include "preprocessing/passes/sygus_inference.h"
17 : :
18 : : #include "options/quantifiers_options.h"
19 : : #include "preprocessing/assertion_pipeline.h"
20 : : #include "preprocessing/preprocessing_pass_context.h"
21 : : #include "smt/logic_exception.h"
22 : : #include "smt/solver_engine.h"
23 : : #include "theory/quantifiers/quantifiers_attributes.h"
24 : : #include "theory/quantifiers/quantifiers_preprocess.h"
25 : : #include "theory/quantifiers/sygus/sygus_utils.h"
26 : : #include "theory/rewriter.h"
27 : : #include "theory/smt_engine_subsolver.h"
28 : :
29 : : using namespace std;
30 : : using namespace cvc5::internal::kind;
31 : : using namespace cvc5::internal::theory;
32 : :
33 : : namespace cvc5::internal {
34 : : namespace preprocessing {
35 : : namespace passes {
36 : :
37 : 50594 : SygusInference::SygusInference(PreprocessingPassContext* preprocContext)
38 : 50594 : : PreprocessingPass(preprocContext, "sygus-infer"){};
39 : :
40 : 61 : PreprocessingPassResult SygusInference::applyInternal(
41 : : AssertionPipeline* assertionsToPreprocess)
42 : : {
43 [ + - ]: 61 : Trace("sygus-infer") << "Run sygus inference..." << std::endl;
44 : 61 : std::vector<Node> funs;
45 : 61 : std::vector<Node> sols;
46 : : // see if we can successfully solve the input as a sygus problem
47 [ + + ]: 61 : if (solveSygus(assertionsToPreprocess->ref(), funs, sols))
48 : : {
49 [ + - ]: 53 : Trace("sygus-infer") << "...Solved:" << std::endl;
50 [ - + ][ - + ]: 53 : Assert(funs.size() == sols.size());
[ - - ]
51 : : // if so, sygus gives us function definitions, which we add as substitutions
52 [ + + ]: 164 : for (unsigned i = 0, size = funs.size(); i < size; i++)
53 : : {
54 [ + - ]: 111 : Trace("sygus-infer") << funs[i] << " -> " << sols[i] << std::endl;
55 : 111 : d_preprocContext->addSubstitution(funs[i], sols[i]);
56 : : }
57 : :
58 : : // apply substitution to everything, should result in SAT
59 [ + + ]: 205 : for (unsigned i = 0, size = assertionsToPreprocess->ref().size(); i < size;
60 : : i++)
61 : : {
62 : 152 : Node prev = (*assertionsToPreprocess)[i];
63 : : Node curr =
64 : 152 : prev.substitute(funs.begin(), funs.end(), sols.begin(), sols.end());
65 [ + + ]: 152 : if (curr != prev)
66 : : {
67 : 99 : curr = rewrite(curr);
68 [ + - ]: 198 : Trace("sygus-infer-debug")
69 : 99 : << "...rewrote " << prev << " to " << curr << std::endl;
70 : 99 : assertionsToPreprocess->replace(i, curr);
71 : : }
72 : 152 : }
73 : : }
74 : 8 : else if (options().quantifiers.sygusInference
75 [ - + ]: 8 : == options::SygusInferenceMode::ON)
76 : : {
77 : 0 : std::stringstream ss;
78 : 0 : ss << "Cannot translate input to sygus for --sygus-inference";
79 : 0 : throw LogicException(ss.str());
80 : 0 : }
81 : 61 : return PreprocessingPassResult::NO_CONFLICT;
82 : 61 : }
83 : :
84 : 61 : bool SygusInference::solveSygus(const std::vector<Node>& assertions,
85 : : std::vector<Node>& funs,
86 : : std::vector<Node>& sols)
87 : : {
88 [ - + ]: 61 : if (assertions.empty())
89 : : {
90 [ - - ]: 0 : Trace("sygus-infer") << "...fail: empty assertions." << std::endl;
91 [ - - ]: 0 : Warning() << "Cannot convert to sygus since there are no assertions."
92 : 0 : << std::endl;
93 : 0 : return false;
94 : : }
95 : :
96 : 61 : NodeManager* nm = nodeManager();
97 : :
98 : : // collect free variables in all assertions
99 : 61 : std::vector<Node> qvars;
100 : 61 : std::map<TypeNode, std::vector<Node> > qtvars;
101 : 61 : std::vector<Node> free_functions;
102 : :
103 : 61 : std::vector<TNode> visit;
104 : 61 : std::unordered_set<TNode> visited;
105 : :
106 : : // add top-level conjuncts to eassertions
107 : 61 : std::vector<Node> assertions_proc = assertions;
108 : 61 : std::vector<Node> eassertions;
109 : 61 : unsigned index = 0;
110 [ + + ]: 241 : while (index < assertions_proc.size())
111 : : {
112 : 180 : Node ca = assertions_proc[index];
113 [ + + ]: 180 : if (ca.getKind() == Kind::AND)
114 : : {
115 [ + + ]: 6 : for (const Node& ai : ca)
116 : : {
117 : 4 : assertions_proc.push_back(ai);
118 : 4 : }
119 : : }
120 : : else
121 : : {
122 : 178 : eassertions.push_back(ca);
123 : : }
124 : 180 : index++;
125 : 180 : }
126 : :
127 : : // process eassertions
128 : 61 : std::vector<Node> processed_assertions;
129 : 61 : quantifiers::QuantifiersPreprocess qp(d_env);
130 [ + + ]: 237 : for (const Node& as : eassertions)
131 : : {
132 : : // substitution for this assertion
133 : 177 : std::vector<Node> vars;
134 : 177 : std::vector<Node> subs;
135 : 177 : std::map<TypeNode, unsigned> type_count;
136 : 177 : Node pas = as;
137 : : // rewrite
138 : 177 : pas = rewrite(pas);
139 [ + - ]: 177 : Trace("sygus-infer") << "assertion : " << pas << std::endl;
140 [ + + ]: 177 : if (pas.getKind() == Kind::FORALL)
141 : : {
142 : : // preprocess the quantified formula
143 : 15 : TrustNode trn = qp.preprocess(pas);
144 [ - + ]: 15 : if (!trn.isNull())
145 : : {
146 : 0 : pas = trn.getNode();
147 : : }
148 [ + - ]: 15 : Trace("sygus-infer-debug") << " ...preprocessed to " << pas << std::endl;
149 : 15 : }
150 [ + + ]: 177 : if (pas.getKind() == Kind::FORALL)
151 : : {
152 : : // it must be a standard quantifier
153 : 15 : theory::quantifiers::QAttributes qa;
154 : 15 : theory::quantifiers::QuantAttributes::computeQuantAttributes(pas, qa);
155 [ - + ]: 15 : if (!qa.isStandard())
156 : : {
157 [ - - ]: 0 : Trace("sygus-infer")
158 : 0 : << "...fail: non-standard top-level quantifier." << std::endl;
159 [ - - ]: 0 : Warning() << "Cannot convert to sygus since there is a non-standard "
160 : 0 : "top-level quantified formula: "
161 : 0 : << pas << std::endl;
162 : 0 : return false;
163 : : }
164 : : // infer prefix
165 [ + + ]: 37 : for (const Node& v : pas[0])
166 : : {
167 : 22 : TypeNode tnv = v.getType();
168 : 22 : unsigned vnum = type_count[tnv];
169 : 22 : type_count[tnv]++;
170 : 22 : vars.push_back(v);
171 [ + + ]: 22 : if (vnum < qtvars[tnv].size())
172 : : {
173 : 3 : subs.push_back(qtvars[tnv][vnum]);
174 : : }
175 : : else
176 : : {
177 [ - + ][ - + ]: 19 : Assert(vnum == qtvars[tnv].size());
[ - - ]
178 : 19 : Node bv = NodeManager::mkBoundVar(tnv);
179 : 19 : qtvars[tnv].push_back(bv);
180 : 19 : qvars.push_back(bv);
181 : 19 : subs.push_back(bv);
182 : 19 : }
183 : 37 : }
184 : 15 : pas = pas[1];
185 [ + - ]: 15 : if (!vars.empty())
186 : : {
187 : : pas =
188 : 15 : pas.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
189 : : }
190 [ + - ]: 15 : }
191 [ + - ]: 177 : Trace("sygus-infer-debug") << " ...substituted to " << pas << std::endl;
192 : :
193 : : // collect free functions, ensure no quantified formulas
194 : 177 : TNode cur = pas;
195 : : // compute free variables
196 : 177 : visit.push_back(cur);
197 : : do
198 : : {
199 : 1148 : cur = visit.back();
200 : 1148 : visit.pop_back();
201 [ + + ]: 1148 : if (visited.find(cur) == visited.end())
202 : : {
203 : 871 : visited.insert(cur);
204 [ + + ]: 871 : if (cur.getKind() == Kind::APPLY_UF)
205 : : {
206 : 45 : Node op = cur.getOperator();
207 : : // visit the operator, which might not be a variable
208 : 45 : visit.push_back(op);
209 : 45 : }
210 [ + + ][ + + ]: 826 : else if (cur.isVar() && cur.getKind() != Kind::BOUND_VARIABLE)
[ + + ]
211 : : {
212 : : // We are either in the case of a free first-order constant or a
213 : : // function in a higher-order context. We add to free_functions
214 : : // in either case. Note that a free constant that is not in a
215 : : // higher-order context is a 0-argument function-to-synthesize.
216 : : // We should not have traversed here before due to our visited cache.
217 [ - + ][ - + ]: 130 : Assert(std::find(free_functions.begin(), free_functions.end(), cur)
[ - - ]
218 : : == free_functions.end());
219 : 130 : free_functions.push_back(cur);
220 : : }
221 [ + + ]: 696 : else if (cur.isClosure())
222 : : {
223 [ + - ]: 2 : Trace("sygus-infer")
224 : 1 : << "...fail: non-top-level quantifier." << std::endl;
225 [ + - ]: 1 : Warning() << "Cannot convert to sygus since there is a non-top-level "
226 : 0 : "quantified formula: "
227 : 1 : << cur << std::endl;
228 : 1 : return false;
229 : : }
230 [ + + ]: 1797 : for (const TNode& cn : cur)
231 : : {
232 : 927 : visit.push_back(cn);
233 : 927 : }
234 : : }
235 [ + + ]: 1147 : } while (!visit.empty());
236 : 176 : processed_assertions.push_back(pas);
237 [ + + ][ + + ]: 181 : }
[ + + ][ + + ]
[ + + ]
238 : :
239 : : // no functions to synthesize
240 [ - + ]: 60 : if (free_functions.empty())
241 : : {
242 [ - - ]: 0 : Warning()
243 : 0 : << "Cannot convert to sygus since there are no free function symbols."
244 : 0 : << std::endl;
245 [ - - ]: 0 : Trace("sygus-infer") << "...fail: no free function symbols." << std::endl;
246 : 0 : return false;
247 : : }
248 : :
249 : : // Note that we do not restrict based on the types of free functions here,
250 : : // i.e. we assume that all types are handled in sygus grammar construction.
251 : :
252 [ - + ][ - + ]: 60 : Assert(!processed_assertions.empty());
[ - - ]
253 : : // conjunction of the assertions
254 [ + - ]: 60 : Trace("sygus-infer") << "Construct body..." << std::endl;
255 : 60 : Node body;
256 [ - + ]: 60 : if (processed_assertions.size() == 1)
257 : : {
258 : 0 : body = processed_assertions[0];
259 : : }
260 : : else
261 : : {
262 : 60 : body = nm->mkNode(Kind::AND, processed_assertions);
263 : : }
264 : :
265 : : // for each free function symbol, make a bound variable of the same type
266 [ + - ]: 60 : Trace("sygus-infer") << "Do free function substitution..." << std::endl;
267 : 60 : std::vector<Node> ff_vars;
268 : 60 : std::map<Node, Node> ff_var_to_ff;
269 [ + + ]: 190 : for (const Node& ff : free_functions)
270 : : {
271 : 130 : Node ffv = NodeManager::mkBoundVar(ff.getType());
272 : 130 : ff_vars.push_back(ffv);
273 [ + - ]: 130 : Trace("sygus-infer") << " synth-fun: " << ff << " as " << ffv << std::endl;
274 : 130 : ff_var_to_ff[ffv] = ff;
275 : 130 : }
276 : : // substitute free functions -> variables
277 : 120 : body = body.substitute(free_functions.begin(),
278 : : free_functions.end(),
279 : : ff_vars.begin(),
280 : 60 : ff_vars.end());
281 [ + - ]: 60 : Trace("sygus-infer-debug") << "...got : " << body << std::endl;
282 : :
283 : : // quantify the body
284 [ + - ]: 60 : Trace("sygus-infer") << "Make inner sygus conjecture..." << std::endl;
285 : 60 : body = body.negate();
286 [ + + ]: 60 : if (!qvars.empty())
287 : : {
288 : 12 : Node bvl = nm->mkNode(Kind::BOUND_VAR_LIST, qvars);
289 : 12 : body = nm->mkNode(Kind::EXISTS, bvl, body);
290 : 12 : }
291 : :
292 : : // sygus attribute to mark the conjecture as a sygus conjecture
293 [ + - ]: 60 : Trace("sygus-infer") << "Make outer sygus conjecture..." << std::endl;
294 : :
295 : : body =
296 : 60 : quantifiers::SygusUtils::mkSygusConjecture(nodeManager(), ff_vars, body);
297 : :
298 [ + - ]: 60 : Trace("sygus-infer") << "*** Return sygus inference : " << body << std::endl;
299 : :
300 : : // make a separate smt call
301 : 60 : std::unique_ptr<SolverEngine> rrSygus;
302 : 60 : theory::initializeSubsolver(rrSygus, d_env);
303 : 60 : rrSygus->assertFormula(body);
304 [ + - ]: 60 : Trace("sygus-infer") << "*** Check sat..." << std::endl;
305 : 60 : Result r = rrSygus->checkSat();
306 [ + - ]: 60 : Trace("sygus-infer") << "...result : " << r << std::endl;
307 : : // get the synthesis solutions
308 : 60 : std::map<Node, Node> synth_sols;
309 [ + + ]: 60 : if (!rrSygus->getSubsolverSynthSolutions(synth_sols))
310 : : {
311 : : // failed, conjecture was infeasible
312 [ - + ]: 7 : if (options().quantifiers.sygusInference == options::SygusInferenceMode::ON)
313 : : {
314 : 0 : std::stringstream ss;
315 : : ss << "Translated to sygus, but failed to show problem to be satisfiable "
316 : 0 : "with --sygus-inference.";
317 : 0 : throw LogicException(ss.str());
318 : 0 : }
319 : 7 : return false;
320 : : }
321 : :
322 : 53 : std::vector<Node> final_ff;
323 : 53 : std::vector<Node> final_ff_sol;
324 : 53 : for (std::map<Node, Node>::iterator it = synth_sols.begin();
325 [ + + ]: 164 : it != synth_sols.end();
326 : 111 : ++it)
327 : : {
328 [ + - ]: 222 : Trace("sygus-infer") << " synth sol : " << it->first << " -> "
329 : 111 : << it->second << std::endl;
330 : 111 : Node ffv = it->first;
331 : 111 : std::map<Node, Node>::iterator itffv = ff_var_to_ff.find(ffv);
332 : : // all synthesis solutions should correspond to a variable we introduced
333 [ - + ][ - + ]: 111 : Assert(itffv != ff_var_to_ff.end());
[ - - ]
334 [ + - ]: 111 : if (itffv != ff_var_to_ff.end())
335 : : {
336 : 111 : Node ff = itffv->second;
337 : 111 : Node body2 = it->second;
338 [ + - ]: 111 : Trace("sygus-infer") << "Define " << ff << " as " << body2 << std::endl;
339 : 111 : funs.push_back(ff);
340 : 111 : sols.push_back(body2);
341 : 111 : }
342 : 111 : }
343 : 53 : return true;
344 : 61 : }
345 : :
346 : :
347 : : } // namespace passes
348 : : } // namespace preprocessing
349 : : } // namespace cvc5::internal
|