1 // Copyright 2021 gRPC authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef GRPC_SRC_CORE_LIB_AVL_AVL_H 16 #define GRPC_SRC_CORE_LIB_AVL_AVL_H 17 18 #include <grpc/support/port_platform.h> 19 20 #include <stdlib.h> 21 22 #include <algorithm> // IWYU pragma: keep 23 #include <iterator> 24 #include <utility> 25 26 #include "src/core/lib/gpr/useful.h" 27 #include "src/core/lib/gprpp/ref_counted.h" 28 #include "src/core/lib/gprpp/ref_counted_ptr.h" 29 30 namespace grpc_core { 31 32 template <class K, class V = void> 33 class AVL { 34 public: AVL()35 AVL() {} 36 Add(K key,V value)37 AVL Add(K key, V value) const { 38 return AVL(AddKey(root_, std::move(key), std::move(value))); 39 } 40 template <typename SomethingLikeK> Remove(const SomethingLikeK & key)41 AVL Remove(const SomethingLikeK& key) const { 42 return AVL(RemoveKey(root_, key)); 43 } 44 template <typename SomethingLikeK> Lookup(const SomethingLikeK & key)45 const V* Lookup(const SomethingLikeK& key) const { 46 NodePtr n = Get(root_, key); 47 return n != nullptr ? &n->kv.second : nullptr; 48 } 49 LookupBelow(const K & key)50 const std::pair<K, V>* LookupBelow(const K& key) const { 51 NodePtr n = GetBelow(root_, *key); 52 return n != nullptr ? &n->kv : nullptr; 53 } 54 Empty()55 bool Empty() const { return root_ == nullptr; } 56 57 template <class F> ForEach(F && f)58 void ForEach(F&& f) const { 59 ForEachImpl(root_.get(), std::forward<F>(f)); 60 } 61 SameIdentity(const AVL & avl)62 bool SameIdentity(const AVL& avl) const { return root_ == avl.root_; } 63 QsortCompare(const AVL & left,const AVL & right)64 friend int QsortCompare(const AVL& left, const AVL& right) { 65 if (left.root_.get() == right.root_.get()) return 0; 66 Iterator a(left.root_); 67 Iterator b(right.root_); 68 for (;;) { 69 Node* p = a.current(); 70 Node* q = b.current(); 71 if (p != q) { 72 if (p == nullptr) return -1; 73 if (q == nullptr) return 1; 74 const int kv = QsortCompare(p->kv, q->kv); 75 if (kv != 0) return kv; 76 } else if (p == nullptr) { 77 return 0; 78 } 79 a.MoveNext(); 80 b.MoveNext(); 81 } 82 } 83 84 bool operator==(const AVL& other) const { 85 return QsortCompare(*this, other) == 0; 86 } 87 88 bool operator<(const AVL& other) const { 89 return QsortCompare(*this, other) < 0; 90 } 91 Height()92 size_t Height() const { 93 if (root_ == nullptr) return 0; 94 return root_->height; 95 } 96 97 private: 98 struct Node; 99 100 typedef RefCountedPtr<Node> NodePtr; 101 struct Node : public RefCounted<Node, NonPolymorphicRefCount> { NodeNode102 Node(K k, V v, NodePtr l, NodePtr r, long h) 103 : kv(std::move(k), std::move(v)), 104 left(std::move(l)), 105 right(std::move(r)), 106 height(h) {} 107 const std::pair<K, V> kv; 108 const NodePtr left; 109 const NodePtr right; 110 const long height; 111 }; 112 NodePtr root_; 113 114 class IteratorStack { 115 public: Push(Node * n)116 void Push(Node* n) { 117 nodes_[depth_] = n; 118 ++depth_; 119 } 120 Pop()121 Node* Pop() { 122 --depth_; 123 return nodes_[depth_]; 124 } 125 Back()126 Node* Back() const { return nodes_[depth_ - 1]; } 127 Empty()128 bool Empty() const { return depth_ == 0; } 129 130 private: 131 size_t depth_{0}; 132 // 32 is the maximum depth we can accept, and corresponds to ~4billion nodes 133 // - which ought to suffice our use cases. 134 Node* nodes_[32]; 135 }; 136 137 class Iterator { 138 public: Iterator(const NodePtr & root)139 explicit Iterator(const NodePtr& root) { 140 auto* n = root.get(); 141 while (n != nullptr) { 142 stack_.Push(n); 143 n = n->left.get(); 144 } 145 } current()146 Node* current() const { return stack_.Empty() ? nullptr : stack_.Back(); } MoveNext()147 void MoveNext() { 148 auto* n = stack_.Pop(); 149 if (n->right != nullptr) { 150 n = n->right.get(); 151 while (n != nullptr) { 152 stack_.Push(n); 153 n = n->left.get(); 154 } 155 } 156 } 157 158 private: 159 IteratorStack stack_; 160 }; 161 AVL(NodePtr root)162 explicit AVL(NodePtr root) : root_(std::move(root)) {} 163 164 template <class F> ForEachImpl(const Node * n,F && f)165 static void ForEachImpl(const Node* n, F&& f) { 166 if (n == nullptr) return; 167 ForEachImpl(n->left.get(), std::forward<F>(f)); 168 f(const_cast<const K&>(n->kv.first), const_cast<const V&>(n->kv.second)); 169 ForEachImpl(n->right.get(), std::forward<F>(f)); 170 } 171 Height(const NodePtr & n)172 static long Height(const NodePtr& n) { return n != nullptr ? n->height : 0; } 173 MakeNode(K key,V value,const NodePtr & left,const NodePtr & right)174 static NodePtr MakeNode(K key, V value, const NodePtr& left, 175 const NodePtr& right) { 176 return MakeRefCounted<Node>(std::move(key), std::move(value), left, right, 177 1 + std::max(Height(left), Height(right))); 178 } 179 180 template <typename SomethingLikeK> Get(const NodePtr & node,const SomethingLikeK & key)181 static NodePtr Get(const NodePtr& node, const SomethingLikeK& key) { 182 if (node == nullptr) { 183 return nullptr; 184 } 185 186 if (node->kv.first > key) { 187 return Get(node->left, key); 188 } else if (node->kv.first < key) { 189 return Get(node->right, key); 190 } else { 191 return node; 192 } 193 } 194 GetBelow(const NodePtr & node,const K & key)195 static NodePtr GetBelow(const NodePtr& node, const K& key) { 196 if (!node) return nullptr; 197 if (node->kv.first > key) { 198 return GetBelow(node->left, key); 199 } else if (node->kv.first < key) { 200 NodePtr n = GetBelow(node->right, key); 201 if (n == nullptr) n = node; 202 return n; 203 } else { 204 return node; 205 } 206 } 207 RotateLeft(K key,V value,const NodePtr & left,const NodePtr & right)208 static NodePtr RotateLeft(K key, V value, const NodePtr& left, 209 const NodePtr& right) { 210 return MakeNode( 211 right->kv.first, right->kv.second, 212 MakeNode(std::move(key), std::move(value), left, right->left), 213 right->right); 214 } 215 RotateRight(K key,V value,const NodePtr & left,const NodePtr & right)216 static NodePtr RotateRight(K key, V value, const NodePtr& left, 217 const NodePtr& right) { 218 return MakeNode( 219 left->kv.first, left->kv.second, left->left, 220 MakeNode(std::move(key), std::move(value), left->right, right)); 221 } 222 RotateLeftRight(K key,V value,const NodePtr & left,const NodePtr & right)223 static NodePtr RotateLeftRight(K key, V value, const NodePtr& left, 224 const NodePtr& right) { 225 // rotate_right(..., rotate_left(left), right) 226 return MakeNode( 227 left->right->kv.first, left->right->kv.second, 228 MakeNode(left->kv.first, left->kv.second, left->left, 229 left->right->left), 230 MakeNode(std::move(key), std::move(value), left->right->right, right)); 231 } 232 RotateRightLeft(K key,V value,const NodePtr & left,const NodePtr & right)233 static NodePtr RotateRightLeft(K key, V value, const NodePtr& left, 234 const NodePtr& right) { 235 // rotate_left(..., left, rotate_right(right)) 236 return MakeNode( 237 right->left->kv.first, right->left->kv.second, 238 MakeNode(std::move(key), std::move(value), left, right->left->left), 239 MakeNode(right->kv.first, right->kv.second, right->left->right, 240 right->right)); 241 } 242 Rebalance(K key,V value,const NodePtr & left,const NodePtr & right)243 static NodePtr Rebalance(K key, V value, const NodePtr& left, 244 const NodePtr& right) { 245 switch (Height(left) - Height(right)) { 246 case 2: 247 if (Height(left->left) - Height(left->right) == -1) { 248 return RotateLeftRight(std::move(key), std::move(value), left, right); 249 } else { 250 return RotateRight(std::move(key), std::move(value), left, right); 251 } 252 case -2: 253 if (Height(right->left) - Height(right->right) == 1) { 254 return RotateRightLeft(std::move(key), std::move(value), left, right); 255 } else { 256 return RotateLeft(std::move(key), std::move(value), left, right); 257 } 258 default: 259 return MakeNode(key, value, left, right); 260 } 261 } 262 AddKey(const NodePtr & node,K key,V value)263 static NodePtr AddKey(const NodePtr& node, K key, V value) { 264 if (node == nullptr) { 265 return MakeNode(std::move(key), std::move(value), nullptr, nullptr); 266 } 267 if (node->kv.first < key) { 268 return Rebalance(node->kv.first, node->kv.second, node->left, 269 AddKey(node->right, std::move(key), std::move(value))); 270 } 271 if (key < node->kv.first) { 272 return Rebalance(node->kv.first, node->kv.second, 273 AddKey(node->left, std::move(key), std::move(value)), 274 node->right); 275 } 276 return MakeNode(std::move(key), std::move(value), node->left, node->right); 277 } 278 InOrderHead(NodePtr node)279 static NodePtr InOrderHead(NodePtr node) { 280 while (node->left != nullptr) { 281 node = node->left; 282 } 283 return node; 284 } 285 InOrderTail(NodePtr node)286 static NodePtr InOrderTail(NodePtr node) { 287 while (node->right != nullptr) { 288 node = node->right; 289 } 290 return node; 291 } 292 293 template <typename SomethingLikeK> RemoveKey(const NodePtr & node,const SomethingLikeK & key)294 static NodePtr RemoveKey(const NodePtr& node, const SomethingLikeK& key) { 295 if (node == nullptr) { 296 return nullptr; 297 } 298 if (key < node->kv.first) { 299 return Rebalance(node->kv.first, node->kv.second, 300 RemoveKey(node->left, key), node->right); 301 } else if (node->kv.first < key) { 302 return Rebalance(node->kv.first, node->kv.second, node->left, 303 RemoveKey(node->right, key)); 304 } else { 305 if (node->left == nullptr) { 306 return node->right; 307 } else if (node->right == nullptr) { 308 return node->left; 309 } else if (node->left->height < node->right->height) { 310 NodePtr h = InOrderHead(node->right); 311 return Rebalance(h->kv.first, h->kv.second, node->left, 312 RemoveKey(node->right, h->kv.first)); 313 } else { 314 NodePtr h = InOrderTail(node->left); 315 return Rebalance(h->kv.first, h->kv.second, 316 RemoveKey(node->left, h->kv.first), node->right); 317 } 318 } 319 abort(); 320 } 321 }; 322 323 } // namespace grpc_core 324 325 #endif // GRPC_SRC_CORE_LIB_AVL_AVL_H 326