LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/theory/bags - bags_rewriter.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 288 308 93.5 %
Date: 2025-03-15 12:03:33 Functions: 25 26 96.2 %
Branches: 189 323 58.5 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * Top contributors (to current version):
       3                 :            :  *   Mudathir Mohamed, Aina Niemetz, Andrew Reynolds
       4                 :            :  *
       5                 :            :  * This file is part of the cvc5 project.
       6                 :            :  *
       7                 :            :  * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
       8                 :            :  * in the top-level source directory and their institutional affiliations.
       9                 :            :  * All rights reserved.  See the file COPYING in the top-level source
      10                 :            :  * directory for licensing information.
      11                 :            :  * ****************************************************************************
      12                 :            :  *
      13                 :            :  * Bags theory rewriter.
      14                 :            :  */
      15                 :            : 
      16                 :            : #include "theory/bags/bags_rewriter.h"
      17                 :            : 
      18                 :            : #include "expr/emptybag.h"
      19                 :            : #include "theory/bags/bags_utils.h"
      20                 :            : #include "theory/rewriter.h"
      21                 :            : #include "util/rational.h"
      22                 :            : #include "util/statistics_registry.h"
      23                 :            : 
      24                 :            : using namespace cvc5::internal::kind;
      25                 :            : 
      26                 :            : namespace cvc5::internal {
      27                 :            : namespace theory {
      28                 :            : namespace bags {
      29                 :            : 
      30                 :      63272 : BagsRewriteResponse::BagsRewriteResponse()
      31                 :      63272 :     : d_node(Node::null()), d_rewrite(Rewrite::NONE)
      32                 :            : {
      33                 :      63272 : }
      34                 :            : 
      35                 :      63272 : BagsRewriteResponse::BagsRewriteResponse(Node n, Rewrite rewrite)
      36                 :      63272 :     : d_node(n), d_rewrite(rewrite)
      37                 :            : {
      38                 :      63272 : }
      39                 :            : 
      40                 :          0 : BagsRewriteResponse::BagsRewriteResponse(const BagsRewriteResponse& r)
      41                 :          0 :     : d_node(r.d_node), d_rewrite(r.d_rewrite)
      42                 :            : {
      43                 :          0 : }
      44                 :            : 
      45                 :      51373 : BagsRewriter::BagsRewriter(NodeManager* nm,
      46                 :            :                            Rewriter* r,
      47                 :      51373 :                            HistogramStat<Rewrite>* statistics)
      48                 :      51373 :     : TheoryRewriter(nm), d_rewriter(r), d_statistics(statistics)
      49                 :            : {
      50                 :      51373 :   d_zero = d_nm->mkConstInt(Rational(0));
      51                 :      51373 :   d_one = d_nm->mkConstInt(Rational(1));
      52                 :      51373 : }
      53                 :            : 
      54                 :      29054 : RewriteResponse BagsRewriter::postRewrite(TNode n)
      55                 :            : {
      56                 :      58108 :   BagsRewriteResponse response;
      57         [ +  + ]:      29054 :   if (n.isConst())
      58                 :            :   {
      59                 :            :     // no need to rewrite n if it is already in a normal form
      60                 :       1730 :     response = BagsRewriteResponse(n, Rewrite::NONE);
      61                 :            :   }
      62         [ +  + ]:      27324 :   else if (n.getKind() == Kind::EQUAL)
      63                 :            :   {
      64                 :      13327 :     response = postRewriteEqual(n);
      65                 :            :   }
      66         [ +  + ]:      13997 :   else if (n.getKind() == Kind::BAG_CHOOSE)
      67                 :            :   {
      68                 :         29 :     response = rewriteChoose(n);
      69                 :            :   }
      70         [ +  + ]:      13968 :   else if (BagsUtils::areChildrenConstants(n))
      71                 :            :   {
      72                 :        706 :     Node value = BagsUtils::evaluate(d_rewriter, n);
      73                 :        706 :     response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION);
      74                 :            :   }
      75                 :            :   else
      76                 :            :   {
      77                 :      13262 :     Kind k = n.getKind();
      78 [ +  + ][ +  + ]:      13262 :     switch (k)
         [ +  + ][ +  + ]
         [ +  + ][ +  + ]
         [ +  + ][ +  + ]
      79                 :            :     {
      80                 :        598 :       case Kind::BAG_MAKE: response = rewriteMakeBag(n); break;
      81                 :      10607 :       case Kind::BAG_COUNT: response = rewriteBagCount(n); break;
      82                 :         33 :       case Kind::BAG_SETOF: response = rewriteSetof(n); break;
      83                 :        117 :       case Kind::BAG_UNION_MAX: response = rewriteUnionMax(n); break;
      84                 :        515 :       case Kind::BAG_UNION_DISJOINT: response = rewriteUnionDisjoint(n); break;
      85                 :         84 :       case Kind::BAG_INTER_MIN: response = rewriteIntersectionMin(n); break;
      86                 :        103 :       case Kind::BAG_DIFFERENCE_SUBTRACT:
      87                 :        103 :         response = rewriteDifferenceSubtract(n);
      88                 :        103 :         break;
      89                 :        119 :       case Kind::BAG_DIFFERENCE_REMOVE:
      90                 :        119 :         response = rewriteDifferenceRemove(n);
      91                 :        119 :         break;
      92                 :         95 :       case Kind::BAG_CARD: response = rewriteCard(n); break;
      93                 :        403 :       case Kind::BAG_MAP: response = postRewriteMap(n); break;
      94                 :        329 :       case Kind::BAG_FILTER: response = postRewriteFilter(n); break;
      95                 :         39 :       case Kind::BAG_FOLD: response = postRewriteFold(n); break;
      96                 :          8 :       case Kind::BAG_PARTITION: response = postRewritePartition(n); break;
      97                 :         42 :       case Kind::TABLE_PRODUCT: response = postRewriteProduct(n); break;
      98                 :          2 :       case Kind::TABLE_AGGREGATE: response = postRewriteAggregate(n); break;
      99                 :        168 :       default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
     100                 :            :     }
     101                 :            :   }
     102                 :            : 
     103         [ +  - ]:      58108 :   Trace("bags-rewrite") << "postRewrite " << n << " to " << response.d_node
     104                 :      29054 :                         << " by " << response.d_rewrite << "." << std::endl;
     105                 :            : 
     106         [ +  + ]:      29054 :   if (d_statistics != nullptr)
     107                 :            :   {
     108                 :      28963 :     (*d_statistics) << response.d_rewrite;
     109                 :            :   }
     110         [ +  + ]:      29054 :   if (response.d_node != n)
     111                 :            :   {
     112                 :       4250 :     return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
     113                 :            :   }
     114                 :      24804 :   return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
     115                 :            : }
     116                 :            : 
     117                 :      34218 : RewriteResponse BagsRewriter::preRewrite(TNode n)
     118                 :            : {
     119                 :      68436 :   BagsRewriteResponse response;
     120                 :      34218 :   Kind k = n.getKind();
     121 [ +  + ][ +  + ]:      34218 :   switch (k)
     122                 :            :   {
     123                 :      17099 :     case Kind::EQUAL: response = preRewriteEqual(n); break;
     124                 :         32 :     case Kind::BAG_SUBBAG: response = rewriteSubBag(n); break;
     125                 :        172 :     case Kind::BAG_MEMBER: response = rewriteMember(n); break;
     126                 :      16915 :     default: response = BagsRewriteResponse(n, Rewrite::NONE);
     127                 :            :   }
     128                 :            : 
     129         [ +  - ]:      68436 :   Trace("bags-rewrite") << "preRewrite " << n << " to " << response.d_node
     130                 :      34218 :                         << " by " << response.d_rewrite << "." << std::endl;
     131                 :            : 
     132         [ +  + ]:      34218 :   if (d_statistics != nullptr)
     133                 :            :   {
     134                 :      34217 :     (*d_statistics) << response.d_rewrite;
     135                 :            :   }
     136         [ +  + ]:      34218 :   if (response.d_node != n)
     137                 :            :   {
     138                 :        787 :     return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
     139                 :            :   }
     140                 :      33431 :   return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
     141                 :            : }
     142                 :            : 
     143                 :      17099 : BagsRewriteResponse BagsRewriter::preRewriteEqual(const TNode& n) const
     144                 :            : {
     145 [ -  + ][ -  + ]:      17099 :   Assert(n.getKind() == Kind::EQUAL);
                 [ -  - ]
     146         [ +  + ]:      17099 :   if (n[0] == n[1])
     147                 :            :   {
     148                 :            :     // (= A A) = true where A is a bag
     149                 :       1166 :     return BagsRewriteResponse(d_nm->mkConst(true), Rewrite::IDENTICAL_NODES);
     150                 :            :   }
     151                 :      16516 :   return BagsRewriteResponse(n, Rewrite::NONE);
     152                 :            : }
     153                 :            : 
     154                 :         32 : BagsRewriteResponse BagsRewriter::rewriteSubBag(const TNode& n) const
     155                 :            : {
     156 [ -  + ][ -  + ]:         32 :   Assert(n.getKind() == Kind::BAG_SUBBAG);
                 [ -  - ]
     157                 :            : 
     158                 :            :   // (bag.subbag A B) = ((bag.difference_subtract A B) == bag.empty)
     159                 :         96 :   Node emptybag = d_nm->mkConst(EmptyBag(n[0].getType()));
     160                 :         96 :   Node subtract = d_nm->mkNode(Kind::BAG_DIFFERENCE_SUBTRACT, n[0], n[1]);
     161                 :         32 :   Node equal = subtract.eqNode(emptybag);
     162                 :         64 :   return BagsRewriteResponse(equal, Rewrite::SUB_BAG);
     163                 :            : }
     164                 :            : 
     165                 :        172 : BagsRewriteResponse BagsRewriter::rewriteMember(const TNode& n) const
     166                 :            : {
     167 [ -  + ][ -  + ]:        172 :   Assert(n.getKind() == Kind::BAG_MEMBER);
                 [ -  - ]
     168                 :            : 
     169                 :            :   // - (bag.member x A) = (>= (bag.count x A) 1)
     170                 :        516 :   Node count = d_nm->mkNode(Kind::BAG_COUNT, n[0], n[1]);
     171                 :        344 :   Node geq = d_nm->mkNode(Kind::GEQ, count, d_one);
     172                 :        344 :   return BagsRewriteResponse(geq, Rewrite::MEMBER);
     173                 :            : }
     174                 :            : 
     175                 :        598 : BagsRewriteResponse BagsRewriter::rewriteMakeBag(const TNode& n) const
     176                 :            : {
     177 [ -  + ][ -  + ]:        598 :   Assert(n.getKind() == Kind::BAG_MAKE);
                 [ -  - ]
     178                 :            :   // return bag.empty for negative or zero multiplicity
     179                 :        598 :   if (n[1].isConst() && n[1].getConst<Rational>().sgn() != 1)
     180                 :            :   {
     181                 :            :     // (bag x c) = bag.empty where c <= 0
     182                 :        128 :     Node emptybag = d_nm->mkConst(EmptyBag(n.getType()));
     183                 :         64 :     return BagsRewriteResponse(emptybag, Rewrite::BAG_MAKE_COUNT_NEGATIVE);
     184                 :            :   }
     185                 :        534 :   return BagsRewriteResponse(n, Rewrite::NONE);
     186                 :            : }
     187                 :            : 
     188                 :      10607 : BagsRewriteResponse BagsRewriter::rewriteBagCount(const TNode& n) const
     189                 :            : {
     190 [ -  + ][ -  + ]:      10607 :   Assert(n.getKind() == Kind::BAG_COUNT);
                 [ -  - ]
     191                 :      10607 :   if (n[1].isConst() && n[1].getKind() == Kind::BAG_EMPTY)
     192                 :            :   {
     193                 :            :     // (bag.count x bag.empty) = 0
     194                 :        369 :     return BagsRewriteResponse(d_zero, Rewrite::COUNT_EMPTY);
     195                 :            :   }
     196                 :      11152 :   if (n[1].getKind() == Kind::BAG_MAKE && n[0] == n[1][0] && n[1][1].isConst()
     197                 :      11152 :       && n[1][1].getConst<Rational>() > Rational(0))
     198                 :            :   {
     199                 :            :     // (bag.count x (bag x c)) = c, c > 0 is a constant
     200                 :        124 :     Node c = n[1][1];
     201                 :         62 :     return BagsRewriteResponse(c, Rewrite::COUNT_BAG_MAKE);
     202                 :            :   }
     203                 :      10176 :   return BagsRewriteResponse(n, Rewrite::NONE);
     204                 :            : }
     205                 :            : 
     206                 :         33 : BagsRewriteResponse BagsRewriter::rewriteSetof(const TNode& n) const
     207                 :            : {
     208 [ -  + ][ -  + ]:         33 :   Assert(n.getKind() == Kind::BAG_SETOF);
                 [ -  - ]
     209                 :         34 :   if (n[0].getKind() == Kind::BAG_MAKE && n[0][1].isConst()
     210                 :         34 :       && n[0][1].getConst<Rational>().sgn() == 1)
     211                 :            :   {
     212                 :            :     // (bag.setof (bag x n)) = (bag x 1)
     213                 :            :     //  where n is a positive constant
     214                 :          2 :     Node bag = d_nm->mkNode(Kind::BAG_MAKE, n[0][0], d_one);
     215                 :          1 :     return BagsRewriteResponse(bag, Rewrite::SETOF_BAG_MAKE);
     216                 :            :   }
     217                 :         32 :   return BagsRewriteResponse(n, Rewrite::NONE);
     218                 :            : }
     219                 :            : 
     220                 :        117 : BagsRewriteResponse BagsRewriter::rewriteUnionMax(const TNode& n) const
     221                 :            : {
     222 [ -  + ][ -  + ]:        117 :   Assert(n.getKind() == Kind::BAG_UNION_MAX);
                 [ -  - ]
     223                 :        117 :   if (n[1].getKind() == Kind::BAG_EMPTY || n[0] == n[1])
     224                 :            :   {
     225                 :            :     // (bag.union_max A A) = A
     226                 :            :     // (bag.union_max A bag.empty) = A
     227                 :          2 :     return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_SAME_OR_EMPTY);
     228                 :            :   }
     229         [ +  + ]:        115 :   if (n[0].getKind() == Kind::BAG_EMPTY)
     230                 :            :   {
     231                 :            :     // (bag.union_max bag.empty A) = A
     232                 :          1 :     return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_EMPTY);
     233                 :            :   }
     234                 :            : 
     235 [ +  + ][ -  - ]:        114 :   if ((n[1].getKind() == Kind::BAG_UNION_MAX
     236 [ +  + ][ +  - ]:        226 :        || n[1].getKind() == Kind::BAG_UNION_DISJOINT)
                 [ -  - ]
     237                 :        226 :       && (n[0] == n[1][0] || n[0] == n[1][1]))
     238                 :            :   {
     239                 :            :     // (bag.union_max A (bag.union_max A B)) = (bag.union_max A B)
     240                 :            :     // (bag.union_max A (bag.union_max B A)) = (bag.union_max B A)
     241                 :            :     // (bag.union_max A (bag.union_disjoint A B)) = (bag.union_disjoint A B)
     242                 :            :     // (bag.union_max A (bag.union_disjoint B A)) = (bag.union_disjoint B A)
     243                 :          4 :     return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_UNION_LEFT);
     244                 :            :   }
     245                 :            : 
     246 [ +  + ][ -  - ]:        110 :   if ((n[0].getKind() == Kind::BAG_UNION_MAX
     247 [ +  + ][ +  - ]:        218 :        || n[0].getKind() == Kind::BAG_UNION_DISJOINT)
                 [ -  - ]
     248                 :        218 :       && (n[0][0] == n[1] || n[0][1] == n[1]))
     249                 :            :   {
     250                 :            :     // (bag.union_max (bag.union_max A B) A)) = (bag.union_max A B)
     251                 :            :     // (bag.union_max (bag.union_max B A) A)) = (bag.union_max B A)
     252                 :            :     // (bag.union_max (bag.union_disjoint A B) A)) = (bag.union_disjoint A B)
     253                 :            :     // (bag.union_max (bag.union_disjoint B A) A)) = (bag.union_disjoint B A)
     254                 :          4 :     return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_UNION_RIGHT);
     255                 :            :   }
     256                 :        106 :   return BagsRewriteResponse(n, Rewrite::NONE);
     257                 :            : }
     258                 :            : 
     259                 :        515 : BagsRewriteResponse BagsRewriter::rewriteUnionDisjoint(const TNode& n) const
     260                 :            : {
     261 [ -  + ][ -  + ]:        515 :   Assert(n.getKind() == Kind::BAG_UNION_DISJOINT);
                 [ -  - ]
     262         [ +  + ]:        515 :   if (n[1].getKind() == Kind::BAG_EMPTY)
     263                 :            :   {
     264                 :            :     // (bag.union_disjoint A bag.empty) = A
     265                 :         17 :     return BagsRewriteResponse(n[0], Rewrite::UNION_DISJOINT_EMPTY_RIGHT);
     266                 :            :   }
     267         [ +  + ]:        498 :   if (n[0].getKind() == Kind::BAG_EMPTY)
     268                 :            :   {
     269                 :            :     // (bag.union_disjoint bag.empty A) = A
     270                 :         21 :     return BagsRewriteResponse(n[1], Rewrite::UNION_DISJOINT_EMPTY_LEFT);
     271                 :            :   }
     272 [ +  + ][ -  - ]:        477 :   if ((n[0].getKind() == Kind::BAG_UNION_MAX
     273 [ -  + ][ +  - ]:        480 :        && n[1].getKind() == Kind::BAG_INTER_MIN)
                 [ -  - ]
     274 [ +  + ][ -  + ]:        957 :       || (n[1].getKind() == Kind::BAG_UNION_MAX
         [ +  + ][ -  - ]
     275 [ -  - ][ -  + ]:        477 :           && n[0].getKind() == Kind::BAG_INTER_MIN))
         [ +  + ][ -  - ]
     276                 :            : 
     277                 :            :   {
     278                 :            :     // (bag.union_disjoint (bag.union_max A B) (bag.inter_min A B)) =
     279                 :            :     //         (bag.union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
     280                 :            :     // check if the operands of bag.union_max and bag.inter_min are the
     281                 :            :     // same
     282                 :          6 :     std::set<Node> left(n[0].begin(), n[0].end());
     283                 :          6 :     std::set<Node> right(n[1].begin(), n[1].end());
     284         [ +  + ]:          3 :     if (left == right)
     285                 :            :     {
     286                 :          4 :       Node rewritten = d_nm->mkNode(Kind::BAG_UNION_DISJOINT, n[0][0], n[0][1]);
     287                 :          2 :       return BagsRewriteResponse(rewritten, Rewrite::UNION_DISJOINT_MAX_MIN);
     288                 :            :     }
     289                 :            :   }
     290                 :        475 :   return BagsRewriteResponse(n, Rewrite::NONE);
     291                 :            : }
     292                 :            : 
     293                 :         84 : BagsRewriteResponse BagsRewriter::rewriteIntersectionMin(const TNode& n) const
     294                 :            : {
     295 [ -  + ][ -  + ]:         84 :   Assert(n.getKind() == Kind::BAG_INTER_MIN);
                 [ -  - ]
     296         [ +  + ]:         84 :   if (n[0].getKind() == Kind::BAG_EMPTY)
     297                 :            :   {
     298                 :            :     // (bag.inter_min bag.empty A) = bag.empty
     299                 :          1 :     return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_EMPTY_LEFT);
     300                 :            :   }
     301         [ +  + ]:         83 :   if (n[1].getKind() == Kind::BAG_EMPTY)
     302                 :            :   {
     303                 :            :     // (bag.inter_min A bag.empty) = bag.empty
     304                 :          4 :     return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_EMPTY_RIGHT);
     305                 :            :   }
     306         [ +  + ]:         79 :   if (n[0] == n[1])
     307                 :            :   {
     308                 :            :     // (bag.inter_min A A) = A
     309                 :          3 :     return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SAME);
     310                 :            :   }
     311 [ +  + ][ -  - ]:         76 :   if (n[1].getKind() == Kind::BAG_UNION_DISJOINT
     312 [ +  + ][ +  + ]:         76 :       || n[1].getKind() == Kind::BAG_UNION_MAX)
         [ +  + ][ +  - ]
                 [ -  - ]
     313                 :            :   {
     314                 :          4 :     if (n[0] == n[1][0] || n[0] == n[1][1])
     315                 :            :     {
     316                 :            :       // (bag.inter_min A (bag.union_disjoint A B)) = A
     317                 :            :       // (bag.inter_min A (bag.union_disjoint B A)) = A
     318                 :            :       // (bag.inter_min A (bag.union_max A B)) = A
     319                 :            :       // (bag.inter_min A (bag.union_max B A)) = A
     320                 :          4 :       return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SHARED_LEFT);
     321                 :            :     }
     322                 :            :   }
     323                 :            : 
     324 [ +  + ][ -  - ]:         72 :   if (n[0].getKind() == Kind::BAG_UNION_DISJOINT
     325 [ +  + ][ +  + ]:         72 :       || n[0].getKind() == Kind::BAG_UNION_MAX)
         [ +  + ][ +  - ]
                 [ -  - ]
     326                 :            :   {
     327                 :          4 :     if (n[1] == n[0][0] || n[1] == n[0][1])
     328                 :            :     {
     329                 :            :       // (bag.inter_min (bag.union_disjoint A B) A) = A
     330                 :            :       // (bag.inter_min (bag.union_disjoint B A) A) = A
     331                 :            :       // (bag.inter_min (bag.union_max A B) A) = A
     332                 :            :       // (bag.inter_min (bag.union_max B A) A) = A
     333                 :          4 :       return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_SHARED_RIGHT);
     334                 :            :     }
     335                 :            :   }
     336                 :            : 
     337                 :         68 :   return BagsRewriteResponse(n, Rewrite::NONE);
     338                 :            : }
     339                 :            : 
     340                 :        103 : BagsRewriteResponse BagsRewriter::rewriteDifferenceSubtract(
     341                 :            :     const TNode& n) const
     342                 :            : {
     343 [ -  + ][ -  + ]:        103 :   Assert(n.getKind() == Kind::BAG_DIFFERENCE_SUBTRACT);
                 [ -  - ]
     344                 :        103 :   if (n[0].getKind() == Kind::BAG_EMPTY || n[1].getKind() == Kind::BAG_EMPTY)
     345                 :            :   {
     346                 :            :     // (bag.difference_subtract A bag.empty) = A
     347                 :            :     // (bag.difference_subtract bag.empty A) = bag.empty
     348                 :          2 :     return BagsRewriteResponse(n[0], Rewrite::SUBTRACT_RETURN_LEFT);
     349                 :            :   }
     350         [ +  + ]:        101 :   if (n[0] == n[1])
     351                 :            :   {
     352                 :            :     // (bag.difference_subtract A A) = bag.empty
     353                 :          2 :     Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
     354                 :          1 :     return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_SAME);
     355                 :            :   }
     356                 :            : 
     357         [ +  + ]:        100 :   if (n[0].getKind() == Kind::BAG_UNION_DISJOINT)
     358                 :            :   {
     359         [ +  + ]:          2 :     if (n[1] == n[0][0])
     360                 :            :     {
     361                 :            :       // (bag.difference_subtract (bag.union_disjoint A B) A) = B
     362                 :            :       return BagsRewriteResponse(n[0][1],
     363                 :          1 :                                  Rewrite::SUBTRACT_DISJOINT_SHARED_LEFT);
     364                 :            :     }
     365         [ +  - ]:          1 :     if (n[1] == n[0][1])
     366                 :            :     {
     367                 :            :       // (bag.difference_subtract (bag.union_disjoint B A) A) = B
     368                 :            :       return BagsRewriteResponse(n[0][0],
     369                 :          1 :                                  Rewrite::SUBTRACT_DISJOINT_SHARED_RIGHT);
     370                 :            :     }
     371                 :            :   }
     372                 :            : 
     373 [ +  + ][ -  - ]:         98 :   if (n[1].getKind() == Kind::BAG_UNION_DISJOINT
     374 [ +  + ][ +  + ]:         98 :       || n[1].getKind() == Kind::BAG_UNION_MAX)
         [ +  + ][ +  - ]
                 [ -  - ]
     375                 :            :   {
     376                 :          4 :     if (n[0] == n[1][0] || n[0] == n[1][1])
     377                 :            :     {
     378                 :            :       // (bag.difference_subtract A (bag.union_disjoint A B)) = bag.empty
     379                 :            :       // (bag.difference_subtract A (bag.union_disjoint B A)) = bag.empty
     380                 :            :       // (bag.difference_subtract A (bag.union_max A B)) = bag.empty
     381                 :            :       // (bag.difference_subtract A (bag.union_max B A)) = bag.empty
     382                 :          8 :       Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
     383                 :          4 :       return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_FROM_UNION);
     384                 :            :     }
     385                 :            :   }
     386                 :            : 
     387         [ +  + ]:         94 :   if (n[0].getKind() == Kind::BAG_INTER_MIN)
     388                 :            :   {
     389                 :          2 :     if (n[1] == n[0][0] || n[1] == n[0][1])
     390                 :            :     {
     391                 :            :       // (bag.difference_subtract (bag.inter_min A B) A) = bag.empty
     392                 :            :       // (bag.difference_subtract (bag.inter_min B A) A) = bag.empty
     393                 :          4 :       Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
     394                 :          2 :       return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_MIN);
     395                 :            :     }
     396                 :            :   }
     397                 :            : 
     398                 :         92 :   return BagsRewriteResponse(n, Rewrite::NONE);
     399                 :            : }
     400                 :            : 
     401                 :        119 : BagsRewriteResponse BagsRewriter::rewriteDifferenceRemove(const TNode& n) const
     402                 :            : {
     403 [ -  + ][ -  + ]:        119 :   Assert(n.getKind() == Kind::BAG_DIFFERENCE_REMOVE);
                 [ -  - ]
     404                 :            : 
     405                 :        119 :   if (n[0].getKind() == Kind::BAG_EMPTY || n[1].getKind() == Kind::BAG_EMPTY)
     406                 :            :   {
     407                 :            :     // (bag.difference_remove A bag.empty) = A
     408                 :            :     // (bag.difference_remove bag.empty B) = bag.empty
     409                 :          2 :     return BagsRewriteResponse(n[0], Rewrite::REMOVE_RETURN_LEFT);
     410                 :            :   }
     411                 :            : 
     412         [ +  + ]:        117 :   if (n[0] == n[1])
     413                 :            :   {
     414                 :            :     // (bag.difference_remove A A) = bag.empty
     415                 :          6 :     Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
     416                 :          3 :     return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_SAME);
     417                 :            :   }
     418                 :            : 
     419 [ +  + ][ -  - ]:        114 :   if (n[1].getKind() == Kind::BAG_UNION_DISJOINT
     420 [ +  + ][ +  + ]:        114 :       || n[1].getKind() == Kind::BAG_UNION_MAX)
         [ +  + ][ +  - ]
                 [ -  - ]
     421                 :            :   {
     422                 :          4 :     if (n[0] == n[1][0] || n[0] == n[1][1])
     423                 :            :     {
     424                 :            :       // (bag.difference_remove A (bag.union_disjoint A B)) = bag.empty
     425                 :            :       // (bag.difference_remove A (bag.union_disjoint B A)) = bag.empty
     426                 :            :       // (bag.difference_remove A (bag.union_max A B)) = bag.empty
     427                 :            :       // (bag.difference_remove A (bag.union_max B A)) = bag.empty
     428                 :          8 :       Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
     429                 :          4 :       return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_FROM_UNION);
     430                 :            :     }
     431                 :            :   }
     432                 :            : 
     433         [ +  + ]:        110 :   if (n[0].getKind() == Kind::BAG_INTER_MIN)
     434                 :            :   {
     435                 :          2 :     if (n[1] == n[0][0] || n[1] == n[0][1])
     436                 :            :     {
     437                 :            :       // (bag.difference_remove (bag.inter_min A B) A) = bag.empty
     438                 :            :       // (bag.difference_remove (bag.inter_min B A) A) = bag.empty
     439                 :          4 :       Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
     440                 :          2 :       return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_MIN);
     441                 :            :     }
     442                 :            :   }
     443                 :            : 
     444                 :        108 :   return BagsRewriteResponse(n, Rewrite::NONE);
     445                 :            : }
     446                 :            : 
     447                 :         29 : BagsRewriteResponse BagsRewriter::rewriteChoose(const TNode& n) const
     448                 :            : {
     449 [ -  + ][ -  + ]:         29 :   Assert(n.getKind() == Kind::BAG_CHOOSE);
                 [ -  - ]
     450                 :         30 :   if (n[0].getKind() == Kind::BAG_MAKE && n[0][1].isConst()
     451                 :         30 :       && n[0][1].getConst<Rational>() > 0)
     452                 :            :   {
     453                 :            :     // (bag.choose (bag x c)) = x where c is a constant > 0
     454                 :          1 :     return BagsRewriteResponse(n[0][0], Rewrite::CHOOSE_BAG_MAKE);
     455                 :            :   }
     456                 :         28 :   return BagsRewriteResponse(n, Rewrite::NONE);
     457                 :            : }
     458                 :            : 
     459                 :         95 : BagsRewriteResponse BagsRewriter::rewriteCard(const TNode& n) const
     460                 :            : {
     461 [ -  + ][ -  + ]:         95 :   Assert(n.getKind() == Kind::BAG_CARD);
                 [ -  - ]
     462                 :         95 :   if (n[0].getKind() == Kind::BAG_MAKE && n[0][1].isConst())
     463                 :            :   {
     464                 :            :     // (bag.card (bag x c)) = c where c is a constant > 0
     465                 :          1 :     return BagsRewriteResponse(n[0][1], Rewrite::CARD_BAG_MAKE);
     466                 :            :   }
     467                 :            : 
     468                 :         94 :   return BagsRewriteResponse(n, Rewrite::NONE);
     469                 :            : }
     470                 :            : 
     471                 :      13327 : BagsRewriteResponse BagsRewriter::postRewriteEqual(const TNode& n) const
     472                 :            : {
     473 [ -  + ][ -  + ]:      13327 :   Assert(n.getKind() == Kind::EQUAL);
                 [ -  - ]
     474         [ +  + ]:      13327 :   if (n[0] == n[1])
     475                 :            :   {
     476                 :         33 :     Node ret = d_nm->mkConst(true);
     477                 :         33 :     return BagsRewriteResponse(ret, Rewrite::EQ_REFL);
     478                 :            :   }
     479                 :            : 
     480                 :      13294 :   if (n[0].isConst() && n[1].isConst())
     481                 :            :   {
     482                 :         88 :     Node ret = d_nm->mkConst(false);
     483                 :         88 :     return BagsRewriteResponse(ret, Rewrite::EQ_CONST_FALSE);
     484                 :            :   }
     485                 :            : 
     486                 :            :   // standard ordering
     487         [ +  + ]:      13206 :   if (n[0] > n[1])
     488                 :            :   {
     489                 :       5570 :     Node ret = d_nm->mkNode(Kind::EQUAL, n[1], n[0]);
     490                 :       2785 :     return BagsRewriteResponse(ret, Rewrite::EQ_SYM);
     491                 :            :   }
     492                 :      10421 :   return BagsRewriteResponse(n, Rewrite::NONE);
     493                 :            : }
     494                 :            : 
     495                 :        403 : BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const
     496                 :            : {
     497 [ -  + ][ -  + ]:        403 :   Assert(n.getKind() == Kind::BAG_MAP);
                 [ -  - ]
     498         [ +  + ]:        403 :   if (n[1].isConst())
     499                 :            :   {
     500                 :            :     // (bag.map f (as bag.empty (Bag T1)) = (as bag.empty (Bag T2))
     501                 :            :     // (bag.map f (bag "a" 3)) = (bag (f "a") 3)
     502                 :         56 :     std::map<Node, Rational> elements = BagsUtils::getBagElements(n[1]);
     503                 :         56 :     std::map<Node, Rational> mappedElements;
     504                 :         28 :     std::map<Node, Rational>::iterator it = elements.begin();
     505         [ +  + ]:         37 :     while (it != elements.end())
     506                 :            :     {
     507                 :         27 :       Node mappedElement = d_nm->mkNode(Kind::APPLY_UF, n[0], it->first);
     508                 :          9 :       mappedElements[mappedElement] = it->second;
     509                 :          9 :       ++it;
     510                 :            :     }
     511                 :         84 :     TypeNode t = d_nm->mkBagType(n[0].getType().getRangeType());
     512                 :         28 :     Node ret = BagsUtils::constructConstantBagFromElements(t, mappedElements);
     513                 :         28 :     return BagsRewriteResponse(ret, Rewrite::MAP_CONST);
     514                 :            :   }
     515                 :        375 :   Kind k = n[1].getKind();
     516    [ -  + ][ + ]:        375 :   switch (k)
     517                 :            :   {
     518                 :          0 :     case Kind::BAG_MAKE:
     519                 :            :     {
     520                 :            :       // (bag.map f (bag x y)) = (bag (apply f x) y)
     521                 :          0 :       Node mappedElement = d_nm->mkNode(Kind::APPLY_UF, n[0], n[1][0]);
     522                 :          0 :       Node ret = d_nm->mkNode(Kind::BAG_MAKE, mappedElement, n[1][1]);
     523                 :          0 :       return BagsRewriteResponse(ret, Rewrite::MAP_BAG_MAKE);
     524                 :            :     }
     525                 :            : 
     526                 :          1 :     case Kind::BAG_UNION_DISJOINT:
     527                 :            :     {
     528                 :            :       // (bag.map f (bag.union_disjoint A B)) =
     529                 :            :       //    (bag.union_disjoint (bag.map f A) (bag.map f B))
     530                 :          3 :       Node a = d_nm->mkNode(Kind::BAG_MAP, n[0], n[1][0]);
     531                 :          3 :       Node b = d_nm->mkNode(Kind::BAG_MAP, n[0], n[1][1]);
     532                 :          2 :       Node ret = d_nm->mkNode(Kind::BAG_UNION_DISJOINT, a, b);
     533                 :          1 :       return BagsRewriteResponse(ret, Rewrite::MAP_UNION_DISJOINT);
     534                 :            :     }
     535                 :            : 
     536                 :        374 :     default: return BagsRewriteResponse(n, Rewrite::NONE);
     537                 :            :   }
     538                 :            : }
     539                 :            : 
     540                 :        329 : BagsRewriteResponse BagsRewriter::postRewriteFilter(const TNode& n) const
     541                 :            : {
     542 [ -  + ][ -  + ]:        329 :   Assert(n.getKind() == Kind::BAG_FILTER);
                 [ -  - ]
     543                 :        658 :   Node P = n[0];
     544                 :        658 :   Node A = n[1];
     545                 :        658 :   TypeNode t = A.getType();
     546         [ +  + ]:        329 :   if (A.isConst())
     547                 :            :   {
     548                 :            :     // (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T))
     549                 :            :     // (bag.filter p (bag "a" 3) ((bag "b" 2))) =
     550                 :            :     //   (bag.union_disjoint
     551                 :            :     //     (ite (p "a") (bag "a" 3) (as bag.empty (Bag T)))
     552                 :            :     //     (ite (p "b") (bag "b" 2) (as bag.empty (Bag T)))
     553                 :            : 
     554                 :          3 :     Node ret = BagsUtils::evaluateBagFilter(n);
     555                 :          3 :     return BagsRewriteResponse(ret, Rewrite::FILTER_CONST);
     556                 :            :   }
     557                 :        326 :   Kind k = A.getKind();
     558    [ -  - ][ + ]:        326 :   switch (k)
     559                 :            :   {
     560                 :          0 :     case Kind::BAG_MAKE:
     561                 :            :     {
     562                 :            :       // (bag.filter p (bag x y)) = (ite (p x) (bag x y) (as bag.empty (Bag T)))
     563                 :          0 :       Node empty = d_nm->mkConst(EmptyBag(t));
     564                 :          0 :       Node pOfe = d_nm->mkNode(Kind::APPLY_UF, P, A[0]);
     565                 :          0 :       Node ret = d_nm->mkNode(Kind::ITE, pOfe, A, empty);
     566                 :          0 :       return BagsRewriteResponse(ret, Rewrite::FILTER_BAG_MAKE);
     567                 :            :     }
     568                 :            : 
     569                 :          0 :     case Kind::BAG_UNION_DISJOINT:
     570                 :            :     {
     571                 :            :       // (bag.filter p (bag.union_disjoint A B)) =
     572                 :            :       //    (bag.union_disjoint (bag.filter p A) (bag.filter p B))
     573                 :          0 :       Node a = d_nm->mkNode(Kind::BAG_FILTER, n[0], n[1][0]);
     574                 :          0 :       Node b = d_nm->mkNode(Kind::BAG_FILTER, n[0], n[1][1]);
     575                 :          0 :       Node ret = d_nm->mkNode(Kind::BAG_UNION_DISJOINT, a, b);
     576                 :          0 :       return BagsRewriteResponse(ret, Rewrite::FILTER_UNION_DISJOINT);
     577                 :            :     }
     578                 :            : 
     579                 :        326 :     default: return BagsRewriteResponse(n, Rewrite::NONE);
     580                 :            :   }
     581                 :            : }
     582                 :            : 
     583                 :         39 : BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const
     584                 :            : {
     585 [ -  + ][ -  + ]:         39 :   Assert(n.getKind() == Kind::BAG_FOLD);
                 [ -  - ]
     586                 :         78 :   Node f = n[0];
     587                 :         78 :   Node t = n[1];
     588                 :         78 :   Node bag = n[2];
     589         [ +  + ]:         39 :   if (bag.isConst())
     590                 :            :   {
     591                 :          9 :     Node value = BagsUtils::evaluateBagFold(n);
     592                 :          9 :     return BagsRewriteResponse(value, Rewrite::FOLD_CONST);
     593                 :            :   }
     594                 :         30 :   Kind k = bag.getKind();
     595    [ +  + ][ + ]:         30 :   switch (k)
     596                 :            :   {
     597                 :          1 :     case Kind::BAG_MAKE:
     598                 :            :     {
     599                 :          1 :       if (bag[1].isConst() && bag[1].getConst<Rational>() > Rational(0))
     600                 :            :       {
     601                 :            :         // (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times, n > 0
     602                 :          1 :         Node value = BagsUtils::evaluateBagFold(n);
     603                 :          1 :         return BagsRewriteResponse(value, Rewrite::FOLD_BAG);
     604                 :            :       }
     605                 :          0 :       break;
     606                 :            :     }
     607                 :          1 :     case Kind::BAG_UNION_DISJOINT:
     608                 :            :     {
     609                 :            :       // (bag.fold f t (bag.union_disjoint A B)) =
     610                 :            :       //       (bag.fold f (bag.fold f t A) B) where A < B to break symmetry
     611         [ +  - ]:          3 :       Node A = bag[0] < bag[1] ? bag[0] : bag[1];
     612         [ +  - ]:          3 :       Node B = bag[0] < bag[1] ? bag[1] : bag[0];
     613                 :          3 :       Node foldA = d_nm->mkNode(Kind::BAG_FOLD, f, t, A);
     614                 :          2 :       Node fold = d_nm->mkNode(Kind::BAG_FOLD, f, foldA, B);
     615                 :          1 :       return BagsRewriteResponse(fold, Rewrite::FOLD_UNION_DISJOINT);
     616                 :            :     }
     617                 :         28 :     default: return BagsRewriteResponse(n, Rewrite::NONE);
     618                 :            :   }
     619                 :          0 :   return BagsRewriteResponse(n, Rewrite::NONE);
     620                 :            : }
     621                 :            : 
     622                 :          8 : BagsRewriteResponse BagsRewriter::postRewritePartition(const TNode& n) const
     623                 :            : {
     624 [ -  + ][ -  + ]:          8 :   Assert(n.getKind() == Kind::BAG_PARTITION);
                 [ -  - ]
     625         [ +  + ]:          8 :   if (n[1].isConst())
     626                 :            :   {
     627                 :          4 :     Node ret = BagsUtils::evaluateBagPartition(d_rewriter, n);
     628         [ +  - ]:          4 :     if (ret != n)
     629                 :            :     {
     630                 :          4 :       return BagsRewriteResponse(ret, Rewrite::PARTITION_CONST);
     631                 :            :     }
     632                 :            :   }
     633                 :            : 
     634                 :          4 :   return BagsRewriteResponse(n, Rewrite::NONE);
     635                 :            : }
     636                 :            : 
     637                 :          2 : BagsRewriteResponse BagsRewriter::postRewriteAggregate(const TNode& n) const
     638                 :            : {
     639 [ -  + ][ -  + ]:          2 :   Assert(n.getKind() == Kind::TABLE_AGGREGATE);
                 [ -  - ]
     640                 :          2 :   if (n[1].isConst() && n[2].isConst())
     641                 :            :   {
     642                 :          2 :     Node ret = BagsUtils::evaluateTableAggregate(d_rewriter, n);
     643         [ +  - ]:          2 :     if (ret != n)
     644                 :            :     {
     645                 :          2 :       return BagsRewriteResponse(ret, Rewrite::AGGREGATE_CONST);
     646                 :            :     }
     647                 :            :   }
     648                 :            : 
     649                 :          0 :   return BagsRewriteResponse(n, Rewrite::NONE);
     650                 :            : }
     651                 :            : 
     652                 :         42 : BagsRewriteResponse BagsRewriter::postRewriteProduct(const TNode& n) const
     653                 :            : {
     654 [ -  + ][ -  + ]:         42 :   Assert(n.getKind() == Kind::TABLE_PRODUCT);
                 [ -  - ]
     655                 :         84 :   TypeNode tableType = n.getType();
     656                 :         84 :   Node empty = d_nm->mkConst(EmptyBag(tableType));
     657                 :         42 :   if (n[0].getKind() == Kind::BAG_EMPTY || n[1].getKind() == Kind::BAG_EMPTY)
     658                 :            :   {
     659                 :          2 :     return BagsRewriteResponse(empty, Rewrite::PRODUCT_EMPTY);
     660                 :            :   }
     661                 :            : 
     662                 :         40 :   return BagsRewriteResponse(n, Rewrite::NONE);
     663                 :            : }
     664                 :            : 
     665                 :            : }  // namespace bags
     666                 :            : }  // namespace theory
     667                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14