Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * Simplifications for ITE expressions.
11 : : *
12 : : * This module implements preprocessing phases designed to simplify ITE
13 : : * expressions. Based on:
14 : : * Kim, Somenzi, Jin. Efficient Term-ITE Conversion for SMT. FMCAD 2009.
15 : : * Burch, Jerry. Techniques for Verifying Superscalar Microprocessors. DAC
16 : : *'96
17 : : */
18 : :
19 : : #include "cvc5_private.h"
20 : :
21 : : #ifndef CVC5__ITE_UTILITIES_H
22 : : #define CVC5__ITE_UTILITIES_H
23 : :
24 : : #include <unordered_map>
25 : : #include <vector>
26 : :
27 : : #include "expr/node.h"
28 : : #include "smt/env_obj.h"
29 : : #include "util/hash.h"
30 : : #include "util/statistics_stats.h"
31 : :
32 : : namespace cvc5::internal {
33 : :
34 : : namespace preprocessing {
35 : :
36 : : class AssertionPipeline;
37 : :
38 : : namespace util {
39 : :
40 : : class ITECompressor;
41 : : class ITESimplifier;
42 : : class ITECareSimplifier;
43 : :
44 : : /**
45 : : * A caching visitor that computes whether a node contains a term ite.
46 : : */
47 : : class ContainsTermITEVisitor
48 : : {
49 : : public:
50 : : ContainsTermITEVisitor();
51 : : ~ContainsTermITEVisitor();
52 : :
53 : : /** returns true if a node contains a term ite. */
54 : : bool containsTermITE(TNode n);
55 : :
56 : : /** Garbage collects the cache. */
57 : : void garbageCollect();
58 : :
59 : : /** returns the size of the cache. */
60 : : size_t cache_size() const { return d_cache.size(); }
61 : :
62 : : private:
63 : : typedef std::unordered_map<Node, bool> NodeBoolMap;
64 : : NodeBoolMap d_cache;
65 : : };
66 : :
67 : : class ITEUtilities : protected EnvObj
68 : : {
69 : : public:
70 : : ITEUtilities(Env& env);
71 : : ~ITEUtilities();
72 : :
73 : : Node simpITE(TNode assertion);
74 : :
75 : : bool simpIteDidALotOfWorkHeuristic() const;
76 : :
77 : : /* returns false if an assertion is discovered to be equal to false. */
78 : : bool compress(AssertionPipeline* assertionsToPreprocess);
79 : :
80 : : Node simplifyWithCare(TNode e);
81 : :
82 : : void clear();
83 : :
84 : 4 : ContainsTermITEVisitor* getContainsVisitor()
85 : : {
86 : 4 : return d_containsVisitor.get();
87 : : }
88 : :
89 : 12 : bool containsTermITE(TNode n)
90 : : {
91 : 12 : return d_containsVisitor->containsTermITE(n);
92 : : }
93 : :
94 : : private:
95 : : std::unique_ptr<ContainsTermITEVisitor> d_containsVisitor;
96 : : ITECompressor* d_compressor;
97 : : ITESimplifier* d_simplifier;
98 : : ITECareSimplifier* d_careSimp;
99 : : };
100 : :
101 : : class IncomingArcCounter
102 : : {
103 : : public:
104 : : IncomingArcCounter(bool skipVars = false, bool skipConstants = false);
105 : : ~IncomingArcCounter();
106 : : void computeReachability(const std::vector<Node>& assertions);
107 : :
108 : 7109 : inline uint32_t lookupIncoming(Node n) const
109 : : {
110 : 7109 : NodeCountMap::const_iterator it = d_reachCount.find(n);
111 [ - + ]: 7109 : if (it == d_reachCount.end())
112 : : {
113 : 0 : return 0;
114 : : }
115 : : else
116 : : {
117 : 7109 : return (*it).second;
118 : : }
119 : : }
120 : : void clear();
121 : :
122 : : private:
123 : : typedef std::unordered_map<Node, uint32_t> NodeCountMap;
124 : : NodeCountMap d_reachCount;
125 : :
126 : : bool d_skipVariables;
127 : : bool d_skipConstants;
128 : : };
129 : :
130 : : class TermITEHeightCounter
131 : : {
132 : : public:
133 : : TermITEHeightCounter();
134 : : ~TermITEHeightCounter();
135 : :
136 : : /**
137 : : * Compute and [potentially] cache the termITEHeight() of e.
138 : : * The term ite height equals the maximum number of term ites
139 : : * encountered on any path from e to a leaf.
140 : : * Inductively:
141 : : * - termITEHeight(leaves) = 0
142 : : * - termITEHeight(e: term-ite(c, t, e) ) =
143 : : * 1 + max(termITEHeight(t), termITEHeight(e)) ; Don't include c
144 : : * - termITEHeight(e not term ite) = max_{c in children(e))
145 : : * (termITEHeight(c))
146 : : */
147 : : uint32_t termITEHeight(TNode e);
148 : :
149 : : /** Clear the cache. */
150 : : void clear();
151 : :
152 : : /** Size of the cache. */
153 : : size_t cache_size() const;
154 : :
155 : : private:
156 : : typedef std::unordered_map<Node, uint32_t> NodeCountMap;
157 : : NodeCountMap d_termITEHeight;
158 : : }; /* class TermITEHeightCounter */
159 : :
160 : : /**
161 : : * A routine designed to undo the potentially large blow up
162 : : * due to expansion caused by the ite simplifier.
163 : : */
164 : : class ITECompressor : protected EnvObj
165 : : {
166 : : public:
167 : : ITECompressor(Env& env, ContainsTermITEVisitor* contains);
168 : : ~ITECompressor();
169 : :
170 : : /* returns false if an assertion is discovered to be equal to false. */
171 : : bool compress(AssertionPipeline* assertionsToPreprocess);
172 : :
173 : : /* garbage Collects the compressor. */
174 : : void garbageCollect();
175 : :
176 : : private:
177 : : class Statistics
178 : : {
179 : : public:
180 : : IntStat d_compressCalls;
181 : : IntStat d_skolemsAdded;
182 : : Statistics(StatisticsRegistry& reg);
183 : : };
184 : :
185 : : void reset();
186 : :
187 : : Node push_back_boolean(Node original, Node compressed);
188 : : bool multipleParents(TNode c);
189 : : Node compressBooleanITEs(Node toCompress);
190 : : Node compressTerm(Node toCompress);
191 : : Node compressBoolean(Node toCompress);
192 : :
193 : : Node d_true; /* Copy of true. */
194 : : Node d_false; /* Copy of false. */
195 : :
196 : : CVC5_UNUSED_FIELD ContainsTermITEVisitor* d_contains; // Only used in DEBUG
197 : : AssertionPipeline* d_assertions;
198 : : IncomingArcCounter d_incoming;
199 : :
200 : : typedef std::unordered_map<Node, Node> NodeMap;
201 : : NodeMap d_compressed;
202 : :
203 : : Statistics d_statistics;
204 : : }; /* class ITECompressor */
205 : :
206 : : class ITESimplifier : protected EnvObj
207 : : {
208 : : public:
209 : : ITESimplifier(Env& env, ContainsTermITEVisitor* d_containsVisitor);
210 : : ~ITESimplifier();
211 : :
212 : : Node simpITE(TNode assertion);
213 : :
214 : : bool doneALotOfWorkHeuristic() const;
215 : : void clearSimpITECaches();
216 : :
217 : : private:
218 : : using NodeVec = std::vector<Node>;
219 : : using ConstantLeavesMap = std::unordered_map<Node, NodeVec*>;
220 : : using NodePair = std::pair<Node, Node>;
221 : : using NodePairHashFunction =
222 : : PairHashFunction<Node, Node, std::hash<Node>, std::hash<Node>>;
223 : : using NodePairMap = std::unordered_map<NodePair, Node, NodePairHashFunction>;
224 : :
225 : : class Statistics
226 : : {
227 : : public:
228 : : IntStat d_maxNonConstantsFolded;
229 : : IntStat d_unexpected;
230 : : IntStat d_unsimplified;
231 : : IntStat d_exactMatchFold;
232 : : IntStat d_binaryPredFold;
233 : : IntStat d_specialEqualityFolds;
234 : : IntStat d_simpITEVisits;
235 : : unsigned d_numBranches;
236 : : unsigned d_numFalseBranches;
237 : : unsigned d_itesMade;
238 : : unsigned d_instance;
239 : :
240 : : HistogramStat<uint32_t> d_inSmaller;
241 : :
242 : : Statistics(StatisticsRegistry& reg);
243 : : };
244 : :
245 : 2031 : inline bool containsTermITE(TNode n)
246 : : {
247 : 2031 : return d_containsVisitor->containsTermITE(n);
248 : : }
249 : :
250 : : inline uint32_t termITEHeight(TNode e)
251 : : {
252 : : return d_termITEHeight.termITEHeight(e);
253 : : }
254 : :
255 : : // d_constantLeaves satisfies the following invariants:
256 : : // not containsTermITE(x) then !isKey(x)
257 : : // containsTermITE(x):
258 : : // - not isKey(x) then this value is uncomputed
259 : : // - d_constantLeaves[x] == NULL, then this contains a non-constant leaf
260 : : // - d_constantLeaves[x] != NULL, then this contains a sorted list of constant
261 : : // leaf
262 : : bool isConstantIte(TNode e);
263 : :
264 : : /** If its not a constant and containsTermITE(ite),
265 : : * returns a sorted NodeVec of the leaves. */
266 : : NodeVec* computeConstantLeaves(TNode ite);
267 : :
268 : : /* transforms */
269 : : Node transformAtom(TNode atom);
270 : : Node attemptConstantRemoval(TNode atom);
271 : : Node attemptLiftEquality(TNode atom);
272 : : Node attemptEagerRemoval(TNode atom);
273 : :
274 : : // Given ConstantIte trees lcite and rcite,
275 : : // return a boolean expression equivalent to (= lcite rcite)
276 : : Node intersectConstantIte(TNode lcite, TNode rcite);
277 : :
278 : : // Given ConstantIte tree cite and a constant c,
279 : : // return a boolean expression equivalent to (= lcite c)
280 : : Node constantIteEqualsConstant(TNode cite, TNode c);
281 : :
282 : : Node replaceOver(Node n, Node replaceWith, Node simpVar);
283 : : Node replaceOverTermIte(Node term, Node simpAtom, Node simpVar);
284 : :
285 : : bool leavesAreConst(TNode e, theory::TheoryId tid);
286 : : bool leavesAreConst(TNode e);
287 : :
288 : : Node simpConstants(TNode simpContext, TNode iteNode, TNode simpVar);
289 : :
290 : : Node createSimpContext(TNode c, Node& iteNode, Node& simpVar);
291 : :
292 : : Node d_true;
293 : : Node d_false;
294 : :
295 : : ContainsTermITEVisitor* d_containsVisitor;
296 : :
297 : : TermITEHeightCounter d_termITEHeight;
298 : :
299 : : // ConstantIte is a small inductive sublanguage:
300 : : // constant
301 : : // or termITE(cnd, ConstantIte, ConstantIte)
302 : : ConstantLeavesMap d_constantLeaves;
303 : :
304 : : // Lists all of the vectors in d_constantLeaves for fast deletion.
305 : : std::vector<NodeVec*> d_allocatedConstantLeaves;
306 : :
307 : : uint32_t d_citeEqConstApplications;
308 : :
309 : : NodePairMap d_constantIteEqualsConstantCache;
310 : : NodePairMap d_replaceOverCache;
311 : : NodePairMap d_replaceOverTermIteCache;
312 : :
313 : : std::unordered_map<Node, bool> d_leavesConstCache;
314 : :
315 : : NodePairMap d_simpConstCache;
316 : : std::unordered_map<TypeNode, Node> d_simpVars;
317 : : Node getSimpVar(TypeNode t);
318 : :
319 : : typedef std::unordered_map<Node, Node> NodeMap;
320 : : NodeMap d_simpContextCache;
321 : :
322 : : NodeMap d_simpITECache;
323 : : Node simpITEAtom(TNode atom);
324 : :
325 : : Statistics d_statistics;
326 : : };
327 : :
328 : : class ITECareSimplifier
329 : : {
330 : : public:
331 : : ITECareSimplifier(NodeManager* nm);
332 : : ~ITECareSimplifier();
333 : :
334 : : Node simplifyWithCare(TNode e);
335 : :
336 : : void clear();
337 : :
338 : : private:
339 : : /**
340 : : * This should always equal the number of care sets allocated by
341 : : * this object - the number of these that have been deleted. This is
342 : : * initially 0 and should always be 0 at the *start* of
343 : : * ~ITECareSimplifier().
344 : : */
345 : : unsigned d_careSetsOutstanding;
346 : :
347 : : Node d_true;
348 : : Node d_false;
349 : :
350 : : typedef std::unordered_map<TNode, Node> TNodeMap;
351 : :
352 : : class CareSetPtr;
353 : : class CareSetPtrVal
354 : : {
355 : : public:
356 : 0 : bool safeToGarbageCollect() const { return d_refCount == 0; }
357 : :
358 : : private:
359 : : friend class ITECareSimplifier::CareSetPtr;
360 : : ITECareSimplifier& d_iteSimplifier;
361 : : unsigned d_refCount;
362 : : std::set<Node> d_careSet;
363 : 0 : CareSetPtrVal(ITECareSimplifier& simp)
364 : 0 : : d_iteSimplifier(simp), d_refCount(1)
365 : : {
366 : 0 : }
367 : : }; /* class ITECareSimplifier::CareSetPtrVal */
368 : :
369 : : std::vector<CareSetPtrVal*> d_usedSets;
370 : 0 : void careSetPtrGC(CareSetPtrVal* val) { d_usedSets.push_back(val); }
371 : :
372 : : class CareSetPtr
373 : : {
374 : : CareSetPtrVal* d_val;
375 : 0 : CareSetPtr(CareSetPtrVal* val) : d_val(val) {}
376 : :
377 : : public:
378 : 0 : CareSetPtr() : d_val(NULL) {}
379 : 0 : CareSetPtr(const CareSetPtr& cs)
380 : 0 : {
381 : 0 : d_val = cs.d_val;
382 [ - - ]: 0 : if (d_val != NULL)
383 : : {
384 : 0 : ++(d_val->d_refCount);
385 : : }
386 : 0 : }
387 : 0 : ~CareSetPtr()
388 : : {
389 [ - - ][ - - ]: 0 : if (d_val != NULL && (--(d_val->d_refCount) == 0))
[ - - ]
390 : : {
391 : 0 : d_val->d_iteSimplifier.careSetPtrGC(d_val);
392 : : }
393 : 0 : }
394 : 0 : CareSetPtr& operator=(const CareSetPtr& cs)
395 : : {
396 [ - - ]: 0 : if (d_val != cs.d_val)
397 : : {
398 [ - - ][ - - ]: 0 : if (d_val != NULL && (--(d_val->d_refCount) == 0))
[ - - ]
399 : : {
400 : 0 : d_val->d_iteSimplifier.careSetPtrGC(d_val);
401 : : }
402 : 0 : d_val = cs.d_val;
403 [ - - ]: 0 : if (d_val != NULL)
404 : : {
405 : 0 : ++(d_val->d_refCount);
406 : : }
407 : : }
408 : 0 : return *this;
409 : : }
410 : 0 : std::set<Node>& getCareSet() { return d_val->d_careSet; }
411 : :
412 : : static CareSetPtr mkNew(ITECareSimplifier& simp);
413 : 0 : static CareSetPtr recycle(CareSetPtrVal* val)
414 : : {
415 : 0 : Assert(val != NULL && val->d_refCount == 0);
416 : 0 : val->d_refCount = 1;
417 : 0 : return CareSetPtr(val);
418 : : }
419 : : }; /* class ITECareSimplifier::CareSetPtr */
420 : :
421 : : CareSetPtr getNewSet();
422 : :
423 : : typedef std::map<TNode, CareSetPtr> CareMap;
424 : : void updateQueue(CareMap& queue, TNode e, CareSetPtr& careSet);
425 : : Node substitute(TNode e, TNodeMap& substTable, TNodeMap& cache);
426 : : };
427 : :
428 : : } // namespace util
429 : : } // namespace preprocessing
430 : : } // namespace cvc5::internal
431 : :
432 : : #endif
|