xref: /aosp_15_r20/external/grpc-grpc/src/core/lib/avl/avl.h (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 // Copyright 2021 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GRPC_SRC_CORE_LIB_AVL_AVL_H
16 #define GRPC_SRC_CORE_LIB_AVL_AVL_H
17 
18 #include <grpc/support/port_platform.h>
19 
20 #include <stdlib.h>
21 
22 #include <algorithm>  // IWYU pragma: keep
23 #include <iterator>
24 #include <utility>
25 
26 #include "src/core/lib/gpr/useful.h"
27 #include "src/core/lib/gprpp/ref_counted.h"
28 #include "src/core/lib/gprpp/ref_counted_ptr.h"
29 
30 namespace grpc_core {
31 
32 template <class K, class V = void>
33 class AVL {
34  public:
AVL()35   AVL() {}
36 
Add(K key,V value)37   AVL Add(K key, V value) const {
38     return AVL(AddKey(root_, std::move(key), std::move(value)));
39   }
40   template <typename SomethingLikeK>
Remove(const SomethingLikeK & key)41   AVL Remove(const SomethingLikeK& key) const {
42     return AVL(RemoveKey(root_, key));
43   }
44   template <typename SomethingLikeK>
Lookup(const SomethingLikeK & key)45   const V* Lookup(const SomethingLikeK& key) const {
46     NodePtr n = Get(root_, key);
47     return n != nullptr ? &n->kv.second : nullptr;
48   }
49 
LookupBelow(const K & key)50   const std::pair<K, V>* LookupBelow(const K& key) const {
51     NodePtr n = GetBelow(root_, *key);
52     return n != nullptr ? &n->kv : nullptr;
53   }
54 
Empty()55   bool Empty() const { return root_ == nullptr; }
56 
57   template <class F>
ForEach(F && f)58   void ForEach(F&& f) const {
59     ForEachImpl(root_.get(), std::forward<F>(f));
60   }
61 
SameIdentity(const AVL & avl)62   bool SameIdentity(const AVL& avl) const { return root_ == avl.root_; }
63 
QsortCompare(const AVL & left,const AVL & right)64   friend int QsortCompare(const AVL& left, const AVL& right) {
65     if (left.root_.get() == right.root_.get()) return 0;
66     Iterator a(left.root_);
67     Iterator b(right.root_);
68     for (;;) {
69       Node* p = a.current();
70       Node* q = b.current();
71       if (p != q) {
72         if (p == nullptr) return -1;
73         if (q == nullptr) return 1;
74         const int kv = QsortCompare(p->kv, q->kv);
75         if (kv != 0) return kv;
76       } else if (p == nullptr) {
77         return 0;
78       }
79       a.MoveNext();
80       b.MoveNext();
81     }
82   }
83 
84   bool operator==(const AVL& other) const {
85     return QsortCompare(*this, other) == 0;
86   }
87 
88   bool operator<(const AVL& other) const {
89     return QsortCompare(*this, other) < 0;
90   }
91 
Height()92   size_t Height() const {
93     if (root_ == nullptr) return 0;
94     return root_->height;
95   }
96 
97  private:
98   struct Node;
99 
100   typedef RefCountedPtr<Node> NodePtr;
101   struct Node : public RefCounted<Node, NonPolymorphicRefCount> {
NodeNode102     Node(K k, V v, NodePtr l, NodePtr r, long h)
103         : kv(std::move(k), std::move(v)),
104           left(std::move(l)),
105           right(std::move(r)),
106           height(h) {}
107     const std::pair<K, V> kv;
108     const NodePtr left;
109     const NodePtr right;
110     const long height;
111   };
112   NodePtr root_;
113 
114   class IteratorStack {
115    public:
Push(Node * n)116     void Push(Node* n) {
117       nodes_[depth_] = n;
118       ++depth_;
119     }
120 
Pop()121     Node* Pop() {
122       --depth_;
123       return nodes_[depth_];
124     }
125 
Back()126     Node* Back() const { return nodes_[depth_ - 1]; }
127 
Empty()128     bool Empty() const { return depth_ == 0; }
129 
130    private:
131     size_t depth_{0};
132     // 32 is the maximum depth we can accept, and corresponds to ~4billion nodes
133     // - which ought to suffice our use cases.
134     Node* nodes_[32];
135   };
136 
137   class Iterator {
138    public:
Iterator(const NodePtr & root)139     explicit Iterator(const NodePtr& root) {
140       auto* n = root.get();
141       while (n != nullptr) {
142         stack_.Push(n);
143         n = n->left.get();
144       }
145     }
current()146     Node* current() const { return stack_.Empty() ? nullptr : stack_.Back(); }
MoveNext()147     void MoveNext() {
148       auto* n = stack_.Pop();
149       if (n->right != nullptr) {
150         n = n->right.get();
151         while (n != nullptr) {
152           stack_.Push(n);
153           n = n->left.get();
154         }
155       }
156     }
157 
158    private:
159     IteratorStack stack_;
160   };
161 
AVL(NodePtr root)162   explicit AVL(NodePtr root) : root_(std::move(root)) {}
163 
164   template <class F>
ForEachImpl(const Node * n,F && f)165   static void ForEachImpl(const Node* n, F&& f) {
166     if (n == nullptr) return;
167     ForEachImpl(n->left.get(), std::forward<F>(f));
168     f(const_cast<const K&>(n->kv.first), const_cast<const V&>(n->kv.second));
169     ForEachImpl(n->right.get(), std::forward<F>(f));
170   }
171 
Height(const NodePtr & n)172   static long Height(const NodePtr& n) { return n != nullptr ? n->height : 0; }
173 
MakeNode(K key,V value,const NodePtr & left,const NodePtr & right)174   static NodePtr MakeNode(K key, V value, const NodePtr& left,
175                           const NodePtr& right) {
176     return MakeRefCounted<Node>(std::move(key), std::move(value), left, right,
177                                 1 + std::max(Height(left), Height(right)));
178   }
179 
180   template <typename SomethingLikeK>
Get(const NodePtr & node,const SomethingLikeK & key)181   static NodePtr Get(const NodePtr& node, const SomethingLikeK& key) {
182     if (node == nullptr) {
183       return nullptr;
184     }
185 
186     if (node->kv.first > key) {
187       return Get(node->left, key);
188     } else if (node->kv.first < key) {
189       return Get(node->right, key);
190     } else {
191       return node;
192     }
193   }
194 
GetBelow(const NodePtr & node,const K & key)195   static NodePtr GetBelow(const NodePtr& node, const K& key) {
196     if (!node) return nullptr;
197     if (node->kv.first > key) {
198       return GetBelow(node->left, key);
199     } else if (node->kv.first < key) {
200       NodePtr n = GetBelow(node->right, key);
201       if (n == nullptr) n = node;
202       return n;
203     } else {
204       return node;
205     }
206   }
207 
RotateLeft(K key,V value,const NodePtr & left,const NodePtr & right)208   static NodePtr RotateLeft(K key, V value, const NodePtr& left,
209                             const NodePtr& right) {
210     return MakeNode(
211         right->kv.first, right->kv.second,
212         MakeNode(std::move(key), std::move(value), left, right->left),
213         right->right);
214   }
215 
RotateRight(K key,V value,const NodePtr & left,const NodePtr & right)216   static NodePtr RotateRight(K key, V value, const NodePtr& left,
217                              const NodePtr& right) {
218     return MakeNode(
219         left->kv.first, left->kv.second, left->left,
220         MakeNode(std::move(key), std::move(value), left->right, right));
221   }
222 
RotateLeftRight(K key,V value,const NodePtr & left,const NodePtr & right)223   static NodePtr RotateLeftRight(K key, V value, const NodePtr& left,
224                                  const NodePtr& right) {
225     // rotate_right(..., rotate_left(left), right)
226     return MakeNode(
227         left->right->kv.first, left->right->kv.second,
228         MakeNode(left->kv.first, left->kv.second, left->left,
229                  left->right->left),
230         MakeNode(std::move(key), std::move(value), left->right->right, right));
231   }
232 
RotateRightLeft(K key,V value,const NodePtr & left,const NodePtr & right)233   static NodePtr RotateRightLeft(K key, V value, const NodePtr& left,
234                                  const NodePtr& right) {
235     // rotate_left(..., left, rotate_right(right))
236     return MakeNode(
237         right->left->kv.first, right->left->kv.second,
238         MakeNode(std::move(key), std::move(value), left, right->left->left),
239         MakeNode(right->kv.first, right->kv.second, right->left->right,
240                  right->right));
241   }
242 
Rebalance(K key,V value,const NodePtr & left,const NodePtr & right)243   static NodePtr Rebalance(K key, V value, const NodePtr& left,
244                            const NodePtr& right) {
245     switch (Height(left) - Height(right)) {
246       case 2:
247         if (Height(left->left) - Height(left->right) == -1) {
248           return RotateLeftRight(std::move(key), std::move(value), left, right);
249         } else {
250           return RotateRight(std::move(key), std::move(value), left, right);
251         }
252       case -2:
253         if (Height(right->left) - Height(right->right) == 1) {
254           return RotateRightLeft(std::move(key), std::move(value), left, right);
255         } else {
256           return RotateLeft(std::move(key), std::move(value), left, right);
257         }
258       default:
259         return MakeNode(key, value, left, right);
260     }
261   }
262 
AddKey(const NodePtr & node,K key,V value)263   static NodePtr AddKey(const NodePtr& node, K key, V value) {
264     if (node == nullptr) {
265       return MakeNode(std::move(key), std::move(value), nullptr, nullptr);
266     }
267     if (node->kv.first < key) {
268       return Rebalance(node->kv.first, node->kv.second, node->left,
269                        AddKey(node->right, std::move(key), std::move(value)));
270     }
271     if (key < node->kv.first) {
272       return Rebalance(node->kv.first, node->kv.second,
273                        AddKey(node->left, std::move(key), std::move(value)),
274                        node->right);
275     }
276     return MakeNode(std::move(key), std::move(value), node->left, node->right);
277   }
278 
InOrderHead(NodePtr node)279   static NodePtr InOrderHead(NodePtr node) {
280     while (node->left != nullptr) {
281       node = node->left;
282     }
283     return node;
284   }
285 
InOrderTail(NodePtr node)286   static NodePtr InOrderTail(NodePtr node) {
287     while (node->right != nullptr) {
288       node = node->right;
289     }
290     return node;
291   }
292 
293   template <typename SomethingLikeK>
RemoveKey(const NodePtr & node,const SomethingLikeK & key)294   static NodePtr RemoveKey(const NodePtr& node, const SomethingLikeK& key) {
295     if (node == nullptr) {
296       return nullptr;
297     }
298     if (key < node->kv.first) {
299       return Rebalance(node->kv.first, node->kv.second,
300                        RemoveKey(node->left, key), node->right);
301     } else if (node->kv.first < key) {
302       return Rebalance(node->kv.first, node->kv.second, node->left,
303                        RemoveKey(node->right, key));
304     } else {
305       if (node->left == nullptr) {
306         return node->right;
307       } else if (node->right == nullptr) {
308         return node->left;
309       } else if (node->left->height < node->right->height) {
310         NodePtr h = InOrderHead(node->right);
311         return Rebalance(h->kv.first, h->kv.second, node->left,
312                          RemoveKey(node->right, h->kv.first));
313       } else {
314         NodePtr h = InOrderTail(node->left);
315         return Rebalance(h->kv.first, h->kv.second,
316                          RemoveKey(node->left, h->kv.first), node->right);
317       }
318     }
319     abort();
320   }
321 };
322 
323 }  // namespace grpc_core
324 
325 #endif  // GRPC_SRC_CORE_LIB_AVL_AVL_H
326