LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/theory/quantifiers - mbqi_enum.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 155 174 89.1 %
Date: 2026-03-01 11:40:25 Functions: 10 10 100.0 %
Branches: 74 138 53.6 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * This file is part of the cvc5 project.
       3                 :            :  *
       4                 :            :  * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
       5                 :            :  * in the top-level source directory and their institutional affiliations.
       6                 :            :  * All rights reserved.  See the file COPYING in the top-level source
       7                 :            :  * directory for licensing information.
       8                 :            :  * ****************************************************************************
       9                 :            :  *
      10                 :            :  * A class for augmenting model-based instantiations via fast sygus enumeration.
      11                 :            :  */
      12                 :            : 
      13                 :            : #include "theory/quantifiers/mbqi_enum.h"
      14                 :            : 
      15                 :            : #include "expr/node_algorithm.h"
      16                 :            : #include "expr/skolem_manager.h"
      17                 :            : #include "printer/smt2/smt2_printer.h"
      18                 :            : #include "theory/datatypes/sygus_datatype_utils.h"
      19                 :            : #include "theory/quantifiers/inst_strategy_mbqi.h"
      20                 :            : #include "theory/quantifiers/sygus/sygus_enumerator.h"
      21                 :            : #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
      22                 :            : #include "theory/smt_engine_subsolver.h"
      23                 :            : #include "util/random.h"
      24                 :            : #include "smt/set_defaults.h"
      25                 :            : 
      26                 :            : namespace cvc5::internal {
      27                 :            : namespace theory {
      28                 :            : namespace quantifiers {
      29                 :            : 
      30                 :         49 : void MVarInfo::initialize(Env& env,
      31                 :            :                           const Node& q,
      32                 :            :                           const Node& v,
      33                 :            :                           const std::vector<Node>& etrules)
      34                 :            : {
      35                 :         49 :   NodeManager* nm = env.getNodeManager();
      36                 :         49 :   TypeNode tn = v.getType();
      37 [ -  + ][ -  + ]:         49 :   Assert(MQuantInfo::shouldEnumerate(tn));
                 [ -  - ]
      38                 :         49 :   TypeNode retType = tn;
      39                 :         49 :   std::vector<Node> trules;
      40         [ +  + ]:         49 :   if (tn.isFunction())
      41                 :            :   {
      42                 :         42 :     std::vector<TypeNode> argTypes = tn.getArgTypes();
      43                 :         42 :     retType = tn.getRangeType();
      44                 :         42 :     std::vector<Node> vs;
      45         [ +  + ]:         89 :     for (const TypeNode& tnc : argTypes)
      46                 :            :     {
      47                 :         47 :       Node vc = NodeManager::mkBoundVar(tnc);
      48                 :         47 :       vs.push_back(vc);
      49                 :         47 :     }
      50                 :         42 :     d_lamVars = nm->mkNode(Kind::BOUND_VAR_LIST, vs);
      51                 :         42 :     trules.insert(trules.end(), vs.begin(), vs.end());
      52                 :         42 :   }
      53                 :            :   // include free symbols from body of quantified formula if applicable
      54         [ +  - ]:         49 :   if (env.getOptions().quantifiers.mbqiEnumFreeSymsGrammar)
      55                 :            :   {
      56                 :         49 :     std::unordered_set<Node> syms;
      57                 :         49 :     expr::getSymbols(q[1], syms);
      58                 :         49 :     trules.insert(trules.end(), syms.begin(), syms.end());
      59                 :         49 :   }
      60                 :            :   // include the external terminal rules
      61         [ +  + ]:         63 :   for (const Node& symbol : etrules)
      62                 :            :   {
      63         [ +  - ]:         14 :     if (std::find(trules.begin(), trules.end(), symbol) == trules.end())
      64                 :            :     {
      65                 :         14 :       trules.push_back(symbol);
      66                 :            :     }
      67                 :            :   }
      68         [ +  - ]:         49 :   Trace("mbqi-fast-enum") << "Symbols: " << trules << std::endl;
      69                 :            :   SygusGrammarCons sgc;
      70                 :         49 :   Node bvl;
      71                 :         49 :   TypeNode tng = sgc.mkDefaultSygusType(env, retType, bvl, trules);
      72         [ -  + ]:         49 :   if (TraceIsOn("mbqi-fast-enum"))
      73                 :            :   {
      74         [ -  - ]:          0 :     Trace("mbqi-fast-enum") << "Enumerate terms for " << retType;
      75         [ -  - ]:          0 :     if (!d_lamVars.isNull())
      76                 :            :     {
      77         [ -  - ]:          0 :       Trace("mbqi-fast-enum") << ", variable list " << d_lamVars;
      78                 :            :     }
      79         [ -  - ]:          0 :     Trace("mbqi-fast-enum") << std::endl;
      80         [ -  - ]:          0 :     Trace("mbqi-fast-enum") << "Based on grammar:" << std::endl;
      81         [ -  - ]:          0 :     Trace("mbqi-fast-enum")
      82                 :          0 :         << printer::smt2::Smt2Printer::sygusGrammarString(tng) << std::endl;
      83                 :            :   }
      84                 :         49 :   d_senum.reset(new SygusTermEnumerator(env, tng));
      85                 :         49 : }
      86                 :            : 
      87                 :       4995 : Node MVarInfo::getEnumeratedTerm(NodeManager* nm, size_t i)
      88                 :            : {
      89                 :       4995 :   size_t nullCount = 0;
      90         [ +  + ]:      11841 :   while (i >= d_enum.size())
      91                 :            :   {
      92                 :       6874 :     Node curr = d_senum->getCurrent();
      93         [ +  - ]:       6874 :     Trace("mbqi-sygus-enum") << "Enumerate: " << curr << std::endl;
      94         [ +  + ]:       6874 :     if (!curr.isNull())
      95                 :            :     {
      96         [ +  + ]:       2049 :       if (!d_lamVars.isNull())
      97                 :            :       {
      98                 :        236 :         curr = nm->mkNode(Kind::LAMBDA, d_lamVars, curr);
      99                 :            :       }
     100                 :       2049 :       d_enum.push_back(curr);
     101                 :       2049 :       nullCount = 0;
     102                 :            :     }
     103                 :            :     else
     104                 :            :     {
     105                 :       4825 :       nullCount++;
     106         [ +  + ]:       4825 :       if (nullCount > 100)
     107                 :            :       {
     108                 :            :         // break if we aren't making progress
     109                 :         28 :         break;
     110                 :            :       }
     111                 :            :     }
     112         [ -  + ]:       6846 :     if (!d_senum->incrementPartial())
     113                 :            :     {
     114                 :            :       // enumeration is finished
     115                 :          0 :       break;
     116                 :            :     }
     117         [ +  + ]:       6874 :   }
     118         [ +  + ]:       4995 :   if (i >= d_enum.size())
     119                 :            :   {
     120                 :         28 :     return Node::null();
     121                 :            :   }
     122                 :       4967 :   return d_enum[i];
     123                 :            : }
     124                 :            : 
     125                 :         63 : void MQuantInfo::initialize(Env& env, InstStrategyMbqi& parent, const Node& q)
     126                 :            : {
     127                 :            :   // The externally provided terminal rules. This set is shared between
     128                 :            :   // all variables we instantiate.
     129                 :         63 :   std::vector<Node> etrules;
     130         [ +  + ]:        154 :   for (const Node& v : q[0])
     131                 :            :   {
     132                 :         91 :     size_t index = d_vinfo.size();
     133                 :         91 :     d_vinfo.emplace_back();
     134                 :         91 :     TypeNode vtn = v.getType();
     135                 :            :     // if enumerated, add to list
     136         [ +  + ]:         91 :     if (shouldEnumerate(vtn))
     137                 :            :     {
     138                 :         49 :       d_indices.push_back(index);
     139                 :            :     }
     140                 :            :     else
     141                 :            :     {
     142                 :         42 :       d_nindices.push_back(index);
     143                 :            :       // include variables defined in terms of others if applicable
     144         [ +  - ]:         42 :       if (env.getOptions().quantifiers.mbqiEnumExtVarsGrammar)
     145                 :            :       {
     146                 :         42 :         etrules.push_back(v);
     147                 :            :       }
     148                 :            :     }
     149                 :        154 :   }
     150                 :            :   // include the global symbols if applicable
     151         [ -  + ]:         63 :   if (env.getOptions().quantifiers.mbqiEnumGlobalSymGrammar)
     152                 :            :   {
     153                 :          0 :     const context::CDHashSet<Node>& gsyms = parent.getGlobalSyms();
     154         [ -  - ]:          0 :     for (const Node& v : gsyms)
     155                 :            :     {
     156                 :          0 :       etrules.push_back(v);
     157                 :          0 :     }
     158                 :            :   }
     159                 :            :   // initialize the variables we are instantiating
     160         [ +  + ]:        112 :   for (size_t index : d_indices)
     161                 :            :   {
     162                 :         49 :     d_vinfo[index].initialize(env, q, q[0][index], etrules);
     163                 :            :   }
     164                 :         63 : }
     165                 :            : 
     166                 :        183 : MVarInfo& MQuantInfo::getVarInfo(size_t index)
     167                 :            : {
     168 [ -  + ][ -  + ]:        183 :   Assert(index < d_vinfo.size());
                 [ -  - ]
     169                 :        183 :   return d_vinfo[index];
     170                 :            : }
     171                 :            : 
     172                 :        204 : std::vector<size_t> MQuantInfo::getInstIndices() { return d_indices; }
     173                 :        204 : std::vector<size_t> MQuantInfo::getNoInstIndices() { return d_nindices; }
     174                 :            : 
     175                 :        140 : bool MQuantInfo::shouldEnumerate(const TypeNode& tn)
     176                 :            : {
     177         [ +  + ]:        140 :   if (tn.isUninterpretedSort())
     178                 :            :   {
     179                 :         42 :     return false;
     180                 :            :   }
     181                 :         98 :   return true;
     182                 :            : }
     183                 :            : 
     184                 :        353 : MbqiEnum::MbqiEnum(Env& env, InstStrategyMbqi& parent)
     185                 :        353 :     : EnvObj(env), d_parent(parent)
     186                 :            : {
     187                 :        353 :   d_subOptions.copyValues(options());
     188                 :        353 :   smt::SetDefaults::disableChecking(d_subOptions);
     189                 :        353 : }
     190                 :            : 
     191                 :        204 : MQuantInfo& MbqiEnum::getOrMkQuantInfo(const Node& q)
     192                 :            : {
     193                 :        204 :   auto [it, inserted] = d_qinfo.try_emplace(q);
     194         [ +  + ]:        204 :   if (inserted)
     195                 :            :   {
     196                 :         63 :     it->second.initialize(d_env, d_parent, q);
     197                 :            :   }
     198                 :        204 :   return it->second;
     199                 :            : }
     200                 :            : 
     201                 :        204 : bool MbqiEnum::constructInstantiation(const Node& q,
     202                 :            :                                       const Node& query,
     203                 :            :                                       const std::vector<Node>& vars,
     204                 :            :                                       std::vector<Node>& mvs,
     205                 :            :                                       const std::map<Node, Node>& mvFreshVar)
     206                 :            : {
     207 [ -  + ][ -  + ]:        204 :   Assert(q[0].getNumChildren() == vars.size());
                 [ -  - ]
     208 [ -  + ][ -  + ]:        204 :   Assert(vars.size() == mvs.size());
                 [ -  - ]
     209         [ -  + ]:        204 :   if (TraceIsOn("mbqi-model-enum"))
     210                 :            :   {
     211         [ -  - ]:          0 :     Trace("mbqi-model-enum") << "Instantiate " << q << std::endl;
     212         [ -  - ]:          0 :     for (size_t i = 0, nvars = vars.size(); i < nvars; i++)
     213                 :            :     {
     214         [ -  - ]:          0 :       Trace("mbqi-model-enum")
     215                 :          0 :           << "  " << q[0][i] << " -> " << mvs[i] << std::endl;
     216                 :            :     }
     217                 :            :   }
     218                 :        204 :   SubsolverSetupInfo ssi(d_env, d_subOptions);
     219                 :        204 :   MQuantInfo& qi = getOrMkQuantInfo(q);
     220                 :        204 :   std::vector<size_t> indices = qi.getInstIndices();
     221                 :        204 :   std::vector<size_t> nindices = qi.getNoInstIndices();
     222                 :        204 :   Subs inst;
     223                 :        204 :   Subs vinst;
     224                 :        204 :   std::unordered_map<Node, Node> tmpCMap;
     225         [ +  + ]:        260 :   for (size_t i : nindices)
     226                 :            :   {
     227                 :         56 :     Node v = mvs[i];
     228                 :         56 :     v = d_parent.convertFromModel(v, tmpCMap, mvFreshVar);
     229         [ -  + ]:         56 :     if (v.isNull())
     230                 :            :     {
     231                 :          0 :       return false;
     232                 :            :     }
     233         [ +  - ]:        112 :     Trace("mbqi-model-enum")
     234                 :         56 :         << "* Assume: " << q[0][i] << " -> " << v << std::endl;
     235                 :            :     // if we don't enumerate it, we are already considering this instantiation
     236                 :         56 :     inst.add(vars[i], v);
     237                 :         56 :     vinst.add(q[0][i], v);
     238         [ +  - ]:         56 :   }
     239                 :        204 :   Node queryCurr = query;
     240         [ +  - ]:        204 :   Trace("mbqi-model-enum") << "...query is " << queryCurr << std::endl;
     241                 :        204 :   queryCurr = rewrite(inst.apply(queryCurr));
     242         [ +  - ]:        204 :   Trace("mbqi-model-enum") << "...processed is " << queryCurr << std::endl;
     243                 :            :   // consider variables in random order, for diversity of instantiations
     244                 :        204 :   std::shuffle(indices.begin(), indices.end(), Random::getRandom());
     245         [ +  + ]:        387 :   for (size_t i = 0, isize = indices.size(); i < isize; i++)
     246                 :            :   {
     247                 :        183 :     size_t ii = indices[i];
     248                 :        183 :     TNode v = vars[ii];
     249                 :        183 :     MVarInfo& vi = qi.getVarInfo(ii);
     250                 :        183 :     size_t cindex = 0;
     251                 :        183 :     bool success = false;
     252                 :            :     bool successEnum;
     253                 :            :     do
     254                 :            :     {
     255                 :       4995 :       Node ret = vi.getEnumeratedTerm(nodeManager(), cindex);
     256                 :       4995 :       cindex++;
     257                 :       4995 :       Node retc;
     258         [ +  + ]:       4995 :       if (!ret.isNull())
     259                 :            :       {
     260         [ +  - ]:       4967 :         Trace("mbqi-model-enum") << "- Try candidate: " << ret << std::endl;
     261                 :            :         // apply current substitution (to account for cases where ret has
     262                 :            :         // other variables in its grammar).
     263                 :       4967 :         ret = vinst.apply(ret);
     264                 :       4967 :         retc = ret;
     265                 :       4967 :         successEnum = true;
     266                 :            :         // now convert the value
     267                 :       4967 :         std::unordered_map<Node, Node> tmpConvertMap;
     268                 :       4967 :         std::map<TypeNode, std::unordered_set<Node> > freshVarType;
     269                 :       4967 :         retc = d_parent.convertToQuery(retc, tmpConvertMap, freshVarType);
     270                 :       4967 :       }
     271                 :            :       else
     272                 :            :       {
     273         [ +  - ]:         56 :         Trace("mbqi-model-enum")
     274                 :         28 :             << "- Failed to enumerate candidate" << std::endl;
     275                 :            :         // if we failed to enumerate, just try the original
     276                 :         28 :         Node mc = d_parent.convertFromModel(mvs[ii], tmpCMap, mvFreshVar);
     277         [ -  + ]:         28 :         if (mc.isNull())
     278                 :            :         {
     279                 :            :           // if failed to convert, we fail
     280                 :          0 :           return false;
     281                 :            :         }
     282                 :         28 :         ret = mc;
     283                 :         28 :         retc = mc;
     284                 :         28 :         successEnum = false;
     285         [ +  - ]:         28 :       }
     286         [ +  - ]:       9990 :       Trace("mbqi-model-enum")
     287                 :       4995 :           << "- Converted candidate: " << v << " -> " << retc << std::endl;
     288                 :            :       // see if it is still satisfiable, if still SAT, we replace
     289                 :       9990 :       Node queryCheck = queryCurr.substitute(v, TNode(retc));
     290                 :       4995 :       queryCheck = rewrite(queryCheck);
     291         [ +  - ]:       4995 :       Trace("mbqi-model-enum") << "...check " << queryCheck << std::endl;
     292                 :       4995 :       Result r = checkWithSubsolver(queryCheck, ssi);
     293         [ +  + ]:       4995 :       if (r == Result::SAT)
     294                 :            :       {
     295                 :            :         // remember the updated query
     296                 :        183 :         queryCurr = queryCheck;
     297         [ +  - ]:        183 :         Trace("mbqi-model-enum") << "...success" << std::endl;
     298         [ +  - ]:        366 :         Trace("mbqi-model-enum")
     299                 :        183 :             << "* Enumerated " << q[0][ii] << " -> " << ret << std::endl;
     300                 :        183 :         mvs[ii] = ret;
     301                 :        183 :         vinst.add(q[0][ii], ret);
     302                 :        183 :         success = true;
     303                 :            :       }
     304         [ -  + ]:       4812 :       else if (!successEnum)
     305                 :            :       {
     306                 :            :         // we did not enumerate a candidate, and tried the original, which
     307                 :            :         // failed.
     308                 :          0 :         return false;
     309                 :            :       }
     310 [ +  - ][ +  - ]:       9990 :     } while (!success);
         [ +  - ][ +  - ]
                 [ +  + ]
     311         [ +  - ]:        183 :   }
     312                 :            :   // try the instantiation
     313                 :        204 :   return d_parent.tryInstantiation(
     314                 :        204 :       q, mvs, InferenceId::QUANTIFIERS_INST_MBQI_ENUM, mvFreshVar);
     315                 :        204 : }
     316                 :            : }  // namespace quantifiers
     317                 :            : }  // namespace theory
     318                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14