Branch data Line data Source code
1 : : /******************************************************************************
2 : : * Top contributors (to current version):
3 : : * Tim King, Morgan Deters, Kshitij Bansal
4 : : *
5 : : * This file is part of the cvc5 project.
6 : : *
7 : : * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
8 : : * in the top-level source directory and their institutional affiliations.
9 : : * All rights reserved. See the file COPYING in the top-level source
10 : : * directory for licensing information.
11 : : * ****************************************************************************
12 : : *
13 : : * [[ Add one-line brief description here ]]
14 : : *
15 : : * [[ Add lengthier description here ]]
16 : : * \todo document this file
17 : : */
18 : :
19 : : #include "cvc5_private.h"
20 : :
21 : : #pragma once
22 : :
23 : : #include <map>
24 : : #include <memory>
25 : : #include <set>
26 : : #include <unordered_map>
27 : :
28 : : #include "expr/kind.h"
29 : : #include "theory/arith/linear/arithvar.h"
30 : : #include "theory/arith/linear/constraint_forward.h"
31 : : #include "util/dense_map.h"
32 : :
33 : : namespace cvc5::internal {
34 : : namespace theory {
35 : : namespace arith::linear {
36 : :
37 : : /** A low level vector of indexed doubles. */
38 : : struct PrimitiveVec {
39 : : int len;
40 : : int* inds;
41 : : double* coeffs;
42 : : PrimitiveVec();
43 : : ~PrimitiveVec();
44 : : bool initialized() const;
45 : : void clear();
46 : : void setup(int l);
47 : : void print(std::ostream& out) const;
48 : : };
49 : : std::ostream& operator<<(std::ostream& os, const PrimitiveVec& pv);
50 : :
51 : : struct DenseVector {
52 : : DenseMap<Rational> lhs;
53 : : Rational rhs;
54 : : void purge();
55 : : void print(std::ostream& os) const;
56 : :
57 : : static void print(std::ostream& os, const DenseMap<Rational>& lhs);
58 : : };
59 : :
60 : : /** The different kinds of cuts. */
61 : : enum CutInfoKlass{ MirCutKlass, GmiCutKlass, BranchCutKlass,
62 : : RowsDeletedKlass,
63 : : UnknownKlass};
64 : : std::ostream& operator<<(std::ostream& os, CutInfoKlass kl);
65 : :
66 : : /** A general class for describing a cut. */
67 : : class CutInfo {
68 : : protected:
69 : : CutInfoKlass d_klass;
70 : : int d_execOrd;
71 : :
72 : : int d_poolOrd; /* cut's ordinal in the current node pool */
73 : : Kind d_cutType; /* Lowerbound, upperbound or undefined. */
74 : : double d_cutRhs; /* right hand side of the cut */
75 : : PrimitiveVec d_cutVec; /* vector of the cut */
76 : :
77 : : /**
78 : : * The number of rows at the time the cut was made.
79 : : * This is required to descramble indices after the fact!
80 : : */
81 : : int d_mAtCreation;
82 : :
83 : : /** This is the number of structural variables. */
84 : : int d_N;
85 : :
86 : : /** if selected, make this non-zero */
87 : : int d_rowId;
88 : :
89 : : /* If the cut has been successfully created,
90 : : * the cut is stored in exact precision in d_exactPrecision.
91 : : * If the cut has not yet been proven, this is null.
92 : : */
93 : : std::unique_ptr<DenseVector> d_exactPrecision;
94 : :
95 : : std::unique_ptr<ConstraintCPVec> d_explanation;
96 : :
97 : : public:
98 : : CutInfo(CutInfoKlass kl, int cutid, int ordinal);
99 : :
100 : : virtual ~CutInfo();
101 : :
102 : : int getId() const;
103 : :
104 : : int getRowId() const;
105 : : void setRowId(int rid);
106 : :
107 : : void print(std::ostream& out) const;
108 : : //void init_cut(int l);
109 : : PrimitiveVec& getCutVector();
110 : : const PrimitiveVec& getCutVector() const;
111 : :
112 : : Kind getKind() const;
113 : : void setKind(Kind k);
114 : :
115 : :
116 : : void setRhs(double r);
117 : : double getRhs() const;
118 : :
119 : : CutInfoKlass getKlass() const;
120 : : int poolOrdinal() const;
121 : :
122 : : void setDimensions(int N, int M);
123 : : int getN() const;
124 : : int getMAtCreation() const;
125 : :
126 : : bool operator<(const CutInfo& o) const;
127 : :
128 : : /* Returns true if the cut was successfully made in exact precision.*/
129 : : bool reconstructed() const;
130 : :
131 : : /* Returns true if the cut has an explanation. */
132 : : bool proven() const;
133 : :
134 : : void setReconstruction(const DenseVector& ep);
135 : : void setExplanation(const ConstraintCPVec& ex);
136 : : void swapExplanation(ConstraintCPVec& ex);
137 : :
138 : : const DenseVector& getReconstruction() const;
139 : : const ConstraintCPVec& getExplanation() const;
140 : :
141 : : void clearReconstruction();
142 : : };
143 : : std::ostream& operator<<(std::ostream& os, const CutInfo& ci);
144 : :
145 : : class BranchCutInfo : public CutInfo {
146 : : public:
147 : : BranchCutInfo(int execOrd, int br, Kind dir, double val);
148 : : };
149 : :
150 : : class RowsDeleted : public CutInfo {
151 : : public:
152 : : RowsDeleted(int execOrd, int nrows, const int num[]);
153 : : };
154 : :
155 : : class TreeLog;
156 : :
157 : : class NodeLog {
158 : : private:
159 : : int d_nid;
160 : : NodeLog* d_parent; /* If null this is the root */
161 : : TreeLog* d_tl; /* TreeLog containing the node. */
162 : :
163 : : struct CmpCutPointer{
164 : 0 : int operator()(const CutInfo* a, const CutInfo* b) const{
165 : 0 : return *a < *b;
166 : : }
167 : : };
168 : : typedef std::set<CutInfo*, CmpCutPointer> CutSet;
169 : : CutSet d_cuts;
170 : : std::map<int, int> d_rowIdsSelected;
171 : :
172 : : enum Status {Open, Closed, Branched};
173 : : Status d_stat;
174 : :
175 : : int d_brVar; // branching variable
176 : : double d_brVal;
177 : : int d_downId;
178 : : int d_upId;
179 : :
180 : : public:
181 : : typedef std::unordered_map<int, ArithVar> RowIdMap;
182 : : private:
183 : : RowIdMap d_rowId2ArithVar;
184 : :
185 : : public:
186 : : NodeLog(); /* default constructor. */
187 : : NodeLog(TreeLog* tl, int node, const RowIdMap& m); /* makes a root node. */
188 : : NodeLog(TreeLog* tl, NodeLog* parent, int node);/* makes a non-root node. */
189 : :
190 : : ~NodeLog();
191 : :
192 : : int getNodeId() const;
193 : : void addSelected(int ord, int sel);
194 : : void applySelected();
195 : : void addCut(CutInfo* ci);
196 : : void print(std::ostream& o) const;
197 : :
198 : : bool isRoot() const;
199 : : const NodeLog& getParent() const;
200 : :
201 : : void copyParentRowIds();
202 : :
203 : : bool isBranch() const;
204 : : int branchVariable() const;
205 : : double branchValue() const;
206 : :
207 : : typedef CutSet::const_iterator const_iterator;
208 : : const_iterator begin() const;
209 : : const_iterator end() const;
210 : :
211 : : void setBranch(int br, double val, int dn, int up);
212 : : void closeNode();
213 : :
214 : : int getDownId() const;
215 : : int getUpId() const;
216 : :
217 : : /**
218 : : * Looks up a row id to the appropriate arith variable.
219 : : * Be careful these are deleted in context during replay!
220 : : * failure returns ARITHVAR_SENTINEL */
221 : : ArithVar lookupRowId(int rowId) const;
222 : :
223 : : /**
224 : : * Maps a row id to an arithvar.
225 : : * Be careful these are deleted in context during replay!
226 : : */
227 : : void mapRowId(int rowid, ArithVar v);
228 : : void applyRowsDeleted(const RowsDeleted& rd);
229 : :
230 : : };
231 : : std::ostream& operator<<(std::ostream& os, const NodeLog& nl);
232 : :
233 : : class TreeLog {
234 : : private:
235 : : int next_exec_ord;
236 : : typedef std::map<int, NodeLog> ToNodeMap;
237 : : ToNodeMap d_toNode;
238 : : DenseMultiset d_branches;
239 : :
240 : : uint32_t d_numCuts;
241 : :
242 : : bool d_active;
243 : :
244 : : public:
245 : : TreeLog();
246 : :
247 : : NodeLog& getNode(int nid);
248 : : void branch(int nid, int br, double val, int dn, int up);
249 : : void close(int nid);
250 : :
251 : : //void applySelected();
252 : : void print(std::ostream& o) const;
253 : :
254 : : typedef ToNodeMap::const_iterator const_iterator;
255 : : const_iterator begin() const;
256 : : const_iterator end() const;
257 : :
258 : : int getExecutionOrd();
259 : :
260 : : void reset(const NodeLog::RowIdMap& m);
261 : :
262 : : // Applies rd tp to the node with id nid
263 : : void applyRowsDeleted(int nid, const RowsDeleted& rd);
264 : :
265 : : // Synonym for getNode(nid).mapRowId(ind, v)
266 : : void mapRowId(int nid, int ind, ArithVar v);
267 : :
268 : : private:
269 : : void clear();
270 : :
271 : : public:
272 : : void makeInactive();
273 : : void makeActive();
274 : :
275 : : bool isActivelyLogging() const;
276 : :
277 : : void addCut();
278 : : uint32_t cutCount() const;
279 : :
280 : : void logBranch(uint32_t x);
281 : : uint32_t numBranches(uint32_t x);
282 : :
283 : : int getRootId() const;
284 : :
285 : : uint32_t numNodes() const{
286 : : return d_toNode.size();
287 : : }
288 : :
289 : : NodeLog& getRootNode();
290 : : void printBranchInfo(std::ostream& os) const;
291 : : };
292 : :
293 : : } // namespace arith
294 : : } // namespace theory
295 : : } // namespace cvc5::internal
|