LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/preprocessing/passes - sygus_inference.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 164 190 86.3 %
Date: 2026-04-30 10:45:04 Functions: 3 3 100.0 %
Branches: 98 160 61.2 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * This file is part of the cvc5 project.
       3                 :            :  *
       4                 :            :  * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
       5                 :            :  * in the top-level source directory and their institutional affiliations.
       6                 :            :  * All rights reserved.  See the file COPYING in the top-level source
       7                 :            :  * directory for licensing information.
       8                 :            :  * ****************************************************************************
       9                 :            :  *
      10                 :            :  * Sygus inference module.
      11                 :            :  */
      12                 :            : 
      13                 :            : #include "preprocessing/passes/sygus_inference.h"
      14                 :            : 
      15                 :            : #include "options/quantifiers_options.h"
      16                 :            : #include "preprocessing/assertion_pipeline.h"
      17                 :            : #include "preprocessing/preprocessing_pass_context.h"
      18                 :            : #include "smt/logic_exception.h"
      19                 :            : #include "smt/solver_engine.h"
      20                 :            : #include "theory/quantifiers/quantifiers_attributes.h"
      21                 :            : #include "theory/quantifiers/quantifiers_preprocess.h"
      22                 :            : #include "theory/quantifiers/sygus/sygus_utils.h"
      23                 :            : #include "theory/rewriter.h"
      24                 :            : #include "theory/smt_engine_subsolver.h"
      25                 :            : 
      26                 :            : using namespace std;
      27                 :            : using namespace cvc5::internal::kind;
      28                 :            : using namespace cvc5::internal::theory;
      29                 :            : 
      30                 :            : namespace cvc5::internal {
      31                 :            : namespace preprocessing {
      32                 :            : namespace passes {
      33                 :            : 
      34                 :      51756 : SygusInference::SygusInference(PreprocessingPassContext* preprocContext)
      35                 :      51756 :     : PreprocessingPass(preprocContext, "sygus-infer") {};
      36                 :            : 
      37                 :         61 : PreprocessingPassResult SygusInference::applyInternal(
      38                 :            :     AssertionPipeline* assertionsToPreprocess)
      39                 :            : {
      40         [ +  - ]:         61 :   Trace("sygus-infer") << "Run sygus inference..." << std::endl;
      41                 :         61 :   std::vector<Node> funs;
      42                 :         61 :   std::vector<Node> sols;
      43                 :            :   // see if we can successfully solve the input as a sygus problem
      44         [ +  + ]:         61 :   if (solveSygus(assertionsToPreprocess->ref(), funs, sols))
      45                 :            :   {
      46         [ +  - ]:         53 :     Trace("sygus-infer") << "...Solved:" << std::endl;
      47 [ -  + ][ -  + ]:         53 :     Assert(funs.size() == sols.size());
                 [ -  - ]
      48                 :            :     // if so, sygus gives us function definitions, which we add as substitutions
      49         [ +  + ]:        164 :     for (unsigned i = 0, size = funs.size(); i < size; i++)
      50                 :            :     {
      51         [ +  - ]:        111 :       Trace("sygus-infer") << funs[i] << " -> " << sols[i] << std::endl;
      52                 :        111 :       d_preprocContext->addSubstitution(funs[i], sols[i]);
      53                 :            :     }
      54                 :            : 
      55                 :            :     // apply substitution to everything, should result in SAT
      56         [ +  + ]:        205 :     for (unsigned i = 0, size = assertionsToPreprocess->ref().size(); i < size;
      57                 :            :          i++)
      58                 :            :     {
      59                 :        152 :       Node prev = (*assertionsToPreprocess)[i];
      60                 :            :       Node curr =
      61                 :        152 :           prev.substitute(funs.begin(), funs.end(), sols.begin(), sols.end());
      62         [ +  + ]:        152 :       if (curr != prev)
      63                 :            :       {
      64                 :         99 :         curr = rewrite(curr);
      65         [ +  - ]:        198 :         Trace("sygus-infer-debug")
      66                 :         99 :             << "...rewrote " << prev << " to " << curr << std::endl;
      67                 :         99 :         assertionsToPreprocess->replace(i, curr);
      68                 :            :       }
      69                 :        152 :     }
      70                 :            :   }
      71                 :          8 :   else if (options().quantifiers.sygusInference
      72         [ -  + ]:          8 :            == options::SygusInferenceMode::ON)
      73                 :            :   {
      74                 :          0 :     std::stringstream ss;
      75                 :          0 :     ss << "Cannot translate input to sygus for --sygus-inference";
      76                 :          0 :     throw LogicException(ss.str());
      77                 :          0 :   }
      78                 :         61 :   return PreprocessingPassResult::NO_CONFLICT;
      79                 :         61 : }
      80                 :            : 
      81                 :         61 : bool SygusInference::solveSygus(const std::vector<Node>& assertions,
      82                 :            :                                 std::vector<Node>& funs,
      83                 :            :                                 std::vector<Node>& sols)
      84                 :            : {
      85         [ -  + ]:         61 :   if (assertions.empty())
      86                 :            :   {
      87         [ -  - ]:          0 :     Trace("sygus-infer") << "...fail: empty assertions." << std::endl;
      88         [ -  - ]:          0 :     Warning() << "Cannot convert to sygus since there are no assertions."
      89                 :          0 :               << std::endl;
      90                 :          0 :     return false;
      91                 :            :   }
      92                 :            : 
      93                 :         61 :   NodeManager* nm = nodeManager();
      94                 :            : 
      95                 :            :   // collect free variables in all assertions
      96                 :         61 :   std::vector<Node> qvars;
      97                 :         61 :   std::map<TypeNode, std::vector<Node> > qtvars;
      98                 :         61 :   std::vector<Node> free_functions;
      99                 :            : 
     100                 :         61 :   std::vector<TNode> visit;
     101                 :         61 :   std::unordered_set<TNode> visited;
     102                 :            : 
     103                 :            :   // add top-level conjuncts to eassertions
     104                 :         61 :   std::vector<Node> assertions_proc = assertions;
     105                 :         61 :   std::vector<Node> eassertions;
     106                 :         61 :   unsigned index = 0;
     107         [ +  + ]:        241 :   while (index < assertions_proc.size())
     108                 :            :   {
     109                 :        180 :     Node ca = assertions_proc[index];
     110         [ +  + ]:        180 :     if (ca.getKind() == Kind::AND)
     111                 :            :     {
     112         [ +  + ]:          6 :       for (const Node& ai : ca)
     113                 :            :       {
     114                 :          4 :         assertions_proc.push_back(ai);
     115                 :          4 :       }
     116                 :            :     }
     117                 :            :     else
     118                 :            :     {
     119                 :        178 :       eassertions.push_back(ca);
     120                 :            :     }
     121                 :        180 :     index++;
     122                 :        180 :   }
     123                 :            : 
     124                 :            :   // process eassertions
     125                 :         61 :   std::vector<Node> processed_assertions;
     126                 :         61 :   quantifiers::QuantifiersPreprocess qp(d_env);
     127         [ +  + ]:        237 :   for (const Node& as : eassertions)
     128                 :            :   {
     129                 :            :     // substitution for this assertion
     130                 :        177 :     std::vector<Node> vars;
     131                 :        177 :     std::vector<Node> subs;
     132                 :        177 :     std::map<TypeNode, unsigned> type_count;
     133                 :        177 :     Node pas = as;
     134                 :            :     // rewrite
     135                 :        177 :     pas = rewrite(pas);
     136         [ +  - ]:        177 :     Trace("sygus-infer") << "assertion : " << pas << std::endl;
     137         [ +  + ]:        177 :     if (pas.getKind() == Kind::FORALL)
     138                 :            :     {
     139                 :            :       // preprocess the quantified formula
     140                 :         15 :       TrustNode trn = qp.preprocess(pas);
     141         [ -  + ]:         15 :       if (!trn.isNull())
     142                 :            :       {
     143                 :          0 :         pas = trn.getNode();
     144                 :            :       }
     145         [ +  - ]:         15 :       Trace("sygus-infer-debug") << "  ...preprocessed to " << pas << std::endl;
     146                 :         15 :     }
     147         [ +  + ]:        177 :     if (pas.getKind() == Kind::FORALL)
     148                 :            :     {
     149                 :            :       // it must be a standard quantifier
     150                 :         15 :       theory::quantifiers::QAttributes qa;
     151                 :         15 :       theory::quantifiers::QuantAttributes::computeQuantAttributes(pas, qa);
     152         [ -  + ]:         15 :       if (!qa.isStandard())
     153                 :            :       {
     154         [ -  - ]:          0 :         Trace("sygus-infer")
     155                 :          0 :             << "...fail: non-standard top-level quantifier." << std::endl;
     156         [ -  - ]:          0 :         Warning() << "Cannot convert to sygus since there is a non-standard "
     157                 :          0 :                      "top-level quantified formula: "
     158                 :          0 :                   << pas << std::endl;
     159                 :          0 :         return false;
     160                 :            :       }
     161                 :            :       // infer prefix
     162         [ +  + ]:         37 :       for (const Node& v : pas[0])
     163                 :            :       {
     164                 :         22 :         TypeNode tnv = v.getType();
     165                 :         22 :         unsigned vnum = type_count[tnv];
     166                 :         22 :         type_count[tnv]++;
     167                 :         22 :         vars.push_back(v);
     168         [ +  + ]:         22 :         if (vnum < qtvars[tnv].size())
     169                 :            :         {
     170                 :          3 :           subs.push_back(qtvars[tnv][vnum]);
     171                 :            :         }
     172                 :            :         else
     173                 :            :         {
     174 [ -  + ][ -  + ]:         19 :           Assert(vnum == qtvars[tnv].size());
                 [ -  - ]
     175                 :         19 :           Node bv = NodeManager::mkBoundVar(tnv);
     176                 :         19 :           qtvars[tnv].push_back(bv);
     177                 :         19 :           qvars.push_back(bv);
     178                 :         19 :           subs.push_back(bv);
     179                 :         19 :         }
     180                 :         37 :       }
     181                 :         15 :       pas = pas[1];
     182         [ +  - ]:         15 :       if (!vars.empty())
     183                 :            :       {
     184                 :            :         pas =
     185                 :         15 :             pas.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
     186                 :            :       }
     187         [ +  - ]:         15 :     }
     188         [ +  - ]:        177 :     Trace("sygus-infer-debug") << "  ...substituted to " << pas << std::endl;
     189                 :            : 
     190                 :            :     // collect free functions, ensure no quantified formulas
     191                 :        177 :     TNode cur = pas;
     192                 :            :     // compute free variables
     193                 :        177 :     visit.push_back(cur);
     194                 :            :     do
     195                 :            :     {
     196                 :       1150 :       cur = visit.back();
     197                 :       1150 :       visit.pop_back();
     198         [ +  + ]:       1150 :       if (visited.find(cur) == visited.end())
     199                 :            :       {
     200                 :        873 :         visited.insert(cur);
     201         [ +  + ]:        873 :         if (cur.getKind() == Kind::APPLY_UF)
     202                 :            :         {
     203                 :         45 :           Node op = cur.getOperator();
     204                 :            :           // visit the operator, which might not be a variable
     205                 :         45 :           visit.push_back(op);
     206                 :         45 :         }
     207 [ +  + ][ +  + ]:        828 :         else if (cur.isVar() && cur.getKind() != Kind::BOUND_VARIABLE)
                 [ +  + ]
     208                 :            :         {
     209                 :            :           // We are either in the case of a free first-order constant or a
     210                 :            :           // function in a higher-order context. We add to free_functions
     211                 :            :           // in either case. Note that a free constant that is not in a
     212                 :            :           // higher-order context is a 0-argument function-to-synthesize.
     213                 :            :           // We should not have traversed here before due to our visited cache.
     214 [ -  + ][ -  + ]:        130 :           Assert(std::find(free_functions.begin(), free_functions.end(), cur)
                 [ -  - ]
     215                 :            :                  == free_functions.end());
     216                 :        130 :           free_functions.push_back(cur);
     217                 :            :         }
     218         [ +  + ]:        698 :         else if (cur.isClosure())
     219                 :            :         {
     220         [ +  - ]:          2 :           Trace("sygus-infer")
     221                 :          1 :               << "...fail: non-top-level quantifier." << std::endl;
     222         [ +  - ]:          1 :           Warning() << "Cannot convert to sygus since there is a non-top-level "
     223                 :          0 :                        "quantified formula: "
     224                 :          1 :                     << cur << std::endl;
     225                 :          1 :           return false;
     226                 :            :         }
     227         [ +  + ]:       1801 :         for (const TNode& cn : cur)
     228                 :            :         {
     229                 :        929 :           visit.push_back(cn);
     230                 :        929 :         }
     231                 :            :       }
     232         [ +  + ]:       1149 :     } while (!visit.empty());
     233                 :        176 :     processed_assertions.push_back(pas);
     234 [ +  + ][ +  + ]:        181 :   }
         [ +  + ][ +  + ]
                 [ +  + ]
     235                 :            : 
     236                 :            :   // no functions to synthesize
     237         [ -  + ]:         60 :   if (free_functions.empty())
     238                 :            :   {
     239         [ -  - ]:          0 :     Warning()
     240                 :          0 :         << "Cannot convert to sygus since there are no free function symbols."
     241                 :          0 :         << std::endl;
     242         [ -  - ]:          0 :     Trace("sygus-infer") << "...fail: no free function symbols." << std::endl;
     243                 :          0 :     return false;
     244                 :            :   }
     245                 :            : 
     246                 :            :   // Note that we do not restrict based on the types of free functions here,
     247                 :            :   // i.e. we assume that all types are handled in sygus grammar construction.
     248                 :            : 
     249 [ -  + ][ -  + ]:         60 :   Assert(!processed_assertions.empty());
                 [ -  - ]
     250                 :            :   // conjunction of the assertions
     251         [ +  - ]:         60 :   Trace("sygus-infer") << "Construct body..." << std::endl;
     252                 :         60 :   Node body;
     253         [ -  + ]:         60 :   if (processed_assertions.size() == 1)
     254                 :            :   {
     255                 :          0 :     body = processed_assertions[0];
     256                 :            :   }
     257                 :            :   else
     258                 :            :   {
     259                 :         60 :     body = nm->mkNode(Kind::AND, processed_assertions);
     260                 :            :   }
     261                 :            : 
     262                 :            :   // for each free function symbol, make a bound variable of the same type
     263         [ +  - ]:         60 :   Trace("sygus-infer") << "Do free function substitution..." << std::endl;
     264                 :         60 :   std::vector<Node> ff_vars;
     265                 :         60 :   std::map<Node, Node> ff_var_to_ff;
     266         [ +  + ]:        190 :   for (const Node& ff : free_functions)
     267                 :            :   {
     268                 :        130 :     Node ffv = NodeManager::mkBoundVar(ff.getType());
     269                 :        130 :     ff_vars.push_back(ffv);
     270         [ +  - ]:        130 :     Trace("sygus-infer") << "  synth-fun: " << ff << " as " << ffv << std::endl;
     271                 :        130 :     ff_var_to_ff[ffv] = ff;
     272                 :        130 :   }
     273                 :            :   // substitute free functions -> variables
     274                 :        120 :   body = body.substitute(free_functions.begin(),
     275                 :            :                          free_functions.end(),
     276                 :            :                          ff_vars.begin(),
     277                 :         60 :                          ff_vars.end());
     278         [ +  - ]:         60 :   Trace("sygus-infer-debug") << "...got : " << body << std::endl;
     279                 :            : 
     280                 :            :   // quantify the body
     281         [ +  - ]:         60 :   Trace("sygus-infer") << "Make inner sygus conjecture..." << std::endl;
     282                 :         60 :   body = body.negate();
     283         [ +  + ]:         60 :   if (!qvars.empty())
     284                 :            :   {
     285                 :         12 :     Node bvl = nm->mkNode(Kind::BOUND_VAR_LIST, qvars);
     286                 :         12 :     body = nm->mkNode(Kind::EXISTS, bvl, body);
     287                 :         12 :   }
     288                 :            : 
     289                 :            :   // sygus attribute to mark the conjecture as a sygus conjecture
     290         [ +  - ]:         60 :   Trace("sygus-infer") << "Make outer sygus conjecture..." << std::endl;
     291                 :            : 
     292                 :            :   body =
     293                 :         60 :       quantifiers::SygusUtils::mkSygusConjecture(nodeManager(), ff_vars, body);
     294                 :            : 
     295         [ +  - ]:         60 :   Trace("sygus-infer") << "*** Return sygus inference : " << body << std::endl;
     296                 :            : 
     297                 :            :   // make a separate smt call
     298                 :         60 :   std::unique_ptr<SolverEngine> rrSygus;
     299                 :         60 :   theory::initializeSubsolver(rrSygus, d_env);
     300                 :         60 :   rrSygus->assertFormula(body);
     301         [ +  - ]:         60 :   Trace("sygus-infer") << "*** Check sat..." << std::endl;
     302                 :         60 :   Result r = rrSygus->checkSat();
     303         [ +  - ]:         60 :   Trace("sygus-infer") << "...result : " << r << std::endl;
     304                 :            :   // get the synthesis solutions
     305                 :         60 :   std::map<Node, Node> synth_sols;
     306         [ +  + ]:         60 :   if (!rrSygus->getSubsolverSynthSolutions(synth_sols))
     307                 :            :   {
     308                 :            :     // failed, conjecture was infeasible
     309         [ -  + ]:          7 :     if (options().quantifiers.sygusInference == options::SygusInferenceMode::ON)
     310                 :            :     {
     311                 :          0 :       std::stringstream ss;
     312                 :            :       ss << "Translated to sygus, but failed to show problem to be satisfiable "
     313                 :          0 :             "with --sygus-inference.";
     314                 :          0 :       throw LogicException(ss.str());
     315                 :          0 :     }
     316                 :          7 :     return false;
     317                 :            :   }
     318                 :            : 
     319                 :         53 :   std::vector<Node> final_ff;
     320                 :         53 :   std::vector<Node> final_ff_sol;
     321                 :         53 :   for (std::map<Node, Node>::iterator it = synth_sols.begin();
     322         [ +  + ]:        164 :        it != synth_sols.end();
     323                 :        111 :        ++it)
     324                 :            :   {
     325         [ +  - ]:        222 :     Trace("sygus-infer") << "  synth sol : " << it->first << " -> "
     326                 :        111 :                          << it->second << std::endl;
     327                 :        111 :     Node ffv = it->first;
     328                 :        111 :     std::map<Node, Node>::iterator itffv = ff_var_to_ff.find(ffv);
     329                 :            :     // all synthesis solutions should correspond to a variable we introduced
     330 [ -  + ][ -  + ]:        111 :     Assert(itffv != ff_var_to_ff.end());
                 [ -  - ]
     331         [ +  - ]:        111 :     if (itffv != ff_var_to_ff.end())
     332                 :            :     {
     333                 :        111 :       Node ff = itffv->second;
     334                 :        111 :       Node body2 = it->second;
     335         [ +  - ]:        111 :       Trace("sygus-infer") << "Define " << ff << " as " << body2 << std::endl;
     336                 :        111 :       funs.push_back(ff);
     337                 :        111 :       sols.push_back(body2);
     338                 :        111 :     }
     339                 :        111 :   }
     340                 :         53 :   return true;
     341                 :         61 : }
     342                 :            : 
     343                 :            : }  // namespace passes
     344                 :            : }  // namespace preprocessing
     345                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14