Branch data Line data Source code
1 : : /****************************************************************************** 2 : : * Top contributors (to current version): 3 : : * Mudathir Mohamed, Aina Niemetz, Andrew Reynolds 4 : : * 5 : : * This file is part of the cvc5 project. 6 : : * 7 : : * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS 8 : : * in the top-level source directory and their institutional affiliations. 9 : : * All rights reserved. See the file COPYING in the top-level source 10 : : * directory for licensing information. 11 : : * **************************************************************************** 12 : : * 13 : : * Implementation of bags state object. 14 : : */ 15 : : 16 : : #include "theory/bags/solver_state.h" 17 : : 18 : : #include "expr/attribute.h" 19 : : #include "expr/bound_var_manager.h" 20 : : #include "expr/skolem_manager.h" 21 : : #include "theory/smt_engine_subsolver.h" 22 : : #include "theory/uf/equality_engine.h" 23 : : 24 : : using namespace std; 25 : : using namespace cvc5::internal::kind; 26 : : 27 : : namespace cvc5::internal { 28 : : namespace theory { 29 : : namespace bags { 30 : : 31 : 50760 : SolverState::SolverState(Env& env, Valuation val) 32 : 50760 : : TheoryState(env, val), d_partElementSkolems(env.getUserContext()) 33 : : { 34 : 50760 : d_true = nodeManager()->mkConst(true); 35 : 50760 : d_false = nodeManager()->mkConst(false); 36 : 50760 : d_nm = nodeManager(); 37 : 50760 : } 38 : : 39 : 8320 : void SolverState::registerBag(TNode n) 40 : : { 41 [ - + ][ - + ]: 8320 : Assert(n.getType().isBag()); [ - - ] 42 : 8320 : d_bags.insert(n); 43 : 8320 : } 44 : : 45 : 60217 : void SolverState::registerCountTerm(Node bag, Node element, Node skolem) 46 : : { 47 : 120434 : Assert(bag.getType().isBag() && bag == getRepresentative(bag)); 48 : 120434 : Assert(element.getType() == bag.getType().getBagElementType() 49 : : && element == getRepresentative(element)); 50 : 120434 : Assert(skolem.isVar() && skolem.getType().isInteger()); 51 : 120434 : std::pair<Node, Node> pair = std::make_pair(element, skolem); 52 : 60217 : if (std::find(d_bagElements[bag].begin(), d_bagElements[bag].end(), pair) 53 [ + + ]: 120434 : == d_bagElements[bag].end()) 54 : : { 55 : 17266 : d_bagElements[bag].push_back(pair); 56 : : } 57 : 60217 : } 58 : : 59 : 586 : void SolverState::registerGroupTerm(Node n) 60 : : { 61 : : std::shared_ptr<context::CDHashSet<Node>> set = 62 : 1172 : std::make_shared<context::CDHashSet<Node>>(d_env.getUserContext()); 63 : 586 : d_partElementSkolems[n] = set; 64 : 586 : } 65 : : 66 : 0 : void SolverState::registerCardinalityTerm(Node n, Node skolem) 67 : : { 68 : 0 : Assert(n.getKind() == Kind::BAG_CARD); 69 : 0 : Assert(skolem.isVar()); 70 : 0 : d_cardTerms[n] = skolem; 71 : 0 : } 72 : : 73 : 0 : Node SolverState::getCardinalitySkolem(Node n) 74 : : { 75 : 0 : Assert(n.getKind() == Kind::BAG_CARD); 76 : 0 : Node bag = getRepresentative(n[0]); 77 : 0 : Node cardTerm = d_nm->mkNode(Kind::BAG_CARD, bag); 78 : 0 : return d_cardTerms[cardTerm]; 79 : : } 80 : : 81 : 0 : bool SolverState::hasCardinalityTerms() const { return !d_cardTerms.empty(); } 82 : : 83 : 210364 : const std::set<Node>& SolverState::getBags() { return d_bags; } 84 : : 85 : 0 : const std::map<Node, Node>& SolverState::getCardinalityTerms() 86 : : { 87 : 0 : return d_cardTerms; 88 : : } 89 : : 90 : 14075 : std::set<Node> SolverState::getElements(Node B) 91 : : { 92 : 42225 : Node bag = getRepresentative(B); 93 : 14075 : std::set<Node> elements; 94 : 28150 : std::vector<std::pair<Node, Node>> pairs = d_bagElements[bag]; 95 [ + + ]: 44199 : for (std::pair<Node, Node> pair : pairs) 96 : : { 97 : 30124 : elements.insert(pair.first); 98 : : } 99 : 28150 : return elements; 100 : : } 101 : : 102 : 384 : const std::vector<std::pair<Node, Node>>& SolverState::getElementCountPairs( 103 : : Node n) 104 : : { 105 : 768 : Node bag = getRepresentative(n); 106 : 768 : return d_bagElements[bag]; 107 : : } 108 : : 109 : : struct BagsDeqAttributeId 110 : : { 111 : : }; 112 : : typedef expr::Attribute<BagsDeqAttributeId, Node> BagsDeqAttribute; 113 : : 114 : 42937 : void SolverState::collectDisequalBagTerms() 115 : : { 116 : 42937 : eq::EqClassIterator it = eq::EqClassIterator(d_false, d_ee); 117 [ + + ]: 113083 : while (!it.isFinished()) 118 : : { 119 : 140292 : Node n = (*it); 120 : 70146 : if (n.getKind() == Kind::EQUAL && n[0].getType().isBag()) 121 : : { 122 [ + - ]: 15452 : Trace("bags-eqc") << "Disequal terms: " << n << std::endl; 123 : 46356 : Node A = getRepresentative(n[0]); 124 : 46356 : Node B = getRepresentative(n[1]); 125 [ + + ]: 30904 : Node equal = A <= B ? A.eqNode(B) : B.eqNode(A); 126 [ + + ]: 15452 : if (d_deq.find(equal) == d_deq.end()) 127 : : { 128 : 4240 : SkolemManager* sm = d_nm->getSkolemManager(); 129 : 21200 : Node skolem = sm->mkSkolemFunction(SkolemId::BAGS_DEQ_DIFF, {A, B}); 130 : 4240 : d_deq[equal] = skolem; 131 : : } 132 : : } 133 : 70146 : ++it; 134 : : } 135 : 42937 : } 136 : : 137 : 42308 : const std::map<Node, Node>& SolverState::getDisequalBagTerms() { return d_deq; } 138 : : 139 : 177 : void SolverState::registerPartElementSkolem(Node group, Node skolemElement) 140 : : { 141 [ - + ][ - + ]: 177 : Assert(group.getKind() == Kind::TABLE_GROUP); [ - - ] 142 [ - + ][ - + ]: 177 : Assert(skolemElement.getType() == group[0].getType().getBagElementType()); [ - - ] 143 : 177 : d_partElementSkolems[group].get()->insert(skolemElement); 144 : 177 : } 145 : : 146 : 330 : std::shared_ptr<context::CDHashSet<Node>> SolverState::getPartElementSkolems( 147 : : Node n) 148 : : { 149 [ - + ][ - + ]: 330 : Assert(n.getKind() == Kind::TABLE_GROUP); [ - - ] 150 : 330 : return d_partElementSkolems[n]; 151 : : } 152 : : 153 : 42937 : void SolverState::reset() 154 : : { 155 : 42937 : d_bagElements.clear(); 156 : 42937 : d_bags.clear(); 157 : 42937 : d_deq.clear(); 158 : 42937 : d_cardTerms.clear(); 159 : 42937 : } 160 : : 161 : 24 : void SolverState::checkInjectivity(Node n) 162 : : { 163 : 24 : SkolemManager* sm = d_nm->getSkolemManager(); 164 : 24 : Node f = sm->getOriginalForm(n); 165 [ + + ]: 24 : if (d_functions.find(f) != d_functions.end()) 166 : : { 167 : : // we already know f 168 : 4 : return; 169 : : } 170 : : 171 [ + + ]: 20 : if (f.isVar()) 172 : : { 173 : : // no need to solve. f can be assigned any non injective function 174 : 6 : d_functions[f] = false; 175 : 6 : return; 176 : : } 177 : : 178 : 42 : TypeNode domainType = f.getType().getArgTypes()[0]; 179 : 42 : Node x = NodeManager::mkDummySkolem("x", domainType); 180 : 42 : Node y = NodeManager::mkDummySkolem("y", domainType); 181 : 42 : Node f_x = d_nm->mkNode(Kind::APPLY_UF, f, x); 182 : 42 : Node f_y = d_nm->mkNode(Kind::APPLY_UF, f, y); 183 : 28 : Node f_x_equals_f_y = f_x.eqNode(f_y); 184 : 28 : Node not_x_equals_y = x.eqNode(y).notNode(); 185 : 28 : Node query = f_x_equals_f_y.andNode(not_x_equals_y); 186 : : 187 : 28 : Options subOptions; 188 : 14 : subOptions.copyValues(d_env.getOptions()); 189 : 28 : SubsolverSetupInfo ssi(d_env, subOptions); 190 : 28 : Result result = checkWithSubsolver(query, ssi); 191 [ + + ]: 14 : if (result.getStatus() == Result::Status::UNSAT) 192 : : { 193 : 3 : d_functions[f] = true; 194 : : } 195 : : else 196 : : { 197 : 11 : d_functions[f] = false; 198 : : } 199 : : } 200 : : 201 : 336 : bool SolverState::isInjective(Node n) const 202 : : { 203 : 1008 : Node f = d_nm->getSkolemManager()->getOriginalForm(n); 204 [ + - ]: 336 : if (d_functions.find(f) != d_functions.end()) 205 : : { 206 : 336 : return d_functions.at(f); 207 : : } 208 : 0 : return false; 209 : : } 210 : : 211 : : } // namespace bags 212 : : } // namespace theory 213 : : } // namespace cvc5::internal