LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/preprocessing/passes - bv_gauss.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 306 341 89.7 %
Date: 2024-11-03 11:40:22 Functions: 8 8 100.0 %
Branches: 272 436 62.4 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * Top contributors (to current version):
       3                 :            :  *   Aina Niemetz, Mathias Preiner, Andrew Reynolds
       4                 :            :  *
       5                 :            :  * This file is part of the cvc5 project.
       6                 :            :  *
       7                 :            :  * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS
       8                 :            :  * in the top-level source directory and their institutional affiliations.
       9                 :            :  * All rights reserved.  See the file COPYING in the top-level source
      10                 :            :  * directory for licensing information.
      11                 :            :  * ****************************************************************************
      12                 :            :  *
      13                 :            :  * Gaussian Elimination preprocessing pass.
      14                 :            :  *
      15                 :            :  * Simplify a given equation system modulo a (prime) number via Gaussian
      16                 :            :  * Elimination if possible.
      17                 :            :  */
      18                 :            : 
      19                 :            : #include "preprocessing/passes/bv_gauss.h"
      20                 :            : 
      21                 :            : #include <unordered_map>
      22                 :            : #include <vector>
      23                 :            : 
      24                 :            : #include "expr/node.h"
      25                 :            : #include "preprocessing/assertion_pipeline.h"
      26                 :            : #include "preprocessing/preprocessing_pass_context.h"
      27                 :            : #include "theory/bv/theory_bv_rewrite_rules_normalization.h"
      28                 :            : #include "theory/bv/theory_bv_utils.h"
      29                 :            : #include "theory/rewriter.h"
      30                 :            : #include "util/bitvector.h"
      31                 :            : 
      32                 :            : using namespace cvc5::internal;
      33                 :            : using namespace cvc5::internal::theory;
      34                 :            : using namespace cvc5::internal::theory::bv;
      35                 :            : 
      36                 :            : namespace cvc5::internal {
      37                 :            : namespace preprocessing {
      38                 :            : namespace passes {
      39                 :            : 
      40                 :       2901 : bool BVGauss::is_bv_const(Node n)
      41                 :            : {
      42         [ +  + ]:       2901 :   if (n.isConst()) { return true; }
      43                 :       1996 :   return rewrite(n).getKind() == Kind::CONST_BITVECTOR;
      44                 :            : }
      45                 :            : 
      46                 :        451 : Node BVGauss::get_bv_const(Node n)
      47                 :            : {
      48 [ -  + ][ -  + ]:        451 :   Assert(is_bv_const(n));
                 [ -  - ]
      49                 :        451 :   return rewrite(n);
      50                 :            : }
      51                 :            : 
      52                 :        241 : Integer BVGauss::get_bv_const_value(Node n)
      53                 :            : {
      54 [ -  + ][ -  + ]:        241 :   Assert(is_bv_const(n));
                 [ -  - ]
      55                 :        482 :   return get_bv_const(n).getConst<BitVector>().getValue();
      56                 :            : }
      57                 :            : 
      58                 :            : /**
      59                 :            :  * Determines if an overflow may occur in given 'expr'.
      60                 :            :  *
      61                 :            :  * Returns 0 if an overflow may occur, and the minimum required
      62                 :            :  * bit-width such that no overflow occurs, otherwise.
      63                 :            :  *
      64                 :            :  * Note that it would suffice for this function to be Boolean.
      65                 :            :  * However, it is handy to determine the minimum required bit-width for
      66                 :            :  * debugging purposes.
      67                 :            :  *
      68                 :            :  * Note: getMinBwExpr assumes that 'expr' is rewritten.
      69                 :            :  *
      70                 :            :  * If not, all operators that are removed via rewriting (e.g., ror, rol, ...)
      71                 :            :  * will be handled via the default case, which is not incorrect but also not
      72                 :            :  * necessarily the minimum.
      73                 :            :  */
      74                 :        101 : uint32_t BVGauss::getMinBwExpr(Node expr)
      75                 :            : {
      76                 :        202 :   std::vector<Node> visit;
      77                 :            :   /* Maps visited nodes to the determined minimum bit-width required. */
      78                 :        202 :   std::unordered_map<Node, unsigned> visited;
      79                 :        101 :   std::unordered_map<Node, unsigned>::iterator it;
      80                 :            : 
      81                 :        101 :   visit.push_back(expr);
      82         [ +  + ]:       1459 :   while (!visit.empty())
      83                 :            :   {
      84                 :       1361 :     Node n = visit.back();
      85                 :       1361 :     visit.pop_back();
      86                 :       1361 :     it = visited.find(n);
      87         [ +  + ]:       1361 :     if (it == visited.end())
      88                 :            :     {
      89         [ +  + ]:        745 :       if (is_bv_const(n))
      90                 :            :       {
      91                 :            :         /* Rewrite const expr, overflows in consts are irrelevant. */
      92                 :        210 :         visited[n] = get_bv_const(n).getConst<BitVector>().getValue().length();
      93                 :            :       }
      94                 :            :       else
      95                 :            :       {
      96                 :        535 :         visited[n] = 0;
      97                 :        535 :         visit.push_back(n);
      98         [ +  + ]:       1263 :         for (const Node &nn : n) { visit.push_back(nn); }
      99                 :            :       }
     100                 :            :     }
     101         [ +  + ]:        616 :     else if (it->second == 0)
     102                 :            :     {
     103                 :        533 :       Kind k = n.getKind();
     104 [ -  + ][ -  + ]:        533 :       Assert(k != Kind::CONST_BITVECTOR);
                 [ -  - ]
     105 [ -  + ][ -  + ]:        533 :       Assert(!is_bv_const(n));
                 [ -  - ]
     106 [ +  + ][ +  + ]:        533 :       switch (k)
         [ +  - ][ +  + ]
     107                 :            :       {
     108                 :         28 :         case Kind::BITVECTOR_EXTRACT:
     109                 :            :         {
     110                 :         28 :           const unsigned size = bv::utils::getSize(n);
     111                 :         28 :           const unsigned low = bv::utils::getExtractLow(n);
     112                 :         28 :           const unsigned child_min_width = visited[n[0]];
     113                 :         28 :           visited[n] = std::min(
     114         [ +  - ]:         28 :               size, child_min_width >= low ? child_min_width - low : 0u);
     115 [ -  + ][ -  + ]:         28 :           Assert(visited[n] <= visited[n[0]]);
                 [ -  - ]
     116                 :         28 :           break;
     117                 :            :         }
     118                 :            : 
     119                 :         36 :         case Kind::BITVECTOR_ZERO_EXTEND:
     120                 :            :         {
     121                 :         36 :           visited[n] = visited[n[0]];
     122                 :         36 :           break;
     123                 :            :         }
     124                 :            : 
     125                 :         69 :         case Kind::BITVECTOR_MULT:
     126                 :            :         {
     127                 :         69 :           Integer maxval = Integer(1);
     128         [ +  + ]:        213 :           for (const Node& nn : n)
     129                 :            :           {
     130         [ +  + ]:        144 :             if (is_bv_const(nn))
     131                 :            :             {
     132                 :         57 :               maxval *= get_bv_const_value(nn);
     133                 :            :             }
     134                 :            :             else
     135                 :            :             {
     136                 :         87 :               maxval *= BitVector::mkOnes(visited[nn]).getValue();
     137                 :            :             }
     138                 :            :           }
     139                 :         69 :           unsigned w = maxval.length();
     140         [ +  + ]:         69 :           if (w > bv::utils::getSize(n)) { return 0; } /* overflow */
     141                 :         67 :           visited[n] = w;
     142                 :         67 :           break;
     143                 :            :         }
     144                 :            : 
     145                 :        164 :         case Kind::BITVECTOR_CONCAT:
     146                 :            :         {
     147                 :            :           unsigned i, wnz, nc;
     148         [ +  + ]:        325 :           for (i = 0, wnz = 0, nc = n.getNumChildren() - 1; i < nc; ++i)
     149                 :            :           {
     150                 :        189 :             unsigned wni = bv::utils::getSize(n[i]);
     151         [ +  + ]:        189 :             if (n[i] != bv::utils::mkZero(wni)) { break; }
     152                 :            :             /* sum of all bit-widths of leading zero concats */
     153                 :        161 :             wnz += wni;
     154                 :            :           }
     155                 :            :           /* Do not consider leading zero concats, i.e.,
     156                 :            :            * min bw of current concat is determined as
     157                 :            :            *   min bw of first non-zero term
     158                 :            :            *   plus actual bw of all subsequent terms */
     159                 :        328 :           visited[n] = bv::utils::getSize(n) + visited[n[i]]
     160                 :        164 :                        - bv::utils::getSize(n[i]) - wnz;
     161                 :        164 :           break;
     162                 :            :         }
     163                 :            : 
     164                 :          3 :         case Kind::BITVECTOR_UREM:
     165                 :            :         case Kind::BITVECTOR_LSHR:
     166                 :            :         case Kind::BITVECTOR_ASHR:
     167                 :            :         {
     168                 :          3 :           visited[n] = visited[n[0]];
     169                 :          3 :           break;
     170                 :            :         }
     171                 :            : 
     172                 :          0 :         case Kind::BITVECTOR_OR:
     173                 :            :         case Kind::BITVECTOR_NOR:
     174                 :            :         case Kind::BITVECTOR_XOR:
     175                 :            :         case Kind::BITVECTOR_XNOR:
     176                 :            :         case Kind::BITVECTOR_AND:
     177                 :            :         case Kind::BITVECTOR_NAND:
     178                 :            :         {
     179                 :          0 :           unsigned wmax = 0;
     180         [ -  - ]:          0 :           for (const Node &nn : n)
     181                 :            :           {
     182         [ -  - ]:          0 :             if (visited[nn] > wmax)
     183                 :            :             {
     184                 :          0 :               wmax = visited[nn];
     185                 :            :             }
     186                 :            :           }
     187                 :          0 :           visited[n] = wmax;
     188                 :          0 :           break;
     189                 :            :         }
     190                 :            : 
     191                 :         52 :         case Kind::BITVECTOR_ADD:
     192                 :            :         {
     193                 :         52 :           Integer maxval = Integer(0);
     194         [ +  + ]:        182 :           for (const Node& nn : n)
     195                 :            :           {
     196         [ -  + ]:        130 :             if (is_bv_const(nn))
     197                 :            :             {
     198                 :          0 :               maxval += get_bv_const_value(nn);
     199                 :            :             }
     200                 :            :             else
     201                 :            :             {
     202                 :        130 :               maxval += BitVector::mkOnes(visited[nn]).getValue();
     203                 :            :             }
     204                 :            :           }
     205                 :         52 :           unsigned w = maxval.length();
     206         [ +  + ]:         52 :           if (w > bv::utils::getSize(n)) { return 0; } /* overflow */
     207                 :         51 :           visited[n] = w;
     208                 :         51 :           break;
     209                 :            :         }
     210                 :            : 
     211                 :        181 :         default:
     212                 :            :         {
     213                 :            :           /* BITVECTOR_UDIV (since x / 0 = -1)
     214                 :            :            * BITVECTOR_NOT
     215                 :            :            * BITVECTOR_NEG
     216                 :            :            * BITVECTOR_SHL */
     217                 :        181 :           visited[n] = bv::utils::getSize(n);
     218                 :            :         }
     219                 :            :       }
     220                 :            :     }
     221                 :            :   }
     222 [ -  + ][ -  + ]:         98 :   Assert(visited.find(expr) != visited.end());
                 [ -  - ]
     223                 :         98 :   return visited[expr];
     224                 :            : }
     225                 :            : 
     226                 :            : /**
     227                 :            :  * Apply Gaussian Elimination modulo a (prime) number.
     228                 :            :  * The given equation system is represented as a matrix of Integers.
     229                 :            :  *
     230                 :            :  * Note that given 'prime' does not have to be prime but can be any
     231                 :            :  * arbitrary number. However, if 'prime' is indeed prime, GE is guaranteed
     232                 :            :  * to succeed, which is not the case, otherwise.
     233                 :            :  *
     234                 :            :  * Returns INVALID if GE can not be applied, UNIQUE and PARTIAL if GE was
     235                 :            :  * successful, and NONE, otherwise.
     236                 :            :  *
     237                 :            :  * Vectors 'rhs' and 'lhs' represent the right hand side and left hand side
     238                 :            :  * of the given matrix, respectively. The resulting matrix (in row echelon
     239                 :            :  * form) is stored in 'rhs' and 'lhs', i.e., the given matrix is overwritten
     240                 :            :  * with the resulting matrix.
     241                 :            :  */
     242                 :         70 : BVGauss::Result BVGauss::gaussElim(Integer prime,
     243                 :            :                                    std::vector<Integer>& rhs,
     244                 :            :                                    std::vector<std::vector<Integer>>& lhs)
     245                 :            : {
     246 [ -  + ][ -  + ]:         70 :   Assert(prime > 0);
                 [ -  - ]
     247 [ -  + ][ -  + ]:         70 :   Assert(lhs.size());
                 [ -  - ]
     248 [ -  + ][ -  + ]:         70 :   Assert(lhs.size() == rhs.size());
                 [ -  - ]
     249 [ -  + ][ -  + ]:         70 :   Assert(lhs.size() <= lhs[0].size());
                 [ -  - ]
     250                 :            : 
     251                 :            :   /* special case: zero ring */
     252         [ +  + ]:         70 :   if (prime == 1)
     253                 :            :   {
     254                 :          1 :     rhs = std::vector<Integer>(rhs.size(), Integer(0));
     255                 :          3 :     lhs = std::vector<std::vector<Integer>>(
     256                 :          3 :         lhs.size(), std::vector<Integer>(lhs[0].size(), Integer(0)));
     257                 :          1 :     return BVGauss::Result::UNIQUE;
     258                 :            :   }
     259                 :            : 
     260                 :         69 :   size_t nrows = lhs.size();
     261                 :         69 :   size_t ncols = lhs[0].size();
     262                 :            : 
     263                 :            : #ifdef CVC5_ASSERTIONS
     264 [ +  + ][ -  + ]:        195 :   for (size_t i = 1; i < nrows; ++i) Assert(lhs[i].size() == ncols);
         [ -  + ][ -  - ]
     265                 :            : #endif
     266                 :            :   /* (1) if element in pivot column is non-zero and != 1, divide row elements
     267                 :            :    *     by element in pivot column modulo prime, i.e., multiply row with
     268                 :            :    *     multiplicative inverse of element in pivot column modulo prime
     269                 :            :    *
     270                 :            :    * (2) subtract pivot row from all rows below pivot row
     271                 :            :    *
     272                 :            :    * (3) subtract (multiple of) current row from all rows above s.t. all
     273                 :            :    *     elements in current pivot column above current row become equal to one
     274                 :            :    *
     275                 :            :    * Note: we do not normalize the given matrix to values modulo prime
     276                 :            :    *       beforehand but on-the-fly. */
     277                 :            : 
     278                 :            :   /* pivot = lhs[pcol][pcol] */
     279 [ +  + ][ +  + ]:        234 :   for (size_t pcol = 0, prow = 0; pcol < ncols && prow < nrows; ++pcol, ++prow)
     280                 :            :   {
     281                 :            :     /* lhs[j][pcol]: element in pivot column */
     282         [ +  + ]:        510 :     for (size_t j = prow; j < nrows; ++j)
     283                 :            :     {
     284                 :            : #ifdef CVC5_ASSERTIONS
     285         [ +  + ]:        572 :       for (size_t k = 0; k < pcol; ++k)
     286                 :            :       {
     287 [ -  + ][ -  + ]:        227 :         Assert(lhs[j][k] == 0);
                 [ -  - ]
     288                 :            :       }
     289                 :            : #endif
     290                 :            :       /* normalize element in pivot column to modulo prime */
     291                 :        345 :       lhs[j][pcol] = lhs[j][pcol].euclidianDivideRemainder(prime);
     292                 :            :       /* exchange rows if pivot elem is 0 */
     293         [ +  + ]:        345 :       if (j == prow)
     294                 :            :       {
     295         [ +  + ]:        209 :         while (lhs[j][pcol] == 0)
     296                 :            :         {
     297         [ +  + ]:         85 :           for (size_t k = prow + 1; k < nrows; ++k)
     298                 :            :           {
     299                 :         53 :             lhs[k][pcol] = lhs[k][pcol].euclidianDivideRemainder(prime);
     300         [ +  + ]:         53 :             if (lhs[k][pcol] != 0)
     301                 :            :             {
     302                 :         27 :               std::swap(rhs[j], rhs[k]);
     303                 :         27 :               std::swap(lhs[j], lhs[k]);
     304                 :         27 :               break;
     305                 :            :             }
     306                 :            :           }
     307         [ +  + ]:         59 :           if (pcol >= ncols - 1) break;
     308         [ +  + ]:         38 :           if (lhs[j][pcol] == 0)
     309                 :            :           {
     310                 :         16 :             pcol += 1;
     311         [ +  + ]:         16 :             if (lhs[j][pcol] != 0)
     312                 :         10 :               lhs[j][pcol] = lhs[j][pcol].euclidianDivideRemainder(prime);
     313                 :            :           }
     314                 :            :         }
     315                 :            :       }
     316                 :            : 
     317         [ +  + ]:        345 :       if (lhs[j][pcol] != 0)
     318                 :            :       {
     319                 :            :         /* (1) */
     320         [ +  + ]:        264 :         if (lhs[j][pcol] != 1)
     321                 :            :         {
     322                 :        196 :           Integer inv = lhs[j][pcol].modInverse(prime);
     323         [ +  + ]:        196 :           if (inv == -1)
     324                 :            :           {
     325                 :          6 :             return BVGauss::Result::INVALID; /* not coprime */
     326                 :            :           }
     327         [ +  + ]:        626 :           for (size_t k = pcol; k < ncols; ++k)
     328                 :            :           {
     329                 :        436 :             lhs[j][k] = lhs[j][k].modMultiply(inv, prime);
     330         [ +  + ]:        436 :             if (j <= prow) continue; /* pivot */
     331                 :        246 :             lhs[j][k] = lhs[j][k].modAdd(-lhs[prow][k], prime);
     332                 :            :           }
     333                 :        190 :           rhs[j] = rhs[j].modMultiply(inv, prime);
     334         [ +  + ]:        190 :           if (j > prow) { rhs[j] = rhs[j].modAdd(-rhs[prow], prime); }
     335                 :            :         }
     336                 :            :         /* (2) */
     337         [ +  + ]:         68 :         else if (j != prow)
     338                 :            :         {
     339         [ +  + ]:         46 :           for (size_t k = pcol; k < ncols; ++k)
     340                 :            :           {
     341                 :         34 :             lhs[j][k] = lhs[j][k].modAdd(-lhs[prow][k], prime);
     342                 :            :           }
     343                 :         12 :           rhs[j] = rhs[j].modAdd(-rhs[prow], prime);
     344                 :            :         }
     345                 :            :       }
     346                 :            :     }
     347                 :            :     /* (3) */
     348         [ +  + ]:        305 :     for (size_t j = 0; j < prow; ++j)
     349                 :            :     {
     350                 :        280 :       Integer mul = lhs[j][pcol];
     351         [ +  + ]:        140 :       if (mul != 0)
     352                 :            :       {
     353         [ +  + ]:        276 :         for (size_t k = pcol; k < ncols; ++k)
     354                 :            :         {
     355                 :        162 :           lhs[j][k] = lhs[j][k].modAdd(-lhs[prow][k] * mul, prime);
     356                 :            :         }
     357                 :        114 :         rhs[j] = rhs[j].modAdd(-rhs[prow] * mul, prime);
     358                 :            :       }
     359                 :            :     }
     360                 :            :   }
     361                 :            : 
     362                 :         63 :   bool ispart = false;
     363         [ +  + ]:        236 :   for (size_t i = 0; i < nrows; ++i)
     364                 :            :   {
     365                 :        177 :     size_t pcol = i;
     366 [ +  + ][ +  + ]:        229 :     while (pcol < ncols && lhs[i][pcol] == 0) ++pcol;
         [ +  + ][ +  + ]
                 [ -  - ]
     367         [ +  + ]:        177 :     if (pcol >= ncols)
     368                 :            :     {
     369                 :         29 :       rhs[i] = rhs[i].euclidianDivideRemainder(prime);
     370         [ +  + ]:         29 :       if (rhs[i] != 0)
     371                 :            :       {
     372                 :            :         /* no solution */
     373                 :          4 :         return BVGauss::Result::NONE;
     374                 :            :       }
     375                 :         25 :       continue;
     376                 :            :     }
     377         [ +  + ]:        488 :     for (size_t j = i; j < ncols; ++j)
     378                 :            :     {
     379 [ +  + ][ -  + ]:        340 :       if (lhs[i][j] >= prime || lhs[i][j] <= -prime)
         [ +  + ][ +  + ]
                 [ -  - ]
     380                 :            :       {
     381                 :          1 :         lhs[i][j] = lhs[i][j].euclidianDivideRemainder(prime);
     382                 :            :       }
     383 [ +  + ][ +  + ]:        340 :       if (j > pcol && lhs[i][j] != 0)
         [ +  + ][ +  + ]
                 [ -  - ]
     384                 :            :       {
     385                 :         36 :         ispart = true;
     386                 :            :       }
     387                 :            :     }
     388                 :            :   }
     389                 :            : 
     390         [ +  + ]:         59 :   if (ispart)
     391                 :            :   {
     392                 :         21 :     return BVGauss::Result::PARTIAL;
     393                 :            :   }
     394                 :            : 
     395                 :         38 :   return BVGauss::Result::UNIQUE;
     396                 :            : }
     397                 :            : 
     398                 :            : /**
     399                 :            :  * Apply Gaussian Elimination on a set of equations modulo some (prime)
     400                 :            :  * number given as bit-vector equations.
     401                 :            :  *
     402                 :            :  * IMPORTANT: Applying GE modulo some number (rather than modulo 2^bw)
     403                 :            :  * on a set of bit-vector equations is only sound if this set of equations
     404                 :            :  * has a solution that does not produce overflows. Consequently, we only
     405                 :            :  * apply GE if the given bit-width guarantees that no overflows can occur
     406                 :            :  * in the given set of equations.
     407                 :            :  *
     408                 :            :  * Note that the given set of equations does not have to be modulo a prime
     409                 :            :  * but can be modulo any arbitrary number. However, if it is indeed modulo
     410                 :            :  * prime, GE is guaranteed to succeed, which is not the case, otherwise.
     411                 :            :  *
     412                 :            :  * Returns INVALID if GE can not be applied, UNIQUE and PARTIAL if GE was
     413                 :            :  * successful, and NONE, otherwise.
     414                 :            :  *
     415                 :            :  * The resulting constraints are stored in 'res' as a mapping of unknown
     416                 :            :  * to result (modulo prime). These mapped results are added as constraints
     417                 :            :  * of the form 'unknown = mapped result' in applyInternal.
     418                 :            :  */
     419                 :         18 : BVGauss::Result BVGauss::gaussElimRewriteForUrem(
     420                 :            :     const std::vector<Node>& equations, std::unordered_map<Node, Node>& res)
     421                 :            : {
     422 [ -  + ][ -  + ]:         18 :   Assert(res.empty());
                 [ -  - ]
     423                 :            : 
     424                 :         36 :   Node prime;
     425                 :         36 :   Integer iprime;
     426                 :         36 :   std::unordered_map<Node, std::vector<Integer>> vars;
     427                 :         18 :   size_t neqs = equations.size();
     428                 :         36 :   std::vector<Integer> rhs;
     429                 :            :   std::vector<std::vector<Integer>> lhs =
     430                 :         54 :       std::vector<std::vector<Integer>>(neqs, std::vector<Integer>());
     431                 :            : 
     432                 :         18 :   res = std::unordered_map<Node, Node>();
     433                 :            : 
     434         [ +  + ]:         64 :   for (size_t i = 0; i < neqs; ++i)
     435                 :            :   {
     436                 :         46 :     Node eq = equations[i];
     437 [ -  + ][ -  + ]:         46 :     Assert(eq.getKind() == Kind::EQUAL);
                 [ -  - ]
     438                 :         46 :     Node urem, eqrhs;
     439                 :            : 
     440         [ +  - ]:         46 :     if (eq[0].getKind() == Kind::BITVECTOR_UREM)
     441                 :            :     {
     442                 :         46 :       urem = eq[0];
     443 [ -  + ][ -  + ]:         46 :       Assert(is_bv_const(eq[1]));
                 [ -  - ]
     444                 :         46 :       eqrhs = eq[1];
     445                 :            :     }
     446                 :            :     else
     447                 :            :     {
     448                 :          0 :       Assert(eq[1].getKind() == Kind::BITVECTOR_UREM);
     449                 :          0 :       urem = eq[1];
     450                 :          0 :       Assert(is_bv_const(eq[0]));
     451                 :          0 :       eqrhs = eq[0];
     452                 :            :     }
     453         [ -  + ]:         46 :     if (getMinBwExpr(rewrite(urem[0])) == 0)
     454                 :            :     {
     455         [ -  - ]:          0 :       Trace("bv-gauss-elim")
     456                 :            :           << "Minimum required bit-width exceeds given bit-width, "
     457                 :          0 :              "will not apply Gaussian Elimination."
     458                 :          0 :           << std::endl;
     459                 :          0 :       return BVGauss::Result::INVALID;
     460                 :            :     }
     461                 :         46 :     rhs.push_back(get_bv_const_value(eqrhs));
     462                 :            : 
     463 [ -  + ][ -  + ]:         46 :     Assert(is_bv_const(urem[1]));
                 [ -  - ]
     464                 :         46 :     Assert(i == 0 || get_bv_const_value(urem[1]) == iprime);
     465         [ +  + ]:         46 :     if (i == 0)
     466                 :            :     {
     467                 :         18 :       prime = urem[1];
     468                 :         18 :       iprime = get_bv_const_value(prime);
     469                 :            :     }
     470                 :            : 
     471                 :         92 :     std::unordered_map<Node, Integer> tmp;
     472                 :         92 :     std::vector<Node> stack;
     473                 :         46 :     stack.push_back(urem[0]);
     474         [ +  + ]:        198 :     while (!stack.empty())
     475                 :            :     {
     476                 :        152 :       Node n = stack.back();
     477                 :        152 :       stack.pop_back();
     478                 :            : 
     479                 :            :       /* Subtract from rhs if const */
     480         [ -  + ]:        152 :       if (is_bv_const(n))
     481                 :            :       {
     482                 :          0 :         Integer val = get_bv_const_value(n);
     483         [ -  - ]:          0 :         if (val > 0) rhs.back() -= val;
     484                 :          0 :         continue;
     485                 :            :       }
     486                 :            : 
     487                 :            :       /* Split into matrix columns */
     488                 :        152 :       Kind k = n.getKind();
     489         [ +  + ]:        152 :       if (k == Kind::BITVECTOR_ADD)
     490                 :            :       {
     491         [ +  + ]:        159 :         for (const Node& nn : n) { stack.push_back(nn); }
     492                 :            :       }
     493         [ +  + ]:         99 :       else if (k == Kind::BITVECTOR_MULT)
     494                 :            :       {
     495                 :        184 :         Node n0, n1;
     496                 :            :         /* Flatten mult expression. */
     497                 :         92 :         n = RewriteRule<FlattenAssocCommut>::run<true>(n);
     498                 :            :         /* Split operands into consts and non-consts */
     499                 :        184 :         NodeBuilder nb_consts(nodeManager(), k);
     500                 :         92 :         NodeBuilder nb_nonconsts(nodeManager(), k);
     501         [ +  + ]:        284 :         for (const Node& nn : n)
     502                 :            :         {
     503                 :        384 :           Node nnrw = rewrite(nn);
     504         [ +  + ]:        192 :           if (is_bv_const(nnrw))
     505                 :            :           {
     506                 :         88 :             nb_consts << nnrw;
     507                 :            :           }
     508                 :            :           else
     509                 :            :           {
     510                 :        104 :             nb_nonconsts << nnrw;
     511                 :            :           }
     512                 :            :         }
     513 [ -  + ][ -  + ]:         92 :         Assert(nb_nonconsts.getNumChildren() > 0);
                 [ -  - ]
     514                 :            :         /* n0 is const */
     515                 :         92 :         unsigned nc = nb_consts.getNumChildren();
     516         [ -  + ]:         92 :         if (nc > 1)
     517                 :            :         {
     518                 :          0 :           n0 = rewrite(nb_consts.constructNode());
     519                 :            :         }
     520         [ +  + ]:         92 :         else if (nc == 1)
     521                 :            :         {
     522                 :         88 :           n0 = nb_consts[0];
     523                 :            :         }
     524                 :            :         else
     525                 :            :         {
     526                 :          4 :           n0 = bv::utils::mkOne(bv::utils::getSize(n));
     527                 :            :         }
     528                 :            :         /* n1 is a mult with non-const operands */
     529         [ +  + ]:         92 :         if (nb_nonconsts.getNumChildren() > 1)
     530                 :            :         {
     531                 :         10 :           n1 = rewrite(nb_nonconsts.constructNode());
     532                 :            :         }
     533                 :            :         else
     534                 :            :         {
     535                 :         82 :           n1 = nb_nonconsts[0];
     536                 :            :         }
     537 [ -  + ][ -  + ]:         92 :         Assert(is_bv_const(n0));
                 [ -  - ]
     538 [ -  + ][ -  + ]:         92 :         Assert(!is_bv_const(n1));
                 [ -  - ]
     539                 :         92 :         tmp[n1] += get_bv_const_value(n0);
     540                 :            :       }
     541                 :            :       else
     542                 :            :       {
     543                 :          7 :         tmp[n] += Integer(1);
     544                 :            :       }
     545                 :            :     }
     546                 :            : 
     547                 :            :     /* Note: "var" is not necessarily a VARIABLE but can be an arbitrary expr */
     548                 :            : 
     549         [ +  + ]:        145 :     for (const auto& p : tmp)
     550                 :            :     {
     551                 :        198 :       Node var = p.first;
     552                 :        198 :       Integer val = p.second;
     553 [ +  + ][ +  + ]:         99 :       if (i > 0 && vars.find(var) == vars.end())
                 [ +  + ]
     554                 :            :       {
     555                 :            :         /* Add column and fill column elements of rows above with 0. */
     556                 :         14 :         vars[var].insert(vars[var].end(), i, Integer(0));
     557                 :            :       }
     558                 :         99 :       vars[var].push_back(val);
     559                 :            :     }
     560                 :            : 
     561         [ +  + ]:        166 :     for (const auto& p : vars)
     562                 :            :     {
     563         [ +  + ]:        120 :       if (tmp.find(p.first) == tmp.end())
     564                 :            :       {
     565                 :         21 :         vars[p.first].push_back(Integer(0));
     566                 :            :       }
     567                 :            :     }
     568                 :            :   }
     569                 :            : 
     570                 :         18 :   size_t nvars = vars.size();
     571         [ -  + ]:         18 :   if (nvars == 0)
     572                 :            :   {
     573                 :          0 :     return BVGauss::Result::INVALID;
     574                 :            :   }
     575                 :         18 :   size_t nrows = vars.begin()->second.size();
     576                 :            : #ifdef CVC5_ASSERTIONS
     577         [ +  + ]:         71 :   for (const auto& p : vars)
     578                 :            :   {
     579 [ -  + ][ -  + ]:         53 :     Assert(p.second.size() == nrows);
                 [ -  - ]
     580                 :            :   }
     581                 :            : #endif
     582                 :            : 
     583         [ -  + ]:         18 :   if (nrows < 1)
     584                 :            :   {
     585                 :          0 :     return BVGauss::Result::INVALID;
     586                 :            :   }
     587                 :            : 
     588         [ +  + ]:         64 :   for (size_t i = 0; i < nrows; ++i)
     589                 :            :   {
     590         [ +  + ]:        182 :     for (const auto& p : vars)
     591                 :            :     {
     592                 :        136 :       lhs[i].push_back(p.second[i]);
     593                 :            :     }
     594                 :            :   }
     595                 :            : 
     596                 :            : #ifdef CVC5_ASSERTIONS
     597         [ +  + ]:         64 :   for (const auto& row : lhs)
     598                 :            :   {
     599 [ -  + ][ -  + ]:         46 :     Assert(row.size() == nvars);
                 [ -  - ]
     600                 :            :   }
     601 [ -  + ][ -  + ]:         18 :   Assert(lhs.size() == rhs.size());
                 [ -  - ]
     602                 :            : #endif
     603                 :            : 
     604         [ +  + ]:         18 :   if (lhs.size() > lhs[0].size())
     605                 :            :   {
     606                 :          1 :     return BVGauss::Result::INVALID;
     607                 :            :   }
     608                 :            : 
     609         [ +  - ]:         17 :   Trace("bv-gauss-elim") << "Applying Gaussian Elimination..." << std::endl;
     610                 :         17 :   BVGauss::Result ret = gaussElim(iprime, rhs, lhs);
     611                 :            : 
     612 [ +  - ][ +  - ]:         17 :   if (ret != BVGauss::Result::NONE && ret != BVGauss::Result::INVALID)
     613                 :            :   {
     614                 :         34 :     std::vector<Node> vvars;
     615         [ +  + ]:         68 :     for (const auto& p : vars) { vvars.push_back(p.first); }
     616 [ -  + ][ -  + ]:         17 :     Assert(nvars == vvars.size());
                 [ -  - ]
     617 [ -  + ][ -  + ]:         17 :     Assert(nrows == lhs.size());
                 [ -  - ]
     618 [ -  + ][ -  + ]:         17 :     Assert(nrows == rhs.size());
                 [ -  - ]
     619                 :         17 :     NodeManager* nm = nodeManager();
     620         [ +  + ]:         17 :     if (ret == BVGauss::Result::UNIQUE)
     621                 :            :     {
     622         [ +  + ]:         27 :       for (size_t i = 0; i < nvars; ++i)
     623                 :            :       {
     624                 :         40 :         res[vvars[i]] = nm->mkConst<BitVector>(
     625                 :         60 :             BitVector(bv::utils::getSize(vvars[i]), rhs[i]));
     626                 :            :       }
     627                 :            :     }
     628                 :            :     else
     629                 :            :     {
     630 [ -  + ][ -  + ]:         10 :       Assert(ret == BVGauss::Result::PARTIAL);
                 [ -  - ]
     631                 :            : 
     632 [ +  - ][ +  + ]:         30 :       for (size_t pcol = 0, prow = 0; pcol < nvars && prow < nrows;
     633                 :            :            ++pcol, ++prow)
     634                 :            :       {
     635                 :         22 :         Assert(lhs[prow][pcol] == 0 || lhs[prow][pcol] == 1);
     636 [ +  + ][ +  + ]:         25 :         while (pcol < nvars && lhs[prow][pcol] == 0) pcol += 1;
         [ +  + ][ +  + ]
                 [ -  - ]
     637         [ +  + ]:         22 :         if (pcol >= nvars)
     638                 :            :         {
     639 [ -  + ][ -  + ]:          2 :           Assert(rhs[prow] == 0);
                 [ -  - ]
     640                 :          2 :           break;
     641                 :            :         }
     642         [ -  + ]:         20 :         if (lhs[prow][pcol] == 0)
     643                 :            :         {
     644                 :          0 :           Assert(rhs[prow] == 0);
     645                 :          0 :           continue;
     646                 :            :         }
     647 [ -  + ][ -  + ]:         20 :         Assert(lhs[prow][pcol] == 1);
                 [ -  - ]
     648                 :         40 :         std::vector<Node> stack;
     649         [ +  + ]:         51 :         for (size_t i = pcol + 1; i < nvars; ++i)
     650                 :            :         {
     651         [ +  + ]:         31 :           if (lhs[prow][i] == 0) continue;
     652                 :            :           /* Normalize (no negative numbers, hence no subtraction)
     653                 :            :            * e.g., x = 4 - 2y  --> x = 4 + 9y (modulo 11) */
     654                 :         34 :           Integer m = iprime - lhs[prow][i];
     655                 :         34 :           Node bv = bv::utils::mkConst(bv::utils::getSize(vvars[i]), m);
     656                 :         51 :           Node mult = nm->mkNode(Kind::BITVECTOR_MULT, vvars[i], bv);
     657                 :         17 :           stack.push_back(mult);
     658                 :            :         }
     659                 :            : 
     660         [ +  + ]:         20 :         if (stack.empty())
     661                 :            :         {
     662                 :          6 :           res[vvars[pcol]] = nm->mkConst<BitVector>(
     663                 :          9 :               BitVector(bv::utils::getSize(vvars[pcol]), rhs[prow]));
     664                 :            :         }
     665                 :            :         else
     666                 :            :         {
     667                 :         34 :           Node tmp = stack.size() == 1 ? stack[0]
     668         [ +  - ]:         34 :                                        : nm->mkNode(Kind::BITVECTOR_ADD, stack);
     669                 :            : 
     670         [ +  + ]:         17 :           if (rhs[prow] != 0)
     671                 :            :           {
     672                 :         64 :             tmp = nm->mkNode(
     673                 :            :                 Kind::BITVECTOR_ADD,
     674                 :         32 :                 bv::utils::mkConst(bv::utils::getSize(vvars[pcol]), rhs[prow]),
     675                 :         16 :                 tmp);
     676                 :            :           }
     677 [ -  + ][ -  + ]:         17 :           Assert(!is_bv_const(tmp));
                 [ -  - ]
     678                 :         17 :           res[vvars[pcol]] = nm->mkNode(Kind::BITVECTOR_UREM, tmp, prime);
     679                 :            :         }
     680                 :            :       }
     681                 :            :     }
     682                 :            :   }
     683                 :         17 :   return ret;
     684                 :            : }
     685                 :            : 
     686                 :      50409 : BVGauss::BVGauss(PreprocessingPassContext* preprocContext,
     687                 :      50409 :                  const std::string& name)
     688                 :      50409 :     : PreprocessingPass(preprocContext, name)
     689                 :            : {
     690                 :      50409 : }
     691                 :            : 
     692                 :          3 : PreprocessingPassResult BVGauss::applyInternal(
     693                 :            :     AssertionPipeline* assertionsToPreprocess)
     694                 :            : {
     695                 :          6 :   std::vector<Node> assertions(assertionsToPreprocess->ref());
     696                 :          6 :   std::unordered_map<Node, std::vector<Node>> equations;
     697                 :            : 
     698         [ +  + ]:         13 :   while (!assertions.empty())
     699                 :            :   {
     700                 :         10 :     Node a = assertions.back();
     701                 :         10 :     assertions.pop_back();
     702                 :         10 :     cvc5::internal::Kind k = a.getKind();
     703                 :            : 
     704         [ -  + ]:         10 :     if (k == Kind::AND)
     705                 :            :     {
     706         [ -  - ]:          0 :       for (const Node& aa : a)
     707                 :            :       {
     708                 :          0 :         assertions.push_back(aa);
     709                 :            :       }
     710                 :            :     }
     711         [ +  - ]:         10 :     else if (k == Kind::EQUAL)
     712                 :            :     {
     713                 :         10 :       Node urem;
     714                 :            : 
     715                 :         10 :       if (is_bv_const(a[1]) && a[0].getKind() == Kind::BITVECTOR_UREM)
     716                 :            :       {
     717                 :         10 :         urem = a[0];
     718                 :            :       }
     719                 :          0 :       else if (is_bv_const(a[0]) && a[1].getKind() == Kind::BITVECTOR_UREM)
     720                 :            :       {
     721                 :          0 :         urem = a[1];
     722                 :            :       }
     723                 :            :       else
     724                 :            :       {
     725                 :          0 :         continue;
     726                 :            :       }
     727                 :            : 
     728                 :         10 :       if (urem[0].getKind() == Kind::BITVECTOR_ADD && is_bv_const(urem[1]))
     729                 :            :       {
     730                 :         10 :         equations[urem[1]].push_back(a);
     731                 :            :       }
     732                 :            :     }
     733                 :            :   }
     734                 :            : 
     735                 :          6 :   std::unordered_map<Node, Node> subst;
     736                 :            : 
     737                 :          3 :   NodeManager* nm = nodeManager();
     738         [ +  + ]:          7 :   for (const auto& eq : equations)
     739                 :            :   {
     740         [ -  + ]:          4 :     if (eq.second.size() <= 1) { continue; }
     741                 :            : 
     742                 :          4 :     std::unordered_map<Node, Node> res;
     743                 :          4 :     BVGauss::Result ret = gaussElimRewriteForUrem(eq.second, res);
     744         [ +  - ]:          8 :     Trace("bv-gauss-elim") << "result: "
     745                 :            :                            << (ret == BVGauss::Result::INVALID
     746         [ -  - ]:          4 :                                    ? "INVALID"
     747                 :            :                                    : (ret == BVGauss::Result::UNIQUE
     748         [ -  - ]:          0 :                                           ? "UNIQUE"
     749                 :            :                                           : (ret == BVGauss::Result::PARTIAL
     750         [ -  - ]:          0 :                                                  ? "PARTIAL"
     751                 :          0 :                                                  : "NONE")))
     752                 :          4 :                            << std::endl;
     753         [ +  - ]:          4 :     if (ret != BVGauss::Result::INVALID)
     754                 :            :     {
     755         [ -  + ]:          4 :       if (ret == BVGauss::Result::NONE)
     756                 :            :       {
     757                 :          0 :         Node n = nm->mkConst<bool>(false);
     758                 :          0 :         assertionsToPreprocess->push_back(n);
     759                 :          0 :         return PreprocessingPassResult::CONFLICT;
     760                 :            :       }
     761                 :            :       else
     762                 :            :       {
     763         [ +  + ]:         14 :         for (const Node& e : eq.second)
     764                 :            :         {
     765                 :         10 :           subst[e] = nm->mkConst<bool>(true);
     766                 :            :         }
     767                 :            :         /* add resulting constraints */
     768         [ +  + ]:         14 :         for (const auto& p : res)
     769                 :            :         {
     770                 :         20 :           Node a = nm->mkNode(Kind::EQUAL, p.first, p.second);
     771         [ +  - ]:         10 :           Trace("bv-gauss-elim") << "added assertion: " << a << std::endl;
     772                 :            :           // add new assertion
     773                 :         10 :           assertionsToPreprocess->push_back(a);
     774                 :            :         }
     775                 :            :       }
     776                 :            :     }
     777                 :            :   }
     778                 :            : 
     779         [ +  - ]:          3 :   if (!subst.empty())
     780                 :            :   {
     781                 :            :     /* delete (= substitute with true) obsolete assertions */
     782                 :          3 :     const std::vector<Node>& aref = assertionsToPreprocess->ref();
     783         [ +  + ]:         23 :     for (size_t i = 0, asize = aref.size(); i < asize; ++i)
     784                 :            :     {
     785                 :         40 :       Node a = aref[i];
     786                 :         20 :       Node as = a.substitute(subst.begin(), subst.end());
     787                 :            :       // replace the assertion
     788                 :         20 :       assertionsToPreprocess->replace(i, as);
     789                 :            :     }
     790                 :            :   }
     791                 :          3 :   return PreprocessingPassResult::NO_CONFLICT;
     792                 :            : }
     793                 :            : 
     794                 :            : 
     795                 :            : }  // namespace passes
     796                 :            : }  // namespace preprocessing
     797                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14