Branch data Line data Source code
1 : : /******************************************************************************
2 : : * This file is part of the cvc5 project.
3 : : *
4 : : * Copyright (c) 2009-2026 by the authors listed in the file AUTHORS
5 : : * in the top-level source directory and their institutional affiliations.
6 : : * All rights reserved. See the file COPYING in the top-level source
7 : : * directory for licensing information.
8 : : * ****************************************************************************
9 : : *
10 : : * Implements a n-ary match trie
11 : : */
12 : :
13 : : #include "expr/nary_match_trie.h"
14 : :
15 : : #include <sstream>
16 : :
17 : : #include "expr/nary_term_util.h"
18 : :
19 : : using namespace cvc5::internal::kind;
20 : :
21 : : namespace cvc5::internal {
22 : : namespace expr {
23 : :
24 : : class NaryMatchFrame
25 : : {
26 : : public:
27 : 11812953 : NaryMatchFrame(const std::vector<Node>& syms, const NaryMatchTrie* t)
28 : 11812953 : : d_syms(syms), d_trie(t), d_index(0), d_variant(0), d_boundVar(false)
29 : : {
30 : 11812953 : }
31 : : /** Symbols to match */
32 : : std::vector<Node> d_syms;
33 : : /** The match trie */
34 : : const NaryMatchTrie* d_trie;
35 : : /** The index we are considering, 0 = operator, n>0 = variable # (n-1) */
36 : : size_t d_index;
37 : : /** List length considering */
38 : : size_t d_variant;
39 : : /** Whether we just bound a variable */
40 : : bool d_boundVar;
41 : : };
42 : :
43 : 1400143 : bool NaryMatchTrie::getMatches(Node n, NotifyMatch* ntm) const
44 : : {
45 : 1400143 : NodeManager* nm = n.getNodeManager();
46 : 1400143 : std::vector<Node> vars;
47 : 1400143 : std::vector<Node> subs;
48 : 1400143 : std::map<Node, Node> smap;
49 : :
50 : 1400143 : std::map<Node, NaryMatchTrie>::const_iterator itc;
51 : :
52 : 1400143 : std::vector<NaryMatchFrame> visit;
53 : 2800286 : visit.push_back(NaryMatchFrame({n}, this));
54 : :
55 [ + + ]: 29129444 : while (!visit.empty())
56 : : {
57 : 28095894 : NaryMatchFrame& curr = visit.back();
58 : : // currently, copy the symbols from previous frame TODO: improve?
59 : 28095894 : std::vector<Node> syms = curr.d_syms;
60 : 28095894 : const NaryMatchTrie* mt = curr.d_trie;
61 [ + + ]: 28095894 : if (syms.empty())
62 : : {
63 : : // if we matched, there must be a data member at this node
64 [ - + ][ - + ]: 1603384 : Assert(!mt->d_data.isNull());
[ - - ]
65 : : // notify match?
66 [ - + ][ - + ]: 1603384 : Assert(n == expr::narySubstitute(mt->d_data, vars, subs));
[ - - ]
67 [ + - ]: 1603384 : Trace("match-debug") << "notify : " << mt->d_data << std::endl;
68 [ + + ]: 1603384 : if (!ntm->notify(n, mt->d_data, vars, subs))
69 : : {
70 : 366593 : return false;
71 : : }
72 : 1236791 : visit.pop_back();
73 : 1236791 : continue;
74 : : }
75 : :
76 : : // clean up if we previously bound a variable
77 [ + + ]: 26492510 : if (curr.d_boundVar)
78 : : {
79 [ - + ][ - + ]: 6729876 : Assert(!vars.empty());
[ - - ]
80 [ - + ][ - + ]: 6729876 : Assert(smap.find(vars.back()) != smap.end());
[ - - ]
81 : 6729876 : smap.erase(vars.back());
82 : 6729876 : vars.pop_back();
83 : 6729876 : subs.pop_back();
84 : 6729876 : curr.d_boundVar = false;
85 : : }
86 : :
87 [ + + ]: 26492510 : if (curr.d_index == 0)
88 : : {
89 : 10209569 : curr.d_index++;
90 : : // finished matching variables, try to match the operator
91 : 10209569 : Node next = syms.back();
92 : : Node op =
93 [ + + ][ + + ]: 10209569 : (!next.isNull() && next.hasOperator()) ? next.getOperator() : next;
94 : 10209569 : itc = mt->d_children.find(op);
95 [ + + ]: 10209569 : if (itc != mt->d_children.end())
96 : : {
97 : 2833376 : syms.pop_back();
98 : : // push the children + null termination marker, in reverse order
99 [ + + ]: 2833376 : if (NodeManager::isNAryKind(next.getKind()))
100 : : {
101 : 369675 : syms.push_back(Node::null());
102 : : }
103 [ + + ]: 2833376 : if (next.hasOperator())
104 : : {
105 : 2349133 : syms.insert(syms.end(), next.rbegin(), next.rend());
106 : : }
107 : : // new frame
108 : 2833376 : visit.push_back(NaryMatchFrame(syms, &itc->second));
109 : : }
110 : 10209569 : }
111 [ + + ]: 16282941 : else if (curr.d_index <= mt->d_vars.size())
112 : : {
113 : : // try to match the next (variable, length)
114 : 9402477 : Node var;
115 : 9402477 : Node next;
116 : : do
117 : : {
118 : 12889536 : var = mt->d_vars[curr.d_index - 1];
119 [ - + ][ - + ]: 12889536 : Assert(mt->d_children.find(var) != mt->d_children.end());
[ - - ]
120 : 12889536 : std::vector<Node> currChildren;
121 [ + + ]: 12889536 : if (isListVar(var))
122 : : {
123 : : // get the length of the list we want to consider
124 : 3914177 : size_t l = curr.d_variant;
125 : 3914177 : curr.d_variant++;
126 : : // match with l, or increment d_index otherwise
127 : 3914177 : bool foundChildren = true;
128 : : // We are in a state where the children of an n-ary child
129 : : // have been pused to syms. We try to extract l children here. If
130 : : // we encounter the null symbol, then we do not have sufficient
131 : : // children to match for this variant and fail.
132 [ + + ]: 75077774 : for (size_t i = 0; i < l; i++)
133 : : {
134 [ - + ][ - + ]: 71798380 : Assert(!syms.empty());
[ - - ]
135 : 71798380 : Node s = syms.back();
136 : : // we currently reject the term if it does not have the same
137 : : // type as the list variable. This rejects certain corner cases of
138 : : // arithmetic operators which are permissive for subtyping.
139 : : // For example, if x is a list variable of type Real, y is a list
140 : : // variable of type Real, then (+ x y) does *not* match
141 : : // (+ 1.0 2 1.5), despite { x -> (+ 1.0 2), y -> 1.5 } being
142 : : // a well-typed match.
143 : 71798380 : if (s.isNull() || !s.getType().isComparableTo(var.getType()))
144 : : {
145 : 634783 : foundChildren = false;
146 : 634783 : break;
147 : : }
148 : 71163597 : currChildren.push_back(s);
149 : 71163597 : syms.pop_back();
150 [ + + ]: 71798380 : }
151 [ + + ]: 3914177 : if (foundChildren)
152 : : {
153 : : // we are matching the next list
154 : 3279394 : next = nm->mkNode(Kind::SEXPR, currChildren);
155 : : }
156 : : else
157 : : {
158 : : // otherwise, we have run out of variants, go to next variable
159 : 634783 : curr.d_index++;
160 : 634783 : curr.d_variant = 0;
161 : : }
162 : : }
163 : : else
164 : : {
165 : 8975359 : next = syms.back();
166 : 8975359 : curr.d_index++;
167 : : // we could be at the end of an n-ary operator, in which case we
168 : : // do not match
169 [ + + ]: 8975359 : if (!next.isNull())
170 : : {
171 : 8747931 : currChildren.push_back(next);
172 : 8747931 : syms.pop_back();
173 [ + - ]: 17495862 : Trace("match-debug")
174 : 0 : << "Compare types " << var << " " << next << " "
175 : 8747931 : << var.getType() << " " << next.getType() << std::endl;
176 : : // check types in the (non-list) case
177 [ + + ]: 8747931 : if (!var.getType().isComparableTo(next.getType()))
178 : : {
179 [ + - ]: 3135841 : Trace("match-debug") << "...fail" << std::endl;
180 : 3135841 : next = Node::null();
181 : : }
182 : : }
183 : : }
184 [ + + ]: 12889536 : if (!next.isNull())
185 : : {
186 : : // check if it is already bound, do the binding if necessary
187 : 8891484 : std::map<Node, Node>::iterator its = smap.find(var);
188 [ + + ]: 8891484 : if (its != smap.end())
189 : : {
190 [ + + ]: 1371541 : if (its->second != next)
191 : : {
192 : : // failed to match
193 : 1312050 : next = Node::null();
194 : : }
195 : : // otherwise, successfully matched, nothing to do
196 : : }
197 : : else
198 : : {
199 : : // add to binding
200 [ + - ]: 15039886 : Trace("match-debug")
201 : 7519943 : << "Set " << var << " -> " << next << std::endl;
202 : 7519943 : vars.push_back(var);
203 : 7519943 : subs.push_back(next);
204 : 7519943 : smap[var] = next;
205 : 7519943 : curr.d_boundVar = true;
206 : : }
207 : : }
208 [ + + ]: 12889536 : if (next.isNull())
209 : : {
210 : : // if we failed, revert changes to syms
211 : 5310102 : syms.insert(syms.end(), currChildren.rbegin(), currChildren.rend());
212 : : }
213 [ + + ][ + + ]: 12889536 : } while (next.isNull() && curr.d_index <= mt->d_vars.size());
[ + + ]
214 [ + + ]: 9402477 : if (next.isNull())
215 : : {
216 : : // we are out of variables to match, finished with this frame
217 : 1823043 : visit.pop_back();
218 : 1823043 : continue;
219 : : }
220 [ + - ]: 7579434 : Trace("match-debug") << "recurse var : " << var << std::endl;
221 : 7579434 : itc = mt->d_children.find(var);
222 [ - + ][ - + ]: 7579434 : Assert(itc != mt->d_children.end());
[ - - ]
223 : 7579434 : visit.push_back(NaryMatchFrame(syms, &itc->second));
224 [ + + ][ + + ]: 11225520 : }
225 : : else
226 : : {
227 : : // no variables to match, we are done
228 : 6880464 : visit.pop_back();
229 : : }
230 [ + + ][ + ]: 28095894 : }
231 : 1033550 : return true;
232 : 1400143 : }
233 : :
234 : 4041765 : void NaryMatchTrie::addTerm(Node n)
235 : : {
236 [ - + ][ - + ]: 4041765 : Assert(!n.isNull());
[ - - ]
237 : 4041765 : std::vector<Node> visit;
238 : 4041765 : visit.push_back(n);
239 : 4041765 : NaryMatchTrie* curr = this;
240 [ + + ]: 25760700 : while (!visit.empty())
241 : : {
242 : 21718935 : Node cn = visit.back();
243 : 21718935 : visit.pop_back();
244 [ + + ]: 21718935 : if (cn.isNull())
245 : : {
246 : 1723302 : curr = &(curr->d_children[cn]);
247 : : }
248 [ + + ]: 19995633 : else if (cn.hasOperator())
249 : : {
250 : 7399539 : curr = &(curr->d_children[cn.getOperator()]);
251 : : // add null terminator if an n-ary kind
252 [ + + ]: 7399539 : if (NodeManager::isNAryKind(cn.getKind()))
253 : : {
254 : 1723302 : visit.push_back(Node::null());
255 : : }
256 : : // note children are processed left to right
257 : 7399539 : visit.insert(visit.end(), cn.rbegin(), cn.rend());
258 : : }
259 : : else
260 : : {
261 : 12596094 : if (cn.isVar()
262 [ + + ][ + + ]: 36331470 : && std::find(curr->d_vars.begin(), curr->d_vars.end(), cn)
263 [ + + ]: 36331470 : == curr->d_vars.end())
264 : : {
265 : 8199009 : curr->d_vars.push_back(cn);
266 : : }
267 : 12596094 : curr = &(curr->d_children[cn]);
268 : : }
269 : 21718935 : }
270 : 4041765 : curr->d_data = n;
271 : 4041765 : }
272 : :
273 : 0 : void NaryMatchTrie::clear()
274 : : {
275 : 0 : d_children.clear();
276 : 0 : d_vars.clear();
277 : 0 : d_data = Node::null();
278 : 0 : }
279 : :
280 : 0 : std::string NaryMatchTrie::debugPrint() const
281 : : {
282 : 0 : std::stringstream ss;
283 : 0 : std::vector<std::tuple<const NaryMatchTrie*, size_t, Node>> visit;
284 : 0 : visit.emplace_back(this, 0, Node::null());
285 : : do
286 : : {
287 : 0 : std::tuple<const NaryMatchTrie*, size_t, Node> curr = visit.back();
288 : 0 : visit.pop_back();
289 : 0 : size_t indent = std::get<1>(curr);
290 [ - - ]: 0 : for (size_t i = 0; i < indent; i++)
291 : : {
292 : 0 : ss << " ";
293 : : }
294 : 0 : Node n = std::get<2>(curr);
295 [ - - ]: 0 : if (indent == 0)
296 : : {
297 : 0 : ss << ".";
298 : : }
299 : : else
300 : : {
301 : 0 : ss << n;
302 : : }
303 : 0 : ss << ((!n.isNull() && isListVar(n)) ? " [*]" : "");
304 : 0 : ss << std::endl;
305 : 0 : const NaryMatchTrie* mt = std::get<0>(curr);
306 [ - - ]: 0 : for (const std::pair<const Node, NaryMatchTrie>& c : mt->d_children)
307 : : {
308 : 0 : visit.emplace_back(&c.second, indent + 1, c.first);
309 : : }
310 [ - - ]: 0 : } while (!visit.empty());
311 : 0 : return ss.str();
312 : 0 : }
313 : :
314 : : } // namespace expr
315 : : } // namespace cvc5::internal
|