LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/theory/bags - bag_reduction.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 125 127 98.4 %
Date: 2026-02-27 11:41:18 Functions: 4 6 66.7 %
Branches: 8 24 33.3 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * This file is part of the cvc5 project.
       3                 :            :  *
       4                 :            :  * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
       5                 :            :  * in the top-level source directory and their institutional affiliations.
       6                 :            :  * All rights reserved.  See the file COPYING in the top-level source
       7                 :            :  * directory for licensing information.
       8                 :            :  * ****************************************************************************
       9                 :            :  *
      10                 :            :  * bag reduction.
      11                 :            :  */
      12                 :            : 
      13                 :            : #include "theory/bags/bag_reduction.h"
      14                 :            : 
      15                 :            : #include "expr/bound_var_manager.h"
      16                 :            : #include "expr/emptybag.h"
      17                 :            : #include "expr/skolem_manager.h"
      18                 :            : #include "theory/datatypes/project_op.h"
      19                 :            : #include "theory/datatypes/tuple_utils.h"
      20                 :            : #include "theory/quantifiers/fmf/bounded_integers.h"
      21                 :            : #include "util/rational.h"
      22                 :            : 
      23                 :            : using namespace cvc5::internal;
      24                 :            : using namespace cvc5::internal::kind;
      25                 :            : 
      26                 :            : namespace cvc5::internal {
      27                 :            : namespace theory {
      28                 :            : namespace bags {
      29                 :            : 
      30                 :          0 : BagReduction::BagReduction() {}
      31                 :            : 
      32                 :          0 : BagReduction::~BagReduction() {}
      33                 :            : 
      34                 :          8 : Node BagReduction::reduceFoldOperator(Node node, std::vector<Node>& asserts)
      35                 :            : {
      36 [ -  + ][ -  + ]:          8 :   Assert(node.getKind() == Kind::BAG_FOLD);
                 [ -  - ]
      37                 :          8 :   NodeManager* nm = node.getNodeManager();
      38                 :          8 :   SkolemManager* sm = nm->getSkolemManager();
      39                 :          8 :   Node f = node[0];
      40                 :          8 :   Node t = node[1];
      41                 :          8 :   Node A = node[2];
      42                 :          8 :   Node zero = nm->mkConstInt(Rational(0));
      43                 :          8 :   Node one = nm->mkConstInt(Rational(1));
      44                 :            :   // skolem functions
      45                 :          8 :   Node n = sm->mkSkolemFunction(SkolemId::BAGS_FOLD_CARD, A);
      46                 :          8 :   Node elements = sm->mkSkolemFunction(SkolemId::BAGS_FOLD_ELEMENTS, A);
      47                 :            :   Node unionDisjoint =
      48                 :          8 :       sm->mkSkolemFunction(SkolemId::BAGS_FOLD_UNION_DISJOINT, A);
      49                 :         40 :   Node combine = sm->mkSkolemFunction(SkolemId::BAGS_FOLD_COMBINE, {f, t, A});
      50                 :            : 
      51                 :          8 :   BoundVarManager* bvm = nm->getBoundVarManager();
      52                 :            :   Node i = bvm->mkBoundVar(
      53                 :         16 :       BoundVarId::BAGS_FIRST_INDEX, node, "i", nm->integerType());
      54                 :          8 :   Node iList = nm->mkNode(Kind::BOUND_VAR_LIST, i);
      55                 :         16 :   Node iMinusOne = nm->mkNode(Kind::SUB, i, one);
      56                 :         16 :   Node elements_i = nm->mkNode(Kind::APPLY_UF, elements, i);
      57                 :         16 :   Node combine_0 = nm->mkNode(Kind::APPLY_UF, combine, zero);
      58                 :         16 :   Node combine_iMinusOne = nm->mkNode(Kind::APPLY_UF, combine, iMinusOne);
      59                 :         16 :   Node combine_i = nm->mkNode(Kind::APPLY_UF, combine, i);
      60                 :         16 :   Node combine_n = nm->mkNode(Kind::APPLY_UF, combine, n);
      61                 :         16 :   Node unionDisjoint_0 = nm->mkNode(Kind::APPLY_UF, unionDisjoint, zero);
      62                 :            :   Node unionDisjoint_iMinusOne =
      63                 :         16 :       nm->mkNode(Kind::APPLY_UF, unionDisjoint, iMinusOne);
      64                 :         16 :   Node unionDisjoint_i = nm->mkNode(Kind::APPLY_UF, unionDisjoint, i);
      65                 :         16 :   Node unionDisjoint_n = nm->mkNode(Kind::APPLY_UF, unionDisjoint, n);
      66                 :          8 :   Node combine_0_equal = combine_0.eqNode(t);
      67                 :            :   Node combine_i_equal = combine_i.eqNode(
      68                 :         16 :       nm->mkNode(Kind::APPLY_UF, f, elements_i, combine_iMinusOne));
      69                 :            :   Node unionDisjoint_0_equal =
      70                 :         16 :       unionDisjoint_0.eqNode(nm->mkConst(EmptyBag(A.getType())));
      71                 :         16 :   Node singleton = nm->mkNode(Kind::BAG_MAKE, elements_i, one);
      72                 :            : 
      73                 :            :   Node unionDisjoint_i_equal = unionDisjoint_i.eqNode(
      74                 :         16 :       nm->mkNode(Kind::BAG_UNION_DISJOINT, singleton, unionDisjoint_iMinusOne));
      75                 :            :   Node interval_i = nm->mkNode(
      76                 :         16 :       Kind::AND, nm->mkNode(Kind::GEQ, i, one), nm->mkNode(Kind::LEQ, i, n));
      77                 :            : 
      78                 :            :   Node body_i =
      79                 :            :       nm->mkNode(Kind::IMPLIES,
      80                 :            :                  interval_i,
      81                 :         16 :                  nm->mkNode(Kind::AND, combine_i_equal, unionDisjoint_i_equal));
      82                 :            :   Node forAll_i =
      83                 :         16 :       quantifiers::BoundedIntegers::mkBoundedForall(nm, iList, body_i);
      84                 :         16 :   Node nonNegative = nm->mkNode(Kind::GEQ, n, zero);
      85                 :          8 :   Node unionDisjoint_n_equal = A.eqNode(unionDisjoint_n);
      86                 :          8 :   asserts.push_back(forAll_i);
      87                 :          8 :   asserts.push_back(combine_0_equal);
      88                 :          8 :   asserts.push_back(unionDisjoint_0_equal);
      89                 :          8 :   asserts.push_back(unionDisjoint_n_equal);
      90                 :          8 :   asserts.push_back(nonNegative);
      91                 :         16 :   return combine_n;
      92                 :          8 : }
      93                 :            : 
      94                 :         33 : Node BagReduction::reduceCardOperator(Node node, std::vector<Node>& asserts)
      95                 :            : {
      96 [ -  + ][ -  + ]:         33 :   Assert(node.getKind() == Kind::BAG_CARD);
                 [ -  - ]
      97                 :         33 :   NodeManager* nm = node.getNodeManager();
      98                 :         33 :   SkolemManager* sm = nm->getSkolemManager();
      99                 :         33 :   Node A = node[0];
     100                 :         33 :   Node zero = nm->mkConstInt(Rational(0));
     101                 :         33 :   Node one = nm->mkConstInt(Rational(1));
     102                 :            :   // types
     103                 :         33 :   TypeNode bagType = A.getType();
     104                 :            :   // skolem functions
     105                 :         33 :   Node n = sm->mkSkolemFunction(SkolemId::BAGS_DISTINCT_ELEMENTS_SIZE, A);
     106                 :         33 :   Node elements = sm->mkSkolemFunction(SkolemId::BAGS_DISTINCT_ELEMENTS, A);
     107                 :            :   Node unionDisjoint =
     108                 :         33 :       sm->mkSkolemFunction(SkolemId::BAGS_DISTINCT_ELEMENTS_UNION_DISJOINT, A);
     109                 :         33 :   Node combine = sm->mkSkolemFunction(SkolemId::BAGS_CARD_COMBINE, A);
     110                 :            : 
     111                 :         33 :   BoundVarManager* bvm = nm->getBoundVarManager();
     112                 :            :   Node i = bvm->mkBoundVar(
     113                 :         66 :       BoundVarId::BAGS_FIRST_INDEX, node, "i", nm->integerType());
     114                 :            :   Node j = bvm->mkBoundVar(
     115                 :         66 :       BoundVarId::BAGS_SECOND_INDEX, node, "j", nm->integerType());
     116                 :         33 :   Node iList = nm->mkNode(Kind::BOUND_VAR_LIST, i);
     117                 :         33 :   Node jList = nm->mkNode(Kind::BOUND_VAR_LIST, j);
     118                 :         66 :   Node iMinusOne = nm->mkNode(Kind::SUB, i, one);
     119                 :         66 :   Node elements_i = nm->mkNode(Kind::APPLY_UF, elements, i);
     120                 :         66 :   Node elements_j = nm->mkNode(Kind::APPLY_UF, elements, j);
     121                 :         66 :   Node combine_0 = nm->mkNode(Kind::APPLY_UF, combine, zero);
     122                 :         66 :   Node combine_iMinusOne = nm->mkNode(Kind::APPLY_UF, combine, iMinusOne);
     123                 :         66 :   Node combine_i = nm->mkNode(Kind::APPLY_UF, combine, i);
     124                 :         66 :   Node combine_n = nm->mkNode(Kind::APPLY_UF, combine, n);
     125                 :         66 :   Node unionDisjoint_0 = nm->mkNode(Kind::APPLY_UF, unionDisjoint, zero);
     126                 :            :   Node unionDisjoint_iMinusOne =
     127                 :         66 :       nm->mkNode(Kind::APPLY_UF, unionDisjoint, iMinusOne);
     128                 :         66 :   Node unionDisjoint_i = nm->mkNode(Kind::APPLY_UF, unionDisjoint, i);
     129                 :         66 :   Node unionDisjoint_n = nm->mkNode(Kind::APPLY_UF, unionDisjoint, n);
     130                 :         33 :   Node combine_0_equal = combine_0.eqNode(zero);
     131                 :         66 :   Node elements_i_multiplicity = nm->mkNode(Kind::BAG_COUNT, elements_i, A);
     132                 :            :   Node combine_i_equal = combine_i.eqNode(
     133                 :         66 :       nm->mkNode(Kind::ADD, elements_i_multiplicity, combine_iMinusOne));
     134                 :            :   Node unionDisjoint_0_equal =
     135                 :         66 :       unionDisjoint_0.eqNode(nm->mkConst(EmptyBag(bagType)));
     136                 :         66 :   Node bag = nm->mkNode(Kind::BAG_MAKE, elements_i, elements_i_multiplicity);
     137                 :            : 
     138                 :            :   Node unionDisjoint_i_equal = unionDisjoint_i.eqNode(
     139                 :         66 :       nm->mkNode(Kind::BAG_UNION_DISJOINT, bag, unionDisjoint_iMinusOne));
     140                 :            :   // 1 <= i <= n
     141                 :            :   Node interval_i = nm->mkNode(
     142                 :         66 :       Kind::AND, nm->mkNode(Kind::GEQ, i, one), nm->mkNode(Kind::LEQ, i, n));
     143                 :            : 
     144                 :            :   // i < j <= n
     145                 :            :   Node interval_j = nm->mkNode(
     146                 :         66 :       Kind::AND, nm->mkNode(Kind::LT, i, j), nm->mkNode(Kind::LEQ, j, n));
     147                 :            :   // elements(i) != elements(j)
     148                 :            :   Node elements_i_equals_elements_j =
     149                 :         66 :       nm->mkNode(Kind::EQUAL, elements_i, elements_j);
     150                 :         66 :   Node notEqual = nm->mkNode(Kind::EQUAL, elements_i, elements_j).negate();
     151                 :         66 :   Node body_j = nm->mkNode(Kind::OR, interval_j.negate(), notEqual);
     152                 :            :   Node forAll_j =
     153                 :         66 :       quantifiers::BoundedIntegers::mkBoundedForall(nm, jList, body_j);
     154                 :            :   Node body_i = nm->mkNode(
     155                 :            :       Kind::IMPLIES,
     156                 :            :       interval_i,
     157                 :         66 :       nm->mkNode(Kind::AND, combine_i_equal, unionDisjoint_i_equal, forAll_j));
     158                 :            :   Node forAll_i =
     159                 :         66 :       quantifiers::BoundedIntegers::mkBoundedForall(nm, iList, body_i);
     160                 :         66 :   Node nonNegative = nm->mkNode(Kind::GEQ, n, zero);
     161                 :         33 :   Node unionDisjoint_n_equal = A.eqNode(unionDisjoint_n);
     162                 :         33 :   asserts.push_back(forAll_i);
     163                 :         33 :   asserts.push_back(combine_0_equal);
     164                 :         33 :   asserts.push_back(unionDisjoint_0_equal);
     165                 :         33 :   asserts.push_back(unionDisjoint_n_equal);
     166                 :         33 :   asserts.push_back(nonNegative);
     167                 :         66 :   return combine_n;
     168                 :         33 : }
     169                 :            : 
     170                 :          2 : Node BagReduction::reduceAggregateOperator(Node node)
     171                 :            : {
     172 [ -  + ][ -  + ]:          2 :   Assert(node.getKind() == Kind::TABLE_AGGREGATE);
                 [ -  - ]
     173                 :          2 :   NodeManager* nm = node.getNodeManager();
     174                 :          2 :   BoundVarManager* bvm = nm->getBoundVarManager();
     175                 :          2 :   Node function = node[0];
     176                 :          4 :   TypeNode elementType = function.getType().getArgTypes()[0];
     177                 :          2 :   Node initialValue = node[1];
     178                 :          2 :   Node A = node[2];
     179                 :          2 :   ProjectOp op = node.getOperator().getConst<ProjectOp>();
     180                 :            : 
     181                 :          2 :   Node groupOp = nm->mkConst(Kind::TABLE_GROUP_OP, op);
     182                 :          6 :   Node group = nm->mkNode(Kind::TABLE_GROUP, {groupOp, A});
     183                 :            : 
     184                 :            :   Node bag = bvm->mkBoundVar(
     185                 :          4 :       BoundVarId::BAGS_FIRST_INDEX, group, "bag", nm->mkBagType(elementType));
     186                 :          2 :   Node foldList = nm->mkNode(Kind::BOUND_VAR_LIST, bag);
     187                 :          4 :   Node foldBody = nm->mkNode(Kind::BAG_FOLD, function, initialValue, bag);
     188                 :            : 
     189                 :          4 :   Node fold = nm->mkNode(Kind::LAMBDA, foldList, foldBody);
     190                 :          4 :   Node map = nm->mkNode(Kind::BAG_MAP, fold, group);
     191                 :          4 :   return map;
     192                 :          2 : }
     193                 :            : 
     194                 :         25 : Node BagReduction::reduceProjectOperator(Node n)
     195                 :            : {
     196 [ -  + ][ -  + ]:         25 :   Assert(n.getKind() == Kind::TABLE_PROJECT);
                 [ -  - ]
     197                 :         25 :   NodeManager* nm = n.getNodeManager();
     198                 :         25 :   Node A = n[0];
     199                 :         25 :   TypeNode elementType = A.getType().getBagElementType();
     200                 :         25 :   ProjectOp projectOp = n.getOperator().getConst<ProjectOp>();
     201                 :         25 :   Node op = nm->mkConst(Kind::TUPLE_PROJECT_OP, projectOp);
     202                 :         50 :   Node t = NodeManager::mkBoundVar("t", elementType);
     203                 :         50 :   Node projection = nm->mkNode(Kind::TUPLE_PROJECT, op, t);
     204                 :            :   Node lambda =
     205                 :         50 :       nm->mkNode(Kind::LAMBDA, nm->mkNode(Kind::BOUND_VAR_LIST, t), projection);
     206                 :         50 :   Node setMap = nm->mkNode(Kind::BAG_MAP, lambda, A);
     207                 :         50 :   return setMap;
     208                 :         25 : }
     209                 :            : 
     210                 :            : }  // namespace bags
     211                 :            : }  // namespace theory
     212                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14