LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/expr - sygus_grammar.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 143 196 73.0 %
Date: 2025-01-02 12:37:25 Functions: 16 19 84.2 %
Branches: 92 156 59.0 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * Top contributors (to current version):
       3                 :            :  *   Abdalrhman Mohamed, Andrew Reynolds, Aina Niemetz
       4                 :            :  *
       5                 :            :  * This file is part of the cvc5 project.
       6                 :            :  *
       7                 :            :  * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS
       8                 :            :  * in the top-level source directory and their institutional affiliations.
       9                 :            :  * All rights reserved.  See the file COPYING in the top-level source
      10                 :            :  * directory for licensing information.
      11                 :            :  * ****************************************************************************
      12                 :            :  *
      13                 :            :  * Implementation of class for constructing SyGuS Grammars.
      14                 :            :  */
      15                 :            : 
      16                 :            : #include "expr/sygus_grammar.h"
      17                 :            : 
      18                 :            : #include <sstream>
      19                 :            : 
      20                 :            : #include "expr/dtype.h"
      21                 :            : #include "expr/dtype_cons.h"
      22                 :            : #include "printer/printer.h"
      23                 :            : #include "printer/smt2/smt2_printer.h"
      24                 :            : #include "theory/datatypes/sygus_datatype_utils.h"
      25                 :            : #include "expr/skolem_manager.h"
      26                 :            : #include "util/hash.h"
      27                 :            : 
      28                 :            : namespace cvc5::internal {
      29                 :            : 
      30                 :       3775 : SygusGrammar::SygusGrammar(const std::vector<Node>& sygusVars,
      31                 :       3775 :                            const std::vector<Node>& ntSyms)
      32                 :       3775 :     : d_sygusVars(sygusVars), d_ntSyms(ntSyms)
      33                 :            : {
      34         [ +  + ]:      10350 :   for (const Node& ntSym : ntSyms)
      35                 :            :   {
      36                 :       6575 :     d_rules.emplace(ntSym, std::vector<Node>{});
      37                 :            :   }
      38                 :       3775 : }
      39                 :            : 
      40                 :          0 : SygusGrammar::SygusGrammar(const std::vector<Node>& sygusVars,
      41                 :          0 :                            const TypeNode& sdt)
      42                 :          0 :     : d_sygusVars(sygusVars)
      43                 :            : {
      44                 :          0 :   Assert(sdt.isSygusDatatype());
      45                 :          0 :   std::vector<TypeNode> tnlist;
      46                 :            :   // ensure that sdt is first
      47                 :          0 :   tnlist.push_back(sdt);
      48                 :          0 :   std::map<TypeNode, Node> ntsyms;
      49         [ -  - ]:          0 :   for (size_t i = 0; i < tnlist.size(); i++)
      50                 :            :   {
      51                 :          0 :     TypeNode tn = tnlist[i];
      52                 :          0 :     Assert(tn.isSygusDatatype());
      53                 :          0 :     const DType& dt = tn.getDType();
      54                 :          0 :     std::stringstream ss;
      55                 :          0 :     ss << dt.getName();
      56                 :          0 :     Node v = NodeManager::mkBoundVar(ss.str(), dt.getSygusType());
      57                 :          0 :     ntsyms[tn] = v;
      58                 :          0 :     d_ntSyms.push_back(v);
      59                 :          0 :     d_rules.emplace(v, std::vector<Node>{});
      60                 :            :     // process the subfield types
      61                 :          0 :     std::unordered_set<TypeNode> tns = dt.getSubfieldTypes();
      62         [ -  - ]:          0 :     for (const TypeNode& tnsc : tns)
      63                 :            :     {
      64                 :          0 :       if (tnsc.isSygusDatatype()
      65                 :          0 :           && std::find(tnlist.begin(), tnlist.end(), tnsc) == tnlist.end())
      66                 :            :       {
      67                 :          0 :         tnlist.push_back(tnsc);
      68                 :            :       }
      69                 :            :     }
      70                 :            :   }
      71                 :          0 :   std::map<TypeNode, Node>::iterator itn;
      72         [ -  - ]:          0 :   for (const TypeNode& tn : tnlist)
      73                 :            :   {
      74         [ -  - ]:          0 :     if (!tn.isSygusDatatype())
      75                 :            :     {
      76                 :          0 :       continue;
      77                 :            :     }
      78                 :          0 :     Node nts = ntsyms[tn];
      79                 :          0 :     const DType& dt = tn.getDType();
      80         [ -  - ]:          0 :     for (size_t i = 0, ncons = dt.getNumConstructors(); i < ncons; i++)
      81                 :            :     {
      82                 :          0 :       const DTypeConstructor& cons = dt[i];
      83         [ -  - ]:          0 :       if (cons.isSygusAnyConstant())
      84                 :            :       {
      85                 :          0 :         addAnyConstant(nts, cons[0].getRangeType());
      86                 :          0 :         continue;
      87                 :            :       }
      88                 :          0 :       Node op = cons.getSygusOp();
      89                 :          0 :       std::vector<Node> args;
      90         [ -  - ]:          0 :       for (size_t j = 0, nargs = cons.getNumArgs(); j < nargs; j++)
      91                 :            :       {
      92                 :          0 :         TypeNode argType = cons[j].getRangeType();
      93                 :          0 :         itn = ntsyms.find(argType);
      94                 :          0 :         Assert(itn != ntsyms.end()) << "Missing " << argType << " in " << op;
      95                 :          0 :         args.push_back(itn->second);
      96                 :            :       }
      97                 :          0 :       Node rule = theory::datatypes::utils::mkSygusTerm(op, args, true);
      98                 :          0 :       addRule(nts, rule);
      99                 :            :     }
     100                 :            :   }
     101                 :          0 : }
     102                 :            : 
     103                 :      41310 : void SygusGrammar::addRule(const Node& ntSym, const Node& rule)
     104                 :            : {
     105 [ -  + ][ -  + ]:      41310 :   Assert(d_rules.find(ntSym) != d_rules.cend());
                 [ -  - ]
     106 [ -  + ][ -  + ]:      41310 :   Assert(rule.getType().isInstanceOf(ntSym.getType()));
                 [ -  - ]
     107                 :            :   // avoid duplication
     108                 :      41310 :   std::vector<Node>& rs = d_rules[ntSym];
     109         [ +  + ]:      41310 :   if (std::find(rs.begin(), rs.end(), rule) == rs.end())
     110                 :            :   {
     111                 :      32399 :     rs.push_back(rule);
     112                 :            :   }
     113                 :      41310 : }
     114                 :            : 
     115                 :         56 : void SygusGrammar::addRules(const Node& ntSym, const std::vector<Node>& rules)
     116                 :            : {
     117         [ +  + ]:        122 :   for (const Node& rule : rules)
     118                 :            :   {
     119                 :         66 :     addRule(ntSym, rule);
     120                 :            :   }
     121                 :         56 : }
     122                 :            : 
     123                 :        174 : void SygusGrammar::addAnyConstant(const Node& ntSym, const TypeNode& tn)
     124                 :            : {
     125 [ -  + ][ -  + ]:        174 :   Assert(d_rules.find(ntSym) != d_rules.cend());
                 [ -  - ]
     126 [ -  + ][ -  + ]:        174 :   Assert(tn.isInstanceOf(ntSym.getType()));
                 [ -  - ]
     127                 :        174 :   SkolemManager* sm = NodeManager::currentNM()->getSkolemManager();
     128                 :            :   Node anyConst =
     129                 :        522 :       sm->mkInternalSkolemFunction(InternalSkolemId::SYGUS_ANY_CONSTANT, tn);
     130                 :        174 :   addRule(ntSym, anyConst);
     131                 :        174 : }
     132                 :            : 
     133                 :         83 : void SygusGrammar::addAnyVariable(const Node& ntSym)
     134                 :            : {
     135 [ -  + ][ -  + ]:         83 :   Assert(d_rules.find(ntSym) != d_rules.cend());
                 [ -  - ]
     136                 :            :   // each variable of appropriate type becomes a rule.
     137         [ +  + ]:        253 :   for (const Node& v : d_sygusVars)
     138                 :            :   {
     139         [ +  + ]:        170 :     if (v.getType().isInstanceOf(ntSym.getType()))
     140                 :            :     {
     141                 :        125 :       addRule(ntSym, v);
     142                 :            :     }
     143                 :            :   }
     144                 :         83 : }
     145                 :            : 
     146                 :        605 : void SygusGrammar::removeRule(const Node& ntSym, const Node& rule)
     147                 :            : {
     148                 :            :   std::unordered_map<Node, std::vector<Node>>::iterator itr =
     149                 :        605 :       d_rules.find(ntSym);
     150 [ -  + ][ -  + ]:        605 :   Assert(itr != d_rules.end());
                 [ -  - ]
     151                 :            :   std::vector<Node>::iterator it =
     152                 :        605 :       std::find(itr->second.begin(), itr->second.end(), rule);
     153 [ -  + ][ -  + ]:        605 :   Assert(it != itr->second.end());
                 [ -  - ]
     154                 :        605 :   itr->second.erase(it);
     155                 :        605 : }
     156                 :            : 
     157                 :            : /**
     158                 :            :  * Purify SyGuS grammar node.
     159                 :            :  *
     160                 :            :  * This returns a node where all occurrences of non-terminal symbols (those in
     161                 :            :  * the domain of \p ntsToUnres) are replaced by fresh variables. For each
     162                 :            :  * variable replaced in this way, we add the fresh variable it is replaced with
     163                 :            :  * to \p args, and the unresolved types corresponding to the non-terminal symbol
     164                 :            :  * to \p cargs (constructor args). In other words, \p args contains the free
     165                 :            :  * variables in the node returned by this method (which should be bound by a
     166                 :            :  * lambda), and \p cargs contains the types of the arguments of the sygus
     167                 :            :  * constructor.
     168                 :            :  *
     169                 :            :  * @param n The node to purify.
     170                 :            :  * @param args The free variables in the node returned by this method.
     171                 :            :  * @param ntSymMap Map from each variable in args to the non-terminal they were
     172                 :            :  * introduced for.
     173                 :            :  * @param nts The list of non-terminal symbols
     174                 :            :  * @return The purfied node.
     175                 :            :  */
     176                 :      64335 : Node purifySygusGNode(const Node& n,
     177                 :            :                       std::vector<Node>& args,
     178                 :            :                       std::map<Node, Node>& ntSymMap,
     179                 :            :                       const std::vector<Node>& nts)
     180                 :            : {
     181                 :      64335 :   NodeManager* nm = NodeManager::currentNM();
     182                 :            :   // if n is non-terminal
     183         [ +  + ]:      64335 :   if (std::find(nts.begin(), nts.end(), n) != nts.end())
     184                 :            :   {
     185                 :      64544 :     Node ret = NodeManager::mkBoundVar(n.getType());
     186                 :      32272 :     ntSymMap[ret] = n;
     187                 :      32272 :     args.push_back(ret);
     188                 :      32272 :     return ret;
     189                 :            :   }
     190                 :      64126 :   std::vector<Node> pchildren;
     191                 :      32063 :   bool childChanged = false;
     192         [ +  + ]:      64984 :   for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
     193                 :            :   {
     194                 :      32921 :     Node ptermc = purifySygusGNode(n[i], args, ntSymMap, nts);
     195                 :      32921 :     pchildren.push_back(ptermc);
     196 [ +  + ][ +  + ]:      32921 :     childChanged = childChanged || ptermc != n[i];
         [ +  + ][ -  - ]
     197                 :            :   }
     198         [ +  + ]:      32063 :   if (!childChanged)
     199                 :            :   {
     200                 :      15473 :     return n;
     201                 :            :   }
     202                 :      33180 :   internal::Node nret;
     203         [ +  + ]:      16590 :   if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
     204                 :            :   {
     205                 :            :     // it's an indexed operator so we should provide the op
     206                 :        766 :     internal::NodeBuilder nb(NodeManager::currentNM(), n.getKind());
     207                 :        766 :     nb << n.getOperator();
     208                 :        766 :     nb.append(pchildren);
     209                 :        766 :     nret = nb.constructNode();
     210                 :            :   }
     211                 :            :   else
     212                 :            :   {
     213                 :      15824 :     nret = nm->mkNode(n.getKind(), pchildren);
     214                 :            :   }
     215                 :      16590 :   return nret;
     216                 :            : }
     217                 :            : 
     218                 :      31414 : bool isId(const Node& n)
     219                 :            : {
     220 [ +  + ][ -  - ]:      48060 :   return n.getKind() == Kind::LAMBDA && n[0].getNumChildren() == 1
     221                 :      79474 :          && n[0][0] == n[1];
     222                 :            : }
     223                 :            : 
     224                 :            : /**
     225                 :            :  * Add \p rule to the set of constructors of \p dt.
     226                 :            :  *
     227                 :            :  * @param dt The datatype to which the rule is added.
     228                 :            :  * @param rule The rule to add.
     229                 :            :  * @param ntsToUnres Mapping from non-terminals to their unresolved types.
     230                 :            :  */
     231                 :      31571 : void addSygusConstructor(DType& dt,
     232                 :            :                          const Node& rule,
     233                 :            :                          const std::vector<Node>& nts,
     234                 :            :                          const std::unordered_map<Node, TypeNode>& ntsToUnres)
     235                 :            : {
     236                 :      31571 :   NodeManager* nm = NodeManager::currentNM();
     237                 :      63142 :   std::stringstream ss;
     238                 :      31571 :   if (rule.getKind() == Kind::SKOLEM
     239 [ +  + ][ +  + ]:      31571 :       && rule.getInternalSkolemId() == InternalSkolemId::SYGUS_ANY_CONSTANT)
                 [ +  + ]
     240                 :            :   {
     241                 :        157 :     ss << dt.getName() << "_any_constant";
     242                 :        314 :     dt.addSygusConstructor(rule, ss.str(), {rule.getType()}, 0);
     243                 :            :   }
     244                 :            :   else
     245                 :            :   {
     246                 :      62828 :     std::vector<Node> args;
     247                 :      62828 :     std::map<Node, Node> ntSymMap;
     248                 :      62828 :     Node op = purifySygusGNode(rule, args, ntSymMap, nts);
     249                 :      31414 :     std::vector<TypeNode> cargs;
     250                 :      31414 :     std::unordered_map<Node, TypeNode>::const_iterator it;
     251         [ +  + ]:      63686 :     for (const Node& a : args)
     252                 :            :     {
     253 [ -  + ][ -  + ]:      32272 :       Assert(ntSymMap.find(a) != ntSymMap.end());
                 [ -  - ]
     254                 :      64544 :       Node na = ntSymMap[a];
     255                 :      32272 :       it = ntsToUnres.find(na);
     256 [ -  + ][ -  + ]:      32272 :       Assert(it != ntsToUnres.end());
                 [ -  - ]
     257                 :      32272 :       cargs.push_back(it->second);
     258                 :            :     }
     259                 :      31414 :     ss << op.getKind();
     260         [ +  + ]:      31414 :     if (!args.empty())
     261                 :            :     {
     262                 :      16646 :       Node lbvl = nm->mkNode(Kind::BOUND_VAR_LIST, args);
     263                 :      16646 :       op = nm->mkNode(Kind::LAMBDA, lbvl, op);
     264                 :            :     }
     265                 :            :     // assign identity rules a weight of 0.
     266         [ +  + ]:      31414 :     dt.addSygusConstructor(op, ss.str(), cargs, isId(op) ? 0 : -1);
     267                 :            :   }
     268                 :      31571 : }
     269                 :            : 
     270                 :          0 : Node SygusGrammar::getLambdaForRule(const Node& r,
     271                 :            :                                     std::map<Node, Node>& ntSymMap) const
     272                 :            : {
     273                 :          0 :   std::vector<Node> args;
     274                 :          0 :   Node rp = purifySygusGNode(r, args, ntSymMap, d_ntSyms);
     275         [ -  - ]:          0 :   if (!args.empty())
     276                 :            :   {
     277                 :          0 :     NodeManager* nm = NodeManager::currentNM();
     278                 :          0 :     return nm->mkNode(Kind::LAMBDA, nm->mkNode(Kind::BOUND_VAR_LIST, args), rp);
     279                 :            :   }
     280                 :          0 :   return r;
     281                 :            : }
     282                 :            : 
     283                 :         58 : bool SygusGrammar::hasRules() const
     284                 :            : {
     285         [ +  + ]:         87 :   for (const auto& r : d_rules)
     286                 :            :   {
     287         [ +  + ]:         58 :     if (r.second.size() > 0)
     288                 :            :     {
     289                 :         29 :       return true;
     290                 :            :     }
     291                 :            :   }
     292                 :         29 :   return false;
     293                 :            : }
     294                 :            : 
     295                 :       3727 : TypeNode SygusGrammar::resolve(bool allowAny)
     296                 :            : {
     297         [ +  + ]:       3727 :   if (!isResolved())
     298                 :            :   {
     299                 :       3122 :     NodeManager* nm = NodeManager::currentNM();
     300                 :       6244 :     Node bvl;
     301         [ +  + ]:       3122 :     if (!d_sygusVars.empty())
     302                 :            :     {
     303                 :       1611 :       bvl = nm->mkNode(Kind::BOUND_VAR_LIST, d_sygusVars);
     304                 :            :     }
     305                 :       6244 :     std::unordered_map<Node, TypeNode> ntsToUnres;
     306         [ +  + ]:       9027 :     for (const Node& ntSym : d_ntSyms)
     307                 :            :     {
     308                 :            :       // make the unresolved type, used for referencing the final version of
     309                 :            :       // the ntSym's datatype
     310                 :       5905 :       ntsToUnres.emplace(ntSym, nm->mkUnresolvedDatatypeSort(ntSym.getName()));
     311                 :            :     }
     312                 :            :     // Set of non-terminals that can be arbitrary constants.
     313                 :       6244 :     std::unordered_set<Node> allowConsts;
     314                 :            :     // push the rules into the sygus datatypes
     315                 :       3122 :     std::vector<DType> dts;
     316         [ +  + ]:       9027 :     for (const Node& ntSym : d_ntSyms)
     317                 :            :     {
     318                 :            :       // make the datatype, which encodes terms generated by this non-terminal
     319                 :      11810 :       DType dt(ntSym.getName());
     320                 :            : 
     321         [ +  + ]:      37476 :       for (const Node& rule : d_rules[ntSym])
     322                 :            :       {
     323                 :      31571 :         if (rule.getKind() == Kind::SKOLEM
     324 [ +  + ][ +  + ]:      31571 :             && rule.getInternalSkolemId() == InternalSkolemId::SYGUS_ANY_CONSTANT)
                 [ +  + ]
     325                 :            :         {
     326                 :        157 :           allowConsts.insert(ntSym);
     327                 :            :         }
     328                 :      31571 :         addSygusConstructor(dt, rule, d_ntSyms, ntsToUnres);
     329                 :            :       }
     330                 :       5905 :       bool allowConst = allowConsts.find(ntSym) != allowConsts.end();
     331 [ +  + ][ +  + ]:       5905 :       dt.setSygus(ntSym.getType(), bvl, allowConst || allowAny, allowAny);
     332                 :            :       // We can be in a case where the only rule specified was (Variable T)
     333                 :            :       // and there are no variables of type T, in which case this is a bogus
     334                 :            :       // grammar. This results in the error below.
     335 [ -  + ][ -  + ]:       5905 :       Assert(dt.getNumConstructors() != 0) << "Grouped rule listing for " << dt
                 [ -  - ]
     336                 :          0 :                                            << " produced an empty rule list";
     337                 :       5905 :       dts.push_back(dt);
     338                 :            :     }
     339                 :       3122 :     d_datatype = nm->mkMutualDatatypeTypes(dts)[0];
     340                 :            :   }
     341                 :            :   // return the first datatype
     342                 :       3727 :   return d_datatype;
     343                 :            : }
     344                 :            : 
     345                 :      12143 : bool SygusGrammar::isResolved() { return !d_datatype.isNull(); }
     346                 :            : 
     347                 :          0 : const std::vector<Node>& SygusGrammar::getSygusVars() const
     348                 :            : {
     349                 :          0 :   return d_sygusVars;
     350                 :            : }
     351                 :            : 
     352                 :      13677 : const std::vector<Node>& SygusGrammar::getNtSyms() const { return d_ntSyms; }
     353                 :            : 
     354                 :      13683 : const std::vector<Node>& SygusGrammar::getRulesFor(const Node& ntSym) const
     355                 :            : {
     356                 :            :   std::unordered_map<Node, std::vector<Node>>::const_iterator itr =
     357                 :      13683 :       d_rules.find(ntSym);
     358 [ -  + ][ -  + ]:      13683 :   Assert(itr != d_rules.end());
                 [ -  - ]
     359                 :      13683 :   return itr->second;
     360                 :            : }
     361                 :            : 
     362                 :         29 : std::string SygusGrammar::toString() const
     363                 :            : {
     364                 :         29 :   std::stringstream ss;
     365                 :            :   // clone this grammar before printing it to avoid freezing it.
     366                 :            :   return printer::smt2::Smt2Printer::sygusGrammarString(
     367                 :         87 :       SygusGrammar(*this).resolve());
     368                 :            : }
     369                 :            : 
     370                 :            : }  // namespace cvc5::internal
     371                 :            : 
     372                 :            : namespace std {
     373                 :       2376 : size_t hash<cvc5::internal::SygusGrammar>::operator()(
     374                 :            :     const cvc5::internal::SygusGrammar& grammar) const
     375                 :            : {
     376                 :       2376 :   uint64_t ret = cvc5::internal::fnv1a::offsetBasis;
     377         [ +  + ]:       3385 :   for (const auto& v : grammar.d_sygusVars)
     378                 :            :   {
     379                 :       1009 :     ret = cvc5::internal::fnv1a::fnv1a_64(ret,
     380                 :       1009 :                                           std::hash<cvc5::internal::Node>{}(v));
     381                 :            :   }
     382         [ +  + ]:       4752 :   for (const auto& nts : grammar.d_ntSyms)
     383                 :            :   {
     384                 :       2376 :     ret = cvc5::internal::fnv1a::fnv1a_64(
     385                 :       2376 :         ret, std::hash<cvc5::internal::Node>{}(nts));
     386                 :            :   }
     387         [ +  + ]:       4752 :   for (const auto& r : grammar.d_rules)
     388                 :            :   {
     389                 :       2376 :     uint64_t rhash = cvc5::internal::fnv1a::offsetBasis;
     390         [ +  + ]:       2406 :     for (const auto& n : r.second)
     391                 :            :     {
     392                 :         30 :       rhash = cvc5::internal::fnv1a::fnv1a_64(
     393                 :         30 :           rhash, std::hash<cvc5::internal::Node>{}(n));
     394                 :            :     }
     395                 :       2376 :     rhash = cvc5::internal::fnv1a::fnv1a_64(
     396                 :       2376 :         rhash, std::hash<cvc5::internal::Node>{}(r.first));
     397                 :       2376 :     ret = cvc5::internal::fnv1a::fnv1a_64(ret, rhash);
     398                 :            :   }
     399                 :       2376 :   return cvc5::internal::fnv1a::fnv1a_64(
     400                 :       4752 :       ret, std::hash<cvc5::internal::TypeNode>{}(grammar.d_datatype));
     401                 :            : }
     402                 :            : }  // namespace std

Generated by: LCOV version 1.14