Branch data Line data Source code
1 : : /****************************************************************************** 2 : : * This file is part of the cvc5 project. 3 : : * 4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS 5 : : * in the top-level source directory and their institutional affiliations. 6 : : * All rights reserved. See the file COPYING in the top-level source 7 : : * directory for licensing information. 8 : : * **************************************************************************** 9 : : * 10 : : * The solver for abduction queries. 11 : : */ 12 : : 13 : : #include "smt/abduction_solver.h" 14 : : 15 : : #include <sstream> 16 : : 17 : : #include "base/modal_exception.h" 18 : : #include "options/quantifiers_options.h" 19 : : #include "options/smt_options.h" 20 : : #include "smt/env.h" 21 : : #include "smt/set_defaults.h" 22 : : #include "smt/sygus_solver.h" 23 : : #include "theory/quantifiers/quantifiers_attributes.h" 24 : : #include "theory/quantifiers/sygus/sygus_abduct.h" 25 : : #include "theory/quantifiers/sygus/sygus_utils.h" 26 : : #include "theory/smt_engine_subsolver.h" 27 : : #include "theory/trust_substitutions.h" 28 : : 29 : : using namespace cvc5::internal::theory; 30 : : 31 : : namespace cvc5::internal { 32 : : namespace smt { 33 : : 34 : 6288 : AbductionSolver::AbductionSolver(Env& env) : EnvObj(env) {} 35 : : 36 : 12516 : AbductionSolver::~AbductionSolver() {} 37 : 416 : bool AbductionSolver::getAbduct(const std::vector<Node>& axioms, 38 : : const Node& goal, 39 : : const TypeNode& grammarType, 40 : : Node& abd) 41 : : { 42 [ - + ]: 416 : if (!options().smt.produceAbducts) 43 : : { 44 : 0 : const char* msg = "Cannot get abduct when produce-abducts options is off."; 45 : 0 : throw ModalException(msg); 46 : : } 47 [ + - ]: 416 : Trace("sygus-abduct") << "Axioms: " << axioms << std::endl; 48 [ + - ]: 832 : Trace("sygus-abduct") << "SolverEngine::getAbduct: goal " << goal 49 : 416 : << std::endl; 50 : 416 : SubstitutionMap& tls = d_env.getTopLevelSubstitutions().get(); 51 : 416 : std::vector<Node> axiomsn; 52 [ + + ]: 947 : for (const Node& ax : axioms) 53 : : { 54 : 531 : axiomsn.emplace_back(tls.apply(ax)); 55 : : } 56 : 416 : std::vector<Node> asserts(axiomsn.begin(), axiomsn.end()); 57 : : // must expand definitions 58 : 416 : Node conjn = tls.apply(goal); 59 : 416 : conjn = rewrite(conjn); 60 : : // now negate 61 : 416 : conjn = conjn.negate(); 62 : 416 : d_abdConj = conjn; 63 : 416 : asserts.push_back(conjn); 64 : 416 : std::string name("__internal_abduct"); 65 : : Node aconj = quantifiers::SygusAbduct::mkAbductionConjecture( 66 : 416 : nodeManager(), name, asserts, axiomsn, grammarType); 67 : : // should be a quantified conjecture with one function-to-synthesize 68 : 416 : Assert(aconj.getKind() == Kind::FORALL && aconj[0].getNumChildren() == 1); 69 : : // remember the abduct-to-synthesize 70 : 416 : d_sssf = aconj[0][0]; 71 [ + - ]: 832 : Trace("sygus-abduct") << "SolverEngine::getAbduct: made conjecture : " 72 : 416 : << aconj << ", solving for " << d_sssf << std::endl; 73 : : 74 : 416 : Options subOptions; 75 : 416 : subOptions.copyValues(d_env.getOptions()); 76 : 416 : subOptions.write_quantifiers().sygus = true; 77 : : // by default, we don't want disjunctive terms (ITE, OR) in abducts 78 [ + - ]: 416 : if (!d_env.getOptions().quantifiers.sygusGrammarUseDisjWasSetByUser) 79 : : { 80 : 416 : subOptions.write_quantifiers().sygusGrammarUseDisj = false; 81 : : } 82 : 416 : SetDefaults::disableChecking(subOptions); 83 : 416 : SubsolverSetupInfo ssi(d_env, subOptions); 84 : : // we generate a new smt engine to do the abduction query 85 : 416 : initializeSubsolver(nodeManager(), d_subsolver, ssi); 86 : : // get the logic 87 : 416 : LogicInfo l = d_subsolver->getLogicInfo().getUnlockedCopy(); 88 : : // enable everything needed for sygus 89 : 416 : l.enableSygus(); 90 : 416 : d_subsolver->setLogic(l); 91 : : // assert the abduction query 92 : 416 : d_subsolver->assertFormula(aconj); 93 : 416 : d_axioms = axioms; 94 : 828 : return getAbductInternal(abd); 95 : 444 : } 96 : : 97 : 92 : bool AbductionSolver::getAbductNext(Node& abd) 98 : : { 99 : : // Since we are using the subsolver's check-sat interface directly, we 100 : : // simply call getAbductInternal again here. We assert that the subsolver 101 : : // is already initialized, which must be the case or else we are not in the 102 : : // proper SMT mode to make this call. Due to the default behavior of 103 : : // subsolvers having synthesis conjectures, this is guaranteed to produce 104 : : // a new solution. 105 [ - + ][ - + ]: 92 : Assert(d_subsolver != nullptr); [ - - ] 106 : 92 : return getAbductInternal(abd); 107 : : } 108 : : 109 : 508 : bool AbductionSolver::getAbductInternal(Node& abd) 110 : : { 111 : : // should have initialized the subsolver by now 112 [ - + ][ - + ]: 508 : Assert(d_subsolver != nullptr); [ - - ] 113 [ + - ]: 1016 : Trace("sygus-abduct") << " SolverEngine::getAbduct check sat..." 114 : 508 : << std::endl; 115 : 508 : Result r = d_subsolver->checkSat(); 116 [ + - ]: 1008 : Trace("sygus-abduct") << " SolverEngine::getAbduct result: " << r 117 : 504 : << std::endl; 118 : : // get the synthesis solution 119 : 504 : std::map<Node, Node> sols; 120 : : // use the "getSubsolverSynthSolutions" interface, since we asserted the 121 : : // internal form of the SyGuS conjecture and used check-sat. 122 [ + + ]: 504 : if (d_subsolver->getSubsolverSynthSolutions(sols)) 123 : : { 124 [ - + ][ - + ]: 491 : Assert(sols.size() == 1); [ - - ] 125 : 491 : std::map<Node, Node>::iterator its = sols.find(d_sssf); 126 [ + - ]: 491 : if (its != sols.end()) 127 : : { 128 [ + - ]: 982 : Trace("sygus-abduct") << "SolverEngine::getAbduct: solution is " 129 : 491 : << its->second << std::endl; 130 : 491 : abd = its->second; 131 [ + + ]: 491 : if (abd.getKind() == Kind::LAMBDA) 132 : : { 133 : 487 : abd = abd[1]; 134 : : } 135 : : // get the grammar type for the abduct 136 : : Node agdtbv = 137 : 491 : theory::quantifiers::SygusUtils::getOrMkSygusArgumentList(d_sssf); 138 [ + + ]: 491 : if (!agdtbv.isNull()) 139 : : { 140 [ - + ][ - + ]: 487 : Assert(agdtbv.getKind() == Kind::BOUND_VAR_LIST); [ - - ] 141 : : // convert back to original 142 : : // must replace formal arguments of abd with the free variables in the 143 : : // input problem that they correspond to. 144 : 487 : std::vector<Node> vars; 145 : 487 : std::vector<Node> syms; 146 : : SygusVarToTermAttribute sta; 147 [ + + ]: 1397 : for (const Node& bv : agdtbv) 148 : : { 149 : 910 : vars.push_back(bv); 150 [ + - ]: 910 : syms.push_back(bv.hasAttribute(sta) ? bv.getAttribute(sta) : bv); 151 : 910 : } 152 : : abd = 153 : 487 : abd.substitute(vars.begin(), vars.end(), syms.begin(), syms.end()); 154 : 487 : } 155 : : 156 : : // if check abducts option is set, we check the correctness 157 [ + + ]: 491 : if (options().smt.checkAbducts) 158 : : { 159 : 35 : checkAbduct(abd); 160 : : } 161 : 491 : return true; 162 : 491 : } 163 [ - - ]: 0 : Trace("sygus-abduct") << "SolverEngine::getAbduct: could not find solution!" 164 : 0 : << std::endl; 165 : 0 : throw RecoverableModalException("Could not find solution for get-abduct."); 166 : : } 167 : 13 : return false; 168 : 504 : } 169 : : 170 : 35 : void AbductionSolver::checkAbduct(Node a) 171 : : { 172 [ - + ][ - + ]: 35 : Assert(a.getType().isBoolean()); [ - - ] 173 [ + - ]: 70 : Trace("check-abduct") << "SolverEngine::checkAbduct: get expanded assertions" 174 : 35 : << std::endl; 175 : 35 : bool canTrustResult = SygusSolver::canTrustSynthesisResult(options()); 176 [ + + ]: 35 : if (!canTrustResult) 177 : : { 178 : 3 : warning() << "Running check-abducts is not guaranteed to pass with the " 179 : 3 : "current options." 180 : 3 : << std::endl; 181 : : } 182 : 35 : std::vector<Node> asserts(d_axioms.begin(), d_axioms.end()); 183 : 35 : asserts.push_back(a); 184 : : 185 : 35 : Options subOptions; 186 : 35 : subOptions.copyValues(d_env.getOptions()); 187 : 35 : subOptions.write_smt().produceAbducts = false; 188 : 35 : SetDefaults::disableChecking(subOptions); 189 : 35 : SubsolverSetupInfo ssi(d_env, subOptions); 190 : : // two checks: first, consistent with assertions, second, implies negated goal 191 : : // is unsatisfiable. 192 [ + + ]: 105 : for (unsigned j = 0; j < 2; j++) 193 : : { 194 [ + - ]: 140 : Trace("check-abduct") << "SolverEngine::checkAbduct: phase " << j 195 : 70 : << ": make new SMT engine" << std::endl; 196 : : // Start new SMT engine to check solution 197 : 70 : std::unique_ptr<SolverEngine> abdChecker; 198 : 70 : initializeSubsolver(nodeManager(), abdChecker, ssi); 199 [ + - ]: 140 : Trace("check-abduct") << "SolverEngine::checkAbduct: phase " << j 200 : 70 : << ": asserting formulas" << std::endl; 201 [ + + ]: 383 : for (const Node& e : asserts) 202 : : { 203 : 313 : abdChecker->assertFormula(e); 204 : : } 205 [ + - ]: 140 : Trace("check-abduct") << "SolverEngine::checkAbduct: phase " << j 206 : 70 : << ": check the assertions" << std::endl; 207 : 70 : Result r = abdChecker->checkSat(); 208 [ + - ]: 140 : Trace("check-abduct") << "SolverEngine::checkAbduct: phase " << j 209 : 70 : << ": result is " << r << std::endl; 210 : 70 : std::stringstream serr; 211 : 70 : bool isError = false; 212 : 70 : bool hardFailure = canTrustResult; 213 [ + + ]: 70 : if (j == 0) 214 : : { 215 [ + + ]: 35 : if (r.getStatus() != Result::SAT) 216 : : { 217 : 1 : isError = true; 218 : : serr 219 : : << "SolverEngine::checkAbduct(): produced solution cannot be shown " 220 : 1 : "to be consistent with assertions, result was " 221 : 1 : << r; 222 [ + - ]: 1 : hardFailure = r.isUnknown() ? false : hardFailure; 223 : : } 224 [ + - ]: 70 : Trace("check-abduct") 225 : 35 : << "SolverEngine::checkAbduct: goal is " << d_abdConj << std::endl; 226 : : // add the goal to the set of assertions 227 [ - + ][ - + ]: 35 : Assert(!d_abdConj.isNull()); [ - - ] 228 : 35 : asserts.push_back(d_abdConj); 229 : : } 230 : : else 231 : : { 232 [ - + ]: 35 : if (r.getStatus() != Result::UNSAT) 233 : : { 234 : 0 : isError = true; 235 : : serr << "SolverEngine::checkAbduct(): negated goal cannot be shown " 236 : 0 : "unsatisfiable with produced solution, result was " 237 : 0 : << r; 238 [ - - ]: 0 : hardFailure = r.isUnknown() ? false : hardFailure; 239 : : } 240 : : } 241 : : // did we get an unexpected result? 242 [ + + ]: 70 : if (isError) 243 : : { 244 [ - + ]: 1 : if (hardFailure) 245 : : { 246 : 0 : InternalError() << serr.str(); 247 : : } 248 : : else 249 : : { 250 : 1 : warning() << serr.str() << std::endl; 251 : : } 252 : : } 253 : 70 : } 254 : 35 : } 255 : : 256 : : } // namespace smt 257 : : } // namespace cvc5::internal