LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/preprocessing/passes - pseudo_boolean_processor.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 147 194 75.8 %
Date: 2025-01-26 14:42:37 Functions: 17 17 100.0 %
Branches: 88 197 44.7 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * Top contributors (to current version):
       3                 :            :  *   Tim King, Andres Noetzli, Aina Niemetz
       4                 :            :  *
       5                 :            :  * This file is part of the cvc5 project.
       6                 :            :  *
       7                 :            :  * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
       8                 :            :  * in the top-level source directory and their institutional affiliations.
       9                 :            :  * All rights reserved.  See the file COPYING in the top-level source
      10                 :            :  * directory for licensing information.
      11                 :            :  * ****************************************************************************
      12                 :            :  *
      13                 :            :  * [[ Add one-line brief description here ]]
      14                 :            :  *
      15                 :            :  * [[ Add lengthier description here ]]
      16                 :            :  * \todo document this file
      17                 :            :  */
      18                 :            : 
      19                 :            : #include "preprocessing/passes/pseudo_boolean_processor.h"
      20                 :            : 
      21                 :            : #include "base/output.h"
      22                 :            : #include "preprocessing/assertion_pipeline.h"
      23                 :            : #include "preprocessing/preprocessing_pass_context.h"
      24                 :            : #include "theory/arith/arith_utilities.h"
      25                 :            : #include "theory/arith/linear/normal_form.h"
      26                 :            : #include "theory/rewriter.h"
      27                 :            : 
      28                 :            : namespace cvc5::internal {
      29                 :            : namespace preprocessing {
      30                 :            : namespace passes {
      31                 :            : 
      32                 :            : using namespace cvc5::internal::theory;
      33                 :            : using namespace cvc5::internal::theory::arith;
      34                 :            : 
      35                 :      51296 : PseudoBooleanProcessor::PseudoBooleanProcessor(
      36                 :      51296 :     PreprocessingPassContext* preprocContext)
      37                 :            :     : PreprocessingPass(preprocContext, "pseudo-boolean-processor"),
      38                 :     102592 :       d_pbBounds(userContext()),
      39                 :     102592 :       d_subCache(userContext()),
      40                 :      51296 :       d_pbs(userContext(), 0)
      41                 :            : {
      42                 :      51296 : }
      43                 :            : 
      44                 :          2 : PreprocessingPassResult PseudoBooleanProcessor::applyInternal(
      45                 :            :     AssertionPipeline* assertionsToPreprocess)
      46                 :            : {
      47                 :          2 :   learn(assertionsToPreprocess->ref());
      48         [ +  - ]:          2 :   if (likelyToHelp())
      49                 :            :   {
      50                 :          2 :     applyReplacements(assertionsToPreprocess);
      51                 :            :   }
      52                 :            : 
      53                 :          2 :   return PreprocessingPassResult::NO_CONFLICT;
      54                 :            : }
      55                 :            : 
      56                 :        202 : bool PseudoBooleanProcessor::decomposeAssertion(Node assertion, bool negated)
      57                 :            : {
      58         [ -  + ]:        202 :   if (assertion.getKind() != Kind::GEQ)
      59                 :            :   {
      60                 :          0 :     return false;
      61                 :            :   }
      62 [ -  + ][ -  + ]:        202 :   Assert(assertion.getKind() == Kind::GEQ);
                 [ -  - ]
      63                 :            : 
      64         [ +  - ]:        202 :   Trace("pbs::rewrites") << "decomposeAssertion" << assertion << std::endl;
      65                 :            : 
      66                 :        404 :   Node l = assertion[0];
      67                 :        404 :   Node r = assertion[1];
      68                 :            : 
      69         [ -  + ]:        202 :   if (!r.isConst())
      70                 :            :   {
      71         [ -  - ]:          0 :     Trace("pbs::rewrites") << "not rhs constant" << assertion << std::endl;
      72                 :          0 :     return false;
      73                 :            :   }
      74                 :            :   // don't bother matching on anything other than + on the left hand side
      75         [ +  + ]:        202 :   if (l.getKind() != Kind::ADD)
      76                 :            :   {
      77         [ +  - ]:        200 :     Trace("pbs::rewrites") << "not plus" << assertion << std::endl;
      78                 :        200 :     return false;
      79                 :            :   }
      80                 :            : 
      81         [ -  + ]:          2 :   if (!linear::Polynomial::isMember(l))
      82                 :            :   {
      83         [ -  - ]:          0 :     Trace("pbs::rewrites") << "not polynomial" << assertion << std::endl;
      84                 :          0 :     return false;
      85                 :            :   }
      86                 :            : 
      87                 :          4 :   linear::Polynomial p = linear::Polynomial::parsePolynomial(l);
      88                 :          2 :   clear();
      89         [ -  + ]:          2 :   if (negated)
      90                 :            :   {
      91                 :            :     // (not (>= p r))
      92                 :            :     // (< p r)
      93                 :            :     // (> (-p) (-r))
      94                 :            :     // (>= (-p) (-r +1))
      95                 :          0 :     d_off = (-r.getConst<Rational>());
      96                 :            : 
      97         [ -  - ]:          0 :     if (d_off.value().isIntegral())
      98                 :            :     {
      99                 :          0 :       d_off = d_off.value() + Rational(1);
     100                 :            :     }
     101                 :            :     else
     102                 :            :     {
     103                 :          0 :       d_off = Rational(d_off.value().ceiling());
     104                 :            :     }
     105                 :            :   }
     106                 :            :   else
     107                 :            :   {
     108                 :            :     // (>= p r)
     109                 :          2 :     d_off = r.getConst<Rational>();
     110                 :          2 :     d_off = Rational(d_off.value().ceiling());
     111                 :            :   }
     112 [ -  + ][ -  + ]:          2 :   Assert(d_off.value().isIntegral());
                 [ -  - ]
     113                 :            : 
     114         [ -  + ]:          2 :   int adj = negated ? -1 : 1;
     115         [ +  + ]:          6 :   for (linear::Polynomial::iterator i = p.begin(), end = p.end(); i != end; ++i)
     116                 :            :   {
     117                 :          4 :     linear::Monomial m = *i;
     118                 :          4 :     const Rational& coeff = m.getConstant().getValue();
     119 [ +  + ][ -  + ]:          4 :     if (!(coeff.isOne() || coeff.isNegativeOne()))
                 [ -  + ]
     120                 :            :     {
     121                 :          0 :       return false;
     122                 :            :     }
     123 [ -  + ][ -  + ]:          4 :     Assert(coeff.sgn() != 0);
                 [ -  - ]
     124                 :            : 
     125                 :          4 :     const linear::VarList& vl = m.getVarList();
     126                 :          4 :     Node v = vl.getNode();
     127                 :            : 
     128         [ -  + ]:          4 :     if (!isPseudoBoolean(v))
     129                 :            :     {
     130                 :          0 :       return false;
     131                 :            :     }
     132                 :          4 :     int sgn = adj * coeff.sgn();
     133         [ +  + ]:          4 :     if (sgn > 0)
     134                 :            :     {
     135                 :          2 :       d_pos.push_back(v);
     136                 :            :     }
     137                 :            :     else
     138                 :            :     {
     139                 :          2 :       d_neg.push_back(v);
     140                 :            :     }
     141                 :            :   }
     142                 :            :   // all of the variables are pseudoboolean
     143                 :            :   // with coefficients +/- and the offsetoff
     144                 :          2 :   return true;
     145                 :            : }
     146                 :            : 
     147                 :        204 : bool PseudoBooleanProcessor::isPseudoBoolean(Node v) const
     148                 :            : {
     149                 :        204 :   CDNode2PairMap::const_iterator ci = d_pbBounds.find(v);
     150         [ +  - ]:        204 :   if (ci != d_pbBounds.end())
     151                 :            :   {
     152                 :        204 :     const std::pair<Node, Node>& p = (*ci).second;
     153 [ +  - ][ +  - ]:        204 :     return !(p.first).isNull() && !(p.second).isNull();
     154                 :            :   }
     155                 :          0 :   return false;
     156                 :            : }
     157                 :            : 
     158                 :        200 : void PseudoBooleanProcessor::addGeqZero(Node v, Node exp)
     159                 :            : {
     160 [ -  + ][ -  + ]:        200 :   Assert(isIntVar(v));
                 [ -  - ]
     161 [ -  + ][ -  + ]:        200 :   Assert(!exp.isNull());
                 [ -  - ]
     162                 :        200 :   CDNode2PairMap::const_iterator ci = d_pbBounds.find(v);
     163                 :            : 
     164         [ +  - ]:        200 :   Trace("pbs::rewrites") << "addGeqZero " << v << std::endl;
     165                 :            : 
     166         [ +  - ]:        200 :   if (ci == d_pbBounds.end())
     167                 :            :   {
     168                 :        200 :     d_pbBounds.insert(v, std::make_pair(exp, Node::null()));
     169                 :            :   }
     170                 :            :   else
     171                 :            :   {
     172                 :          0 :     const std::pair<Node, Node>& p = (*ci).second;
     173         [ -  - ]:          0 :     if (p.first.isNull())
     174                 :            :     {
     175                 :          0 :       Assert(!p.second.isNull());
     176                 :          0 :       d_pbBounds.insert(v, std::make_pair(exp, p.second));
     177         [ -  - ]:          0 :       Trace("pbs::rewrites") << "add pbs " << v << std::endl;
     178                 :          0 :       Assert(isPseudoBoolean(v));
     179                 :          0 :       d_pbs = d_pbs + 1;
     180                 :            :     }
     181                 :            :   }
     182                 :        200 : }
     183                 :            : 
     184                 :        200 : void PseudoBooleanProcessor::addLeqOne(Node v, Node exp)
     185                 :            : {
     186 [ -  + ][ -  + ]:        200 :   Assert(isIntVar(v));
                 [ -  - ]
     187 [ -  + ][ -  + ]:        200 :   Assert(!exp.isNull());
                 [ -  - ]
     188         [ +  - ]:        200 :   Trace("pbs::rewrites") << "addLeqOne " << v << std::endl;
     189                 :        200 :   CDNode2PairMap::const_iterator ci = d_pbBounds.find(v);
     190         [ -  + ]:        200 :   if (ci == d_pbBounds.end())
     191                 :            :   {
     192                 :          0 :     d_pbBounds.insert(v, std::make_pair(Node::null(), exp));
     193                 :            :   }
     194                 :            :   else
     195                 :            :   {
     196                 :        200 :     const std::pair<Node, Node>& p = (*ci).second;
     197         [ +  - ]:        200 :     if (p.second.isNull())
     198                 :            :     {
     199 [ -  + ][ -  + ]:        200 :       Assert(!p.first.isNull());
                 [ -  - ]
     200                 :        200 :       d_pbBounds.insert(v, std::make_pair(p.first, exp));
     201         [ +  - ]:        200 :       Trace("pbs::rewrites") << "add pbs " << v << std::endl;
     202 [ -  + ][ -  + ]:        200 :       Assert(isPseudoBoolean(v));
                 [ -  - ]
     203                 :        200 :       d_pbs = d_pbs + 1;
     204                 :            :     }
     205                 :            :   }
     206                 :        200 : }
     207                 :            : 
     208                 :        402 : void PseudoBooleanProcessor::learnRewrittenGeq(Node assertion,
     209                 :            :                                                bool negated,
     210                 :            :                                                Node orig)
     211                 :            : {
     212 [ -  + ][ -  + ]:        402 :   Assert(assertion.getKind() == Kind::GEQ);
                 [ -  - ]
     213 [ -  + ][ -  + ]:        402 :   Assert(assertion == rewrite(assertion));
                 [ -  - ]
     214                 :            : 
     215                 :            :   // assume assertion is rewritten
     216                 :        804 :   Node l = assertion[0];
     217                 :        804 :   Node r = assertion[1];
     218                 :            : 
     219         [ +  - ]:        402 :   if (r.isConst())
     220                 :            :   {
     221                 :        402 :     const Rational& rc = r.getConst<Rational>();
     222         [ +  + ]:        402 :     if (isIntVar(l))
     223                 :            :     {
     224 [ +  + ][ +  - ]:        400 :       if (!negated && rc.isZero())
                 [ +  + ]
     225                 :            :       {  // (>= x 0)
     226                 :        200 :         addGeqZero(l, orig);
     227                 :            :       }
     228 [ +  - ][ +  - ]:        200 :       else if (negated && rc == Rational(2))
         [ +  - ][ +  - ]
                 [ -  - ]
     229                 :            :       {
     230                 :        200 :         addLeqOne(l, orig);
     231                 :            :       }
     232                 :            :     }
     233 [ -  + ][ -  - ]:          2 :     else if (l.getKind() == Kind::MULT && l.getNumChildren() == 2)
                 [ -  + ]
     234                 :            :     {
     235                 :          0 :       Node c = l[0], v = l[1];
     236                 :          0 :       if (c.isConst() && c.getConst<Rational>().isNegativeOne())
     237                 :            :       {
     238         [ -  - ]:          0 :         if (isIntVar(v))
     239                 :            :         {
     240                 :          0 :           if (!negated && rc.isNegativeOne())
     241                 :            :           {  // (>= (* -1 x) -1)
     242                 :          0 :             addLeqOne(v, orig);
     243                 :            :           }
     244                 :            :         }
     245                 :            :       }
     246                 :            :     }
     247                 :            :   }
     248                 :            : 
     249         [ +  + ]:        402 :   if (!negated)
     250                 :            :   {
     251                 :        202 :     learnGeqSub(assertion);
     252                 :            :   }
     253                 :        402 : }
     254                 :            : 
     255                 :        604 : void PseudoBooleanProcessor::learnInternal(Node assertion,
     256                 :            :                                            bool negated,
     257                 :            :                                            Node orig)
     258                 :            : {
     259    [ +  + ][ + ]:        604 :   switch (assertion.getKind())
     260                 :            :   {
     261                 :        402 :     case Kind::GEQ:
     262                 :            :     case Kind::GT:
     263                 :            :     case Kind::LEQ:
     264                 :            :     case Kind::LT:
     265                 :            :     {
     266                 :        804 :       Node rw = rewrite(assertion);
     267         [ +  - ]:        402 :       if (assertion == rw)
     268                 :            :       {
     269         [ +  - ]:        402 :         if (assertion.getKind() == Kind::GEQ)
     270                 :            :         {
     271                 :        402 :           learnRewrittenGeq(assertion, negated, orig);
     272                 :            :         }
     273                 :            :       }
     274                 :            :       else
     275                 :            :       {
     276                 :          0 :         learnInternal(rw, negated, orig);
     277                 :            :       }
     278                 :            :     }
     279                 :        402 :     break;
     280                 :        200 :     case Kind::NOT: learnInternal(assertion[0], !negated, orig); break;
     281                 :          2 :     default: break;  // do nothing
     282                 :            :   }
     283                 :        604 : }
     284                 :            : 
     285                 :        404 : void PseudoBooleanProcessor::learn(Node assertion)
     286                 :            : {
     287         [ -  + ]:        404 :   if (assertion.getKind() == Kind::AND)
     288                 :            :   {
     289                 :          0 :     Node::iterator ci = assertion.begin(), cend = assertion.end();
     290         [ -  - ]:          0 :     for (; ci != cend; ++ci)
     291                 :            :     {
     292                 :          0 :       learn(*ci);
     293                 :            :     }
     294                 :            :   }
     295                 :            :   else
     296                 :            :   {
     297                 :        404 :     learnInternal(assertion, false, assertion);
     298                 :            :   }
     299                 :        404 : }
     300                 :            : 
     301                 :          4 : Node PseudoBooleanProcessor::mkGeqOne(NodeManager* nm, Node v)
     302                 :            : {
     303                 :            :   return nm->mkNode(
     304                 :          4 :       Kind::GEQ, v, nm->mkConstRealOrInt(v.getType(), Rational(1)));
     305                 :            : }
     306                 :            : 
     307                 :          2 : void PseudoBooleanProcessor::learn(const std::vector<Node>& assertions)
     308                 :            : {
     309                 :          2 :   std::vector<Node>::const_iterator ci, cend;
     310                 :          2 :   ci = assertions.begin();
     311                 :          2 :   cend = assertions.end();
     312         [ +  + ]:        406 :   for (; ci != cend; ++ci)
     313                 :            :   {
     314                 :        404 :     learn(*ci);
     315                 :            :   }
     316                 :          2 : }
     317                 :            : 
     318                 :          2 : void PseudoBooleanProcessor::addSub(Node from, Node to)
     319                 :            : {
     320         [ +  - ]:          2 :   if (!d_subCache.hasSubstitution(from))
     321                 :            :   {
     322                 :          2 :     Node rw_to = rewrite(to);
     323                 :          2 :     d_subCache.addSubstitution(from, rw_to);
     324                 :            :   }
     325                 :          2 : }
     326                 :            : 
     327                 :        202 : void PseudoBooleanProcessor::learnGeqSub(Node geq)
     328                 :            : {
     329 [ -  + ][ -  + ]:        202 :   Assert(geq.getKind() == Kind::GEQ);
                 [ -  - ]
     330                 :        202 :   const bool negated = false;
     331                 :        202 :   bool success = decomposeAssertion(geq, negated);
     332         [ +  + ]:        202 :   if (!success)
     333                 :            :   {
     334         [ +  - ]:        200 :     Trace("pbs::rewrites") << "failed " << std::endl;
     335                 :        200 :     return;
     336                 :            :   }
     337 [ -  + ][ -  + ]:          2 :   Assert(d_off.value().isIntegral());
                 [ -  - ]
     338                 :          4 :   Integer off = d_off.value().ceiling();
     339                 :            : 
     340                 :            :   // \sum pos >= \sum neg + off
     341                 :            : 
     342                 :          2 :   NodeManager* nm = nodeManager();
     343                 :            : 
     344                 :            :   // for now special case everything we want
     345                 :            :   // target easy clauses
     346 [ +  - ][ +  - ]:          2 :   if (d_pos.size() == 1 && d_neg.size() == 1 && off.isZero())
         [ +  - ][ +  - ]
     347                 :            :   {
     348                 :            :     // x >= y
     349                 :            :     // |- (y >= 1) => (x >= 1)
     350                 :          4 :     Node x = d_pos.front();
     351                 :          4 :     Node y = d_neg.front();
     352                 :            : 
     353                 :          4 :     Node xGeq1 = mkGeqOne(nm, x);
     354                 :          4 :     Node yGeq1 = mkGeqOne(nm, y);
     355                 :          2 :     Node imp = yGeq1.impNode(xGeq1);
     356                 :          2 :     addSub(geq, imp);
     357                 :            :   }
     358                 :          0 :   else if (d_pos.size() == 0 && d_neg.size() == 2 && off.isNegativeOne())
     359                 :            :   {
     360                 :            :     // 0 >= (x + y -1)
     361                 :            :     // |- 1 >= x + y
     362                 :            :     // |- (or (not (x >= 1)) (not (y >= 1)))
     363                 :          0 :     Node x = d_neg[0];
     364                 :          0 :     Node y = d_neg[1];
     365                 :            : 
     366                 :          0 :     Node xGeq1 = mkGeqOne(nm, x);
     367                 :          0 :     Node yGeq1 = mkGeqOne(nm, y);
     368                 :          0 :     Node cases = (xGeq1.notNode()).orNode(yGeq1.notNode());
     369                 :          0 :     addSub(geq, cases);
     370                 :            :   }
     371                 :          0 :   else if (d_pos.size() == 2 && d_neg.size() == 1 && off.isZero())
     372                 :            :   {
     373                 :            :     // (x + y) >= z
     374                 :            :     // |- (z >= 1) => (or (x >= 1) (y >=1 ))
     375                 :          0 :     Node x = d_pos[0];
     376                 :          0 :     Node y = d_pos[1];
     377                 :          0 :     Node z = d_neg[0];
     378                 :            : 
     379                 :          0 :     Node xGeq1 = mkGeqOne(nm, x);
     380                 :          0 :     Node yGeq1 = mkGeqOne(nm, y);
     381                 :          0 :     Node zGeq1 = mkGeqOne(nm, z);
     382                 :          0 :     Node dis = nm->mkNode(Kind::OR, zGeq1.notNode(), xGeq1, yGeq1);
     383                 :          0 :     addSub(geq, dis);
     384                 :            :   }
     385                 :            : }
     386                 :            : 
     387                 :        404 : Node PseudoBooleanProcessor::applyReplacements(Node pre)
     388                 :            : {
     389                 :        808 :   Node assertion = rewrite(pre);
     390                 :            : 
     391                 :        404 :   Node result = d_subCache.apply(assertion);
     392                 :        404 :   if (TraceIsOn("pbs::rewrites") && result != assertion)
     393                 :            :   {
     394         [ -  - ]:          0 :     Trace("pbs::rewrites") << "applyReplacements" << assertion << "-> "
     395                 :          0 :                            << result << std::endl;
     396                 :            :   }
     397                 :        808 :   return result;
     398                 :            : }
     399                 :            : 
     400                 :          2 : bool PseudoBooleanProcessor::likelyToHelp() const { return d_pbs >= 100; }
     401                 :            : 
     402                 :          2 : void PseudoBooleanProcessor::applyReplacements(
     403                 :            :     AssertionPipeline* assertionsToPreprocess)
     404                 :            : {
     405         [ +  + ]:        406 :   for (size_t i = 0, N = assertionsToPreprocess->size(); i < N; ++i)
     406                 :            :   {
     407                 :        404 :     assertionsToPreprocess->replace(
     408                 :        808 :         i, applyReplacements((*assertionsToPreprocess)[i]));
     409                 :            :   }
     410                 :          2 : }
     411                 :            : 
     412                 :          2 : void PseudoBooleanProcessor::clear()
     413                 :            : {
     414                 :          2 :   d_off.reset();
     415                 :          2 :   d_pos.clear();
     416                 :          2 :   d_neg.clear();
     417                 :          2 : }
     418                 :            : 
     419                 :            : 
     420                 :            : }  // namespace passes
     421                 :            : }  // namespace preprocessing
     422                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14