1// Copyright 2022 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package abt 6 7import ( 8 "fmt" 9 "strconv" 10 "strings" 11) 12 13const ( 14 LEAF_HEIGHT = 1 15 ZERO_HEIGHT = 0 16 NOT_KEY32 = int32(-0x80000000) 17) 18 19// T is the exported applicative balanced tree data type. 20// A T can be used as a value; updates to one copy of the value 21// do not change other copies. 22type T struct { 23 root *node32 24 size int 25} 26 27// node32 is the internal tree node data type 28type node32 struct { 29 // Standard conventions hold for left = smaller, right = larger 30 left, right *node32 31 data interface{} 32 key int32 33 height_ int8 34} 35 36func makeNode(key int32) *node32 { 37 return &node32{key: key, height_: LEAF_HEIGHT} 38} 39 40// IsEmpty returns true iff t is empty. 41func (t *T) IsEmpty() bool { 42 return t.root == nil 43} 44 45// IsSingle returns true iff t is a singleton (leaf). 46func (t *T) IsSingle() bool { 47 return t.root != nil && t.root.isLeaf() 48} 49 50// VisitInOrder applies f to the key and data pairs in t, 51// with keys ordered from smallest to largest. 52func (t *T) VisitInOrder(f func(int32, interface{})) { 53 if t.root == nil { 54 return 55 } 56 t.root.visitInOrder(f) 57} 58 59func (n *node32) nilOrData() interface{} { 60 if n == nil { 61 return nil 62 } 63 return n.data 64} 65 66func (n *node32) nilOrKeyAndData() (k int32, d interface{}) { 67 if n == nil { 68 k = NOT_KEY32 69 d = nil 70 } else { 71 k = n.key 72 d = n.data 73 } 74 return 75} 76 77func (n *node32) height() int8 { 78 if n == nil { 79 return 0 80 } 81 return n.height_ 82} 83 84// Find returns the data associated with x in the tree, or 85// nil if x is not in the tree. 
86func (t *T) Find(x int32) interface{} { 87 return t.root.find(x).nilOrData() 88} 89 90// Insert either adds x to the tree if x was not previously 91// a key in the tree, or updates the data for x in the tree if 92// x was already a key in the tree. The previous data associated 93// with x is returned, and is nil if x was not previously a 94// key in the tree. 95func (t *T) Insert(x int32, data interface{}) interface{} { 96 if x == NOT_KEY32 { 97 panic("Cannot use sentinel value -0x80000000 as key") 98 } 99 n := t.root 100 var newroot *node32 101 var o *node32 102 if n == nil { 103 n = makeNode(x) 104 newroot = n 105 } else { 106 newroot, n, o = n.aInsert(x) 107 } 108 var r interface{} 109 if o != nil { 110 r = o.data 111 } else { 112 t.size++ 113 } 114 n.data = data 115 t.root = newroot 116 return r 117} 118 119func (t *T) Copy() *T { 120 u := *t 121 return &u 122} 123 124func (t *T) Delete(x int32) interface{} { 125 n := t.root 126 if n == nil { 127 return nil 128 } 129 d, s := n.aDelete(x) 130 if d == nil { 131 return nil 132 } 133 t.root = s 134 t.size-- 135 return d.data 136} 137 138func (t *T) DeleteMin() (int32, interface{}) { 139 n := t.root 140 if n == nil { 141 return NOT_KEY32, nil 142 } 143 d, s := n.aDeleteMin() 144 if d == nil { 145 return NOT_KEY32, nil 146 } 147 t.root = s 148 t.size-- 149 return d.key, d.data 150} 151 152func (t *T) DeleteMax() (int32, interface{}) { 153 n := t.root 154 if n == nil { 155 return NOT_KEY32, nil 156 } 157 d, s := n.aDeleteMax() 158 if d == nil { 159 return NOT_KEY32, nil 160 } 161 t.root = s 162 t.size-- 163 return d.key, d.data 164} 165 166func (t *T) Size() int { 167 return t.size 168} 169 170// Intersection returns the intersection of t and u, where the result 171// data for any common keys is given by f(t's data, u's data) -- f need 172// not be symmetric. If f returns nil, then the key and data are not 173// added to the result. 
If f itself is nil, then whatever value was 174// already present in the smaller set is used. 175func (t *T) Intersection(u *T, f func(x, y interface{}) interface{}) *T { 176 if t.Size() == 0 || u.Size() == 0 { 177 return &T{} 178 } 179 180 // For faster execution and less allocation, prefer t smaller, iterate over t. 181 if t.Size() <= u.Size() { 182 v := t.Copy() 183 for it := t.Iterator(); !it.Done(); { 184 k, d := it.Next() 185 e := u.Find(k) 186 if e == nil { 187 v.Delete(k) 188 continue 189 } 190 if f == nil { 191 continue 192 } 193 if c := f(d, e); c != d { 194 if c == nil { 195 v.Delete(k) 196 } else { 197 v.Insert(k, c) 198 } 199 } 200 } 201 return v 202 } 203 v := u.Copy() 204 for it := u.Iterator(); !it.Done(); { 205 k, e := it.Next() 206 d := t.Find(k) 207 if d == nil { 208 v.Delete(k) 209 continue 210 } 211 if f == nil { 212 continue 213 } 214 if c := f(d, e); c != d { 215 if c == nil { 216 v.Delete(k) 217 } else { 218 v.Insert(k, c) 219 } 220 } 221 } 222 223 return v 224} 225 226// Union returns the union of t and u, where the result data for any common keys 227// is given by f(t's data, u's data) -- f need not be symmetric. If f returns nil, 228// then the key and data are not added to the result. If f itself is nil, then 229// whatever value was already present in the larger set is used. 
230func (t *T) Union(u *T, f func(x, y interface{}) interface{}) *T { 231 if t.Size() == 0 { 232 return u 233 } 234 if u.Size() == 0 { 235 return t 236 } 237 238 if t.Size() >= u.Size() { 239 v := t.Copy() 240 for it := u.Iterator(); !it.Done(); { 241 k, e := it.Next() 242 d := t.Find(k) 243 if d == nil { 244 v.Insert(k, e) 245 continue 246 } 247 if f == nil { 248 continue 249 } 250 if c := f(d, e); c != d { 251 if c == nil { 252 v.Delete(k) 253 } else { 254 v.Insert(k, c) 255 } 256 } 257 } 258 return v 259 } 260 261 v := u.Copy() 262 for it := t.Iterator(); !it.Done(); { 263 k, d := it.Next() 264 e := u.Find(k) 265 if e == nil { 266 v.Insert(k, d) 267 continue 268 } 269 if f == nil { 270 continue 271 } 272 if c := f(d, e); c != d { 273 if c == nil { 274 v.Delete(k) 275 } else { 276 v.Insert(k, c) 277 } 278 } 279 } 280 return v 281} 282 283// Difference returns the difference of t and u, subject to the result 284// of f applied to data corresponding to equal keys. If f returns nil 285// (or if f is nil) then the key+data are excluded, as usual. If f 286// returns not-nil, then that key+data pair is inserted. instead. 
287func (t *T) Difference(u *T, f func(x, y interface{}) interface{}) *T { 288 if t.Size() == 0 { 289 return &T{} 290 } 291 if u.Size() == 0 { 292 return t 293 } 294 v := t.Copy() 295 for it := t.Iterator(); !it.Done(); { 296 k, d := it.Next() 297 e := u.Find(k) 298 if e != nil { 299 if f == nil { 300 v.Delete(k) 301 continue 302 } 303 c := f(d, e) 304 if c == nil { 305 v.Delete(k) 306 continue 307 } 308 if c != d { 309 v.Insert(k, c) 310 } 311 } 312 } 313 return v 314} 315 316func (t *T) Iterator() Iterator { 317 return Iterator{it: t.root.iterator()} 318} 319 320func (t *T) Equals(u *T) bool { 321 if t == u { 322 return true 323 } 324 if t.Size() != u.Size() { 325 return false 326 } 327 return t.root.equals(u.root) 328} 329 330func (t *T) String() string { 331 var b strings.Builder 332 first := true 333 for it := t.Iterator(); !it.Done(); { 334 k, v := it.Next() 335 if first { 336 first = false 337 } else { 338 b.WriteString("; ") 339 } 340 b.WriteString(strconv.FormatInt(int64(k), 10)) 341 b.WriteString(":") 342 fmt.Fprint(&b, v) 343 } 344 return b.String() 345} 346 347func (t *node32) equals(u *node32) bool { 348 if t == u { 349 return true 350 } 351 it, iu := t.iterator(), u.iterator() 352 for !it.done() && !iu.done() { 353 nt := it.next() 354 nu := iu.next() 355 if nt == nu { 356 continue 357 } 358 if nt.key != nu.key { 359 return false 360 } 361 if nt.data != nu.data { 362 return false 363 } 364 } 365 return it.done() == iu.done() 366} 367 368func (t *T) Equiv(u *T, eqv func(x, y interface{}) bool) bool { 369 if t == u { 370 return true 371 } 372 if t.Size() != u.Size() { 373 return false 374 } 375 return t.root.equiv(u.root, eqv) 376} 377 378func (t *node32) equiv(u *node32, eqv func(x, y interface{}) bool) bool { 379 if t == u { 380 return true 381 } 382 it, iu := t.iterator(), u.iterator() 383 for !it.done() && !iu.done() { 384 nt := it.next() 385 nu := iu.next() 386 if nt == nu { 387 continue 388 } 389 if nt.key != nu.key { 390 return false 391 } 392 if 
!eqv(nt.data, nu.data) { 393 return false 394 } 395 } 396 return it.done() == iu.done() 397} 398 399type iterator struct { 400 parents []*node32 401} 402 403type Iterator struct { 404 it iterator 405} 406 407func (it *Iterator) Next() (int32, interface{}) { 408 x := it.it.next() 409 if x == nil { 410 return NOT_KEY32, nil 411 } 412 return x.key, x.data 413} 414 415func (it *Iterator) Done() bool { 416 return len(it.it.parents) == 0 417} 418 419func (t *node32) iterator() iterator { 420 if t == nil { 421 return iterator{} 422 } 423 it := iterator{parents: make([]*node32, 0, int(t.height()))} 424 it.leftmost(t) 425 return it 426} 427 428func (it *iterator) leftmost(t *node32) { 429 for t != nil { 430 it.parents = append(it.parents, t) 431 t = t.left 432 } 433} 434 435func (it *iterator) done() bool { 436 return len(it.parents) == 0 437} 438 439func (it *iterator) next() *node32 { 440 l := len(it.parents) 441 if l == 0 { 442 return nil 443 } 444 x := it.parents[l-1] // return value 445 if x.right != nil { 446 it.leftmost(x.right) 447 return x 448 } 449 // discard visited top of parents 450 l-- 451 it.parents = it.parents[:l] 452 y := x // y is known visited/returned 453 for l > 0 && y == it.parents[l-1].right { 454 y = it.parents[l-1] 455 l-- 456 it.parents = it.parents[:l] 457 } 458 459 return x 460} 461 462// Min returns the minimum element of t. 463// If t is empty, then (NOT_KEY32, nil) is returned. 464func (t *T) Min() (k int32, d interface{}) { 465 return t.root.min().nilOrKeyAndData() 466} 467 468// Max returns the maximum element of t. 469// If t is empty, then (NOT_KEY32, nil) is returned. 470func (t *T) Max() (k int32, d interface{}) { 471 return t.root.max().nilOrKeyAndData() 472} 473 474// Glb returns the greatest-lower-bound-exclusive of x and the associated 475// data. If x has no glb in the tree, then (NOT_KEY32, nil) is returned. 
func (t *T) Glb(x int32) (k int32, d interface{}) {
	return t.root.glb(x, false).nilOrKeyAndData()
}

// GlbEq returns the greatest-lower-bound-inclusive of x and the associated
// data. If x has no glbEq in the tree, then (NOT_KEY32, nil) is returned.
func (t *T) GlbEq(x int32) (k int32, d interface{}) {
	return t.root.glb(x, true).nilOrKeyAndData()
}

// Lub returns the least-upper-bound-exclusive of x and the associated
// data. If x has no lub in the tree, then (NOT_KEY32, nil) is returned.
func (t *T) Lub(x int32) (k int32, d interface{}) {
	return t.root.lub(x, false).nilOrKeyAndData()
}

// LubEq returns the least-upper-bound-inclusive of x and the associated
// data. If x has no lubEq in the tree, then (NOT_KEY32, nil) is returned.
func (t *T) LubEq(x int32) (k int32, d interface{}) {
	return t.root.lub(x, true).nilOrKeyAndData()
}

// isLeaf reports whether t has no children and records LEAF_HEIGHT.
// t must be non-nil.
func (t *node32) isLeaf() bool {
	return t.left == nil && t.right == nil && t.height_ == LEAF_HEIGHT
}

// visitInOrder recursively applies f to every key/data pair in the
// subtree rooted at t, in increasing key order. t must be non-nil.
func (t *node32) visitInOrder(f func(int32, interface{})) {
	if t.left != nil {
		t.left.visitInOrder(f)
	}
	f(t.key, t.data)
	if t.right != nil {
		t.right.visitInOrder(f)
	}
}

// find returns the node whose key equals key in the subtree rooted at t,
// or nil if there is none (including when t is nil).
func (t *node32) find(key int32) *node32 {
	for t != nil {
		if key < t.key {
			t = t.left
		} else if key > t.key {
			t = t.right
		} else {
			return t
		}
	}
	return nil
}

// min returns the node with the smallest key in the subtree rooted at t,
// or nil if t is nil.
func (t *node32) min() *node32 {
	if t == nil {
		return t
	}
	for t.left != nil {
		t = t.left
	}
	return t
}

// max returns the node with the largest key in the subtree rooted at t,
// or nil if t is nil.
func (t *node32) max() *node32 {
	if t == nil {
		return t
	}
	for t.right != nil {
		t = t.right
	}
	return t
}

// glb returns the node with the largest key that is less than key (or
// equal to it, when allow_eq is true) in the subtree rooted at t, or nil
// if no such node exists.
func (t *node32) glb(key int32, allow_eq bool) *node32 {
	var best *node32 = nil
	for t != nil {
		if key <= t.key {
			if allow_eq && key == t.key {
				return t
			}
			// t is too big, glb is to left.
			t = t.left
		} else {
			// t is a lower bound, record it and seek a better one.
			best = t
			t = t.right
		}
	}
	return best
}

// lub returns the node with the smallest key that is greater than key (or
// equal to it, when allow_eq is true) in the subtree rooted at t, or nil
// if no such node exists.
func (t *node32) lub(key int32, allow_eq bool) *node32 {
	var best *node32 = nil
	for t != nil {
		if key >= t.key {
			if allow_eq && key == t.key {
				return t
			}
			// t is too small, lub is to right.
			t = t.right
		} else {
			// t is an upper bound, record it and seek a better one.
			best = t
			t = t.left
		}
	}
	return best
}

// aInsert returns the root of a new version of the subtree rooted at t
// (which must be non-nil) that contains key x. Nodes on the search path
// are copied before being changed, so the original subtree is left intact.
//
// newnode is the node holding x in the new subtree: either a fresh leaf
// or a copy of the existing node for x. The caller is expected to set its
// data. oldnode is the pre-existing node for x, or nil if x was not
// present. When a child becomes more than one level taller than its
// sibling, aLeftIsHigh/aRightIsHigh rotate the subtree back into balance.
func (t *node32) aInsert(x int32) (newroot, newnode, oldnode *node32) {
	// oldnode default of nil is good, others should be assigned.
	if x == t.key {
		oldnode = t
		newt := *t
		newnode = &newt
		newroot = newnode
		return
	}
	if x < t.key {
		if t.left == nil {
			t = t.copy()
			n := makeNode(x)
			t.left = n
			newnode = n
			newroot = t
			t.height_ = 2 // was balanced w/ 0, sibling is height 0 or 1
			return
		}
		var new_l *node32
		new_l, newnode, oldnode = t.left.aInsert(x)
		t = t.copy()
		t.left = new_l
		if new_l.height() > 1+t.right.height() {
			newroot = t.aLeftIsHigh(newnode)
		} else {
			t.height_ = 1 + max(t.left.height(), t.right.height())
			newroot = t
		}
	} else { // x > t.key
		if t.right == nil {
			t = t.copy()
			n := makeNode(x)
			t.right = n
			newnode = n
			newroot = t
			t.height_ = 2 // was balanced w/ 0, sibling is height 0 or 1
			return
		}
		var new_r *node32
		new_r, newnode, oldnode = t.right.aInsert(x)
		t = t.copy()
		t.right = new_r
		if new_r.height() > 1+t.left.height() {
			newroot = t.aRightIsHigh(newnode)
		} else {
			t.height_ = 1 + max(t.left.height(), t.right.height())
			newroot = t
		}
	}
	return
}

// aDelete removes key from the subtree rooted at t without modifying
// existing nodes. It returns the node that held key (nil if key is
// absent) and the root of the new subtree. When key is absent the
// original subtree pointer is returned unchanged, which lets callers skip
// copying. An interior node's key and data are replaced by those of its
// in-order predecessor or successor, taken from the taller child (the
// right child on ties).
func (t *node32) aDelete(key int32) (deleted, newSubTree *node32) {
	if t == nil {
		return nil, nil
	}

	if key < t.key {
		oh := t.left.height()
		d, tleft := t.left.aDelete(key)
		if tleft == t.left {
			return d, t
		}
		return d, t.copy().aRebalanceAfterLeftDeletion(oh, tleft)
	} else if key > t.key {
		oh := t.right.height()
		d, tright := t.right.aDelete(key)
		if tright == t.right {
			return d, t
		}
		return d, t.copy().aRebalanceAfterRightDeletion(oh, tright)
	}

	if t.height() == LEAF_HEIGHT {
		return t, nil
	}

	// Interior delete by removing left.Max or right.Min,
	// then swapping contents
	if t.left.height() > t.right.height() {
		oh := t.left.height()
		d, tleft := t.left.aDeleteMax()
		r := t
		t = t.copy()
		t.data, t.key = d.data, d.key
		return r, t.aRebalanceAfterLeftDeletion(oh, tleft)
	}

	oh := t.right.height()
	d, tright := t.right.aDeleteMin()
	r := t
	t = t.copy()
	t.data, t.key = d.data, d.key
	return r, t.aRebalanceAfterRightDeletion(oh, tright)
}

// aDeleteMin removes the node with the smallest key from the subtree
// rooted at t, returning that node and the root of the new subtree
// (nil, nil when t is nil). Existing nodes are not modified.
func (t *node32) aDeleteMin() (deleted, newSubTree *node32) {
	if t == nil {
		return nil, nil
	}
	if t.left == nil { // leaf or left-most
		return t, t.right
	}
	oh := t.left.height()
	d, tleft := t.left.aDeleteMin()
	if tleft == t.left {
		return d, t
	}
	return d, t.copy().aRebalanceAfterLeftDeletion(oh, tleft)
}

// aDeleteMax removes the node with the largest key from the subtree
// rooted at t, returning that node and the root of the new subtree
// (nil, nil when t is nil). Existing nodes are not modified.
func (t *node32) aDeleteMax() (deleted, newSubTree *node32) {
	if t == nil {
		return nil, nil
	}

	if t.right == nil { // leaf or right-most
		return t, t.left
	}

	oh := t.right.height()
	d, tright := t.right.aDeleteMax()
	if tright == t.right {
		return d, t
	}
	return d, t.copy().aRebalanceAfterRightDeletion(oh, tright)
}

// aRebalanceAfterLeftDeletion installs tleft as t's left child after a
// deletion from a left subtree whose height was oldLeftHeight. It then
// fixes t's height or rotates so that the children's heights again differ
// by at most one, and returns the subtree's new root. t is modified in
// place, so callers pass a fresh copy.
func (t *node32) aRebalanceAfterLeftDeletion(oldLeftHeight int8, tleft *node32) *node32 {
	t.left = tleft

	if oldLeftHeight == tleft.height() || oldLeftHeight == t.right.height() {
		// this node is still balanced and its height is unchanged
		return t
	}

	if oldLeftHeight > t.right.height() {
		// left was larger
		t.height_--
		return t
	}

	// left height fell by 1 and it was already less than right height
	t.right = t.right.copy()
	return t.aRightIsHigh(nil)
}

// aRebalanceAfterRightDeletion is the mirror image of
// aRebalanceAfterLeftDeletion for a deletion from t's right subtree,
// whose height was oldRightHeight. t is modified in place, so callers
// pass a fresh copy.
func (t *node32) aRebalanceAfterRightDeletion(oldRightHeight int8, tright *node32) *node32 {
	t.right = tright

	if oldRightHeight == tright.height() || oldRightHeight == t.left.height() {
		// this node is still balanced and its height is unchanged
		return t
	}

	if oldRightHeight > t.left.height() {
		// right was larger, and is now the same height as left
		t.height_--
		return t
	}

	// right height fell by 1 and it was already less than left height
	t.left = t.left.copy()
	return t.aLeftIsHigh(nil)
}

// aRightIsHigh does rotations necessary to fix a high right child
// (taller than its sibling by more than one) and returns the new subtree
// root. It assumes that t and t.right are already fresh copies.
// newnode, if non-nil, is a node the caller knows is freshly allocated; it
// is not copied again before the inner rotation of a double rotation
// modifies it.
func (t *node32) aRightIsHigh(newnode *node32) *node32 {
	right := t.right
	if right.right.height() < right.left.height() {
		// double rotation
		if newnode != right.left {
			right.left = right.left.copy()
		}
		t.right = right.leftToRoot()
	}
	t = t.rightToRoot()
	return t
}

// aLeftIsHigh does rotations necessary to fix a high left child (taller
// than its sibling by more than one) and returns the new subtree root.
// It assumes that t and t.left are already fresh copies. newnode, if
// non-nil, is a node the caller knows is freshly allocated; it is not
// copied again before the inner rotation of a double rotation modifies it.
func (t *node32) aLeftIsHigh(newnode *node32) *node32 {
	left := t.left
	if left.left.height() < left.right.height() {
		// double rotation
		if newnode != left.right {
			left.right = left.right.copy()
		}
		t.left = left.rightToRoot()
	}
	t = t.leftToRoot()
	return t
}

// rightToRoot does that rotation, modifying t and t.right in the process:
// t.right becomes the root of the subtree and t becomes its left child.
// The new subtree root is returned; the caller must store it in the
// parent's child pointer.
779func (t *node32) rightToRoot() *node32 { 780 // this 781 // left right 782 // rl rr 783 // 784 // becomes 785 // 786 // right 787 // this rr 788 // left rl 789 // 790 right := t.right 791 rl := right.left 792 right.left = t 793 // parent's child ptr fixed in caller 794 t.right = rl 795 t.height_ = 1 + max(rl.height(), t.left.height()) 796 right.height_ = 1 + max(t.height(), right.right.height()) 797 return right 798} 799 800// leftToRoot does that rotation, modifying t and t.left in the process. 801func (t *node32) leftToRoot() *node32 { 802 // this 803 // left right 804 // ll lr 805 // 806 // becomes 807 // 808 // left 809 // ll this 810 // lr right 811 // 812 left := t.left 813 lr := left.right 814 left.right = t 815 // parent's child ptr fixed in caller 816 t.left = lr 817 t.height_ = 1 + max(lr.height(), t.right.height()) 818 left.height_ = 1 + max(t.height(), left.left.height()) 819 return left 820} 821 822func max(a, b int8) int8 { 823 if a > b { 824 return a 825 } 826 return b 827} 828 829func (t *node32) copy() *node32 { 830 u := *t 831 return &u 832} 833