LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/preprocessing/passes - sygus_inference.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 164 190 86.3 %
Date: 2026-02-21 11:58:00 Functions: 3 3 100.0 %
Branches: 98 160 61.2 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * Top contributors (to current version):
       3                 :            :  *   Andrew Reynolds, Aina Niemetz, Mathias Preiner
       4                 :            :  *
       5                 :            :  * This file is part of the cvc5 project.
       6                 :            :  *
       7                 :            :  * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
       8                 :            :  * in the top-level source directory and their institutional affiliations.
       9                 :            :  * All rights reserved.  See the file COPYING in the top-level source
      10                 :            :  * directory for licensing information.
      11                 :            :  * ****************************************************************************
      12                 :            :  *
      13                 :            :  * Sygus inference module.
      14                 :            :  */
      15                 :            : 
      16                 :            : #include "preprocessing/passes/sygus_inference.h"
      17                 :            : 
      18                 :            : #include "options/quantifiers_options.h"
      19                 :            : #include "preprocessing/assertion_pipeline.h"
      20                 :            : #include "preprocessing/preprocessing_pass_context.h"
      21                 :            : #include "smt/logic_exception.h"
      22                 :            : #include "smt/solver_engine.h"
      23                 :            : #include "theory/quantifiers/quantifiers_attributes.h"
      24                 :            : #include "theory/quantifiers/quantifiers_preprocess.h"
      25                 :            : #include "theory/quantifiers/sygus/sygus_utils.h"
      26                 :            : #include "theory/rewriter.h"
      27                 :            : #include "theory/smt_engine_subsolver.h"
      28                 :            : 
      29                 :            : using namespace std;
      30                 :            : using namespace cvc5::internal::kind;
      31                 :            : using namespace cvc5::internal::theory;
      32                 :            : 
      33                 :            : namespace cvc5::internal {
      34                 :            : namespace preprocessing {
      35                 :            : namespace passes {
      36                 :            : 
      37                 :      50594 : SygusInference::SygusInference(PreprocessingPassContext* preprocContext)
      38                 :      50594 :     : PreprocessingPass(preprocContext, "sygus-infer"){};
      39                 :            : 
      40                 :         61 : PreprocessingPassResult SygusInference::applyInternal(
      41                 :            :     AssertionPipeline* assertionsToPreprocess)
      42                 :            : {
      43         [ +  - ]:         61 :   Trace("sygus-infer") << "Run sygus inference..." << std::endl;
      44                 :         61 :   std::vector<Node> funs;
      45                 :         61 :   std::vector<Node> sols;
      46                 :            :   // see if we can successfully solve the input as a sygus problem
      47         [ +  + ]:         61 :   if (solveSygus(assertionsToPreprocess->ref(), funs, sols))
      48                 :            :   {
      49         [ +  - ]:         53 :     Trace("sygus-infer") << "...Solved:" << std::endl;
      50 [ -  + ][ -  + ]:         53 :     Assert(funs.size() == sols.size());
                 [ -  - ]
      51                 :            :     // if so, sygus gives us function definitions, which we add as substitutions
      52         [ +  + ]:        164 :     for (unsigned i = 0, size = funs.size(); i < size; i++)
      53                 :            :     {
      54         [ +  - ]:        111 :       Trace("sygus-infer") << funs[i] << " -> " << sols[i] << std::endl;
      55                 :        111 :       d_preprocContext->addSubstitution(funs[i], sols[i]);
      56                 :            :     }
      57                 :            : 
      58                 :            :     // apply substitution to everything, should result in SAT
      59         [ +  + ]:        205 :     for (unsigned i = 0, size = assertionsToPreprocess->ref().size(); i < size;
      60                 :            :          i++)
      61                 :            :     {
      62                 :        152 :       Node prev = (*assertionsToPreprocess)[i];
      63                 :            :       Node curr =
      64                 :        152 :           prev.substitute(funs.begin(), funs.end(), sols.begin(), sols.end());
      65         [ +  + ]:        152 :       if (curr != prev)
      66                 :            :       {
      67                 :         99 :         curr = rewrite(curr);
      68         [ +  - ]:        198 :         Trace("sygus-infer-debug")
      69                 :         99 :             << "...rewrote " << prev << " to " << curr << std::endl;
      70                 :         99 :         assertionsToPreprocess->replace(i, curr);
      71                 :            :       }
      72                 :        152 :     }
      73                 :            :   }
      74                 :          8 :   else if (options().quantifiers.sygusInference
      75         [ -  + ]:          8 :            == options::SygusInferenceMode::ON)
      76                 :            :   {
      77                 :          0 :     std::stringstream ss;
      78                 :          0 :     ss << "Cannot translate input to sygus for --sygus-inference";
      79                 :          0 :     throw LogicException(ss.str());
      80                 :          0 :   }
      81                 :         61 :   return PreprocessingPassResult::NO_CONFLICT;
      82                 :         61 : }
      83                 :            : 
      84                 :         61 : bool SygusInference::solveSygus(const std::vector<Node>& assertions,
      85                 :            :                                 std::vector<Node>& funs,
      86                 :            :                                 std::vector<Node>& sols)
      87                 :            : {
      88         [ -  + ]:         61 :   if (assertions.empty())
      89                 :            :   {
      90         [ -  - ]:          0 :     Trace("sygus-infer") << "...fail: empty assertions." << std::endl;
      91         [ -  - ]:          0 :     Warning() << "Cannot convert to sygus since there are no assertions."
      92                 :          0 :               << std::endl;
      93                 :          0 :     return false;
      94                 :            :   }
      95                 :            : 
      96                 :         61 :   NodeManager* nm = nodeManager();
      97                 :            : 
      98                 :            :   // collect free variables in all assertions
      99                 :         61 :   std::vector<Node> qvars;
     100                 :         61 :   std::map<TypeNode, std::vector<Node> > qtvars;
     101                 :         61 :   std::vector<Node> free_functions;
     102                 :            : 
     103                 :         61 :   std::vector<TNode> visit;
     104                 :         61 :   std::unordered_set<TNode> visited;
     105                 :            : 
     106                 :            :   // add top-level conjuncts to eassertions
     107                 :         61 :   std::vector<Node> assertions_proc = assertions;
     108                 :         61 :   std::vector<Node> eassertions;
     109                 :         61 :   unsigned index = 0;
     110         [ +  + ]:        241 :   while (index < assertions_proc.size())
     111                 :            :   {
     112                 :        180 :     Node ca = assertions_proc[index];
     113         [ +  + ]:        180 :     if (ca.getKind() == Kind::AND)
     114                 :            :     {
     115         [ +  + ]:          6 :       for (const Node& ai : ca)
     116                 :            :       {
     117                 :          4 :         assertions_proc.push_back(ai);
     118                 :          4 :       }
     119                 :            :     }
     120                 :            :     else
     121                 :            :     {
     122                 :        178 :       eassertions.push_back(ca);
     123                 :            :     }
     124                 :        180 :     index++;
     125                 :        180 :   }
     126                 :            : 
     127                 :            :   // process eassertions
     128                 :         61 :   std::vector<Node> processed_assertions;
     129                 :         61 :   quantifiers::QuantifiersPreprocess qp(d_env);
     130         [ +  + ]:        237 :   for (const Node& as : eassertions)
     131                 :            :   {
     132                 :            :     // substitution for this assertion
     133                 :        177 :     std::vector<Node> vars;
     134                 :        177 :     std::vector<Node> subs;
     135                 :        177 :     std::map<TypeNode, unsigned> type_count;
     136                 :        177 :     Node pas = as;
     137                 :            :     // rewrite
     138                 :        177 :     pas = rewrite(pas);
     139         [ +  - ]:        177 :     Trace("sygus-infer") << "assertion : " << pas << std::endl;
     140         [ +  + ]:        177 :     if (pas.getKind() == Kind::FORALL)
     141                 :            :     {
     142                 :            :       // preprocess the quantified formula
     143                 :         15 :       TrustNode trn = qp.preprocess(pas);
     144         [ -  + ]:         15 :       if (!trn.isNull())
     145                 :            :       {
     146                 :          0 :         pas = trn.getNode();
     147                 :            :       }
     148         [ +  - ]:         15 :       Trace("sygus-infer-debug") << "  ...preprocessed to " << pas << std::endl;
     149                 :         15 :     }
     150         [ +  + ]:        177 :     if (pas.getKind() == Kind::FORALL)
     151                 :            :     {
     152                 :            :       // it must be a standard quantifier
     153                 :         15 :       theory::quantifiers::QAttributes qa;
     154                 :         15 :       theory::quantifiers::QuantAttributes::computeQuantAttributes(pas, qa);
     155         [ -  + ]:         15 :       if (!qa.isStandard())
     156                 :            :       {
     157         [ -  - ]:          0 :         Trace("sygus-infer")
     158                 :          0 :             << "...fail: non-standard top-level quantifier." << std::endl;
     159         [ -  - ]:          0 :         Warning() << "Cannot convert to sygus since there is a non-standard "
     160                 :          0 :                      "top-level quantified formula: "
     161                 :          0 :                   << pas << std::endl;
     162                 :          0 :         return false;
     163                 :            :       }
     164                 :            :       // infer prefix
     165         [ +  + ]:         37 :       for (const Node& v : pas[0])
     166                 :            :       {
     167                 :         22 :         TypeNode tnv = v.getType();
     168                 :         22 :         unsigned vnum = type_count[tnv];
     169                 :         22 :         type_count[tnv]++;
     170                 :         22 :         vars.push_back(v);
     171         [ +  + ]:         22 :         if (vnum < qtvars[tnv].size())
     172                 :            :         {
     173                 :          3 :           subs.push_back(qtvars[tnv][vnum]);
     174                 :            :         }
     175                 :            :         else
     176                 :            :         {
     177 [ -  + ][ -  + ]:         19 :           Assert(vnum == qtvars[tnv].size());
                 [ -  - ]
     178                 :         19 :           Node bv = NodeManager::mkBoundVar(tnv);
     179                 :         19 :           qtvars[tnv].push_back(bv);
     180                 :         19 :           qvars.push_back(bv);
     181                 :         19 :           subs.push_back(bv);
     182                 :         19 :         }
     183                 :         37 :       }
     184                 :         15 :       pas = pas[1];
     185         [ +  - ]:         15 :       if (!vars.empty())
     186                 :            :       {
     187                 :            :         pas =
     188                 :         15 :             pas.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
     189                 :            :       }
     190         [ +  - ]:         15 :     }
     191         [ +  - ]:        177 :     Trace("sygus-infer-debug") << "  ...substituted to " << pas << std::endl;
     192                 :            : 
     193                 :            :     // collect free functions, ensure no quantified formulas
     194                 :        177 :     TNode cur = pas;
     195                 :            :     // compute free variables
     196                 :        177 :     visit.push_back(cur);
     197                 :            :     do
     198                 :            :     {
     199                 :       1148 :       cur = visit.back();
     200                 :       1148 :       visit.pop_back();
     201         [ +  + ]:       1148 :       if (visited.find(cur) == visited.end())
     202                 :            :       {
     203                 :        871 :         visited.insert(cur);
     204         [ +  + ]:        871 :         if (cur.getKind() == Kind::APPLY_UF)
     205                 :            :         {
     206                 :         45 :           Node op = cur.getOperator();
     207                 :            :           // visit the operator, which might not be a variable
     208                 :         45 :           visit.push_back(op);
     209                 :         45 :         }
     210 [ +  + ][ +  + ]:        826 :         else if (cur.isVar() && cur.getKind() != Kind::BOUND_VARIABLE)
                 [ +  + ]
     211                 :            :         {
     212                 :            :           // We are either in the case of a free first-order constant or a
     213                 :            :           // function in a higher-order context. We add to free_functions
     214                 :            :           // in either case. Note that a free constant that is not in a
     215                 :            :           // higher-order context is a 0-argument function-to-synthesize.
     216                 :            :           // We should not have traversed here before due to our visited cache.
     217 [ -  + ][ -  + ]:        130 :           Assert(std::find(free_functions.begin(), free_functions.end(), cur)
                 [ -  - ]
     218                 :            :                  == free_functions.end());
     219                 :        130 :           free_functions.push_back(cur);
     220                 :            :         }
     221         [ +  + ]:        696 :         else if (cur.isClosure())
     222                 :            :         {
     223         [ +  - ]:          2 :           Trace("sygus-infer")
     224                 :          1 :               << "...fail: non-top-level quantifier." << std::endl;
     225         [ +  - ]:          1 :           Warning() << "Cannot convert to sygus since there is a non-top-level "
     226                 :          0 :                        "quantified formula: "
     227                 :          1 :                     << cur << std::endl;
     228                 :          1 :           return false;
     229                 :            :         }
     230         [ +  + ]:       1797 :         for (const TNode& cn : cur)
     231                 :            :         {
     232                 :        927 :           visit.push_back(cn);
     233                 :        927 :         }
     234                 :            :       }
     235         [ +  + ]:       1147 :     } while (!visit.empty());
     236                 :        176 :     processed_assertions.push_back(pas);
     237 [ +  + ][ +  + ]:        181 :   }
         [ +  + ][ +  + ]
                 [ +  + ]
     238                 :            : 
     239                 :            :   // no functions to synthesize
     240         [ -  + ]:         60 :   if (free_functions.empty())
     241                 :            :   {
     242         [ -  - ]:          0 :     Warning()
     243                 :          0 :         << "Cannot convert to sygus since there are no free function symbols."
     244                 :          0 :         << std::endl;
     245         [ -  - ]:          0 :     Trace("sygus-infer") << "...fail: no free function symbols." << std::endl;
     246                 :          0 :     return false;
     247                 :            :   }
     248                 :            : 
     249                 :            :   // Note that we do not restrict based on the types of free functions here,
     250                 :            :   // i.e. we assume that all types are handled in sygus grammar construction.
     251                 :            : 
     252 [ -  + ][ -  + ]:         60 :   Assert(!processed_assertions.empty());
                 [ -  - ]
     253                 :            :   // conjunction of the assertions
     254         [ +  - ]:         60 :   Trace("sygus-infer") << "Construct body..." << std::endl;
     255                 :         60 :   Node body;
     256         [ -  + ]:         60 :   if (processed_assertions.size() == 1)
     257                 :            :   {
     258                 :          0 :     body = processed_assertions[0];
     259                 :            :   }
     260                 :            :   else
     261                 :            :   {
     262                 :         60 :     body = nm->mkNode(Kind::AND, processed_assertions);
     263                 :            :   }
     264                 :            : 
     265                 :            :   // for each free function symbol, make a bound variable of the same type
     266         [ +  - ]:         60 :   Trace("sygus-infer") << "Do free function substitution..." << std::endl;
     267                 :         60 :   std::vector<Node> ff_vars;
     268                 :         60 :   std::map<Node, Node> ff_var_to_ff;
     269         [ +  + ]:        190 :   for (const Node& ff : free_functions)
     270                 :            :   {
     271                 :        130 :     Node ffv = NodeManager::mkBoundVar(ff.getType());
     272                 :        130 :     ff_vars.push_back(ffv);
     273         [ +  - ]:        130 :     Trace("sygus-infer") << "  synth-fun: " << ff << " as " << ffv << std::endl;
     274                 :        130 :     ff_var_to_ff[ffv] = ff;
     275                 :        130 :   }
     276                 :            :   // substitute free functions -> variables
     277                 :        120 :   body = body.substitute(free_functions.begin(),
     278                 :            :                          free_functions.end(),
     279                 :            :                          ff_vars.begin(),
     280                 :         60 :                          ff_vars.end());
     281         [ +  - ]:         60 :   Trace("sygus-infer-debug") << "...got : " << body << std::endl;
     282                 :            : 
     283                 :            :   // quantify the body
     284         [ +  - ]:         60 :   Trace("sygus-infer") << "Make inner sygus conjecture..." << std::endl;
     285                 :         60 :   body = body.negate();
     286         [ +  + ]:         60 :   if (!qvars.empty())
     287                 :            :   {
     288                 :         12 :     Node bvl = nm->mkNode(Kind::BOUND_VAR_LIST, qvars);
     289                 :         12 :     body = nm->mkNode(Kind::EXISTS, bvl, body);
     290                 :         12 :   }
     291                 :            : 
     292                 :            :   // sygus attribute to mark the conjecture as a sygus conjecture
     293         [ +  - ]:         60 :   Trace("sygus-infer") << "Make outer sygus conjecture..." << std::endl;
     294                 :            : 
     295                 :            :   body =
     296                 :         60 :       quantifiers::SygusUtils::mkSygusConjecture(nodeManager(), ff_vars, body);
     297                 :            : 
     298         [ +  - ]:         60 :   Trace("sygus-infer") << "*** Return sygus inference : " << body << std::endl;
     299                 :            : 
     300                 :            :   // make a separate smt call
     301                 :         60 :   std::unique_ptr<SolverEngine> rrSygus;
     302                 :         60 :   theory::initializeSubsolver(rrSygus, d_env);
     303                 :         60 :   rrSygus->assertFormula(body);
     304         [ +  - ]:         60 :   Trace("sygus-infer") << "*** Check sat..." << std::endl;
     305                 :         60 :   Result r = rrSygus->checkSat();
     306         [ +  - ]:         60 :   Trace("sygus-infer") << "...result : " << r << std::endl;
     307                 :            :   // get the synthesis solutions
     308                 :         60 :   std::map<Node, Node> synth_sols;
     309         [ +  + ]:         60 :   if (!rrSygus->getSubsolverSynthSolutions(synth_sols))
     310                 :            :   {
     311                 :            :     // failed, conjecture was infeasible
     312         [ -  + ]:          7 :     if (options().quantifiers.sygusInference == options::SygusInferenceMode::ON)
     313                 :            :     {
     314                 :          0 :       std::stringstream ss;
     315                 :            :       ss << "Translated to sygus, but failed to show problem to be satisfiable "
     316                 :          0 :             "with --sygus-inference.";
     317                 :          0 :       throw LogicException(ss.str());
     318                 :          0 :     }
     319                 :          7 :     return false;
     320                 :            :   }
     321                 :            : 
     322                 :         53 :   std::vector<Node> final_ff;
     323                 :         53 :   std::vector<Node> final_ff_sol;
     324                 :         53 :   for (std::map<Node, Node>::iterator it = synth_sols.begin();
     325         [ +  + ]:        164 :        it != synth_sols.end();
     326                 :        111 :        ++it)
     327                 :            :   {
     328         [ +  - ]:        222 :     Trace("sygus-infer") << "  synth sol : " << it->first << " -> "
     329                 :        111 :                          << it->second << std::endl;
     330                 :        111 :     Node ffv = it->first;
     331                 :        111 :     std::map<Node, Node>::iterator itffv = ff_var_to_ff.find(ffv);
     332                 :            :     // all synthesis solutions should correspond to a variable we introduced
     333 [ -  + ][ -  + ]:        111 :     Assert(itffv != ff_var_to_ff.end());
                 [ -  - ]
     334         [ +  - ]:        111 :     if (itffv != ff_var_to_ff.end())
     335                 :            :     {
     336                 :        111 :       Node ff = itffv->second;
     337                 :        111 :       Node body2 = it->second;
     338         [ +  - ]:        111 :       Trace("sygus-infer") << "Define " << ff << " as " << body2 << std::endl;
     339                 :        111 :       funs.push_back(ff);
     340                 :        111 :       sols.push_back(body2);
     341                 :        111 :     }
     342                 :        111 :   }
     343                 :         53 :   return true;
     344                 :         61 : }
     345                 :            : 
     346                 :            : 
     347                 :            : }  // namespace passes
     348                 :            : }  // namespace preprocessing
     349                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14