LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/theory/quantifiers - mbqi_fast_sygus.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 149 163 91.4 %
Date: 2024-11-10 12:40:22 Functions: 10 10 100.0 %
Branches: 67 118 56.8 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * Top contributors (to current version):
       3                 :            :  *   Andrew Reynolds
       4                 :            :  *
       5                 :            :  * This file is part of the cvc5 project.
       6                 :            :  *
       7                 :            :  * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS
       8                 :            :  * in the top-level source directory and their institutional affiliations.
       9                 :            :  * All rights reserved.  See the file COPYING in the top-level source
      10                 :            :  * directory for licensing information.
      11                 :            :  * ****************************************************************************
      12                 :            :  *
      13                 :            :  * A class for augmenting model-based instantiations via fast sygus enumeration.
      14                 :            :  */
      15                 :            : 
      16                 :            : #include "theory/quantifiers/mbqi_fast_sygus.h"
      17                 :            : 
      18                 :            : #include "expr/node_algorithm.h"
      19                 :            : #include "expr/skolem_manager.h"
      20                 :            : #include "printer/smt2/smt2_printer.h"
      21                 :            : #include "theory/datatypes/sygus_datatype_utils.h"
      22                 :            : #include "theory/quantifiers/inst_strategy_mbqi.h"
      23                 :            : #include "theory/quantifiers/sygus/sygus_enumerator.h"
      24                 :            : #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
      25                 :            : #include "theory/smt_engine_subsolver.h"
      26                 :            : #include "util/random.h"
      27                 :            : #include "smt/set_defaults.h"
      28                 :            : 
      29                 :            : namespace cvc5::internal {
      30                 :            : namespace theory {
      31                 :            : namespace quantifiers {
      32                 :            : 
      33                 :         56 : void MVarInfo::initialize(Env& env,
      34                 :            :                           const Node& q,
      35                 :            :                           const Node& v,
      36                 :            :                           const std::vector<Node>& etrules)
      37                 :            : {
      38                 :         56 :   NodeManager* nm = NodeManager::currentNM();
      39                 :        112 :   TypeNode tn = v.getType();
      40 [ -  + ][ -  + ]:         56 :   Assert(MQuantInfo::shouldEnumerate(tn));
                 [ -  - ]
      41                 :        112 :   TypeNode retType = tn;
      42                 :        112 :   std::vector<Node> trules;
      43         [ +  + ]:         56 :   if (tn.isFunction())
      44                 :            :   {
      45                 :         96 :     std::vector<TypeNode> argTypes = tn.getArgTypes();
      46                 :         48 :     retType = tn.getRangeType();
      47                 :         48 :     std::vector<Node> vs;
      48         [ +  + ]:        102 :     for (const TypeNode& tnc : argTypes)
      49                 :            :     {
      50                 :        108 :       Node vc = nm->mkBoundVar(tnc);
      51                 :         54 :       vs.push_back(vc);
      52                 :            :     }
      53                 :         48 :     d_lamVars = nm->mkNode(Kind::BOUND_VAR_LIST, vs);
      54                 :         48 :     trules.insert(trules.end(), vs.begin(), vs.end());
      55                 :            :   }
      56                 :            :   // include free symbols from body of quantified formula if applicable
      57                 :        112 :   std::unordered_set<Node> syms;
      58                 :         56 :   expr::getSymbols(q[1], syms);
      59                 :         56 :   trules.insert(trules.end(), syms.begin(), syms.end());
      60                 :            :   // include the external terminal rules
      61         [ +  + ]:        190 :   for (const Node& symbol : etrules)
      62                 :            :   {
      63         [ +  + ]:        134 :     if (std::find(trules.begin(), trules.end(), symbol) == trules.end())
      64                 :            :     {
      65                 :         16 :       trules.push_back(symbol);
      66                 :            :     }
      67                 :            :   }
      68                 :            :   SygusGrammarCons sgc;
      69                 :        112 :   Node bvl;
      70                 :        112 :   TypeNode tng = sgc.mkDefaultSygusType(env, retType, bvl, trules);
      71         [ -  + ]:         56 :   if (TraceIsOn("mbqi-model-enum"))
      72                 :            :   {
      73         [ -  - ]:          0 :     Trace("mbqi-model-enum") << "Enumerate terms for " << retType;
      74         [ -  - ]:          0 :     if (!d_lamVars.isNull())
      75                 :            :     {
      76         [ -  - ]:          0 :       Trace("mbqi-model-enum") << ", variable list " << d_lamVars;
      77                 :            :     }
      78         [ -  - ]:          0 :     Trace("mbqi-model-enum") << std::endl;
      79         [ -  - ]:          0 :     Trace("mbqi-model-enum") << "Based on grammar:" << std::endl;
      80         [ -  - ]:          0 :     Trace("mbqi-model-enum")
      81                 :          0 :         << printer::smt2::Smt2Printer::sygusGrammarString(tng) << std::endl;
      82                 :            :   }
      83                 :         56 :   d_senum.reset(new SygusTermEnumerator(env, tng));
      84                 :         56 : }
      85                 :            : 
      86                 :       5673 : Node MVarInfo::getEnumeratedTerm(size_t i)
      87                 :            : {
      88                 :       5673 :   NodeManager* nm = NodeManager::currentNM();
      89                 :       5673 :   size_t nullCount = 0;
      90         [ +  + ]:      11921 :   while (i >= d_enum.size())
      91                 :            :   {
      92                 :       6264 :     Node curr = d_senum->getCurrent();
      93         [ +  - ]:       6264 :     Trace("mbqi-sygus-enum") << "Enumerate: " << curr << std::endl;
      94         [ +  + ]:       6264 :     if (!curr.isNull())
      95                 :            :     {
      96         [ +  + ]:       2351 :       if (!d_lamVars.isNull())
      97                 :            :       {
      98                 :        279 :         curr = nm->mkNode(Kind::LAMBDA, d_lamVars, curr);
      99                 :            :       }
     100                 :       2351 :       d_enum.push_back(curr);
     101                 :       2351 :       nullCount = 0;
     102                 :            :     }
     103                 :            :     else
     104                 :            :     {
     105                 :       3913 :       nullCount++;
     106         [ +  + ]:       3913 :       if (nullCount > 100)
     107                 :            :       {
     108                 :            :         // break if we aren't making progress
     109                 :         16 :         break;
     110                 :            :       }
     111                 :            :     }
     112         [ -  + ]:       6248 :     if (!d_senum->incrementPartial())
     113                 :            :     {
     114                 :            :       // enumeration is finished
     115                 :          0 :       break;
     116                 :            :     }
     117                 :            :   }
     118         [ +  + ]:       5673 :   if (i >= d_enum.size())
     119                 :            :   {
     120                 :         16 :     return Node::null();
     121                 :            :   }
     122                 :       5657 :   return d_enum[i];
     123                 :            : }
     124                 :            : 
     125                 :         56 : void MQuantInfo::initialize(Env& env, InstStrategyMbqi& parent, const Node& q)
     126                 :            : {
     127                 :            :   // The externally provided terminal rules. This set is shared between
     128                 :            :   // all variables we instantiate.
     129                 :        112 :   std::vector<Node> etrules;
     130         [ +  + ]:        128 :   for (const Node& v : q[0])
     131                 :            :   {
     132                 :         72 :     size_t index = d_vinfo.size();
     133                 :         72 :     d_vinfo.emplace_back();
     134                 :        144 :     TypeNode vtn = v.getType();
     135                 :            :     // if enumerated, add to list
     136         [ +  + ]:         72 :     if (shouldEnumerate(vtn))
     137                 :            :     {
     138                 :         56 :       d_indices.push_back(index);
     139                 :            :     }
     140                 :            :     else
     141                 :            :     {
     142                 :         16 :       d_nindices.push_back(index);
     143                 :            :       // Include variables defined in terms of others we are not enumerating.
     144                 :         16 :       etrules.push_back(v);
     145                 :            :     }
     146                 :            :   }
     147                 :            :   // Get free symbols from body of quantified formula here
     148                 :        112 :   std::unordered_set<Node> syms;
     149                 :         56 :   expr::getSymbols(q[1], syms);
     150         [ +  + ]:        174 :   for (const Node& symbol : syms)
     151                 :            :   {
     152         [ +  - ]:        118 :     if (std::find(etrules.begin(), etrules.end(), symbol) == etrules.end())
     153                 :            :     {
     154                 :        118 :       etrules.push_back(symbol);
     155                 :            :     }
     156                 :            :   }
     157         [ +  - ]:         56 :   Trace("mbqi-model-enum") << "Terminals: " << etrules << std::endl;
     158                 :            :   // initialize the variables we are instantiating
     159         [ +  + ]:        112 :   for (size_t index : d_indices)
     160                 :            :   {
     161                 :         56 :     d_vinfo[index].initialize(env, q, q[0][index], etrules);
     162                 :            :   }
     163                 :         56 : }
     164                 :            : 
     165                 :        195 : MVarInfo& MQuantInfo::getVarInfo(size_t index)
     166                 :            : {
     167 [ -  + ][ -  + ]:        195 :   Assert(index < d_vinfo.size());
                 [ -  - ]
     168                 :        195 :   return d_vinfo[index];
     169                 :            : }
     170                 :            : 
     171                 :        195 : std::vector<size_t> MQuantInfo::getInstIndices() { return d_indices; }
     172                 :        195 : std::vector<size_t> MQuantInfo::getNoInstIndices() { return d_nindices; }
     173                 :            : 
     174                 :        128 : bool MQuantInfo::shouldEnumerate(const TypeNode& tn)
     175                 :            : {
     176         [ +  + ]:        128 :   if (tn.isUninterpretedSort())
     177                 :            :   {
     178                 :         16 :     return false;
     179                 :            :   }
     180                 :        112 :   return true;
     181                 :            : }
     182                 :            : 
     183                 :        333 : MbqiFastSygus::MbqiFastSygus(Env& env, InstStrategyMbqi& parent)
     184                 :        333 :     : EnvObj(env), d_parent(parent)
     185                 :            : {
     186                 :        333 :   d_subOptions.copyValues(options());
     187                 :        333 :   smt::SetDefaults::disableChecking(d_subOptions);
     188                 :        333 : }
     189                 :            : 
     190                 :        195 : MQuantInfo& MbqiFastSygus::getOrMkQuantInfo(const Node& q)
     191                 :            : {
     192                 :        195 :   auto [it, inserted] = d_qinfo.try_emplace(q);
     193         [ +  + ]:        195 :   if (inserted)
     194                 :            :   {
     195                 :         56 :     it->second.initialize(d_env, d_parent, q);
     196                 :            :   }
     197                 :        195 :   return it->second;
     198                 :            : }
     199                 :            : 
     200                 :        195 : bool MbqiFastSygus::constructInstantiation(
     201                 :            :     const Node& q,
     202                 :            :     const Node& query,
     203                 :            :     const std::vector<Node>& vars,
     204                 :            :     std::vector<Node>& mvs,
     205                 :            :     const std::map<Node, Node>& mvFreshVar)
     206                 :            : {
     207 [ -  + ][ -  + ]:        195 :   Assert(q[0].getNumChildren() == vars.size());
                 [ -  - ]
     208 [ -  + ][ -  + ]:        195 :   Assert(vars.size() == mvs.size());
                 [ -  - ]
     209         [ -  + ]:        195 :   if (TraceIsOn("mbqi-model-enum"))
     210                 :            :   {
     211         [ -  - ]:          0 :     Trace("mbqi-model-enum") << "Instantiate " << q << std::endl;
     212         [ -  - ]:          0 :     for (size_t i = 0, nvars = vars.size(); i < nvars; i++)
     213                 :            :     {
     214         [ -  - ]:          0 :       Trace("mbqi-model-enum")
     215                 :          0 :           << "  " << q[0][i] << " -> " << mvs[i] << std::endl;
     216                 :            :     }
     217                 :            :   }
     218                 :        390 :   SubsolverSetupInfo ssi(d_env, d_subOptions);
     219                 :        195 :   MQuantInfo& qi = getOrMkQuantInfo(q);
     220                 :        390 :   std::vector<size_t> indices = qi.getInstIndices();
     221                 :        390 :   std::vector<size_t> nindices = qi.getNoInstIndices();
     222                 :        390 :   Subs inst;
     223                 :        390 :   Subs vinst;
     224                 :        390 :   std::unordered_map<Node, Node> tmpCMap;
     225         [ +  + ]:        219 :   for (size_t i : nindices)
     226                 :            :   {
     227                 :         24 :     Node v = mvs[i];
     228                 :         24 :     v = d_parent.convertFromModel(v, tmpCMap, mvFreshVar);
     229         [ -  + ]:         24 :     if (v.isNull())
     230                 :            :     {
     231                 :          0 :       return false;
     232                 :            :     }
     233         [ +  - ]:         48 :     Trace("mbqi-model-enum")
     234                 :         24 :         << "* Assume: " << q[0][i] << " -> " << v << std::endl;
     235                 :            :     // if we don't enumerate it, we are already considering this instantiation
     236                 :         24 :     inst.add(vars[i], v);
     237                 :         24 :     vinst.add(q[0][i], v);
     238                 :            :   }
     239                 :        390 :   Node queryCurr = query;
     240         [ +  - ]:        195 :   Trace("mbqi-model-enum") << "...query is " << queryCurr << std::endl;
     241                 :        195 :   queryCurr = rewrite(inst.apply(queryCurr));
     242         [ +  - ]:        195 :   Trace("mbqi-model-enum") << "...processed is " << queryCurr << std::endl;
     243                 :            :   // consider variables in random order, for diversity of instantiations
     244                 :        195 :   std::shuffle(indices.begin(), indices.end(), Random::getRandom());
     245         [ +  + ]:        382 :   for (size_t i = 0, isize = indices.size(); i < isize; i++)
     246                 :            :   {
     247                 :        195 :     size_t ii = indices[i];
     248                 :        195 :     TNode v = vars[ii];
     249                 :        195 :     MVarInfo& vi = qi.getVarInfo(ii);
     250                 :        195 :     size_t cindex = 0;
     251                 :        195 :     bool success = false;
     252                 :            :     bool successEnum;
     253                 :       5478 :     do
     254                 :            :     {
     255                 :       5673 :       Node ret = vi.getEnumeratedTerm(cindex);
     256                 :       5673 :       cindex++;
     257                 :       5673 :       Node retc;
     258         [ +  + ]:       5673 :       if (!ret.isNull())
     259                 :            :       {
     260         [ +  - ]:       5657 :         Trace("mbqi-model-enum") << "- Try candidate: " << ret << std::endl;
     261                 :            :         // apply current substitution (to account for cases where ret has
     262                 :            :         // other variables in its grammar).
     263                 :       5657 :         ret = vinst.apply(ret);
     264                 :       5657 :         retc = ret;
     265                 :       5657 :         successEnum = true;
     266                 :            :         // now convert the value
     267                 :      11314 :         std::unordered_map<Node, Node> tmpConvertMap;
     268                 :       5657 :         std::map<TypeNode, std::unordered_set<Node> > freshVarType;
     269                 :       5657 :         retc = d_parent.convertToQuery(retc, tmpConvertMap, freshVarType);
     270                 :            :       }
     271                 :            :       else
     272                 :            :       {
     273         [ +  - ]:         32 :         Trace("mbqi-model-enum")
     274                 :         16 :             << "- Failed to enumerate candidate" << std::endl;
     275                 :            :         // if we failed to enumerate, just try the original
     276                 :         16 :         Node mc = d_parent.convertFromModel(mvs[ii], tmpCMap, mvFreshVar);
     277         [ +  + ]:         16 :         if (mc.isNull())
     278                 :            :         {
     279                 :            :           // if failed to convert, we fail
     280                 :          8 :           return false;
     281                 :            :         }
     282                 :          8 :         ret = mc;
     283                 :          8 :         retc = mc;
     284                 :          8 :         successEnum = false;
     285                 :            :       }
     286         [ +  - ]:      11330 :       Trace("mbqi-model-enum")
     287                 :       5665 :           << "- Converted candidate: " << v << " -> " << retc << std::endl;
     288                 :            :       // see if it is still satisfiable, if still SAT, we replace
     289                 :      11330 :       Node queryCheck = queryCurr.substitute(v, TNode(retc));
     290                 :       5665 :       queryCheck = rewrite(queryCheck);
     291         [ +  - ]:       5665 :       Trace("mbqi-model-enum") << "...check " << queryCheck << std::endl;
     292                 :       5665 :       Result r = checkWithSubsolver(queryCheck, ssi);
     293         [ +  + ]:       5665 :       if (r == Result::SAT)
     294                 :            :       {
     295                 :            :         // remember the updated query
     296                 :        187 :         queryCurr = queryCheck;
     297         [ +  - ]:        187 :         Trace("mbqi-model-enum") << "...success" << std::endl;
     298         [ +  - ]:        374 :         Trace("mbqi-model-enum")
     299                 :        187 :             << "* Enumerated " << q[0][ii] << " -> " << ret << std::endl;
     300                 :        187 :         mvs[ii] = ret;
     301                 :        187 :         vinst.add(q[0][ii], ret);
     302                 :        187 :         success = true;
     303                 :            :       }
     304         [ -  + ]:       5478 :       else if (!successEnum)
     305                 :            :       {
     306                 :            :         // we did not enumerate a candidate, and tried the original, which
     307                 :            :         // failed.
     308                 :          0 :         return false;
     309                 :            :       }
     310         [ +  + ]:       5665 :     } while (!success);
     311                 :            :   }
     312                 :        187 :   return true;
     313                 :            : }
     314                 :            : }  // namespace quantifiers
     315                 :            : }  // namespace theory
     316                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14