LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/preprocessing/passes - bv_gauss.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 309 344 89.8 %
Date: 2026-01-20 13:04:20 Functions: 8 8 100.0 %
Branches: 272 436 62.4 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * Top contributors (to current version):
       3                 :            :  *   Aina Niemetz, Mathias Preiner, Andrew Reynolds
       4                 :            :  *
       5                 :            :  * This file is part of the cvc5 project.
       6                 :            :  *
       7                 :            :  * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
       8                 :            :  * in the top-level source directory and their institutional affiliations.
       9                 :            :  * All rights reserved.  See the file COPYING in the top-level source
      10                 :            :  * directory for licensing information.
      11                 :            :  * ****************************************************************************
      12                 :            :  *
      13                 :            :  * Gaussian Elimination preprocessing pass.
      14                 :            :  *
      15                 :            :  * Simplify a given equation system modulo a (prime) number via Gaussian
      16                 :            :  * Elimination if possible.
      17                 :            :  */
      18                 :            : 
      19                 :            : #include "preprocessing/passes/bv_gauss.h"
      20                 :            : 
      21                 :            : #include <unordered_map>
      22                 :            : #include <vector>
      23                 :            : 
      24                 :            : #include "expr/node.h"
      25                 :            : #include "preprocessing/assertion_pipeline.h"
      26                 :            : #include "preprocessing/preprocessing_pass_context.h"
      27                 :            : #include "theory/bv/theory_bv_rewrite_rules_normalization.h"
      28                 :            : #include "theory/bv/theory_bv_utils.h"
      29                 :            : #include "theory/rewriter.h"
      30                 :            : #include "util/bitvector.h"
      31                 :            : 
      32                 :            : using namespace cvc5::internal;
      33                 :            : using namespace cvc5::internal::theory;
      34                 :            : using namespace cvc5::internal::theory::bv;
      35                 :            : 
      36                 :            : namespace cvc5::internal {
      37                 :            : namespace preprocessing {
      38                 :            : namespace passes {
      39                 :            : 
      40                 :       2901 : bool BVGauss::is_bv_const(Node n)
      41                 :            : {
      42         [ +  + ]:       2901 :   if (n.isConst()) { return true; }
      43                 :       1996 :   return rewrite(n).getKind() == Kind::CONST_BITVECTOR;
      44                 :            : }
      45                 :            : 
      46                 :        451 : Node BVGauss::get_bv_const(Node n)
      47                 :            : {
      48 [ -  + ][ -  + ]:        451 :   Assert(is_bv_const(n));
                 [ -  - ]
      49                 :        451 :   return rewrite(n);
      50                 :            : }
      51                 :            : 
      52                 :        241 : Integer BVGauss::get_bv_const_value(Node n)
      53                 :            : {
      54 [ -  + ][ -  + ]:        241 :   Assert(is_bv_const(n));
                 [ -  - ]
      55                 :        482 :   return get_bv_const(n).getConst<BitVector>().getValue();
      56                 :            : }
      57                 :            : 
      58                 :            : /**
      59                 :            :  * Determines if an overflow may occur in given 'expr'.
      60                 :            :  *
      61                 :            :  * Returns 0 if an overflow may occur, and the minimum required
      62                 :            :  * bit-width such that no overflow occurs, otherwise.
      63                 :            :  *
      64                 :            :  * Note that it would suffice for this function to be Boolean.
      65                 :            :  * However, it is handy to determine the minimum required bit-width for
      66                 :            :  * debugging purposes.
      67                 :            :  *
      68                 :            :  * Note: getMinBwExpr assumes that 'expr' is rewritten.
      69                 :            :  *
      70                 :            :  * If not, all operators that are removed via rewriting (e.g., ror, rol, ...)
      71                 :            :  * will be handled via the default case, which is not incorrect but also not
      72                 :            :  * necessarily the minimum.
      73                 :            :  */
      74                 :        101 : uint32_t BVGauss::getMinBwExpr(Node expr)
      75                 :            : {
      76                 :        202 :   std::vector<Node> visit;
      77                 :            :   /* Maps visited nodes to the determined minimum bit-width required. */
      78                 :        202 :   std::unordered_map<Node, unsigned> visited;
      79                 :        101 :   std::unordered_map<Node, unsigned>::iterator it;
      80                 :            : 
      81                 :        101 :   visit.push_back(expr);
      82                 :        101 :   NodeManager* nm = nodeManager();
      83         [ +  + ]:       1459 :   while (!visit.empty())
      84                 :            :   {
      85                 :       1361 :     Node n = visit.back();
      86                 :       1361 :     visit.pop_back();
      87                 :       1361 :     it = visited.find(n);
      88         [ +  + ]:       1361 :     if (it == visited.end())
      89                 :            :     {
      90         [ +  + ]:        745 :       if (is_bv_const(n))
      91                 :            :       {
      92                 :            :         /* Rewrite const expr, overflows in consts are irrelevant. */
      93                 :        210 :         visited[n] = get_bv_const(n).getConst<BitVector>().getValue().length();
      94                 :            :       }
      95                 :            :       else
      96                 :            :       {
      97                 :        535 :         visited[n] = 0;
      98                 :        535 :         visit.push_back(n);
      99         [ +  + ]:       1263 :         for (const Node &nn : n) { visit.push_back(nn); }
     100                 :            :       }
     101                 :            :     }
     102         [ +  + ]:        616 :     else if (it->second == 0)
     103                 :            :     {
     104                 :        533 :       Kind k = n.getKind();
     105 [ -  + ][ -  + ]:        533 :       Assert(k != Kind::CONST_BITVECTOR);
                 [ -  - ]
     106 [ -  + ][ -  + ]:        533 :       Assert(!is_bv_const(n));
                 [ -  - ]
     107 [ +  + ][ +  + ]:        533 :       switch (k)
         [ +  - ][ +  + ]
     108                 :            :       {
     109                 :         28 :         case Kind::BITVECTOR_EXTRACT:
     110                 :            :         {
     111                 :         28 :           const unsigned size = bv::utils::getSize(n);
     112                 :         28 :           const unsigned low = bv::utils::getExtractLow(n);
     113                 :         28 :           const unsigned child_min_width = visited[n[0]];
     114                 :         28 :           visited[n] = std::min(
     115         [ +  - ]:         28 :               size, child_min_width >= low ? child_min_width - low : 0u);
     116 [ -  + ][ -  + ]:         28 :           Assert(visited[n] <= visited[n[0]]);
                 [ -  - ]
     117                 :         28 :           break;
     118                 :            :         }
     119                 :            : 
     120                 :         36 :         case Kind::BITVECTOR_ZERO_EXTEND:
     121                 :            :         {
     122                 :         36 :           visited[n] = visited[n[0]];
     123                 :         36 :           break;
     124                 :            :         }
     125                 :            : 
     126                 :         69 :         case Kind::BITVECTOR_MULT:
     127                 :            :         {
     128                 :         69 :           Integer maxval = Integer(1);
     129         [ +  + ]:        213 :           for (const Node& nn : n)
     130                 :            :           {
     131         [ +  + ]:        144 :             if (is_bv_const(nn))
     132                 :            :             {
     133                 :         57 :               maxval *= get_bv_const_value(nn);
     134                 :            :             }
     135                 :            :             else
     136                 :            :             {
     137                 :         87 :               maxval *= BitVector::mkOnes(visited[nn]).getValue();
     138                 :            :             }
     139                 :            :           }
     140                 :         69 :           unsigned w = maxval.length();
     141         [ +  + ]:         69 :           if (w > bv::utils::getSize(n)) { return 0; } /* overflow */
     142                 :         67 :           visited[n] = w;
     143                 :         67 :           break;
     144                 :            :         }
     145                 :            : 
     146                 :        164 :         case Kind::BITVECTOR_CONCAT:
     147                 :            :         {
     148                 :            :           unsigned i, wnz, nc;
     149         [ +  + ]:        325 :           for (i = 0, wnz = 0, nc = n.getNumChildren() - 1; i < nc; ++i)
     150                 :            :           {
     151                 :        189 :             unsigned wni = bv::utils::getSize(n[i]);
     152         [ +  + ]:        189 :             if (n[i] != bv::utils::mkZero(nm, wni))
     153                 :            :             {
     154                 :         28 :               break;
     155                 :            :             }
     156                 :            :             /* sum of all bit-widths of leading zero concats */
     157                 :        161 :             wnz += wni;
     158                 :            :           }
     159                 :            :           /* Do not consider leading zero concats, i.e.,
     160                 :            :            * min bw of current concat is determined as
     161                 :            :            *   min bw of first non-zero term
     162                 :            :            *   plus actual bw of all subsequent terms */
     163                 :        328 :           visited[n] = bv::utils::getSize(n) + visited[n[i]]
     164                 :        164 :                        - bv::utils::getSize(n[i]) - wnz;
     165                 :        164 :           break;
     166                 :            :         }
     167                 :            : 
     168                 :          3 :         case Kind::BITVECTOR_UREM:
     169                 :            :         case Kind::BITVECTOR_LSHR:
     170                 :            :         case Kind::BITVECTOR_ASHR:
     171                 :            :         {
     172                 :          3 :           visited[n] = visited[n[0]];
     173                 :          3 :           break;
     174                 :            :         }
     175                 :            : 
     176                 :          0 :         case Kind::BITVECTOR_OR:
     177                 :            :         case Kind::BITVECTOR_NOR:
     178                 :            :         case Kind::BITVECTOR_XOR:
     179                 :            :         case Kind::BITVECTOR_XNOR:
     180                 :            :         case Kind::BITVECTOR_AND:
     181                 :            :         case Kind::BITVECTOR_NAND:
     182                 :            :         {
     183                 :          0 :           unsigned wmax = 0;
     184         [ -  - ]:          0 :           for (const Node &nn : n)
     185                 :            :           {
     186         [ -  - ]:          0 :             if (visited[nn] > wmax)
     187                 :            :             {
     188                 :          0 :               wmax = visited[nn];
     189                 :            :             }
     190                 :            :           }
     191                 :          0 :           visited[n] = wmax;
     192                 :          0 :           break;
     193                 :            :         }
     194                 :            : 
     195                 :         52 :         case Kind::BITVECTOR_ADD:
     196                 :            :         {
     197                 :         52 :           Integer maxval = Integer(0);
     198         [ +  + ]:        182 :           for (const Node& nn : n)
     199                 :            :           {
     200         [ -  + ]:        130 :             if (is_bv_const(nn))
     201                 :            :             {
     202                 :          0 :               maxval += get_bv_const_value(nn);
     203                 :            :             }
     204                 :            :             else
     205                 :            :             {
     206                 :        130 :               maxval += BitVector::mkOnes(visited[nn]).getValue();
     207                 :            :             }
     208                 :            :           }
     209                 :         52 :           unsigned w = maxval.length();
     210         [ +  + ]:         52 :           if (w > bv::utils::getSize(n)) { return 0; } /* overflow */
     211                 :         51 :           visited[n] = w;
     212                 :         51 :           break;
     213                 :            :         }
     214                 :            : 
     215                 :        181 :         default:
     216                 :            :         {
     217                 :            :           /* BITVECTOR_UDIV (since x / 0 = -1)
     218                 :            :            * BITVECTOR_NOT
     219                 :            :            * BITVECTOR_NEG
     220                 :            :            * BITVECTOR_SHL */
     221                 :        181 :           visited[n] = bv::utils::getSize(n);
     222                 :            :         }
     223                 :            :       }
     224                 :            :     }
     225                 :            :   }
     226 [ -  + ][ -  + ]:         98 :   Assert(visited.find(expr) != visited.end());
                 [ -  - ]
     227                 :         98 :   return visited[expr];
     228                 :            : }
     229                 :            : 
     230                 :            : /**
     231                 :            :  * Apply Gaussian Elimination modulo a (prime) number.
     232                 :            :  * The given equation system is represented as a matrix of Integers.
     233                 :            :  *
     234                 :            :  * Note that given 'prime' does not have to be prime but can be any
     235                 :            :  * arbitrary number. However, if 'prime' is indeed prime, GE is guaranteed
     236                 :            :  * to succeed, which is not the case, otherwise.
     237                 :            :  *
     238                 :            :  * Returns INVALID if GE can not be applied, UNIQUE and PARTIAL if GE was
     239                 :            :  * successful, and NONE, otherwise.
     240                 :            :  *
     241                 :            :  * Vectors 'rhs' and 'lhs' represent the right hand side and left hand side
     242                 :            :  * of the given matrix, respectively. The resulting matrix (in row echelon
     243                 :            :  * form) is stored in 'rhs' and 'lhs', i.e., the given matrix is overwritten
     244                 :            :  * with the resulting matrix.
     245                 :            :  */
     246                 :         70 : BVGauss::Result BVGauss::gaussElim(Integer prime,
     247                 :            :                                    std::vector<Integer>& rhs,
     248                 :            :                                    std::vector<std::vector<Integer>>& lhs)
     249                 :            : {
     250 [ -  + ][ -  + ]:         70 :   Assert(prime > 0);
                 [ -  - ]
     251 [ -  + ][ -  + ]:         70 :   Assert(lhs.size());
                 [ -  - ]
     252 [ -  + ][ -  + ]:         70 :   Assert(lhs.size() == rhs.size());
                 [ -  - ]
     253 [ -  + ][ -  + ]:         70 :   Assert(lhs.size() <= lhs[0].size());
                 [ -  - ]
     254                 :            : 
     255                 :            :   /* special case: zero ring */
     256         [ +  + ]:         70 :   if (prime == 1)
     257                 :            :   {
     258                 :          1 :     rhs = std::vector<Integer>(rhs.size(), Integer(0));
     259                 :          3 :     lhs = std::vector<std::vector<Integer>>(
     260                 :          3 :         lhs.size(), std::vector<Integer>(lhs[0].size(), Integer(0)));
     261                 :          1 :     return BVGauss::Result::UNIQUE;
     262                 :            :   }
     263                 :            : 
     264                 :         69 :   size_t nrows = lhs.size();
     265                 :         69 :   size_t ncols = lhs[0].size();
     266                 :            : 
     267                 :            : #ifdef CVC5_ASSERTIONS
     268 [ +  + ][ -  + ]:        195 :   for (size_t i = 1; i < nrows; ++i) Assert(lhs[i].size() == ncols);
         [ -  + ][ -  - ]
     269                 :            : #endif
     270                 :            :   /* (1) if element in pivot column is non-zero and != 1, divide row elements
     271                 :            :    *     by element in pivot column modulo prime, i.e., multiply row with
     272                 :            :    *     multiplicative inverse of element in pivot column modulo prime
     273                 :            :    *
     274                 :            :    * (2) subtract pivot row from all rows below pivot row
     275                 :            :    *
     276                 :            :    * (3) subtract (multiple of) current row from all rows above s.t. all
     277                 :            :    *     elements in current pivot column above current row become equal to one
     278                 :            :    *
     279                 :            :    * Note: we do not normalize the given matrix to values modulo prime
     280                 :            :    *       beforehand but on-the-fly. */
     281                 :            : 
     282                 :            :   /* pivot = lhs[pcol][pcol] */
     283 [ +  + ][ +  + ]:        234 :   for (size_t pcol = 0, prow = 0; pcol < ncols && prow < nrows; ++pcol, ++prow)
     284                 :            :   {
     285                 :            :     /* lhs[j][pcol]: element in pivot column */
     286         [ +  + ]:        510 :     for (size_t j = prow; j < nrows; ++j)
     287                 :            :     {
     288                 :            : #ifdef CVC5_ASSERTIONS
     289         [ +  + ]:        572 :       for (size_t k = 0; k < pcol; ++k)
     290                 :            :       {
     291 [ -  + ][ -  + ]:        227 :         Assert(lhs[j][k] == 0);
                 [ -  - ]
     292                 :            :       }
     293                 :            : #endif
     294                 :            :       /* normalize element in pivot column to modulo prime */
     295                 :        345 :       lhs[j][pcol] = lhs[j][pcol].euclidianDivideRemainder(prime);
     296                 :            :       /* exchange rows if pivot elem is 0 */
     297         [ +  + ]:        345 :       if (j == prow)
     298                 :            :       {
     299         [ +  + ]:        209 :         while (lhs[j][pcol] == 0)
     300                 :            :         {
     301         [ +  + ]:         85 :           for (size_t k = prow + 1; k < nrows; ++k)
     302                 :            :           {
     303                 :         53 :             lhs[k][pcol] = lhs[k][pcol].euclidianDivideRemainder(prime);
     304         [ +  + ]:         53 :             if (lhs[k][pcol] != 0)
     305                 :            :             {
     306                 :         27 :               std::swap(rhs[j], rhs[k]);
     307                 :         27 :               std::swap(lhs[j], lhs[k]);
     308                 :         27 :               break;
     309                 :            :             }
     310                 :            :           }
     311         [ +  + ]:         59 :           if (pcol >= ncols - 1) break;
     312         [ +  + ]:         38 :           if (lhs[j][pcol] == 0)
     313                 :            :           {
     314                 :         16 :             pcol += 1;
     315         [ +  + ]:         16 :             if (lhs[j][pcol] != 0)
     316                 :         10 :               lhs[j][pcol] = lhs[j][pcol].euclidianDivideRemainder(prime);
     317                 :            :           }
     318                 :            :         }
     319                 :            :       }
     320                 :            : 
     321         [ +  + ]:        345 :       if (lhs[j][pcol] != 0)
     322                 :            :       {
     323                 :            :         /* (1) */
     324         [ +  + ]:        264 :         if (lhs[j][pcol] != 1)
     325                 :            :         {
     326                 :        196 :           Integer inv = lhs[j][pcol].modInverse(prime);
     327         [ +  + ]:        196 :           if (inv == -1)
     328                 :            :           {
     329                 :          6 :             return BVGauss::Result::INVALID; /* not coprime */
     330                 :            :           }
     331         [ +  + ]:        626 :           for (size_t k = pcol; k < ncols; ++k)
     332                 :            :           {
     333                 :        436 :             lhs[j][k] = lhs[j][k].modMultiply(inv, prime);
     334         [ +  + ]:        436 :             if (j <= prow) continue; /* pivot */
     335                 :        246 :             lhs[j][k] = lhs[j][k].modAdd(-lhs[prow][k], prime);
     336                 :            :           }
     337                 :        190 :           rhs[j] = rhs[j].modMultiply(inv, prime);
     338         [ +  + ]:        190 :           if (j > prow) { rhs[j] = rhs[j].modAdd(-rhs[prow], prime); }
     339                 :            :         }
     340                 :            :         /* (2) */
     341         [ +  + ]:         68 :         else if (j != prow)
     342                 :            :         {
     343         [ +  + ]:         46 :           for (size_t k = pcol; k < ncols; ++k)
     344                 :            :           {
     345                 :         34 :             lhs[j][k] = lhs[j][k].modAdd(-lhs[prow][k], prime);
     346                 :            :           }
     347                 :         12 :           rhs[j] = rhs[j].modAdd(-rhs[prow], prime);
     348                 :            :         }
     349                 :            :       }
     350                 :            :     }
     351                 :            :     /* (3) */
     352         [ +  + ]:        305 :     for (size_t j = 0; j < prow; ++j)
     353                 :            :     {
     354                 :        280 :       Integer mul = lhs[j][pcol];
     355         [ +  + ]:        140 :       if (mul != 0)
     356                 :            :       {
     357         [ +  + ]:        276 :         for (size_t k = pcol; k < ncols; ++k)
     358                 :            :         {
     359                 :        162 :           lhs[j][k] = lhs[j][k].modAdd(-lhs[prow][k] * mul, prime);
     360                 :            :         }
     361                 :        114 :         rhs[j] = rhs[j].modAdd(-rhs[prow] * mul, prime);
     362                 :            :       }
     363                 :            :     }
     364                 :            :   }
     365                 :            : 
     366                 :         63 :   bool ispart = false;
     367         [ +  + ]:        236 :   for (size_t i = 0; i < nrows; ++i)
     368                 :            :   {
     369                 :        177 :     size_t pcol = i;
     370 [ +  + ][ +  + ]:        229 :     while (pcol < ncols && lhs[i][pcol] == 0) ++pcol;
         [ +  + ][ +  + ]
                 [ -  - ]
     371         [ +  + ]:        177 :     if (pcol >= ncols)
     372                 :            :     {
     373                 :         29 :       rhs[i] = rhs[i].euclidianDivideRemainder(prime);
     374         [ +  + ]:         29 :       if (rhs[i] != 0)
     375                 :            :       {
     376                 :            :         /* no solution */
     377                 :          4 :         return BVGauss::Result::NONE;
     378                 :            :       }
     379                 :         25 :       continue;
     380                 :            :     }
     381         [ +  + ]:        488 :     for (size_t j = i; j < ncols; ++j)
     382                 :            :     {
     383 [ +  + ][ -  + ]:        340 :       if (lhs[i][j] >= prime || lhs[i][j] <= -prime)
         [ +  + ][ +  + ]
                 [ -  - ]
     384                 :            :       {
     385                 :          1 :         lhs[i][j] = lhs[i][j].euclidianDivideRemainder(prime);
     386                 :            :       }
     387 [ +  + ][ +  + ]:        340 :       if (j > pcol && lhs[i][j] != 0)
         [ +  + ][ +  + ]
                 [ -  - ]
     388                 :            :       {
     389                 :         36 :         ispart = true;
     390                 :            :       }
     391                 :            :     }
     392                 :            :   }
     393                 :            : 
     394         [ +  + ]:         59 :   if (ispart)
     395                 :            :   {
     396                 :         21 :     return BVGauss::Result::PARTIAL;
     397                 :            :   }
     398                 :            : 
     399                 :         38 :   return BVGauss::Result::UNIQUE;
     400                 :            : }
     401                 :            : 
     402                 :            : /**
     403                 :            :  * Apply Gaussian Elimination on a set of equations modulo some (prime)
     404                 :            :  * number given as bit-vector equations.
     405                 :            :  *
     406                 :            :  * IMPORTANT: Applying GE modulo some number (rather than modulo 2^bw)
     407                 :            :  * on a set of bit-vector equations is only sound if this set of equations
     408                 :            :  * has a solution that does not produce overflows. Consequently, we only
     409                 :            :  * apply GE if the given bit-width guarantees that no overflows can occur
     410                 :            :  * in the given set of equations.
     411                 :            :  *
     412                 :            :  * Note that the given set of equations does not have to be modulo a prime
     413                 :            :  * but can be modulo any arbitrary number. However, if it is indeed modulo
     414                 :            :  * prime, GE is guaranteed to succeed, which is not the case, otherwise.
     415                 :            :  *
     416                 :            :  * Returns INVALID if GE can not be applied, UNIQUE and PARTIAL if GE was
     417                 :            :  * successful, and NONE, otherwise.
     418                 :            :  *
     419                 :            :  * The resulting constraints are stored in 'res' as a mapping of unknown
     420                 :            :  * to result (modulo prime). These mapped results are added as constraints
     421                 :            :  * of the form 'unknown = mapped result' in applyInternal.
     422                 :            :  */
     423                 :         18 : BVGauss::Result BVGauss::gaussElimRewriteForUrem(
     424                 :            :     const std::vector<Node>& equations, std::unordered_map<Node, Node>& res)
     425                 :            : {
     426 [ -  + ][ -  + ]:         18 :   Assert(res.empty());
                 [ -  - ]
     427                 :            : 
     428                 :         36 :   Node prime;
     429                 :         36 :   Integer iprime;
     430                 :         36 :   std::unordered_map<Node, std::vector<Integer>> vars;
     431                 :         18 :   size_t neqs = equations.size();
     432                 :         36 :   std::vector<Integer> rhs;
     433                 :            :   std::vector<std::vector<Integer>> lhs =
     434                 :         54 :       std::vector<std::vector<Integer>>(neqs, std::vector<Integer>());
     435                 :            : 
     436                 :         18 :   res = std::unordered_map<Node, Node>();
     437                 :            : 
     438                 :         18 :   NodeManager* nm = nodeManager();
     439         [ +  + ]:         64 :   for (size_t i = 0; i < neqs; ++i)
     440                 :            :   {
     441                 :         46 :     Node eq = equations[i];
     442 [ -  + ][ -  + ]:         46 :     Assert(eq.getKind() == Kind::EQUAL);
                 [ -  - ]
     443                 :         46 :     Node urem, eqrhs;
     444                 :            : 
     445         [ +  - ]:         46 :     if (eq[0].getKind() == Kind::BITVECTOR_UREM)
     446                 :            :     {
     447                 :         46 :       urem = eq[0];
     448 [ -  + ][ -  + ]:         46 :       Assert(is_bv_const(eq[1]));
                 [ -  - ]
     449                 :         46 :       eqrhs = eq[1];
     450                 :            :     }
     451                 :            :     else
     452                 :            :     {
     453                 :          0 :       Assert(eq[1].getKind() == Kind::BITVECTOR_UREM);
     454                 :          0 :       urem = eq[1];
     455                 :          0 :       Assert(is_bv_const(eq[0]));
     456                 :          0 :       eqrhs = eq[0];
     457                 :            :     }
     458         [ -  + ]:         46 :     if (getMinBwExpr(rewrite(urem[0])) == 0)
     459                 :            :     {
     460         [ -  - ]:          0 :       Trace("bv-gauss-elim")
     461                 :            :           << "Minimum required bit-width exceeds given bit-width, "
     462                 :          0 :              "will not apply Gaussian Elimination."
     463                 :          0 :           << std::endl;
     464                 :          0 :       return BVGauss::Result::INVALID;
     465                 :            :     }
     466                 :         46 :     rhs.push_back(get_bv_const_value(eqrhs));
     467                 :            : 
     468 [ -  + ][ -  + ]:         46 :     Assert(is_bv_const(urem[1]));
                 [ -  - ]
     469                 :         46 :     Assert(i == 0 || get_bv_const_value(urem[1]) == iprime);
     470         [ +  + ]:         46 :     if (i == 0)
     471                 :            :     {
     472                 :         18 :       prime = urem[1];
     473                 :         18 :       iprime = get_bv_const_value(prime);
     474                 :            :     }
     475                 :            : 
     476                 :         92 :     std::unordered_map<Node, Integer> tmp;
     477                 :         92 :     std::vector<Node> stack;
     478                 :         46 :     stack.push_back(urem[0]);
     479         [ +  + ]:        198 :     while (!stack.empty())
     480                 :            :     {
     481                 :        152 :       Node n = stack.back();
     482                 :        152 :       stack.pop_back();
     483                 :            : 
     484                 :            :       /* Subtract from rhs if const */
     485         [ -  + ]:        152 :       if (is_bv_const(n))
     486                 :            :       {
     487                 :          0 :         Integer val = get_bv_const_value(n);
     488         [ -  - ]:          0 :         if (val > 0) rhs.back() -= val;
     489                 :          0 :         continue;
     490                 :            :       }
     491                 :            : 
     492                 :            :       /* Split into matrix columns */
     493                 :        152 :       Kind k = n.getKind();
     494         [ +  + ]:        152 :       if (k == Kind::BITVECTOR_ADD)
     495                 :            :       {
     496         [ +  + ]:        159 :         for (const Node& nn : n) { stack.push_back(nn); }
     497                 :            :       }
     498         [ +  + ]:         99 :       else if (k == Kind::BITVECTOR_MULT)
     499                 :            :       {
     500                 :        184 :         Node n0, n1;
     501                 :            :         /* Flatten mult expression. */
     502                 :         92 :         n = RewriteRule<FlattenAssocCommut>::run<true>(n);
     503                 :            :         /* Split operands into consts and non-consts */
     504                 :        184 :         NodeBuilder nb_consts(nm, k);
     505                 :         92 :         NodeBuilder nb_nonconsts(nm, k);
     506         [ +  + ]:        284 :         for (const Node& nn : n)
     507                 :            :         {
     508                 :        384 :           Node nnrw = rewrite(nn);
     509         [ +  + ]:        192 :           if (is_bv_const(nnrw))
     510                 :            :           {
     511                 :         88 :             nb_consts << nnrw;
     512                 :            :           }
     513                 :            :           else
     514                 :            :           {
     515                 :        104 :             nb_nonconsts << nnrw;
     516                 :            :           }
     517                 :            :         }
     518 [ -  + ][ -  + ]:         92 :         Assert(nb_nonconsts.getNumChildren() > 0);
                 [ -  - ]
     519                 :            :         /* n0 is const */
     520                 :         92 :         unsigned nc = nb_consts.getNumChildren();
     521         [ -  + ]:         92 :         if (nc > 1)
     522                 :            :         {
     523                 :          0 :           n0 = rewrite(nb_consts.constructNode());
     524                 :            :         }
     525         [ +  + ]:         92 :         else if (nc == 1)
     526                 :            :         {
     527                 :         88 :           n0 = nb_consts[0];
     528                 :            :         }
     529                 :            :         else
     530                 :            :         {
     531                 :          4 :           n0 = bv::utils::mkOne(nm, bv::utils::getSize(n));
     532                 :            :         }
     533                 :            :         /* n1 is a mult with non-const operands */
     534         [ +  + ]:         92 :         if (nb_nonconsts.getNumChildren() > 1)
     535                 :            :         {
     536                 :         10 :           n1 = rewrite(nb_nonconsts.constructNode());
     537                 :            :         }
     538                 :            :         else
     539                 :            :         {
     540                 :         82 :           n1 = nb_nonconsts[0];
     541                 :            :         }
     542 [ -  + ][ -  + ]:         92 :         Assert(is_bv_const(n0));
                 [ -  - ]
     543 [ -  + ][ -  + ]:         92 :         Assert(!is_bv_const(n1));
                 [ -  - ]
     544                 :         92 :         tmp[n1] += get_bv_const_value(n0);
     545                 :            :       }
     546                 :            :       else
     547                 :            :       {
     548                 :          7 :         tmp[n] += Integer(1);
     549                 :            :       }
     550                 :            :     }
     551                 :            : 
     552                 :            :     /* Note: "var" is not necessarily a VARIABLE but can be an arbitrary expr */
     553                 :            : 
     554         [ +  + ]:        145 :     for (const auto& p : tmp)
     555                 :            :     {
     556                 :        198 :       Node var = p.first;
     557                 :        198 :       Integer val = p.second;
     558 [ +  + ][ +  + ]:         99 :       if (i > 0 && vars.find(var) == vars.end())
                 [ +  + ]
     559                 :            :       {
     560                 :            :         /* Add column and fill column elements of rows above with 0. */
     561                 :         14 :         vars[var].insert(vars[var].end(), i, Integer(0));
     562                 :            :       }
     563                 :         99 :       vars[var].push_back(val);
     564                 :            :     }
     565                 :            : 
     566         [ +  + ]:        166 :     for (const auto& p : vars)
     567                 :            :     {
     568         [ +  + ]:        120 :       if (tmp.find(p.first) == tmp.end())
     569                 :            :       {
     570                 :         21 :         vars[p.first].push_back(Integer(0));
     571                 :            :       }
     572                 :            :     }
     573                 :            :   }
     574                 :            : 
     575                 :         18 :   size_t nvars = vars.size();
     576         [ -  + ]:         18 :   if (nvars == 0)
     577                 :            :   {
     578                 :          0 :     return BVGauss::Result::INVALID;
     579                 :            :   }
     580                 :         18 :   size_t nrows = vars.begin()->second.size();
     581                 :            : #ifdef CVC5_ASSERTIONS
     582         [ +  + ]:         71 :   for (const auto& p : vars)
     583                 :            :   {
     584 [ -  + ][ -  + ]:         53 :     Assert(p.second.size() == nrows);
                 [ -  - ]
     585                 :            :   }
     586                 :            : #endif
     587                 :            : 
     588         [ -  + ]:         18 :   if (nrows < 1)
     589                 :            :   {
     590                 :          0 :     return BVGauss::Result::INVALID;
     591                 :            :   }
     592                 :            : 
     593         [ +  + ]:         64 :   for (size_t i = 0; i < nrows; ++i)
     594                 :            :   {
     595         [ +  + ]:        182 :     for (const auto& p : vars)
     596                 :            :     {
     597                 :        136 :       lhs[i].push_back(p.second[i]);
     598                 :            :     }
     599                 :            :   }
     600                 :            : 
     601                 :            : #ifdef CVC5_ASSERTIONS
     602         [ +  + ]:         64 :   for (const auto& row : lhs)
     603                 :            :   {
     604 [ -  + ][ -  + ]:         46 :     Assert(row.size() == nvars);
                 [ -  - ]
     605                 :            :   }
     606 [ -  + ][ -  + ]:         18 :   Assert(lhs.size() == rhs.size());
                 [ -  - ]
     607                 :            : #endif
     608                 :            : 
     609         [ +  + ]:         18 :   if (lhs.size() > lhs[0].size())
     610                 :            :   {
     611                 :          1 :     return BVGauss::Result::INVALID;
     612                 :            :   }
     613                 :            : 
     614         [ +  - ]:         17 :   Trace("bv-gauss-elim") << "Applying Gaussian Elimination..." << std::endl;
     615                 :         17 :   BVGauss::Result ret = gaussElim(iprime, rhs, lhs);
     616                 :            : 
     617 [ +  - ][ +  - ]:         17 :   if (ret != BVGauss::Result::NONE && ret != BVGauss::Result::INVALID)
     618                 :            :   {
     619                 :         34 :     std::vector<Node> vvars;
     620         [ +  + ]:         68 :     for (const auto& p : vars) { vvars.push_back(p.first); }
     621 [ -  + ][ -  + ]:         17 :     Assert(nvars == vvars.size());
                 [ -  - ]
     622 [ -  + ][ -  + ]:         17 :     Assert(nrows == lhs.size());
                 [ -  - ]
     623 [ -  + ][ -  + ]:         17 :     Assert(nrows == rhs.size());
                 [ -  - ]
     624         [ +  + ]:         17 :     if (ret == BVGauss::Result::UNIQUE)
     625                 :            :     {
     626         [ +  + ]:         27 :       for (size_t i = 0; i < nvars; ++i)
     627                 :            :       {
     628                 :         40 :         res[vvars[i]] = nm->mkConst<BitVector>(
     629                 :         60 :             BitVector(bv::utils::getSize(vvars[i]), rhs[i]));
     630                 :            :       }
     631                 :            :     }
     632                 :            :     else
     633                 :            :     {
     634 [ -  + ][ -  + ]:         10 :       Assert(ret == BVGauss::Result::PARTIAL);
                 [ -  - ]
     635                 :            : 
     636 [ +  - ][ +  + ]:         30 :       for (size_t pcol = 0, prow = 0; pcol < nvars && prow < nrows;
     637                 :            :            ++pcol, ++prow)
     638                 :            :       {
     639                 :         22 :         Assert(lhs[prow][pcol] == 0 || lhs[prow][pcol] == 1);
     640 [ +  + ][ +  + ]:         25 :         while (pcol < nvars && lhs[prow][pcol] == 0) pcol += 1;
         [ +  + ][ +  + ]
                 [ -  - ]
     641         [ +  + ]:         22 :         if (pcol >= nvars)
     642                 :            :         {
     643 [ -  + ][ -  + ]:          2 :           Assert(rhs[prow] == 0);
                 [ -  - ]
     644                 :          2 :           break;
     645                 :            :         }
     646         [ -  + ]:         20 :         if (lhs[prow][pcol] == 0)
     647                 :            :         {
     648                 :          0 :           Assert(rhs[prow] == 0);
     649                 :          0 :           continue;
     650                 :            :         }
     651 [ -  + ][ -  + ]:         20 :         Assert(lhs[prow][pcol] == 1);
                 [ -  - ]
     652                 :         40 :         std::vector<Node> stack;
     653         [ +  + ]:         51 :         for (size_t i = pcol + 1; i < nvars; ++i)
     654                 :            :         {
     655         [ +  + ]:         31 :           if (lhs[prow][i] == 0) continue;
     656                 :            :           /* Normalize (no negative numbers, hence no subtraction)
     657                 :            :            * e.g., x = 4 - 2y  --> x = 4 + 9y (modulo 11) */
     658                 :         34 :           Integer m = iprime - lhs[prow][i];
     659                 :         34 :           Node bv = bv::utils::mkConst(nm, bv::utils::getSize(vvars[i]), m);
     660                 :         51 :           Node mult = nm->mkNode(Kind::BITVECTOR_MULT, vvars[i], bv);
     661                 :         17 :           stack.push_back(mult);
     662                 :            :         }
     663                 :            : 
     664         [ +  + ]:         20 :         if (stack.empty())
     665                 :            :         {
     666                 :          6 :           res[vvars[pcol]] = nm->mkConst<BitVector>(
     667                 :          9 :               BitVector(bv::utils::getSize(vvars[pcol]), rhs[prow]));
     668                 :            :         }
     669                 :            :         else
     670                 :            :         {
     671                 :         34 :           Node tmp = stack.size() == 1 ? stack[0]
     672         [ +  - ]:         34 :                                        : nm->mkNode(Kind::BITVECTOR_ADD, stack);
     673                 :            : 
     674         [ +  + ]:         17 :           if (rhs[prow] != 0)
     675                 :            :           {
     676                 :            :             tmp =
     677                 :         64 :                 nm->mkNode(Kind::BITVECTOR_ADD,
     678                 :         32 :                            bv::utils::mkConst(
     679                 :         16 :                                nm, bv::utils::getSize(vvars[pcol]), rhs[prow]),
     680                 :         16 :                            tmp);
     681                 :            :           }
     682 [ -  + ][ -  + ]:         17 :           Assert(!is_bv_const(tmp));
                 [ -  - ]
     683                 :         17 :           res[vvars[pcol]] = nm->mkNode(Kind::BITVECTOR_UREM, tmp, prime);
     684                 :            :         }
     685                 :            :       }
     686                 :            :     }
     687                 :            :   }
     688                 :         17 :   return ret;
     689                 :            : }
     690                 :            : 
     691                 :      50925 : BVGauss::BVGauss(PreprocessingPassContext* preprocContext,
     692                 :      50925 :                  const std::string& name)
     693                 :      50925 :     : PreprocessingPass(preprocContext, name)
     694                 :            : {
     695                 :      50925 : }
     696                 :            : 
     697                 :          3 : PreprocessingPassResult BVGauss::applyInternal(
     698                 :            :     AssertionPipeline* assertionsToPreprocess)
     699                 :            : {
     700                 :          6 :   std::vector<Node> assertions(assertionsToPreprocess->ref());
     701                 :          6 :   std::unordered_map<Node, std::vector<Node>> equations;
     702                 :            : 
     703         [ +  + ]:         13 :   while (!assertions.empty())
     704                 :            :   {
     705                 :         10 :     Node a = assertions.back();
     706                 :         10 :     assertions.pop_back();
     707                 :         10 :     cvc5::internal::Kind k = a.getKind();
     708                 :            : 
     709         [ -  + ]:         10 :     if (k == Kind::AND)
     710                 :            :     {
     711         [ -  - ]:          0 :       for (const Node& aa : a)
     712                 :            :       {
     713                 :          0 :         assertions.push_back(aa);
     714                 :            :       }
     715                 :            :     }
     716         [ +  - ]:         10 :     else if (k == Kind::EQUAL)
     717                 :            :     {
     718                 :         10 :       Node urem;
     719                 :            : 
     720                 :         10 :       if (is_bv_const(a[1]) && a[0].getKind() == Kind::BITVECTOR_UREM)
     721                 :            :       {
     722                 :         10 :         urem = a[0];
     723                 :            :       }
     724                 :          0 :       else if (is_bv_const(a[0]) && a[1].getKind() == Kind::BITVECTOR_UREM)
     725                 :            :       {
     726                 :          0 :         urem = a[1];
     727                 :            :       }
     728                 :            :       else
     729                 :            :       {
     730                 :          0 :         continue;
     731                 :            :       }
     732                 :            : 
     733                 :         10 :       if (urem[0].getKind() == Kind::BITVECTOR_ADD && is_bv_const(urem[1]))
     734                 :            :       {
     735                 :         10 :         equations[urem[1]].push_back(a);
     736                 :            :       }
     737                 :            :     }
     738                 :            :   }
     739                 :            : 
     740                 :          6 :   std::unordered_map<Node, Node> subst;
     741                 :            : 
     742                 :          3 :   NodeManager* nm = nodeManager();
     743         [ +  + ]:          7 :   for (const auto& eq : equations)
     744                 :            :   {
     745         [ -  + ]:          4 :     if (eq.second.size() <= 1) { continue; }
     746                 :            : 
     747                 :          4 :     std::unordered_map<Node, Node> res;
     748                 :          4 :     BVGauss::Result ret = gaussElimRewriteForUrem(eq.second, res);
     749         [ +  - ]:          8 :     Trace("bv-gauss-elim") << "result: "
     750                 :            :                            << (ret == BVGauss::Result::INVALID
     751         [ -  - ]:          4 :                                    ? "INVALID"
     752                 :            :                                    : (ret == BVGauss::Result::UNIQUE
     753         [ -  - ]:          0 :                                           ? "UNIQUE"
     754                 :            :                                           : (ret == BVGauss::Result::PARTIAL
     755         [ -  - ]:          0 :                                                  ? "PARTIAL"
     756                 :          0 :                                                  : "NONE")))
     757                 :          4 :                            << std::endl;
     758         [ +  - ]:          4 :     if (ret != BVGauss::Result::INVALID)
     759                 :            :     {
     760         [ -  + ]:          4 :       if (ret == BVGauss::Result::NONE)
     761                 :            :       {
     762                 :          0 :         Node n = nm->mkConst<bool>(false);
     763                 :          0 :         assertionsToPreprocess->push_back(
     764                 :            :             n, false, nullptr, TrustId::PREPROCESS_BV_GUASS_LEMMA);
     765                 :          0 :         return PreprocessingPassResult::CONFLICT;
     766                 :            :       }
     767                 :            :       else
     768                 :            :       {
     769         [ +  + ]:         14 :         for (const Node& e : eq.second)
     770                 :            :         {
     771                 :         10 :           subst[e] = nm->mkConst<bool>(true);
     772                 :            :         }
     773                 :            :         /* add resulting constraints */
     774         [ +  + ]:         14 :         for (const auto& p : res)
     775                 :            :         {
     776                 :         20 :           Node a = nm->mkNode(Kind::EQUAL, p.first, p.second);
     777         [ +  - ]:         10 :           Trace("bv-gauss-elim") << "added assertion: " << a << std::endl;
     778                 :            :           // add new assertion
     779                 :         10 :           assertionsToPreprocess->push_back(
     780                 :            :               a, false, nullptr, TrustId::PREPROCESS_BV_GUASS_LEMMA);
     781                 :            :         }
     782                 :            :       }
     783                 :            :     }
     784                 :            :   }
     785                 :            : 
     786         [ +  - ]:          3 :   if (!subst.empty())
     787                 :            :   {
     788                 :            :     /* delete (= substitute with true) obsolete assertions */
     789                 :          3 :     const std::vector<Node>& aref = assertionsToPreprocess->ref();
     790         [ +  + ]:         23 :     for (size_t i = 0, asize = aref.size(); i < asize; ++i)
     791                 :            :     {
     792                 :         40 :       Node a = aref[i];
     793                 :         20 :       Node as = a.substitute(subst.begin(), subst.end());
     794                 :            :       // replace the assertion
     795                 :         20 :       assertionsToPreprocess->replace(
     796                 :            :           i, as, nullptr, TrustId::PREPROCESS_BV_GUASS);
     797                 :            :     }
     798                 :            :   }
     799                 :          3 :   return PreprocessingPassResult::NO_CONFLICT;
     800                 :            : }
     801                 :            : 
     802                 :            : 
     803                 :            : }  // namespace passes
     804                 :            : }  // namespace preprocessing
     805                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14