LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/theory/quantifiers - oracle_engine.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 128 155 82.6 %
Date: 2025-02-18 13:42:10 Functions: 12 14 85.7 %
Branches: 71 142 50.0 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * Top contributors (to current version):
       3                 :            :  *   Andrew Reynolds, Aina Niemetz, Daniel Larraz
       4                 :            :  *
       5                 :            :  * This file is part of the cvc5 project.
       6                 :            :  *
       7                 :            :  * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
       8                 :            :  * in the top-level source directory and their institutional affiliations.
       9                 :            :  * All rights reserved.  See the file COPYING in the top-level source
      10                 :            :  * directory for licensing information.
      11                 :            :  * ****************************************************************************
      12                 :            :  *
      13                 :            :  * Oracle engine
      14                 :            :  */
      15                 :            : 
      16                 :            : #include "theory/quantifiers/oracle_engine.h"
      17                 :            : 
      18                 :            : #include "expr/attribute.h"
      19                 :            : #include "expr/skolem_manager.h"
      20                 :            : #include "options/quantifiers_options.h"
      21                 :            : #include "theory/decision_manager.h"
      22                 :            : #include "theory/quantifiers/first_order_model.h"
      23                 :            : #include "theory/quantifiers/quantifiers_attributes.h"
      24                 :            : #include "theory/quantifiers/quantifiers_inference_manager.h"
      25                 :            : #include "theory/quantifiers/quantifiers_registry.h"
      26                 :            : #include "theory/quantifiers/term_registry.h"
      27                 :            : #include "theory/quantifiers/term_tuple_enumerator.h"
      28                 :            : #include "theory/trust_substitutions.h"
      29                 :            : 
      30                 :            : using namespace cvc5::internal::kind;
      31                 :            : using namespace cvc5::context;
      32                 :            : 
      33                 :            : namespace cvc5::internal {
      34                 :            : namespace theory {
      35                 :            : namespace quantifiers {
      36                 :            : 
      37                 :            : /** Attribute true for input variables */
      38                 :            : struct OracleInputVarAttributeId
      39                 :            : {
      40                 :            : };
      41                 :            : typedef expr::Attribute<OracleInputVarAttributeId, bool>
      42                 :            :     OracleInputVarAttribute;
      43                 :            : /** Attribute true for output variables */
      44                 :            : struct OracleOutputVarAttributeId
      45                 :            : {
      46                 :            : };
      47                 :            : typedef expr::Attribute<OracleOutputVarAttributeId, bool>
      48                 :            :     OracleOutputVarAttribute;
      49                 :            : 
      50                 :        466 : OracleEngine::OracleEngine(Env& env,
      51                 :            :                            QuantifiersState& qs,
      52                 :            :                            QuantifiersInferenceManager& qim,
      53                 :            :                            QuantifiersRegistry& qr,
      54                 :        466 :                            TermRegistry& tr)
      55                 :            :     : QuantifiersModule(env, qs, qim, qr, tr),
      56                 :        932 :       d_oracleFuns(userContext()),
      57                 :        932 :       d_ochecker(env.getOracleChecker()),
      58                 :            :       d_consistencyCheckPassed(false),
      59                 :        466 :       d_dstrat(env, "OracleArgValue", qs.getValuation())
      60                 :            : {
      61 [ -  + ][ -  + ]:        466 :   Assert(d_ochecker != nullptr);
                 [ -  - ]
      62                 :        466 : }
      63                 :            : 
      64                 :        282 : void OracleEngine::presolve() {
      65                 :            :   // Ensure all oracle functions in top-level substitutions occur in
      66                 :            :   // lemmas. Otherwise the oracles will not be invoked for those values
      67                 :            :   // and the model will be inaccurate.
      68                 :            :   std::unordered_map<Node, Node> subs =
      69                 :        564 :       d_env.getTopLevelSubstitutions().get().getSubstitutions();
      70                 :        564 :   std::unordered_set<Node> visited;
      71                 :        564 :   std::vector<TNode> visit;
      72         [ -  + ]:        282 :   for (const std::pair<const Node, Node>& s : subs)
      73                 :            :   {
      74                 :          0 :     visit.push_back(s.second);
      75                 :            :   }
      76                 :        564 :   TNode cur;
      77         [ -  + ]:        282 :   while (!visit.empty())
      78                 :            :   {
      79                 :          0 :     cur = visit.back();
      80                 :          0 :     visit.pop_back();
      81         [ -  - ]:          0 :     if (visited.find(cur) == visited.end())
      82                 :            :     {
      83                 :          0 :       visited.insert(cur);
      84         [ -  - ]:          0 :       if (OracleCaller::isOracleFunctionApp(cur))
      85                 :            :       {
      86                 :          0 :         Node k = SkolemManager::mkPurifySkolem(cur);
      87                 :          0 :         Node eq = k.eqNode(cur);
      88                 :          0 :         d_qim.lemma(eq, InferenceId::QUANTIFIERS_ORACLE_PURIFY_SUBS);
      89                 :            :       }
      90         [ -  - ]:          0 :       if (cur.getNumChildren() > 0)
      91                 :            :       {
      92                 :          0 :         visit.insert(visit.end(), cur.begin(), cur.end());
      93                 :            :       }
      94                 :            :     }
      95                 :            :   }
      96                 :            :   // register the decision strategy which will insist that arguments are
      97                 :            :   // decided to be equal to values.
      98                 :        282 :   d_qim.getDecisionManager()->registerStrategy(
      99                 :            :       DecisionManager::STRAT_ORACLE_ARG_VALUE,
     100                 :            :       &d_dstrat,
     101                 :            :       DecisionManager::STRAT_SCOPE_LOCAL_SOLVE);
     102                 :        282 : }
     103                 :            : 
     104                 :       2350 : bool OracleEngine::needsCheck(Theory::Effort e)
     105                 :            : {
     106                 :       2350 :   return e == Theory::Effort::EFFORT_LAST_CALL;
     107                 :            : }
     108                 :            : 
     109                 :            : // the model is built at this effort level
     110                 :       1034 : OracleEngine::QEffort OracleEngine::needsModel(Theory::Effort e)
     111                 :            : {
     112                 :       1034 :   return QEFFORT_MODEL;
     113                 :            : }
     114                 :            : 
     115                 :       2068 : void OracleEngine::reset_round(Theory::Effort e)
     116                 :            : {
     117                 :       2068 :   d_consistencyCheckPassed = false;
     118                 :       2068 : }
     119                 :            : 
     120                 :        282 : void OracleEngine::registerQuantifier(Node q) {}
     121                 :            : 
     122                 :       3102 : void OracleEngine::check(Theory::Effort e, QEffort quant_e)
     123                 :            : {
     124         [ +  + ]:       3102 :   if (quant_e != QEFFORT_MODEL)
     125                 :            :   {
     126                 :       2068 :     return;
     127                 :            :   }
     128                 :            : 
     129                 :       1034 :   FirstOrderModel* fm = d_treg.getModel();
     130                 :       1034 :   TermDb* termDatabase = d_treg.getTermDatabase();
     131                 :       1034 :   NodeManager* nm = nodeManager();
     132                 :       1034 :   unsigned nquant = fm->getNumAssertedQuantifiers();
     133                 :       1034 :   std::vector<Node> currInterfaces;
     134         [ +  + ]:       2068 :   for (unsigned i = 0; i < nquant; i++)
     135                 :            :   {
     136                 :       1034 :     Node q = fm->getAssertedQuantifier(i);
     137         [ -  + ]:       1034 :     if (d_qreg.getOwner(q) != this)
     138                 :            :     {
     139                 :          0 :       continue;
     140                 :            :     }
     141                 :       1034 :     currInterfaces.push_back(q);
     142                 :            :   }
     143 [ -  + ][ -  - ]:       1034 :   if (d_oracleFuns.empty() && currInterfaces.empty())
                 [ -  + ]
     144                 :            :   {
     145                 :          0 :     return;
     146                 :            :   }
     147                 :       1034 :   beginCallDebug();
     148                 :            :   // Note that we currently ignore oracle interface quantified formulas, and
     149                 :            :   // look directly at the oracle functions. Note that:
     150                 :            :   // (1) The lemmas with InferenceId QUANTIFIERS_ORACLE_INTERFACE are not
     151                 :            :   // guarded by a quantified formula. This means that we are assuming that all
     152                 :            :   // oracle interface quantified formulas are top-level assertions. This is
     153                 :            :   // correct because we do not expose a way of embedding oracle interfaces into
     154                 :            :   // formulas at the user level.
     155                 :            :   // (2) We assume that oracle functions have associated oracle interface
     156                 :            :   // quantified formulas that are in currInterfaces.
     157                 :            :   // (3) We currently ignore oracle interface quantified formulas that are
     158                 :            :   // not associated with oracle functions.
     159                 :            :   //
     160                 :            :   // The current design choices above are due to the fact that our support is
     161                 :            :   // limited to "definitional SMTO" (see Polgreen et al 2022). In particular,
     162                 :            :   // we only support oracles that define I/O equalities for oracle functions
     163                 :            :   // only. The net effect of this class hence is to check the consistency of
     164                 :            :   // oracle functions, and allow "sat" or otherwise add a lemma with id
     165                 :            :   // QUANTIFIERS_ORACLE_INTERFACE.
     166                 :       2068 :   std::vector<Node> learnedLemmas;
     167                 :       1034 :   bool allFappsConsistent = true;
     168                 :            :   // iterate over oracle functions
     169         [ +  + ]:       2068 :   for (const Node& f : d_oracleFuns)
     170                 :            :   {
     171                 :       1034 :     TNodeTrie* tat = termDatabase->getTermArgTrie(f);
     172         [ -  + ]:       1034 :     if (!tat)
     173                 :            :     {
     174                 :          0 :       continue;
     175                 :            :     }
     176                 :       3102 :     std::vector<Node> apps = tat->getLeaves(f.getType().getArgTypes().size());
     177         [ +  - ]:       2068 :     Trace("oracle-calls") << "Oracle fun " << f << " with " << apps.size()
     178                 :       1034 :                           << " applications." << std::endl;
     179         [ +  + ]:       2068 :     for (const auto& fapp : apps)
     180                 :            :     {
     181                 :       2068 :       std::vector<Node> arguments;
     182                 :       1034 :       arguments.push_back(f);
     183                 :            :       // evaluate arguments
     184         [ +  + ]:       2256 :       for (const auto& arg : fapp)
     185                 :            :       {
     186                 :       1222 :         arguments.push_back(fm->getValue(arg));
     187                 :            :       }
     188                 :            :       // call oracle
     189                 :       2068 :       Node fappWithValues = nm->mkNode(Kind::APPLY_UF, arguments);
     190                 :       2068 :       Node predictedResponse = fm->getValue(fapp);
     191                 :            :       Node result =
     192                 :       3102 :           d_ochecker->checkConsistent(fappWithValues, predictedResponse);
     193         [ +  + ]:       1034 :       if (!result.isNull())
     194                 :            :       {
     195                 :            :         // Note that we add (=> (= args values) (= (f args) result))
     196                 :            :         // instead of (= (f values) result) here. The latter may be more
     197                 :            :         // compact, but we require introducing literals for (= args values)
     198                 :            :         // so that they can be preferred by the decision strategy.
     199                 :       1692 :         std::vector<Node> disj;
     200                 :       2538 :         Node conc = nm->mkNode(Kind::EQUAL, fapp, result);
     201                 :        846 :         disj.push_back(conc);
     202         [ +  + ]:       1786 :         for (size_t i = 0, nchild = fapp.getNumChildren(); i < nchild; i++)
     203                 :            :         {
     204                 :        940 :           Node eqa = fapp[i].eqNode(arguments[i + 1]);
     205                 :        940 :           eqa = rewrite(eqa);
     206                 :            :           // Insist that the decision strategy tries to make (= args values)
     207                 :            :           // true first. This is to ensure that the value of the oracle can be
     208                 :            :           // used.
     209                 :        940 :           d_dstrat.addLiteral(eqa);
     210                 :        940 :           disj.push_back(eqa.notNode());
     211                 :            :         }
     212                 :        846 :         Node lem = nm->mkOr(disj);
     213                 :        846 :         learnedLemmas.push_back(lem);
     214                 :        846 :         allFappsConsistent = false;
     215                 :            :       }
     216                 :            :     }
     217                 :            :   }
     218                 :            :   // if all were consistent, we can terminate
     219         [ +  + ]:       1034 :   if (allFappsConsistent)
     220                 :            :   {
     221         [ +  - ]:        376 :     Trace("oracle-engine-state")
     222                 :        188 :         << "All responses consistent, no lemmas added" << std::endl;
     223                 :        188 :     d_consistencyCheckPassed = true;
     224                 :            :   }
     225                 :            :   else
     226                 :            :   {
     227         [ +  + ]:       1692 :     for (const Node& l : learnedLemmas)
     228                 :            :     {
     229         [ +  - ]:        846 :       Trace("oracle-engine-state") << "adding lemma " << l << std::endl;
     230                 :        846 :       d_qim.lemma(l, InferenceId::QUANTIFIERS_ORACLE_INTERFACE);
     231                 :            :     }
     232                 :            :   }
     233                 :            :   // general SMTO: call constraint generators and assumption generators here
     234                 :            : 
     235                 :       1034 :   endCallDebug();
     236                 :            : }
     237                 :            : 
     238                 :        188 : bool OracleEngine::checkCompleteFor(Node q)
     239                 :            : {
     240         [ -  + ]:        188 :   if (d_qreg.getOwner(q) != this)
     241                 :            :   {
     242                 :          0 :     return false;
     243                 :            :   }
     244                 :            :   // Only true if oracle consistency check was successful. Notice that
     245                 :            :   // we can say true for *all* oracle interface quantified formulas in the
     246                 :            :   // case that the consistency check passed. In particular, the invocation
     247                 :            :   // of oracle interfaces does not need to be complete.
     248                 :        188 :   return d_consistencyCheckPassed;
     249                 :            : }
     250                 :            : 
     251                 :        282 : void OracleEngine::checkOwnership(Node q)
     252                 :            : {
     253                 :            :   // take ownership of quantified formulas that are oracle interfaces
     254                 :        282 :   QuantAttributes& qa = d_qreg.getQuantAttributes();
     255         [ -  + ]:        282 :   if (!qa.isOracleInterface(q))
     256                 :            :   {
     257                 :          0 :     return;
     258                 :            :   }
     259                 :        282 :   d_qreg.setOwner(q, this);
     260                 :            :   // We expect oracle interfaces to be limited to definitional SMTO currently.
     261         [ +  - ]:        282 :   if (Configuration::isAssertionBuild())
     262                 :            :   {
     263                 :        564 :     std::vector<Node> inputs, outputs;
     264                 :        564 :     Node assume, constraint, oracle;
     265         [ -  + ]:        282 :     if (!getOracleInterface(q, inputs, outputs, assume, constraint, oracle))
     266                 :            :     {
     267                 :          0 :       Assert(false) << "Not an oracle interface " << q;
     268                 :            :     }
     269                 :            :     else
     270                 :            :     {
     271                 :        282 :       Assert(outputs.size() == 1) << "Unhandled oracle constraint " << q;
     272                 :        564 :       Assert(constraint.isConst() && constraint.getConst<bool>())
     273                 :          0 :           << "Unhandled oracle constraint " << q;
     274                 :            :     }
     275                 :        282 :     CVC5_UNUSED bool isOracleFun = false;
     276         [ +  - ]:        282 :     if (assume.getKind() == Kind::EQUAL)
     277                 :            :     {
     278         [ +  + ]:        846 :       for (size_t i = 0; i < 2; i++)
     279                 :            :       {
     280 [ +  + ][ -  - ]:        564 :         if (OracleCaller::isOracleFunctionApp(assume[i])
     281 [ +  + ][ +  - ]:        564 :             && assume[1 - i] == outputs[0])
         [ +  + ][ +  - ]
                 [ -  - ]
     282                 :            :         {
     283                 :        282 :           isOracleFun = true;
     284                 :            :         }
     285                 :            :       }
     286                 :            :     }
     287                 :        282 :     Assert(isOracleFun)
     288                 :          0 :         << "Non-definitional oracle interface quantified formula " << q;
     289                 :            :   }
     290                 :            : }
     291                 :            : 
     292                 :          0 : std::string OracleEngine::identify() const
     293                 :            : {
     294                 :          0 :   return std::string("OracleEngine");
     295                 :            : }
     296                 :            : 
     297                 :        654 : void OracleEngine::declareOracleFun(Node f) { d_oracleFuns.push_back(f); }
     298                 :            : 
     299                 :          0 : std::vector<Node> OracleEngine::getOracleFuns() const
     300                 :            : {
     301                 :          0 :   std::vector<Node> ofuns;
     302         [ -  - ]:          0 :   for (const Node& f : d_oracleFuns)
     303                 :            :   {
     304                 :          0 :     ofuns.push_back(f);
     305                 :            :   }
     306                 :          0 :   return ofuns;
     307                 :            : }
     308                 :            : 
     309                 :        654 : Node OracleEngine::mkOracleInterface(const std::vector<Node>& inputs,
     310                 :            :                                      const std::vector<Node>& outputs,
     311                 :            :                                      Node assume,
     312                 :            :                                      Node constraint,
     313                 :            :                                      Node oracleNode)
     314                 :            : {
     315 [ -  + ][ -  + ]:        654 :   Assert(!assume.isNull());
                 [ -  - ]
     316 [ -  + ][ -  + ]:        654 :   Assert(!constraint.isNull());
                 [ -  - ]
     317 [ -  + ][ -  + ]:        654 :   Assert(oracleNode.getKind() == Kind::ORACLE);
                 [ -  - ]
     318                 :        654 :   NodeManager* nm = NodeManager::currentNM();
     319                 :            :   Node ipl = nm->mkNode(Kind::INST_PATTERN_LIST,
     320                 :       1962 :                         nm->mkNode(Kind::INST_ATTRIBUTE, oracleNode));
     321                 :       1308 :   std::vector<Node> vars;
     322                 :            :   OracleInputVarAttribute oiva;
     323         [ +  + ]:       1492 :   for (Node v : inputs)
     324                 :            :   {
     325                 :        838 :     v.setAttribute(oiva, true);
     326                 :        838 :     vars.push_back(v);
     327                 :            :   }
     328                 :            :   OracleOutputVarAttribute oova;
     329         [ +  + ]:       1308 :   for (Node v : outputs)
     330                 :            :   {
     331                 :        654 :     v.setAttribute(oova, true);
     332                 :        654 :     vars.push_back(v);
     333                 :            :   }
     334                 :       1308 :   Node bvl = nm->mkNode(Kind::BOUND_VAR_LIST, vars);
     335                 :       1308 :   Node body = nm->mkNode(Kind::ORACLE_FORMULA_GEN, assume, constraint);
     336                 :       1308 :   return nm->mkNode(Kind::FORALL, bvl, body, ipl);
     337                 :            : }
     338                 :            : 
     339                 :        282 : bool OracleEngine::getOracleInterface(Node q,
     340                 :            :                                       std::vector<Node>& inputs,
     341                 :            :                                       std::vector<Node>& outputs,
     342                 :            :                                       Node& assume,
     343                 :            :                                       Node& constraint,
     344                 :            :                                       Node& oracleNode) const
     345                 :            : {
     346                 :        282 :   QuantAttributes& qa = d_qreg.getQuantAttributes();
     347         [ +  - ]:        282 :   if (qa.isOracleInterface(q))
     348                 :            :   {
     349                 :            :     // fill in data
     350                 :            :     OracleInputVarAttribute oiva;
     351         [ +  + ]:        940 :     for (const Node& v : q[0])
     352                 :            :     {
     353         [ +  + ]:        658 :       if (v.getAttribute(oiva))
     354                 :            :       {
     355                 :        376 :         inputs.push_back(v);
     356                 :            :       }
     357                 :            :       else
     358                 :            :       {
     359 [ -  + ][ -  + ]:        282 :         Assert(v.getAttribute(OracleOutputVarAttribute()));
                 [ -  - ]
     360                 :        282 :         outputs.push_back(v);
     361                 :            :       }
     362                 :            :     }
     363 [ -  + ][ -  + ]:        282 :     Assert(q[1].getKind() == Kind::ORACLE_FORMULA_GEN);
                 [ -  - ]
     364                 :        282 :     assume = q[1][0];
     365                 :        282 :     constraint = q[1][1];
     366 [ -  + ][ -  + ]:        282 :     Assert(q.getNumChildren() == 3);
                 [ -  - ]
     367 [ -  + ][ -  + ]:        282 :     Assert(q[2].getNumChildren() == 1);
                 [ -  - ]
     368 [ -  + ][ -  + ]:        282 :     Assert(q[2][0].getNumChildren() == 1);
                 [ -  - ]
     369 [ -  + ][ -  + ]:        282 :     Assert(q[2][0][0].getKind() == Kind::ORACLE);
                 [ -  - ]
     370                 :        282 :     oracleNode = q[2][0][0];
     371                 :        282 :     return true;
     372                 :            :   }
     373                 :          0 :   return false;
     374                 :            : }
     375                 :            : 
     376                 :            : }  // namespace quantifiers
     377                 :            : }  // namespace theory
     378                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14