1// Copyright 2022 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package abt
6
7import (
8	"fmt"
9	"strconv"
10	"strings"
11)
12
// Height values and the reserved key used throughout the tree.
const (
	// ZERO_HEIGHT is zero; height() reports 0 for a nil subtree.
	ZERO_HEIGHT = 0
	// LEAF_HEIGHT is the height of a node with no children.
	LEAF_HEIGHT = 1
	// NOT_KEY32 is the smallest int32. Insert rejects it as a key, and the
	// lookup/removal methods return it when there is no matching entry.
	NOT_KEY32   = int32(-1 << 31)
)
18
19// T is the exported applicative balanced tree data type.
20// A T can be used as a value; updates to one copy of the value
21// do not change other copies.
22type T struct {
23	root *node32
24	size int
25}
26
// node32 is the internal tree node data type. Nodes may be shared by
// several trees (T.Copy shares the root), so the update routines copy a
// node before changing it.
type node32 struct {
	// Standard conventions hold: keys in left are smaller than key,
	// keys in right are larger.
	left  *node32
	right *node32

	data interface{} // value associated with key
	key  int32

	// height_ is the height of the subtree rooted here (LEAF_HEIGHT for a
	// childless node). Read it through height(), which also handles nil.
	height_ int8
}
35
36func makeNode(key int32) *node32 {
37	return &node32{key: key, height_: LEAF_HEIGHT}
38}
39
40// IsEmpty returns true iff t is empty.
41func (t *T) IsEmpty() bool {
42	return t.root == nil
43}
44
45// IsSingle returns true iff t is a singleton (leaf).
46func (t *T) IsSingle() bool {
47	return t.root != nil && t.root.isLeaf()
48}
49
50// VisitInOrder applies f to the key and data pairs in t,
51// with keys ordered from smallest to largest.
52func (t *T) VisitInOrder(f func(int32, interface{})) {
53	if t.root == nil {
54		return
55	}
56	t.root.visitInOrder(f)
57}
58
59func (n *node32) nilOrData() interface{} {
60	if n == nil {
61		return nil
62	}
63	return n.data
64}
65
66func (n *node32) nilOrKeyAndData() (k int32, d interface{}) {
67	if n == nil {
68		k = NOT_KEY32
69		d = nil
70	} else {
71		k = n.key
72		d = n.data
73	}
74	return
75}
76
77func (n *node32) height() int8 {
78	if n == nil {
79		return 0
80	}
81	return n.height_
82}
83
84// Find returns the data associated with x in the tree, or
85// nil if x is not in the tree.
86func (t *T) Find(x int32) interface{} {
87	return t.root.find(x).nilOrData()
88}
89
90// Insert either adds x to the tree if x was not previously
91// a key in the tree, or updates the data for x in the tree if
92// x was already a key in the tree.  The previous data associated
93// with x is returned, and is nil if x was not previously a
94// key in the tree.
95func (t *T) Insert(x int32, data interface{}) interface{} {
96	if x == NOT_KEY32 {
97		panic("Cannot use sentinel value -0x80000000 as key")
98	}
99	n := t.root
100	var newroot *node32
101	var o *node32
102	if n == nil {
103		n = makeNode(x)
104		newroot = n
105	} else {
106		newroot, n, o = n.aInsert(x)
107	}
108	var r interface{}
109	if o != nil {
110		r = o.data
111	} else {
112		t.size++
113	}
114	n.data = data
115	t.root = newroot
116	return r
117}
118
119func (t *T) Copy() *T {
120	u := *t
121	return &u
122}
123
124func (t *T) Delete(x int32) interface{} {
125	n := t.root
126	if n == nil {
127		return nil
128	}
129	d, s := n.aDelete(x)
130	if d == nil {
131		return nil
132	}
133	t.root = s
134	t.size--
135	return d.data
136}
137
138func (t *T) DeleteMin() (int32, interface{}) {
139	n := t.root
140	if n == nil {
141		return NOT_KEY32, nil
142	}
143	d, s := n.aDeleteMin()
144	if d == nil {
145		return NOT_KEY32, nil
146	}
147	t.root = s
148	t.size--
149	return d.key, d.data
150}
151
152func (t *T) DeleteMax() (int32, interface{}) {
153	n := t.root
154	if n == nil {
155		return NOT_KEY32, nil
156	}
157	d, s := n.aDeleteMax()
158	if d == nil {
159		return NOT_KEY32, nil
160	}
161	t.root = s
162	t.size--
163	return d.key, d.data
164}
165
166func (t *T) Size() int {
167	return t.size
168}
169
170// Intersection returns the intersection of t and u, where the result
171// data for any common keys is given by f(t's data, u's data) -- f need
172// not be symmetric.  If f returns nil, then the key and data are not
173// added to the result.  If f itself is nil, then whatever value was
174// already present in the smaller set is used.
175func (t *T) Intersection(u *T, f func(x, y interface{}) interface{}) *T {
176	if t.Size() == 0 || u.Size() == 0 {
177		return &T{}
178	}
179
180	// For faster execution and less allocation, prefer t smaller, iterate over t.
181	if t.Size() <= u.Size() {
182		v := t.Copy()
183		for it := t.Iterator(); !it.Done(); {
184			k, d := it.Next()
185			e := u.Find(k)
186			if e == nil {
187				v.Delete(k)
188				continue
189			}
190			if f == nil {
191				continue
192			}
193			if c := f(d, e); c != d {
194				if c == nil {
195					v.Delete(k)
196				} else {
197					v.Insert(k, c)
198				}
199			}
200		}
201		return v
202	}
203	v := u.Copy()
204	for it := u.Iterator(); !it.Done(); {
205		k, e := it.Next()
206		d := t.Find(k)
207		if d == nil {
208			v.Delete(k)
209			continue
210		}
211		if f == nil {
212			continue
213		}
214		if c := f(d, e); c != d {
215			if c == nil {
216				v.Delete(k)
217			} else {
218				v.Insert(k, c)
219			}
220		}
221	}
222
223	return v
224}
225
226// Union returns the union of t and u, where the result data for any common keys
227// is given by f(t's data, u's data) -- f need not be symmetric.  If f returns nil,
228// then the key and data are not added to the result.  If f itself is nil, then
229// whatever value was already present in the larger set is used.
230func (t *T) Union(u *T, f func(x, y interface{}) interface{}) *T {
231	if t.Size() == 0 {
232		return u
233	}
234	if u.Size() == 0 {
235		return t
236	}
237
238	if t.Size() >= u.Size() {
239		v := t.Copy()
240		for it := u.Iterator(); !it.Done(); {
241			k, e := it.Next()
242			d := t.Find(k)
243			if d == nil {
244				v.Insert(k, e)
245				continue
246			}
247			if f == nil {
248				continue
249			}
250			if c := f(d, e); c != d {
251				if c == nil {
252					v.Delete(k)
253				} else {
254					v.Insert(k, c)
255				}
256			}
257		}
258		return v
259	}
260
261	v := u.Copy()
262	for it := t.Iterator(); !it.Done(); {
263		k, d := it.Next()
264		e := u.Find(k)
265		if e == nil {
266			v.Insert(k, d)
267			continue
268		}
269		if f == nil {
270			continue
271		}
272		if c := f(d, e); c != d {
273			if c == nil {
274				v.Delete(k)
275			} else {
276				v.Insert(k, c)
277			}
278		}
279	}
280	return v
281}
282
283// Difference returns the difference of t and u, subject to the result
284// of f applied to data corresponding to equal keys.  If f returns nil
285// (or if f is nil) then the key+data are excluded, as usual.  If f
286// returns not-nil, then that key+data pair is inserted. instead.
287func (t *T) Difference(u *T, f func(x, y interface{}) interface{}) *T {
288	if t.Size() == 0 {
289		return &T{}
290	}
291	if u.Size() == 0 {
292		return t
293	}
294	v := t.Copy()
295	for it := t.Iterator(); !it.Done(); {
296		k, d := it.Next()
297		e := u.Find(k)
298		if e != nil {
299			if f == nil {
300				v.Delete(k)
301				continue
302			}
303			c := f(d, e)
304			if c == nil {
305				v.Delete(k)
306				continue
307			}
308			if c != d {
309				v.Insert(k, c)
310			}
311		}
312	}
313	return v
314}
315
316func (t *T) Iterator() Iterator {
317	return Iterator{it: t.root.iterator()}
318}
319
320func (t *T) Equals(u *T) bool {
321	if t == u {
322		return true
323	}
324	if t.Size() != u.Size() {
325		return false
326	}
327	return t.root.equals(u.root)
328}
329
330func (t *T) String() string {
331	var b strings.Builder
332	first := true
333	for it := t.Iterator(); !it.Done(); {
334		k, v := it.Next()
335		if first {
336			first = false
337		} else {
338			b.WriteString("; ")
339		}
340		b.WriteString(strconv.FormatInt(int64(k), 10))
341		b.WriteString(":")
342		fmt.Fprint(&b, v)
343	}
344	return b.String()
345}
346
347func (t *node32) equals(u *node32) bool {
348	if t == u {
349		return true
350	}
351	it, iu := t.iterator(), u.iterator()
352	for !it.done() && !iu.done() {
353		nt := it.next()
354		nu := iu.next()
355		if nt == nu {
356			continue
357		}
358		if nt.key != nu.key {
359			return false
360		}
361		if nt.data != nu.data {
362			return false
363		}
364	}
365	return it.done() == iu.done()
366}
367
368func (t *T) Equiv(u *T, eqv func(x, y interface{}) bool) bool {
369	if t == u {
370		return true
371	}
372	if t.Size() != u.Size() {
373		return false
374	}
375	return t.root.equiv(u.root, eqv)
376}
377
378func (t *node32) equiv(u *node32, eqv func(x, y interface{}) bool) bool {
379	if t == u {
380		return true
381	}
382	it, iu := t.iterator(), u.iterator()
383	for !it.done() && !iu.done() {
384		nt := it.next()
385		nu := iu.next()
386		if nt == nu {
387			continue
388		}
389		if nt.key != nu.key {
390			return false
391		}
392		if !eqv(nt.data, nu.data) {
393			return false
394		}
395	}
396	return it.done() == iu.done()
397}
398
399type iterator struct {
400	parents []*node32
401}
402
403type Iterator struct {
404	it iterator
405}
406
407func (it *Iterator) Next() (int32, interface{}) {
408	x := it.it.next()
409	if x == nil {
410		return NOT_KEY32, nil
411	}
412	return x.key, x.data
413}
414
415func (it *Iterator) Done() bool {
416	return len(it.it.parents) == 0
417}
418
419func (t *node32) iterator() iterator {
420	if t == nil {
421		return iterator{}
422	}
423	it := iterator{parents: make([]*node32, 0, int(t.height()))}
424	it.leftmost(t)
425	return it
426}
427
428func (it *iterator) leftmost(t *node32) {
429	for t != nil {
430		it.parents = append(it.parents, t)
431		t = t.left
432	}
433}
434
435func (it *iterator) done() bool {
436	return len(it.parents) == 0
437}
438
439func (it *iterator) next() *node32 {
440	l := len(it.parents)
441	if l == 0 {
442		return nil
443	}
444	x := it.parents[l-1] // return value
445	if x.right != nil {
446		it.leftmost(x.right)
447		return x
448	}
449	// discard visited top of parents
450	l--
451	it.parents = it.parents[:l]
452	y := x // y is known visited/returned
453	for l > 0 && y == it.parents[l-1].right {
454		y = it.parents[l-1]
455		l--
456		it.parents = it.parents[:l]
457	}
458
459	return x
460}
461
462// Min returns the minimum element of t.
463// If t is empty, then (NOT_KEY32, nil) is returned.
464func (t *T) Min() (k int32, d interface{}) {
465	return t.root.min().nilOrKeyAndData()
466}
467
468// Max returns the maximum element of t.
469// If t is empty, then (NOT_KEY32, nil) is returned.
470func (t *T) Max() (k int32, d interface{}) {
471	return t.root.max().nilOrKeyAndData()
472}
473
474// Glb returns the greatest-lower-bound-exclusive of x and the associated
475// data.  If x has no glb in the tree, then (NOT_KEY32, nil) is returned.
476func (t *T) Glb(x int32) (k int32, d interface{}) {
477	return t.root.glb(x, false).nilOrKeyAndData()
478}
479
480// GlbEq returns the greatest-lower-bound-inclusive of x and the associated
481// data.  If x has no glbEQ in the tree, then (NOT_KEY32, nil) is returned.
482func (t *T) GlbEq(x int32) (k int32, d interface{}) {
483	return t.root.glb(x, true).nilOrKeyAndData()
484}
485
486// Lub returns the least-upper-bound-exclusive of x and the associated
487// data.  If x has no lub in the tree, then (NOT_KEY32, nil) is returned.
488func (t *T) Lub(x int32) (k int32, d interface{}) {
489	return t.root.lub(x, false).nilOrKeyAndData()
490}
491
492// LubEq returns the least-upper-bound-inclusive of x and the associated
493// data.  If x has no lubEq in the tree, then (NOT_KEY32, nil) is returned.
494func (t *T) LubEq(x int32) (k int32, d interface{}) {
495	return t.root.lub(x, true).nilOrKeyAndData()
496}
497
498func (t *node32) isLeaf() bool {
499	return t.left == nil && t.right == nil && t.height_ == LEAF_HEIGHT
500}
501
502func (t *node32) visitInOrder(f func(int32, interface{})) {
503	if t.left != nil {
504		t.left.visitInOrder(f)
505	}
506	f(t.key, t.data)
507	if t.right != nil {
508		t.right.visitInOrder(f)
509	}
510}
511
512func (t *node32) find(key int32) *node32 {
513	for t != nil {
514		if key < t.key {
515			t = t.left
516		} else if key > t.key {
517			t = t.right
518		} else {
519			return t
520		}
521	}
522	return nil
523}
524
525func (t *node32) min() *node32 {
526	if t == nil {
527		return t
528	}
529	for t.left != nil {
530		t = t.left
531	}
532	return t
533}
534
535func (t *node32) max() *node32 {
536	if t == nil {
537		return t
538	}
539	for t.right != nil {
540		t = t.right
541	}
542	return t
543}
544
545func (t *node32) glb(key int32, allow_eq bool) *node32 {
546	var best *node32 = nil
547	for t != nil {
548		if key <= t.key {
549			if allow_eq && key == t.key {
550				return t
551			}
552			// t is too big, glb is to left.
553			t = t.left
554		} else {
555			// t is a lower bound, record it and seek a better one.
556			best = t
557			t = t.right
558		}
559	}
560	return best
561}
562
563func (t *node32) lub(key int32, allow_eq bool) *node32 {
564	var best *node32 = nil
565	for t != nil {
566		if key >= t.key {
567			if allow_eq && key == t.key {
568				return t
569			}
570			// t is too small, lub is to right.
571			t = t.right
572		} else {
573			// t is an upper bound, record it and seek a better one.
574			best = t
575			t = t.left
576		}
577	}
578	return best
579}
580
// aInsert inserts key x into the non-nil subtree rooted at t. t and every
// node on the search path are copied rather than modified, so the original
// subtree is left intact.
//
// It returns:
//   - newroot: the root of the resulting subtree;
//   - newnode: a freshly allocated node in that subtree that holds x (a
//     copy of the old node when x was already present, otherwise a new
//     leaf), so the caller may set its data in place;
//   - oldnode: the original node that held x, or nil if x was not present.
//
// If a child ends up more than one level taller than its sibling, the
// subtree is rebalanced with aLeftIsHigh/aRightIsHigh.
func (t *node32) aInsert(x int32) (newroot, newnode, oldnode *node32) {
	// oldnode default of nil is good, others should be assigned.
	if x == t.key {
		// Key already present: the replacement is a shallow copy of t, so
		// the subtree shape and height are unchanged.
		oldnode = t
		newt := *t
		newnode = &newt
		newroot = newnode
		return
	}
	if x < t.key {
		if t.left == nil {
			// Attach a new leaf directly as the left child.
			t = t.copy()
			n := makeNode(x)
			t.left = n
			newnode = n
			newroot = t
			t.height_ = 2 // was balanced w/ 0, sibling is height 0 or 1
			return
		}
		var new_l *node32
		new_l, newnode, oldnode = t.left.aInsert(x)
		t = t.copy()
		t.left = new_l
		// Left side is now two levels taller than the right: rotate.
		// Otherwise just recompute this node's height.
		if new_l.height() > 1+t.right.height() {
			newroot = t.aLeftIsHigh(newnode)
		} else {
			t.height_ = 1 + max(t.left.height(), t.right.height())
			newroot = t
		}
	} else { // x > t.key
		if t.right == nil {
			// Attach a new leaf directly as the right child.
			t = t.copy()
			n := makeNode(x)
			t.right = n
			newnode = n
			newroot = t
			t.height_ = 2 // was balanced w/ 0, sibling is height 0 or 1
			return
		}
		var new_r *node32
		new_r, newnode, oldnode = t.right.aInsert(x)
		t = t.copy()
		t.right = new_r
		// Right side is now two levels taller than the left: rotate.
		// Otherwise just recompute this node's height.
		if new_r.height() > 1+t.left.height() {
			newroot = t.aRightIsHigh(newnode)
		} else {
			t.height_ = 1 + max(t.left.height(), t.right.height())
			newroot = t
		}
	}
	return
}
633
// aDelete removes key from the subtree rooted at t. Nodes on the modified
// path are copied (t.copy()) before being changed, so t itself is never
// modified.
//
// It returns the node that held key in the original subtree (nil if key is
// absent) and the root of the resulting subtree. When key is absent the
// returned root is t itself, which lets callers detect "no change" by
// pointer comparison.
func (t *node32) aDelete(key int32) (deleted, newSubTree *node32) {
	if t == nil {
		return nil, nil
	}

	if key < t.key {
		oh := t.left.height()
		d, tleft := t.left.aDelete(key)
		if tleft == t.left {
			// Nothing was removed below; reuse t unchanged.
			return d, t
		}
		return d, t.copy().aRebalanceAfterLeftDeletion(oh, tleft)
	} else if key > t.key {
		oh := t.right.height()
		d, tright := t.right.aDelete(key)
		if tright == t.right {
			// Nothing was removed below; reuse t unchanged.
			return d, t
		}
		return d, t.copy().aRebalanceAfterRightDeletion(oh, tright)
	}

	// t holds key. A node with leaf height has no children, so removing it
	// leaves an empty subtree.
	if t.height() == LEAF_HEIGHT {
		return t, nil
	}

	// Interior delete by removing left.Max or right.Min,
	// then swapping contents
	// The replacement is taken from the taller child. The original t is
	// returned as the deleted node; its copy takes over the donor's key/data.
	if t.left.height() > t.right.height() {
		oh := t.left.height()
		d, tleft := t.left.aDeleteMax()
		r := t
		t = t.copy()
		t.data, t.key = d.data, d.key
		return r, t.aRebalanceAfterLeftDeletion(oh, tleft)
	}

	oh := t.right.height()
	d, tright := t.right.aDeleteMin()
	r := t
	t = t.copy()
	t.data, t.key = d.data, d.key
	return r, t.aRebalanceAfterRightDeletion(oh, tright)
}
677
678func (t *node32) aDeleteMin() (deleted, newSubTree *node32) {
679	if t == nil {
680		return nil, nil
681	}
682	if t.left == nil { // leaf or left-most
683		return t, t.right
684	}
685	oh := t.left.height()
686	d, tleft := t.left.aDeleteMin()
687	if tleft == t.left {
688		return d, t
689	}
690	return d, t.copy().aRebalanceAfterLeftDeletion(oh, tleft)
691}
692
693func (t *node32) aDeleteMax() (deleted, newSubTree *node32) {
694	if t == nil {
695		return nil, nil
696	}
697
698	if t.right == nil { // leaf or right-most
699		return t, t.left
700	}
701
702	oh := t.right.height()
703	d, tright := t.right.aDeleteMax()
704	if tright == t.right {
705		return d, t
706	}
707	return d, t.copy().aRebalanceAfterRightDeletion(oh, tright)
708}
709
// aRebalanceAfterLeftDeletion installs tleft as t's left child after a
// deletion in the left subtree, whose height was oldLeftHeight before the
// deletion. It then fixes t's height and, if needed, rotates to restore
// balance. t is modified in place, so it must be a fresh copy (callers pass
// t.copy()). It returns the root of the rebalanced subtree, which differs
// from t when a rotation happened.
func (t *node32) aRebalanceAfterLeftDeletion(oldLeftHeight int8, tleft *node32) *node32 {
	t.left = tleft

	if oldLeftHeight == tleft.height() || oldLeftHeight == t.right.height() {
		// this node is still balanced and its height is unchanged
		return t
	}

	if oldLeftHeight > t.right.height() {
		// left was larger
		t.height_--
		return t
	}

	// left height fell by 1 and it was already less than right height
	// so the right side is now two levels taller. The rotation mutates
	// t.right, so copy it first (aRightIsHigh expects fresh copies).
	t.right = t.right.copy()
	return t.aRightIsHigh(nil)
}
728
// aRebalanceAfterRightDeletion installs tright as t's right child after a
// deletion in the right subtree, whose height was oldRightHeight before
// the deletion. It then fixes t's height and, if needed, rotates to restore
// balance. t is modified in place, so it must be a fresh copy (callers pass
// t.copy()). It returns the root of the rebalanced subtree, which differs
// from t when a rotation happened.
func (t *node32) aRebalanceAfterRightDeletion(oldRightHeight int8, tright *node32) *node32 {
	t.right = tright

	if oldRightHeight == tright.height() || oldRightHeight == t.left.height() {
		// this node is still balanced and its height is unchanged
		return t
	}

	if oldRightHeight > t.left.height() {
		// right was larger
		t.height_--
		return t
	}

	// right height fell by 1 and it was already less than left height
	// so the left side is now two levels taller. The rotation mutates
	// t.left, so copy it first (aLeftIsHigh expects fresh copies).
	t.left = t.left.copy()
	return t.aLeftIsHigh(nil)
}
747
// aRightIsHigh does rotations necessary to fix a high right child
// assume that t and t.right are already fresh copies.
//
// newnode, when non-nil, is a node the caller knows was freshly allocated
// (the just-inserted node), so it need not be copied again before being
// rotated. When right's inner (left) child is taller than its outer child,
// a double rotation is done: right.left is first rotated up within right,
// then the result is rotated up into t's position. It returns the new root
// of the subtree.
func (t *node32) aRightIsHigh(newnode *node32) *node32 {
	right := t.right
	if right.right.height() < right.left.height() {
		// double rotation
		// leftToRoot mutates right.left, so it must be a private copy.
		if newnode != right.left {
			right.left = right.left.copy()
		}
		t.right = right.leftToRoot()
	}
	t = t.rightToRoot()
	return t
}
762
// aLeftIsHigh does rotations necessary to fix a high left child
// assume that t and t.left are already fresh copies.
//
// newnode, when non-nil, is a node the caller knows was freshly allocated
// (the just-inserted node), so it need not be copied again before being
// rotated. When left's inner (right) child is taller than its outer child,
// a double rotation is done: left.right is first rotated up within left,
// then the result is rotated up into t's position. It returns the new root
// of the subtree.
func (t *node32) aLeftIsHigh(newnode *node32) *node32 {
	left := t.left
	if left.left.height() < left.right.height() {
		// double rotation
		// rightToRoot mutates left.right, so it must be a private copy.
		if newnode != left.right {
			left.right = left.right.copy()
		}
		t.left = left.rightToRoot()
	}
	t = t.leftToRoot()
	return t
}
777
// rightToRoot does that rotation, modifying t and t.right in the process.
// Both must therefore be fresh copies. It returns the former right child,
// which is the new root of the subtree.
func (t *node32) rightToRoot() *node32 {
	//    this
	// left  right
	//      rl   rr
	//
	// becomes
	//
	//       right
	//    this   rr
	// left  rl
	//
	right := t.right
	rl := right.left
	right.left = t
	// parent's child ptr fixed in caller
	t.right = rl
	// Heights are recomputed bottom-up: t is now right's child, so its
	// height must be updated before right's.
	t.height_ = 1 + max(rl.height(), t.left.height())
	right.height_ = 1 + max(t.height(), right.right.height())
	return right
}
799
// leftToRoot does that rotation, modifying t and t.left in the process.
// Both must therefore be fresh copies. It returns the former left child,
// which is the new root of the subtree.
func (t *node32) leftToRoot() *node32 {
	//     this
	//  left  right
	// ll  lr
	//
	// becomes
	//
	//    left
	//   ll  this
	//      lr  right
	//
	left := t.left
	lr := left.right
	left.right = t
	// parent's child ptr fixed in caller
	t.left = lr
	// Heights are recomputed bottom-up: t is now left's child, so its
	// height must be updated before left's.
	t.height_ = 1 + max(lr.height(), t.right.height())
	left.height_ = 1 + max(t.height(), left.left.height())
	return left
}
821
// max returns the larger of the two heights a and b (b when they are equal).
func max(a, b int8) int8 {
	if b >= a {
		return b
	}
	return a
}
828
829func (t *node32) copy() *node32 {
830	u := *t
831	return &u
832}
833