xref: /aosp_15_r20/external/mesa3d/src/util/rb_tree.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2017 Faith Ekstrand
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice shall be included in
12  * all copies or substantial portions of the Software.
13  *
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20  * DEALINGS IN THE SOFTWARE.
21  */
22 
23 #include "rb_tree.h"
24 
25 /** \file rb_tree.c
26  *
27  * An implementation of a red-black tree
28  *
29  * This file implements the guts of a red-black tree.  The implementation
30  * is mostly based on the one in "Introduction to Algorithms", third
31  * edition, by Cormen, Leiserson, Rivest, and Stein.  The primary
32  * divergence in our algorithms from those presented in CLRS is that we use
33  * NULL for the leaves instead of a sentinel.  This means we have to do a
34  * tiny bit more tracking in our implementation of delete but it makes the
35  * algorithms far more explicit than stashing stuff in the sentinel.
36  */
37 
38 #include <stdlib.h>
39 #include <string.h>
40 #include <assert.h>
41 
42 #include "macros.h"
43 
44 static bool
rb_node_is_black(struct rb_node * n)45 rb_node_is_black(struct rb_node *n)
46 {
47     /* NULL nodes are leaves and therefore black */
48     return (n == NULL) || (n->parent & 1);
49 }
50 
51 static bool
rb_node_is_red(struct rb_node * n)52 rb_node_is_red(struct rb_node *n)
53 {
54     return !rb_node_is_black(n);
55 }
56 
57 static void
rb_node_set_black(struct rb_node * n)58 rb_node_set_black(struct rb_node *n)
59 {
60     n->parent |= 1;
61 }
62 
63 static void
rb_node_set_red(struct rb_node * n)64 rb_node_set_red(struct rb_node *n)
65 {
66     n->parent &= ~1ull;
67 }
68 
69 static void
rb_node_copy_color(struct rb_node * dst,struct rb_node * src)70 rb_node_copy_color(struct rb_node *dst, struct rb_node *src)
71 {
72     dst->parent = (dst->parent & ~1ull) | (src->parent & 1);
73 }
74 
75 static void
rb_node_set_parent(struct rb_node * n,struct rb_node * p)76 rb_node_set_parent(struct rb_node *n, struct rb_node *p)
77 {
78     n->parent = (n->parent & 1) | (uintptr_t)p;
79 }
80 
81 static struct rb_node *
rb_node_minimum(struct rb_node * node)82 rb_node_minimum(struct rb_node *node)
83 {
84     while (node->left)
85         node = node->left;
86     return node;
87 }
88 
89 static struct rb_node *
rb_node_maximum(struct rb_node * node)90 rb_node_maximum(struct rb_node *node)
91 {
92     while (node->right)
93         node = node->right;
94     return node;
95 }
96 
97 /**
98  * Replace the subtree of T rooted at u with the subtree rooted at v
99  *
100  * This is called RB-transplant in CLRS.
101  *
102  * The node to be replaced is assumed to be a non-leaf.
103  */
104 static void
rb_tree_splice(struct rb_tree * T,struct rb_node * u,struct rb_node * v)105 rb_tree_splice(struct rb_tree *T, struct rb_node *u, struct rb_node *v)
106 {
107     assert(u);
108     struct rb_node *p = rb_node_parent(u);
109     if (p == NULL) {
110         assert(T->root == u);
111         T->root = v;
112     } else if (u == p->left) {
113         p->left = v;
114     } else {
115         assert(u == p->right);
116         p->right = v;
117     }
118     if (v)
119         rb_node_set_parent(v, p);
120 }
121 
122 static void
rb_tree_rotate_left(struct rb_tree * T,struct rb_node * x,void (* update)(struct rb_node *))123 rb_tree_rotate_left(struct rb_tree *T, struct rb_node *x,
124                     void (*update)(struct rb_node *))
125 {
126     assert(x && x->right);
127 
128     struct rb_node *y = x->right;
129     x->right = y->left;
130     if (y->left)
131         rb_node_set_parent(y->left, x);
132     rb_tree_splice(T, x, y);
133     y->left = x;
134     rb_node_set_parent(x, y);
135     if (update) {
136         update(x);
137         update(y);
138     }
139 }
140 
141 static void
rb_tree_rotate_right(struct rb_tree * T,struct rb_node * y,void (* update)(struct rb_node *))142 rb_tree_rotate_right(struct rb_tree *T, struct rb_node *y,
143                      void (*update)(struct rb_node *))
144 {
145     assert(y && y->left);
146 
147     struct rb_node *x = y->left;
148     y->left = x->right;
149     if (x->right)
150         rb_node_set_parent(x->right, y);
151     rb_tree_splice(T, y, x);
152     x->right = y;
153     rb_node_set_parent(y, x);
154     if (update) {
155         update(y);
156         update(x);
157     }
158 }
159 
/* Link node into T as a child of parent and restore the red-black
 * properties.
 *
 * parent is the node to attach under; NULL means the tree is empty (this
 * is asserted) and node becomes the black root.  insert_left picks the
 * child slot, which must currently be empty (asserted).  node is zeroed
 * before use, so any previous links or color in it are discarded.
 *
 * update may be NULL.  Otherwise it is called on node, then on every
 * ancestor from parent up to the root, and again on the nodes touched by
 * each rotation during the fixup.
 */
void
rb_augmented_tree_insert_at(struct rb_tree *T, struct rb_node *parent,
                            struct rb_node *node, bool insert_left,
                            void (*update)(struct rb_node *node))
{
    /* This sets null children, parent, and a color of red */
    memset(node, 0, sizeof(*node));

    /* Let the callback initialise the new node while it is still a leaf */
    if (update)
       update(node);

    /* Empty tree: the new node becomes the root, which must be black */
    if (parent == NULL) {
        assert(T->root == NULL);
        T->root = node;
        rb_node_set_black(node);
        return;
    }

    if (insert_left) {
        assert(parent->left == NULL);
        parent->left = node;
    } else {
        assert(parent->right == NULL);
        parent->right = node;
    }
    rb_node_set_parent(node, parent);

    /* Refresh every ancestor, bottom-up, now that its subtree has grown */
    if (update) {
        struct rb_node *p = parent;
        while (p) {
            update(p);
            p = rb_node_parent(p);
        }
    }

    /* Now we do the insertion fixup */
    struct rb_node *z = node;
    /* Loop while z and its parent are both red.  A NULL parent counts as
     * black, so the loop also stops when z reaches the root.
     */
    while (rb_node_is_red(rb_node_parent(z))) {
        struct rb_node *z_p = rb_node_parent(z);
        assert(z == z_p->left || z == z_p->right);
        /* A red parent cannot be the root, so the grandparent exists */
        struct rb_node *z_p_p = rb_node_parent(z_p);
        assert(z_p_p != NULL);
        if (z_p == z_p_p->left) {
            /* y is z's uncle */
            struct rb_node *y = z_p_p->right;
            if (rb_node_is_red(y)) {
                /* Red uncle: recolor and continue from the grandparent */
                rb_node_set_black(z_p);
                rb_node_set_black(y);
                rb_node_set_red(z_p_p);
                z = z_p_p;
            } else {
                if (z == z_p->right) {
                    /* z is an inner grandchild: rotate at its parent so
                     * that it becomes an outer one.
                     */
                    z = z_p;
                    rb_tree_rotate_left(T, z, update);
                    /* We changed z */
                    z_p = rb_node_parent(z);
                    assert(z == z_p->left || z == z_p->right);
                    z_p_p = rb_node_parent(z_p);
                }
                /* Outer grandchild: recolor and rotate at the grandparent.
                 * z's parent is black afterwards, which ends the loop.
                 */
                rb_node_set_black(z_p);
                rb_node_set_red(z_p_p);
                rb_tree_rotate_right(T, z_p_p, update);
            }
        } else {
            /* Mirror image of the branch above with left/right swapped */
            struct rb_node *y = z_p_p->left;
            if (rb_node_is_red(y)) {
                rb_node_set_black(z_p);
                rb_node_set_black(y);
                rb_node_set_red(z_p_p);
                z = z_p_p;
            } else {
                if (z == z_p->left) {
                    z = z_p;
                    rb_tree_rotate_right(T, z, update);
                    /* We changed z */
                    z_p = rb_node_parent(z);
                    assert(z == z_p->left || z == z_p->right);
                    z_p_p = rb_node_parent(z_p);
                }
                rb_node_set_black(z_p);
                rb_node_set_red(z_p_p);
                rb_tree_rotate_left(T, z_p_p, update);
            }
        }
    }
    /* Recoloring may have turned the root red; force it back to black */
    rb_node_set_black(T->root);
}
246 
/* Unlink node z from T and restore the red-black properties.
 *
 * z's own fields are not cleared; it is only detached from the tree.
 *
 * update may be NULL.  Otherwise it is called on every node from the
 * lowest structurally changed position (x_p) up to the root, and again on
 * the nodes touched by each rotation during the fixup.
 */
void
rb_augmented_tree_remove(struct rb_tree *T, struct rb_node *z,
                         void (*update)(struct rb_node *))
{
    /* x_p is always the parent node of X.  We have to track this
     * separately because x may be NULL.
     */
    struct rb_node *x, *x_p;
    /* y is the node that is actually unlinked from its position: z itself
     * when z has at most one child, otherwise z's in-order successor.
     * x is the subtree that moves into y's old position.
     */
    struct rb_node *y = z;
    bool y_was_black = rb_node_is_black(y);
    if (z->left == NULL) {
        /* No left child: the right child (possibly NULL) replaces z */
        x = z->right;
        x_p = rb_node_parent(z);
        rb_tree_splice(T, z, x);
    } else if (z->right == NULL) {
        /* Only a left child: it replaces z */
        x = z->left;
        x_p = rb_node_parent(z);
        rb_tree_splice(T, z, x);
    } else {
        /* Find the minimum sub-node of z->right */
        y = rb_node_minimum(z->right);
        y_was_black = rb_node_is_black(y);

        x = y->right;
        if (rb_node_parent(y) == z) {
            /* y is z's direct child; x stays where it is, under y */
            x_p = y;
        } else {
            /* Detach y from deeper in the right subtree, then let it adopt
             * z's right subtree.
             */
            x_p = rb_node_parent(y);
            rb_tree_splice(T, y, x);
            y->right = z->right;
            rb_node_set_parent(y->right, y);
        }
        assert(y->left == NULL);
        /* Move y into z's position, taking over z's left subtree and
         * z's color.
         */
        rb_tree_splice(T, z, y);
        y->left = z->left;
        rb_node_set_parent(y->left, y);
        rb_node_copy_color(y, z);
    }

    assert(x_p == NULL || x == x_p->left || x == x_p->right);

    /* Refresh every node from x_p up to the root, bottom-up.  In the
     * two-child case this walk passes through y as well.
     */
    if (update) {
        struct rb_node *p = x_p;
        while (p) {
            update(p);
            p = rb_node_parent(p);
        }
    }

    /* Taking a red node out of its position leaves all black heights
     * unchanged, so no fixup is needed.
     */
    if (!y_was_black)
        return;

    /* Fixup RB tree after the delete */
    /* x's side is one black node short.  Push the deficit up the tree
     * until it reaches the root or can be absorbed.
     */
    while (x != T->root && rb_node_is_black(x)) {
        if (x == x_p->left) {
            /* w is x's sibling */
            struct rb_node *w = x_p->right;
            if (rb_node_is_red(w)) {
                /* Red sibling: rotate so that x gets a black sibling */
                rb_node_set_black(w);
                rb_node_set_red(x_p);
                rb_tree_rotate_left(T, x_p, update);
                assert(x == x_p->left);
                w = x_p->right;
            }
            if (rb_node_is_black(w->left) && rb_node_is_black(w->right)) {
                /* Sibling has two black children: recolor it red and move
                 * the deficit up to the parent.
                 */
                rb_node_set_red(w);
                x = x_p;
            } else {
                if (rb_node_is_black(w->right)) {
                    /* Only the near child of w is red: rotate at w so that
                     * the red child ends up on the far side.
                     */
                    rb_node_set_black(w->left);
                    rb_node_set_red(w);
                    rb_tree_rotate_right(T, w, update);
                    w = x_p->right;
                }
                /* Far child of w is red: recolor and rotate at the parent.
                 * This absorbs the deficit.
                 */
                rb_node_copy_color(w, x_p);
                rb_node_set_black(x_p);
                rb_node_set_black(w->right);
                rb_tree_rotate_left(T, x_p, update);
                /* Setting x to the root terminates the loop */
                x = T->root;
            }
        } else {
            /* Mirror image of the branch above with left/right swapped */
            struct rb_node *w = x_p->left;
            if (rb_node_is_red(w)) {
                rb_node_set_black(w);
                rb_node_set_red(x_p);
                rb_tree_rotate_right(T, x_p, update);
                assert(x == x_p->right);
                w = x_p->left;
            }
            if (rb_node_is_black(w->right) && rb_node_is_black(w->left)) {
                rb_node_set_red(w);
                x = x_p;
            } else {
                if (rb_node_is_black(w->left)) {
                    rb_node_set_black(w->right);
                    rb_node_set_red(w);
                    rb_tree_rotate_left(T, w, update);
                    w = x_p->left;
                }
                rb_node_copy_color(w, x_p);
                rb_node_set_black(x_p);
                rb_node_set_black(w->left);
                rb_tree_rotate_right(T, x_p, update);
                x = T->root;
            }
        }
        /* x is a real node here, so its parent can be read from it */
        x_p = rb_node_parent(x);
    }
    /* x is NULL only when the removal emptied the position entirely */
    if (x)
        rb_node_set_black(x);
}
357 
358 struct rb_node *
rb_tree_first(struct rb_tree * T)359 rb_tree_first(struct rb_tree *T)
360 {
361     return T->root ? rb_node_minimum(T->root) : NULL;
362 }
363 
364 struct rb_node *
rb_tree_last(struct rb_tree * T)365 rb_tree_last(struct rb_tree *T)
366 {
367     return T->root ? rb_node_maximum(T->root) : NULL;
368 }
369 
370 struct rb_node *
rb_node_next(struct rb_node * node)371 rb_node_next(struct rb_node *node)
372 {
373     if (node->right) {
374         /* If we have a right child, then the next thing (compared to this
375          * node) is the left-most child of our right child.
376          */
377         return rb_node_minimum(node->right);
378     } else {
379         /* If node doesn't have a right child, crawl back up the to the
380          * left until we hit a parent to the right.
381          */
382         struct rb_node *p = rb_node_parent(node);
383         while (p && node == p->right) {
384             node = p;
385             p = rb_node_parent(node);
386         }
387         assert(p == NULL || node == p->left);
388         return p;
389     }
390 }
391 
392 struct rb_node *
rb_node_prev(struct rb_node * node)393 rb_node_prev(struct rb_node *node)
394 {
395     if (node->left) {
396         /* If we have a left child, then the previous thing (compared to
397          * this node) is the right-most child of our left child.
398          */
399         return rb_node_maximum(node->left);
400     } else {
401         /* If node doesn't have a left child, crawl back up the to the
402          * right until we hit a parent to the left.
403          */
404         struct rb_node *p = rb_node_parent(node);
405         while (p && node == p->left) {
406             node = p;
407             p = rb_node_parent(node);
408         }
409         assert(p == NULL || node == p->right);
410         return p;
411     }
412 }
413 
414 /* Return the first node in an interval tree that intersects a given interval
415  * or point. The tests against the interval and the max field are abstracted
416  * via function pointers, so that this works for any type of interval.
417  */
418 static struct rb_node *
rb_node_min_intersecting(struct rb_node * node,void * interval,int (* cmp_interval)(const struct rb_node * node,const void * interval),bool (* cmp_max)(const struct rb_node * node,const void * interval))419 rb_node_min_intersecting(struct rb_node *node, void *interval,
420                          int (*cmp_interval)(const struct rb_node *node,
421                                              const void *interval),
422                          bool (*cmp_max)(const struct rb_node *node,
423                                          const void *interval))
424 {
425     if (!cmp_max(node, interval))
426         return NULL;
427 
428     while (node) {
429         int cmp = cmp_interval(node, interval);
430 
431         /* If the node's interval is entirely to the right of the interval
432          * we're searching for, all of its right descendants are also to the
433          * right and don't intersect so we have to search to the left.
434          */
435         if (cmp > 0) {
436             node = node->left;
437             continue;
438         }
439 
440         /* The interval overlaps or is to the left. This must also be true for
441          * its left descendants because their start points are to the left of
442          * node's. We can use the max to tell if there is an interval in its
443          * left descendants which overlaps our interval, in which case we
444          * should descend to the left.
445          */
446         if (node->left && cmp_max(node->left, interval)) {
447             node = node->left;
448             continue;
449         }
450 
451         /* Now the only possibilities are the node's interval intersects the
452          * interval or one of its right descendants does.
453          */
454         if (cmp == 0)
455             return node;
456 
457         node = node->right;
458         if (node && !cmp_max(node, interval))
459             return NULL;
460     }
461 
462     return NULL;
463 }
464 
465 /* Return the next node after "node" that intersects a given interval.
466  *
467  * Because rb_node_min_intersecting() takes O(log n) time and may be called up
468  * to O(log n) times, in addition to the O(log n) crawl up the tree, a naive
469  * runtime analysis would show that this takes O((log n)^2) time, but actually
470  * it's O(log n). Proving this is tricky:
471  *
472  * Call the rightmost node in the tree whose start is before the end of the
473  * interval we're searching for N. All nodes after N in the tree are to the
474  * right of the interval. We'll divide the search into two phases: in phase 1,
475  * "node" is *not* an ancestor of N, and in phase 2 it is. Because we always
476  * crawl up the tree, phase 2 cannot turn back into phase 1, but phase 1 may
477  * be followed by phase 2. We'll prove that the calls to
478  * rb_node_min_intersecting() take O(log n) time in both phases.
479  *
480  * Phase 1: Because "node" is to the left of N and N isn't a descendant of
481  * "node", the start of every interval in "node"'s subtree must be less than
482  * or equal to N's start which is less than or equal to the end of the search
483  * interval. Furthermore, either "node"'s max_end is less than the start of
484  * the interval, in which case rb_node_min_intersecting() immediately returns
485  * NULL, or some descendant has an end equal to "node"'s max_end which is
486  * greater than or equal to the search interval's start, and therefore it
487  * intersects the search interval and rb_node_min_intersecting() must return
488  * non-NULL which causes us to terminate. rb_node_min_intersecting() is called
489  * O(log n) times, with all but the last call taking constant time and the
490  * last call taking O(log n), so the overall runtime is O(log n).
491  *
492  * Phase 2: After the first call to rb_node_min_intersecting, we may crawl up
493  * the tree until we get to a node p where "node", and therefore N, is in p's
494  * left subtree. However this means that p is to the right of N in the tree
495  * and is therefore to the right of the search interval, and the search
496  * terminates on the first iteration of the loop so that
497  * rb_node_min_intersecting() is only called once.
498  */
499 static struct rb_node *
rb_node_next_intersecting(struct rb_node * node,void * interval,int (* cmp_interval)(const struct rb_node * node,const void * interval),bool (* cmp_max)(const struct rb_node * node,const void * interval))500 rb_node_next_intersecting(struct rb_node *node,
501                           void *interval,
502                           int (*cmp_interval)(const struct rb_node *node,
503                                               const void *interval),
504                           bool (*cmp_max)(const struct rb_node *node,
505                                           const void *interval))
506 {
507     while (true) {
508         /* The first place to search is the node's right subtree. */
509         if (node->right) {
510             struct rb_node *next =
511                 rb_node_min_intersecting(node->right, interval, cmp_interval, cmp_max);
512             if (next)
513                 return next;
514         }
515 
516         /* If we don't find a matching interval there, crawl up the tree until
517          * we find an ancestor to the right. This is the next node after the
518          * right subtree which we determined didn't match.
519          */
520         struct rb_node *p = rb_node_parent(node);
521         while (p && node == p->right) {
522             node = p;
523             p = rb_node_parent(node);
524         }
525         assert(p == NULL || node == p->left);
526 
527         /* Check if we've searched everything in the tree. */
528         if (!p)
529             return NULL;
530 
531         int cmp = cmp_interval(p, interval);
532 
533         /* If it intersects, return it. */
534         if (cmp == 0)
535             return p;
536 
537         /* If it's to the right of the interval, all following nodes will be
538          * to the right and we can bail early.
539          */
540         if (cmp > 0)
541             return NULL;
542 
543         node = p;
544     }
545 }
546 
547 static int
uinterval_cmp(struct uinterval a,struct uinterval b)548 uinterval_cmp(struct uinterval a, struct uinterval b)
549 {
550     if (a.end < b.start)
551         return -1;
552     else if (b.end < a.start)
553         return 1;
554     else
555         return 0;
556 }
557 
558 static int
uinterval_node_cmp(const struct rb_node * _a,const struct rb_node * _b)559 uinterval_node_cmp(const struct rb_node *_a, const struct rb_node *_b)
560 {
561     const struct uinterval_node *a = rb_node_data(struct uinterval_node, _a, node);
562     const struct uinterval_node *b = rb_node_data(struct uinterval_node, _b, node);
563 
564     return (int) (b->interval.start - a->interval.start);
565 }
566 
567 static int
uinterval_search_cmp(const struct rb_node * _node,const void * _interval)568 uinterval_search_cmp(const struct rb_node *_node, const void *_interval)
569 {
570     const struct uinterval_node *node = rb_node_data(struct uinterval_node, _node, node);
571     const struct uinterval *interval = _interval;
572 
573     return uinterval_cmp(node->interval, *interval);
574 }
575 
576 static bool
uinterval_max_cmp(const struct rb_node * _node,const void * data)577 uinterval_max_cmp(const struct rb_node *_node, const void *data)
578 {
579     const struct uinterval_node *node = rb_node_data(struct uinterval_node, _node, node);
580     const struct uinterval *interval = data;
581 
582     return node->max_end >= interval->start;
583 }
584 
585 static void
uinterval_update_max(struct rb_node * _node)586 uinterval_update_max(struct rb_node *_node)
587 {
588     struct uinterval_node *node = rb_node_data(struct uinterval_node, _node, node);
589     node->max_end = node->interval.end;
590     if (node->node.left) {
591         struct uinterval_node *left = rb_node_data(struct uinterval_node, node->node.left, node);
592         node->max_end = MAX2(node->max_end, left->max_end);
593     }
594     if (node->node.right) {
595         struct uinterval_node *right = rb_node_data(struct uinterval_node, node->node.right, node);
596         node->max_end = MAX2(node->max_end, right->max_end);
597     }
598 }
599 
600 void
uinterval_tree_insert(struct rb_tree * tree,struct uinterval_node * node)601 uinterval_tree_insert(struct rb_tree *tree, struct uinterval_node *node)
602 {
603     rb_augmented_tree_insert(tree, &node->node, uinterval_node_cmp,
604                              uinterval_update_max);
605 }
606 
607 void
uinterval_tree_remove(struct rb_tree * tree,struct uinterval_node * node)608 uinterval_tree_remove(struct rb_tree *tree, struct uinterval_node *node)
609 {
610     rb_augmented_tree_remove(tree, &node->node, uinterval_update_max);
611 }
612 
613 struct uinterval_node *
uinterval_tree_first(struct rb_tree * tree,struct uinterval interval)614 uinterval_tree_first(struct rb_tree *tree, struct uinterval interval)
615 {
616     if (!tree->root)
617         return NULL;
618 
619     struct rb_node *node =
620         rb_node_min_intersecting(tree->root, &interval, uinterval_search_cmp,
621                                  uinterval_max_cmp);
622 
623     return node ? rb_node_data(struct uinterval_node, node, node) : NULL;
624 }
625 
626 struct uinterval_node *
uinterval_node_next(struct uinterval_node * node,struct uinterval interval)627 uinterval_node_next(struct uinterval_node *node,
628                     struct uinterval interval)
629 {
630     struct rb_node *next =
631         rb_node_next_intersecting(&node->node, &interval, uinterval_search_cmp,
632                                   uinterval_max_cmp);
633 
634     return next ? rb_node_data(struct uinterval_node, next, node) : NULL;
635 }
636 
637 static void
validate_rb_node(struct rb_node * n,int black_depth)638 validate_rb_node(struct rb_node *n, int black_depth)
639 {
640     if (n == NULL) {
641         assert(black_depth == 0);
642         return;
643     }
644 
645     if (rb_node_is_black(n)) {
646         black_depth--;
647     } else {
648         assert(rb_node_is_black(n->left));
649         assert(rb_node_is_black(n->right));
650     }
651 
652     validate_rb_node(n->left, black_depth);
653     validate_rb_node(n->right, black_depth);
654 }
655 
656 void
rb_tree_validate(struct rb_tree * T)657 rb_tree_validate(struct rb_tree *T)
658 {
659     if (T->root == NULL)
660         return;
661 
662     assert(rb_node_is_black(T->root));
663 
664     unsigned black_depth = 0;
665     for (struct rb_node *n = T->root; n; n = n->left) {
666         if (rb_node_is_black(n))
667             black_depth++;
668     }
669 
670     validate_rb_node(T->root, black_depth);
671 }
672