Branch data Line data Source code
1 : : /******************************************************************************
2 : : * Top contributors (to current version):
3 : : * Gereon Kremer, Andrew Reynolds, Andres Noetzli
4 : : *
5 : : * This file is part of the cvc5 project.
6 : : *
7 : : * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
8 : : * in the top-level source directory and their institutional affiliations.
9 : : * All rights reserved. See the file COPYING in the top-level source
10 : : * directory for licensing information.
11 : : * ****************************************************************************
12 : : *
13 : : * Extract bounds on variables from theory atoms.
14 : : */
15 : :
16 : : #include "theory/arith/bound_inference.h"
17 : :
18 : : #include "smt/env.h"
19 : : #include "theory/arith/arith_utilities.h"
20 : : #include "theory/arith/linear/normal_form.h"
21 : : #include "theory/rewriter.h"
22 : :
23 : : using namespace cvc5::internal::kind;
24 : :
25 : : namespace cvc5::internal {
26 : : namespace theory {
27 : : namespace arith {
28 : :
29 : 0 : std::ostream& operator<<(std::ostream& os, const Bounds& b) {
30 [ - - ]: 0 : return os << (b.lower_strict ? '(' : '[') << b.lower_value << " .. "
31 [ - - ]: 0 : << b.upper_value << (b.upper_strict ? ')' : ']');
32 : : }
33 : :
34 : 45231 : BoundInference::BoundInference(Env& env) : EnvObj(env) {}
35 : :
36 : 10 : void BoundInference::reset() { d_bounds.clear(); }
37 : :
38 : 428187 : Bounds& BoundInference::get_or_add(const Node& lhs)
39 : : {
40 : 428187 : auto it = d_bounds.find(lhs);
41 [ + + ]: 428187 : if (it == d_bounds.end())
42 : : {
43 : 193522 : it = d_bounds.emplace(lhs, Bounds()).first;
44 : : }
45 : 428187 : return it->second;
46 : : }
47 : 40 : Bounds BoundInference::get(const Node& lhs) const
48 : : {
49 : 40 : auto it = d_bounds.find(lhs);
50 [ + + ]: 40 : if (it == d_bounds.end())
51 : : {
52 : 25 : return Bounds();
53 : : }
54 : 15 : return it->second;
55 : : }
56 : :
57 : 12256 : const std::map<Node, Bounds>& BoundInference::get() const { return d_bounds; }
58 : 498984 : bool BoundInference::add(const Node& n, bool onlyVariables)
59 : : {
60 : 997968 : Node tmp = rewrite(n);
61 [ - + ]: 498984 : if (tmp.getKind() == Kind::CONST_BOOLEAN)
62 : : {
63 : 0 : return false;
64 : : }
65 : : // Parse the node as a comparison
66 : 997968 : auto comp = linear::Comparison::parseNormalForm(tmp);
67 : 997968 : auto dec = comp.decompose(true);
68 [ + + ][ + + ]: 498984 : if (onlyVariables && !std::get<0>(dec).isVariable())
[ + + ]
69 : : {
70 : 34 : return false;
71 : : }
72 : :
73 : 997900 : Node lhs = std::get<0>(dec).getNode();
74 : 498950 : Kind relation = std::get<1>(dec);
75 [ + + ]: 498950 : if (relation == Kind::DISTINCT) return false;
76 : 389469 : Node bound = std::get<2>(dec).getNode();
77 : : // has the form lhs ~relation~ bound
78 : :
79 [ + + ]: 389469 : if (lhs.getType().isInteger())
80 : : {
81 : 238245 : Rational br = bound.getConst<Rational>();
82 : 238245 : auto* nm = nodeManager();
83 [ - + ][ - + ]: 238245 : switch (relation)
[ + ]
84 : : {
85 : 0 : case Kind::LEQ: bound = nm->mkConstInt(br.floor()); break;
86 : 79001 : case Kind::LT:
87 : 79001 : bound = nm->mkConstInt((br - 1).ceiling());
88 : 79001 : relation = Kind::LEQ;
89 : 79001 : break;
90 : 0 : case Kind::GT:
91 : 0 : bound = nm->mkConstInt((br + 1).floor());
92 : 0 : relation = Kind::GEQ;
93 : 0 : break;
94 : 133326 : case Kind::GEQ: bound = nm->mkConstInt(br.ceiling()); break;
95 : 25918 : default:
96 : : // always ensure integer
97 : 25918 : bound = nm->mkConstInt(br);
98 : 25918 : break;
99 : : }
100 [ + - ]: 476490 : Trace("bound-inf") << "Strengthened " << n << " to " << lhs << " "
101 : 238245 : << relation << " " << bound << std::endl;
102 : : }
103 : :
104 [ + + ][ + + ]: 389469 : switch (relation)
[ + - ]
105 : : {
106 : 101877 : case Kind::LEQ: update_upper_bound(n, lhs, bound, false); break;
107 : 42599 : case Kind::LT: update_upper_bound(n, lhs, bound, true); break;
108 : 38718 : case Kind::EQUAL:
109 : 38718 : update_lower_bound(n, lhs, bound, false);
110 : 38718 : update_upper_bound(n, lhs, bound, false);
111 : 38718 : break;
112 : 30112 : case Kind::GT: update_lower_bound(n, lhs, bound, true); break;
113 : 176163 : case Kind::GEQ: update_lower_bound(n, lhs, bound, false); break;
114 : 0 : default: Assert(false);
115 : : }
116 : 389469 : return true;
117 : : }
118 : :
119 : 0 : void BoundInference::replaceByOrigins(std::vector<Node>& nodes) const
120 : : {
121 : 0 : std::vector<Node> toAdd;
122 [ - - ]: 0 : for (auto& n : nodes)
123 : : {
124 [ - - ]: 0 : for (const auto& b : d_bounds)
125 : : {
126 : 0 : if (n == b.second.lower_bound && n == b.second.upper_bound)
127 : : {
128 : 0 : if (n != b.second.lower_origin && n != b.second.upper_origin)
129 : : {
130 [ - - ]: 0 : Trace("bound-inf")
131 : 0 : << "Replace " << n << " by origins " << b.second.lower_origin
132 : 0 : << " and " << b.second.upper_origin << std::endl;
133 : 0 : n = b.second.lower_origin;
134 : 0 : toAdd.emplace_back(b.second.upper_origin);
135 : : }
136 : : }
137 [ - - ]: 0 : else if (n == b.second.lower_bound)
138 : : {
139 [ - - ]: 0 : if (n != b.second.lower_origin)
140 : : {
141 [ - - ]: 0 : Trace("bound-inf") << "Replace " << n << " by origin "
142 : 0 : << b.second.lower_origin << std::endl;
143 : 0 : n = b.second.lower_origin;
144 : : }
145 : : }
146 [ - - ]: 0 : else if (n == b.second.upper_bound)
147 : : {
148 [ - - ]: 0 : if (n != b.second.upper_origin)
149 : : {
150 [ - - ]: 0 : Trace("bound-inf") << "Replace " << n << " by origin "
151 : 0 : << b.second.upper_origin << std::endl;
152 : 0 : n = b.second.upper_origin;
153 : : }
154 : : }
155 : : }
156 : : }
157 : 0 : nodes.insert(nodes.end(), toAdd.begin(), toAdd.end());
158 : 0 : }
159 : :
160 : 244993 : void BoundInference::update_lower_bound(const Node& origin,
161 : : const Node& lhs,
162 : : const Node& value,
163 : : bool strict)
164 : : {
165 [ - + ][ - + ]: 244993 : Assert(value.isConst());
[ - - ]
166 : : // lhs > or >= value because of origin
167 [ + - ][ - - ]: 489986 : Trace("bound-inf") << "\tNew bound " << lhs << (strict ? ">" : ">=") << value
168 : 244993 : << " due to " << origin << std::endl;
169 : 244993 : Bounds& b = get_or_add(lhs);
170 : 244993 : if (b.lower_value.isNull()
171 [ + + ][ + + ]: 244993 : || b.lower_value.getConst<Rational>() < value.getConst<Rational>())
[ + + ]
172 : : {
173 : 155300 : auto* nm = nodeManager();
174 : 155300 : b.lower_value = value;
175 : 155300 : b.lower_strict = strict;
176 : :
177 : 155300 : b.lower_origin = origin;
178 : :
179 [ + + ][ + + ]: 155300 : if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
[ + + ][ + + ]
180 : : {
181 : 2768 : Node eq = mkEquality(lhs, value);
182 : 2768 : b.lower_bound = b.upper_bound = rewrite(eq);
183 : : }
184 : : else
185 : : {
186 : : b.lower_bound =
187 [ + + ]: 152532 : rewrite(nm->mkNode(strict ? Kind::GT : Kind::GEQ, lhs, value));
188 : : }
189 : : }
190 [ + + ][ + + ]: 89693 : else if (strict && b.lower_value == value)
[ + + ]
191 : : {
192 : 4318 : auto* nm = nodeManager();
193 : 4318 : b.lower_strict = strict;
194 : 4318 : b.lower_bound = rewrite(nm->mkNode(Kind::GT, lhs, value));
195 : 4318 : b.lower_origin = origin;
196 : : }
197 : 244993 : }
198 : 183194 : void BoundInference::update_upper_bound(const Node& origin,
199 : : const Node& lhs,
200 : : const Node& value,
201 : : bool strict)
202 : : {
203 : : // lhs < or <= value because of origin
204 [ + - ][ - - ]: 366388 : Trace("bound-inf") << "\tNew bound " << lhs << (strict ? "<" : "<=") << value
205 : 183194 : << " due to " << origin << std::endl;
206 : 183194 : Bounds& b = get_or_add(lhs);
207 : 183194 : if (b.upper_value.isNull()
208 [ + + ][ + + ]: 183194 : || b.upper_value.getConst<Rational>() > value.getConst<Rational>())
[ + + ]
209 : : {
210 : 126022 : auto* nm = nodeManager();
211 : 126022 : b.upper_value = value;
212 : 126022 : b.upper_strict = strict;
213 : 126022 : b.upper_origin = origin;
214 [ + + ][ + + ]: 126022 : if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
[ + + ][ + + ]
215 : : {
216 : 37231 : Node eq = mkEquality(lhs, value);
217 : 37231 : b.lower_bound = b.upper_bound = rewrite(eq);
218 : : }
219 : : else
220 : : {
221 : : b.upper_bound =
222 [ + + ]: 88791 : rewrite(nm->mkNode(strict ? Kind::LT : Kind::LEQ, lhs, value));
223 : : }
224 : : }
225 [ + + ][ + + ]: 57172 : else if (strict && b.upper_value == value)
[ + + ]
226 : : {
227 : 2907 : auto* nm = nodeManager();
228 : 2907 : b.upper_strict = strict;
229 : 2907 : b.upper_bound = rewrite(nm->mkNode(Kind::LT, lhs, value));
230 : 2907 : b.upper_origin = origin;
231 : : }
232 : 183194 : }
233 : :
234 : 0 : std::ostream& operator<<(std::ostream& os, const BoundInference& bi)
235 : : {
236 : 0 : os << "Bounds:" << std::endl;
237 [ - - ]: 0 : for (const auto& b : bi.get())
238 : : {
239 : 0 : os << "\t" << b.first << " -> " << b.second.lower_value << ".."
240 : 0 : << b.second.upper_value << std::endl;
241 : : }
242 : 0 : return os;
243 : : }
244 : :
245 : : } // namespace arith
246 : : } // namespace theory
247 : : } // namespace cvc5::internal
|