LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/theory/quantifiers/sygus - embedding_converter.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 218 219 99.5 %
Date: 2024-11-22 12:41:49 Functions: 6 6 100.0 %
Branches: 131 206 63.6 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * Top contributors (to current version):
       3                 :            :  *   Andrew Reynolds, Aina Niemetz, Mathias Preiner
       4                 :            :  *
       5                 :            :  * This file is part of the cvc5 project.
       6                 :            :  *
       7                 :            :  * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS
       8                 :            :  * in the top-level source directory and their institutional affiliations.
       9                 :            :  * All rights reserved.  See the file COPYING in the top-level source
      10                 :            :  * directory for licensing information.
      11                 :            :  * ****************************************************************************
      12                 :            :  *
      13                 :            :  * Class for applying the deep embedding for SyGuS
      14                 :            :  */
      15                 :            : 
      16                 :            : #include "theory/quantifiers/sygus/embedding_converter.h"
      17                 :            : 
      18                 :            : #include "options/base_options.h"
      19                 :            : #include "options/quantifiers_options.h"
      20                 :            : #include "printer/smt2/smt2_printer.h"
      21                 :            : #include "theory/datatypes/sygus_datatype_utils.h"
      22                 :            : #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
      23                 :            : #include "theory/quantifiers/sygus/sygus_grammar_norm.h"
      24                 :            : #include "theory/quantifiers/sygus/sygus_utils.h"
      25                 :            : #include "theory/quantifiers/sygus/synth_conjecture.h"
      26                 :            : #include "theory/quantifiers/sygus/term_database_sygus.h"
      27                 :            : #include "util/rational.h"
      28                 :            : 
      29                 :            : using namespace cvc5::internal::kind;
      30                 :            : 
      31                 :            : namespace cvc5::internal {
      32                 :            : namespace theory {
      33                 :            : namespace quantifiers {
      34                 :            : 
      35                 :       7264 : EmbeddingConverter::EmbeddingConverter(Env& env,
      36                 :            :                                        TermDbSygus* tds,
      37                 :       7264 :                                        SynthConjecture* p)
      38                 :       7264 :     : EnvObj(env), d_tds(tds), d_parent(p), d_is_syntax_restricted(false)
      39                 :            : {
      40                 :       7264 : }
      41                 :            : 
      42                 :        296 : bool EmbeddingConverter::hasSyntaxRestrictions(Node q)
      43                 :            : {
      44 [ -  + ][ -  + ]:        296 :   Assert(q.getKind() == Kind::FORALL);
                 [ -  - ]
      45         [ +  + ]:        442 :   for (const Node& f : q[0])
      46                 :            :   {
      47                 :        296 :     TypeNode tn = SygusUtils::getSygusType(f);
      48 [ +  + ][ +  - ]:        296 :     if (tn.isDatatype() && tn.getDType().isSygus())
                 [ +  + ]
      49                 :            :     {
      50                 :        150 :       return true;
      51                 :            :     }
      52                 :            :   }
      53                 :        146 :   return false;
      54                 :            : }
      55                 :            : 
      56                 :       1746 : void EmbeddingConverter::collectTerms(
      57                 :            :     Node n, std::map<TypeNode, std::unordered_set<Node>>& consts)
      58                 :            : {
      59                 :       1746 :   NodeManager* nm = nodeManager();
      60                 :       3492 :   std::unordered_map<TNode, bool> visited;
      61                 :       1746 :   std::unordered_map<TNode, bool>::iterator it;
      62                 :       3492 :   std::vector<TNode> visit;
      63                 :       3492 :   TNode cur;
      64                 :       1746 :   visit.push_back(n);
      65                 :      39150 :   do
      66                 :            :   {
      67                 :      40896 :     cur = visit.back();
      68                 :      40896 :     visit.pop_back();
      69                 :      40896 :     it = visited.find(cur);
      70         [ +  + ]:      40896 :     if (it == visited.end())
      71                 :            :     {
      72                 :      25724 :       visited[cur] = true;
      73                 :            :       // is this a constant?
      74         [ +  + ]:      25724 :       if (cur.isConst())
      75                 :            :       {
      76                 :       7584 :         TypeNode tn = cur.getType();
      77                 :       7584 :         Node c = cur;
      78         [ +  + ]:       3792 :         if (tn.isRealOrInt())
      79                 :            :         {
      80                 :       1910 :           c = nm->mkConstRealOrInt(tn, c.getConst<Rational>().abs());
      81                 :            :         }
      82                 :       3792 :         consts[tn].insert(c);
      83         [ +  + ]:       3792 :         if (tn.isInteger())
      84                 :            :         {
      85                 :       1880 :           c = nm->mkConstReal(c.getConst<Rational>().abs());
      86                 :       3760 :           TypeNode rtype = nm->realType();
      87                 :       1880 :           consts[rtype].insert(c);
      88                 :            :         }
      89                 :            :       }
      90                 :            :       // recurse
      91                 :      25724 :       visit.insert(visit.end(), cur.begin(), cur.end());
      92                 :            :     }
      93         [ +  + ]:      40896 :   } while (!visit.empty());
      94                 :       1746 : }
      95                 :            : 
      96                 :       1761 : Node EmbeddingConverter::process(Node q,
      97                 :            :                                  const std::map<Node, Node>& templates,
      98                 :            :                                  const std::map<Node, Node>& templates_arg)
      99                 :            : {
     100                 :            :   // convert to deep embedding and finalize single invocation here
     101                 :            :   // now, construct the grammar
     102         [ +  - ]:       3522 :   Trace("cegqi") << "SynthConjecture : convert to deep embedding..."
     103                 :       1761 :                  << std::endl;
     104                 :       3522 :   std::map<TypeNode, std::unordered_set<Node>> extra_cons;
     105                 :       1761 :   if (options().quantifiers.sygusAddConstGrammar
     106 [ +  + ][ +  + ]:       1761 :       && options().quantifiers.sygusGrammarConsMode
                 [ +  + ]
     107                 :            :              == options::SygusGrammarConsMode::SIMPLE)
     108                 :            :   {
     109         [ +  - ]:       1746 :     Trace("cegqi") << "SynthConjecture : collect constants..." << std::endl;
     110                 :       1746 :     collectTerms(q[1], extra_cons);
     111                 :            :   }
     112                 :       3522 :   std::map<TypeNode, std::unordered_set<Node>> exc_cons;
     113                 :       3522 :   std::map<TypeNode, std::unordered_set<Node>> inc_cons;
     114                 :            : 
     115                 :       1761 :   NodeManager* nm = nodeManager();
     116                 :            : 
     117                 :       1761 :   std::vector<Node> ebvl;
     118         [ +  + ]:       3856 :   for (unsigned i = 0; i < q[0].getNumChildren(); i++)
     119                 :            :   {
     120                 :       4190 :     Node sf = q[0][i];
     121                 :            :     // if non-null, v encodes the syntactic restrictions (via an inductive
     122                 :            :     // datatype) on sf from the input.
     123                 :       4190 :     TypeNode preGrammarType = SygusUtils::getSygusType(sf);
     124         [ +  + ]:       2095 :     if (preGrammarType.isNull())
     125                 :            :     {
     126                 :            :       // otherwise, the grammar is the default for the range of the function
     127                 :       1107 :       preGrammarType = sf.getType();
     128         [ +  + ]:       1107 :       if (preGrammarType.isFunction())
     129                 :            :       {
     130                 :        829 :         preGrammarType = preGrammarType.getRangeType();
     131                 :            :       }
     132                 :            :     }
     133                 :            : 
     134                 :            :     // the actual sygus datatype we will use (normalized below)
     135                 :       4190 :     TypeNode tn;
     136                 :       4190 :     Node sfvl;
     137 [ +  + ][ +  + ]:       2095 :     if (preGrammarType.isDatatype() && preGrammarType.getDType().isSygus())
                 [ +  + ]
     138                 :            :     {
     139                 :        988 :       sfvl = preGrammarType.getDType().getSygusVarList();
     140                 :        988 :       tn = preGrammarType;
     141                 :            :       // normalize type, if user-provided
     142                 :        988 :       SygusGrammarNorm sygus_norm(d_env, d_tds);
     143                 :        988 :       tn = sygus_norm.normalizeSygusType(tn, sfvl);
     144                 :            :     }
     145                 :            :     else
     146                 :            :     {
     147                 :       1107 :       sfvl = SygusUtils::getOrMkSygusArgumentList(sf);
     148                 :            :       // check which arguments are irrelevant
     149                 :       2214 :       std::unordered_set<unsigned> arg_irrelevant;
     150                 :       1107 :       d_parent->getProcess()->getIrrelevantArgs(sf, arg_irrelevant);
     151                 :       1107 :       std::vector<Node> trules;
     152                 :            :       // add the variables from the free variable list that we did not
     153                 :            :       // infer were irrelevant.
     154         [ +  + ]:       3028 :       for (size_t j = 0, nargs = sfvl.getNumChildren(); j < nargs; j++)
     155                 :            :       {
     156         [ +  + ]:       1921 :         if (arg_irrelevant.find(j) == arg_irrelevant.end())
     157                 :            :         {
     158                 :       1853 :           trules.push_back(sfvl[j]);
     159                 :            :         }
     160                 :            :       }
     161                 :            :       // add the constants computed avove
     162                 :       1862 :       for (const std::pair<const TypeNode, std::unordered_set<Node>>& c :
     163         [ +  + ]:       2969 :            extra_cons)
     164                 :            :       {
     165                 :       1862 :         trules.insert(trules.end(), c.second.begin(), c.second.end());
     166                 :            :       }
     167                 :       1107 :       tn = SygusGrammarCons::mkDefaultSygusType(
     168                 :       1107 :           d_env, preGrammarType, sfvl, trules);
     169                 :            :     }
     170                 :            :     // Ensure the expanded definition forms are set. This is done after
     171                 :            :     // normalization above.
     172                 :       2095 :     datatypes::utils::computeExpandedDefinitionForms(d_env, tn);
     173                 :            :     // print the grammar
     174         [ +  + ]:       2095 :     if (isOutputOn(OutputTag::SYGUS_GRAMMAR))
     175                 :            :     {
     176                 :          2 :       output(OutputTag::SYGUS_GRAMMAR)
     177                 :          4 :           << "(sygus-grammar " << sf << " "
     178                 :          4 :           << printer::smt2::Smt2Printer::sygusGrammarString(tn) << ")"
     179                 :          2 :           << std::endl;
     180                 :            :     }
     181                 :            :     // sfvl may be null for constant synthesis functions
     182         [ +  - ]:       4190 :     Trace("cegqi-debug") << "...sygus var list associated with " << sf << " is "
     183                 :       2095 :                          << sfvl << std::endl;
     184                 :            : 
     185                 :       2095 :     std::map<Node, Node>::const_iterator itt = templates.find(sf);
     186         [ +  + ]:       2095 :     if (itt != templates.end())
     187                 :            :     {
     188                 :         88 :       Node templ = itt->second;
     189                 :         44 :       std::map<Node, Node>::const_iterator itta = templates_arg.find(sf);
     190 [ -  + ][ -  + ]:         44 :       Assert(itta != templates_arg.end());
                 [ -  - ]
     191                 :         44 :       TNode templ_arg = itta->second;
     192 [ -  + ][ -  + ]:         44 :       Assert(!templ_arg.isNull());
                 [ -  - ]
     193                 :            :     }
     194                 :            : 
     195                 :            :     // ev is the first-order variable corresponding to this synth fun
     196                 :       4190 :     Node ev = nm->mkBoundVar("f" + sf.getName(), tn);
     197                 :       2095 :     ebvl.push_back(ev);
     198         [ +  - ]:       4190 :     Trace("cegqi") << "...embedding synth fun : " << sf << " -> " << ev
     199                 :       2095 :                    << std::endl;
     200                 :            :   }
     201                 :       3522 :   return process(q, templates, templates_arg, ebvl);
     202                 :            : }
     203                 :            : 
     204                 :       1761 : Node EmbeddingConverter::process(Node q,
     205                 :            :                                  const std::map<Node, Node>& templates,
     206                 :            :                                  const std::map<Node, Node>& templates_arg,
     207                 :            :                                  const std::vector<Node>& ebvl)
     208                 :            : {
     209 [ -  + ][ -  + ]:       1761 :   Assert(q[0].getNumChildren() == ebvl.size());
                 [ -  - ]
     210 [ -  + ][ -  + ]:       1761 :   Assert(d_synth_fun_vars.empty());
                 [ -  - ]
     211                 :            : 
     212                 :       1761 :   NodeManager* nm = nodeManager();
     213                 :            : 
     214                 :       3522 :   std::vector<Node> qchildren;
     215                 :       3522 :   Node qbody_subs = q[1];
     216         [ +  + ]:       3856 :   for (unsigned i = 0, size = q[0].getNumChildren(); i < size; i++)
     217                 :            :   {
     218                 :       4190 :     Node sf = q[0][i];
     219                 :       2095 :     d_synth_fun_vars[sf] = ebvl[i];
     220                 :       4190 :     Node sfvl = SygusUtils::getOrMkSygusArgumentList(sf);
     221                 :       4190 :     TypeNode tn = ebvl[i].getType();
     222                 :            :     // check if there is a template
     223                 :       2095 :     std::map<Node, Node>::const_iterator itt = templates.find(sf);
     224         [ +  + ]:       2095 :     if (itt != templates.end())
     225                 :            :     {
     226                 :         88 :       Node templ = itt->second;
     227                 :         44 :       std::map<Node, Node>::const_iterator itta = templates_arg.find(sf);
     228 [ -  + ][ -  + ]:         44 :       Assert(itta != templates_arg.end());
                 [ -  - ]
     229                 :         88 :       TNode templ_arg = itta->second;
     230 [ -  + ][ -  + ]:         44 :       Assert(!templ_arg.isNull());
                 [ -  - ]
     231                 :            :       // if there is a template for this argument, make a sygus type on top of
     232                 :            :       // it
     233                 :            :       // otherwise, apply it as a preprocessing pass
     234         [ +  - ]:         88 :       Trace("cegqi-debug") << "Template for " << sf << " is : " << templ
     235                 :         44 :                            << " with arg " << templ_arg << std::endl;
     236         [ +  - ]:         88 :       Trace("cegqi-debug")
     237                 :          0 :           << "  apply this template as a substitution during preprocess..."
     238                 :         44 :           << std::endl;
     239                 :         88 :       std::vector<Node> schildren;
     240                 :         88 :       std::vector<Node> largs;
     241         [ +  + ]:        261 :       for (unsigned j = 0; j < sfvl.getNumChildren(); j++)
     242                 :            :       {
     243                 :        217 :         schildren.push_back(sfvl[j]);
     244                 :        217 :         largs.push_back(nm->mkBoundVar(sfvl[j].getType()));
     245                 :            :       }
     246                 :         88 :       std::vector<Node> subsfn_children;
     247                 :         44 :       subsfn_children.push_back(sf);
     248                 :            :       subsfn_children.insert(
     249                 :         44 :           subsfn_children.end(), schildren.begin(), schildren.end());
     250                 :         88 :       Node subsfn = nm->mkNode(Kind::APPLY_UF, subsfn_children);
     251                 :         88 :       TNode subsf = subsfn;
     252         [ +  - ]:         88 :       Trace("cegqi-debug") << "  substitute arg : " << templ_arg << " -> "
     253                 :         44 :                            << subsf << std::endl;
     254                 :         44 :       templ = templ.substitute(templ_arg, subsf);
     255                 :            :       // substitute lambda arguments
     256                 :         88 :       templ = templ.substitute(
     257                 :         44 :           schildren.begin(), schildren.end(), largs.begin(), largs.end());
     258                 :            :       Node subsn = nm->mkNode(
     259                 :        132 :           Kind::LAMBDA, nm->mkNode(Kind::BOUND_VAR_LIST, largs), templ);
     260                 :         88 :       TNode var = sf;
     261                 :         44 :       TNode subs = subsn;
     262         [ +  - ]:         88 :       Trace("cegqi-debug") << "  substitute : " << var << " -> " << subs
     263                 :         44 :                            << std::endl;
     264                 :         44 :       qbody_subs = qbody_subs.substitute(var, subs);
     265         [ +  - ]:         44 :       Trace("cegqi-debug") << "  body is now : " << qbody_subs << std::endl;
     266                 :            :     }
     267                 :       2095 :     d_tds->registerSygusType(tn);
     268 [ -  + ][ -  + ]:       2095 :     Assert(tn.isDatatype());
                 [ -  - ]
     269                 :       2095 :     const DType& dt = tn.getDType();
     270 [ -  + ][ -  + ]:       2095 :     Assert(dt.isSygus());
                 [ -  - ]
     271         [ +  + ]:       2095 :     if (!dt.getSygusAllowAll())
     272                 :            :     {
     273                 :        986 :       d_is_syntax_restricted = true;
     274                 :            :     }
     275                 :            :   }
     276                 :       1761 :   qchildren.push_back(nm->mkNode(Kind::BOUND_VAR_LIST, ebvl));
     277         [ +  + ]:       1761 :   if (qbody_subs != q[1])
     278                 :            :   {
     279         [ +  - ]:         44 :     Trace("cegqi") << "...rewriting : " << qbody_subs << std::endl;
     280                 :         44 :     qbody_subs = rewrite(qbody_subs);
     281         [ +  - ]:         44 :     Trace("cegqi") << "...got : " << qbody_subs << std::endl;
     282                 :            :   }
     283                 :       1761 :   qchildren.push_back(convertToEmbedding(qbody_subs));
     284         [ +  - ]:       1761 :   if (q.getNumChildren() == 3)
     285                 :            :   {
     286                 :       1761 :     qchildren.push_back(q[2]);
     287                 :            :   }
     288                 :       3522 :   return nm->mkNode(Kind::FORALL, qchildren);
     289                 :            : }
     290                 :            : 
     291                 :       2354 : Node EmbeddingConverter::convertToEmbedding(Node n)
     292                 :            : {
     293                 :       2354 :   NodeManager* nm = nodeManager();
     294                 :       4708 :   std::unordered_map<TNode, Node> visited;
     295                 :       2354 :   std::unordered_map<TNode, Node>::iterator it;
     296                 :       4708 :   std::vector<TNode> visit;
     297                 :       2354 :   TNode cur;
     298                 :       2354 :   visit.push_back(n);
     299                 :      78942 :   do
     300                 :            :   {
     301                 :      81296 :     cur = visit.back();
     302                 :      81296 :     visit.pop_back();
     303                 :      81296 :     it = visited.find(cur);
     304         [ +  + ]:      81296 :     if (it == visited.end())
     305                 :            :     {
     306                 :      31656 :       visited[cur] = Node::null();
     307                 :      31656 :       visit.push_back(cur);
     308                 :      31656 :       visit.insert(visit.end(), cur.begin(), cur.end());
     309                 :            :     }
     310         [ +  + ]:      49640 :     else if (it->second.isNull())
     311                 :            :     {
     312                 :      63312 :       Node ret = cur;
     313                 :      31656 :       Kind ret_k = cur.getKind();
     314                 :      63312 :       Node op;
     315                 :      31656 :       bool childChanged = false;
     316                 :      63312 :       std::vector<Node> children;
     317                 :            :       // get the potential operator
     318         [ +  + ]:      31656 :       if (cur.getNumChildren() > 0)
     319                 :            :       {
     320         [ +  + ]:      22846 :         if (cur.getKind() == Kind::APPLY_UF)
     321                 :            :         {
     322                 :       2718 :           op = cur.getOperator();
     323                 :            :         }
     324                 :            :       }
     325                 :            :       else
     326                 :            :       {
     327                 :       8810 :         op = cur;
     328                 :            :       }
     329                 :            :       // is the operator a synth function?
     330                 :      31656 :       bool makeEvalFun = false;
     331         [ +  + ]:      31656 :       if (!op.isNull())
     332                 :            :       {
     333                 :      11528 :         std::map<Node, Node>::iterator its = d_synth_fun_vars.find(op);
     334         [ +  + ]:      11528 :         if (its != d_synth_fun_vars.end())
     335                 :            :         {
     336                 :       3180 :           children.push_back(its->second);
     337                 :       3180 :           makeEvalFun = true;
     338                 :            :         }
     339                 :            :       }
     340         [ +  + ]:      31656 :       if (!makeEvalFun)
     341                 :            :       {
     342                 :            :         // otherwise, we apply the previous operator
     343         [ +  + ]:      28476 :         if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
     344                 :            :         {
     345                 :        439 :           children.push_back(cur.getOperator());
     346                 :            :         }
     347                 :            :       }
     348         [ +  + ]:      78942 :       for (unsigned i = 0; i < cur.getNumChildren(); i++)
     349                 :            :       {
     350                 :      47286 :         it = visited.find(cur[i]);
     351 [ -  + ][ -  + ]:      47286 :         Assert(it != visited.end());
                 [ -  - ]
     352 [ -  + ][ -  + ]:      47286 :         Assert(!it->second.isNull());
                 [ -  - ]
     353 [ +  + ][ +  + ]:      47286 :         childChanged = childChanged || cur[i] != it->second;
         [ +  + ][ -  - ]
     354                 :      47286 :         children.push_back(it->second);
     355                 :            :       }
     356         [ +  + ]:      31656 :       if (makeEvalFun)
     357                 :            :       {
     358         [ +  + ]:       3180 :         if (!cur.getType().isFunction())
     359                 :            :         {
     360                 :            :           // will make into an application of an evaluation function
     361                 :       3170 :           ret = nm->mkNode(Kind::DT_SYGUS_EVAL, children);
     362                 :            :         }
     363                 :            :         else
     364                 :            :         {
     365 [ -  + ][ -  + ]:         10 :           Assert(children.size() == 1);
                 [ -  - ]
     366                 :         20 :           Node ef = children[0];
     367                 :            :           // Otherwise, we are using the function-to-synthesize itself in a
     368                 :            :           // higher-order setting. We must return the lambda term:
     369                 :            :           //   lambda x1...xn. (DT_SYGUS_EVAL ef x1 ... xn)
     370                 :            :           // where ef is the first order variable for the
     371                 :            :           // function-to-synthesize.
     372                 :         10 :           SygusTypeInfo& ti = d_tds->getTypeInfo(ef.getType());
     373                 :         10 :           const std::vector<Node>& vars = ti.getVarList();
     374 [ -  + ][ -  + ]:         10 :           Assert(!vars.empty());
                 [ -  - ]
     375                 :         20 :           std::vector<Node> vs;
     376         [ +  + ]:         24 :           for (const Node& v : vars)
     377                 :            :           {
     378                 :         14 :             vs.push_back(nm->mkBoundVar(v.getType()));
     379                 :            :           }
     380                 :         20 :           Node lvl = nm->mkNode(Kind::BOUND_VAR_LIST, vs);
     381                 :         10 :           std::vector<Node> eargs;
     382                 :         10 :           eargs.push_back(ef);
     383                 :         10 :           eargs.insert(eargs.end(), vs.begin(), vs.end());
     384                 :         20 :           ret = nm->mkNode(
     385                 :         30 :               Kind::LAMBDA, lvl, nm->mkNode(Kind::DT_SYGUS_EVAL, eargs));
     386                 :            :         }
     387                 :            :       }
     388         [ +  + ]:      28476 :       else if (childChanged)
     389                 :            :       {
     390                 :      10445 :         ret = nm->mkNode(ret_k, children);
     391                 :            :       }
     392                 :      31656 :       visited[cur] = ret;
     393                 :            :     }
     394         [ +  + ]:      81296 :   } while (!visit.empty());
     395 [ -  + ][ -  + ]:       2354 :   Assert(visited.find(n) != visited.end());
                 [ -  - ]
     396 [ -  + ][ -  + ]:       2354 :   Assert(!visited.find(n)->second.isNull());
                 [ -  - ]
     397                 :       4708 :   return visited[n];
     398                 :            : }
     399                 :            : 
     400                 :            : }  // namespace quantifiers
     401                 :            : }  // namespace theory
     402                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14