1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package bitvec
6
7import (
8	"math/bits"
9
10	"cmd/compile/internal/base"
11)
12
// Bit vectors are stored in 32-bit words; these constants describe how a
// bit index is split into a word index and an offset within that word.
const (
	wordShift = 5              // log2(wordBits): shift that turns a bit index into a word index
	wordBits  = 1 << wordShift // number of bits held by one word
	wordMask  = wordBits - 1   // mask that extracts a bit's offset within its word
)
18
// A BitVec is a vector of N bits packed into 32-bit words.
type BitVec struct {
	// N is the number of bits in the vector.
	N int32
	// B holds the bits, 32 per word.
	B []uint32
}
24
25func New(n int32) BitVec {
26	nword := (n + wordBits - 1) / wordBits
27	return BitVec{n, make([]uint32, nword)}
28}
29
// A Bulk is a single backing allocation from which equally sized bit
// vectors are handed out one after another.
type Bulk struct {
	words       []uint32 // storage not yet handed out
	nbit, nword int32    // bits per vector, and words per vector
}
35
36func NewBulk(nbit int32, count int32) Bulk {
37	nword := (nbit + wordBits - 1) / wordBits
38	size := int64(nword) * int64(count)
39	if int64(int32(size*4)) != size*4 {
40		base.Fatalf("NewBulk too big: nbit=%d count=%d nword=%d size=%d", nbit, count, nword, size)
41	}
42	return Bulk{
43		words: make([]uint32, size),
44		nbit:  nbit,
45		nword: nword,
46	}
47}
48
49func (b *Bulk) Next() BitVec {
50	out := BitVec{b.nbit, b.words[:b.nword]}
51	b.words = b.words[b.nword:]
52	return out
53}
54
55func (bv1 BitVec) Eq(bv2 BitVec) bool {
56	if bv1.N != bv2.N {
57		base.Fatalf("bvequal: lengths %d and %d are not equal", bv1.N, bv2.N)
58	}
59	for i, x := range bv1.B {
60		if x != bv2.B[i] {
61			return false
62		}
63	}
64	return true
65}
66
67func (dst BitVec) Copy(src BitVec) {
68	copy(dst.B, src.B)
69}
70
71func (bv BitVec) Get(i int32) bool {
72	if i < 0 || i >= bv.N {
73		base.Fatalf("bvget: index %d is out of bounds with length %d\n", i, bv.N)
74	}
75	mask := uint32(1 << uint(i%wordBits))
76	return bv.B[i>>wordShift]&mask != 0
77}
78
79func (bv BitVec) Set(i int32) {
80	if i < 0 || i >= bv.N {
81		base.Fatalf("bvset: index %d is out of bounds with length %d\n", i, bv.N)
82	}
83	mask := uint32(1 << uint(i%wordBits))
84	bv.B[i/wordBits] |= mask
85}
86
87func (bv BitVec) Unset(i int32) {
88	if i < 0 || i >= bv.N {
89		base.Fatalf("bvunset: index %d is out of bounds with length %d\n", i, bv.N)
90	}
91	mask := uint32(1 << uint(i%wordBits))
92	bv.B[i/wordBits] &^= mask
93}
94
95// bvnext returns the smallest index >= i for which bvget(bv, i) == 1.
96// If there is no such index, bvnext returns -1.
97func (bv BitVec) Next(i int32) int32 {
98	if i >= bv.N {
99		return -1
100	}
101
102	// Jump i ahead to next word with bits.
103	if bv.B[i>>wordShift]>>uint(i&wordMask) == 0 {
104		i &^= wordMask
105		i += wordBits
106		for i < bv.N && bv.B[i>>wordShift] == 0 {
107			i += wordBits
108		}
109	}
110
111	if i >= bv.N {
112		return -1
113	}
114
115	// Find 1 bit.
116	w := bv.B[i>>wordShift] >> uint(i&wordMask)
117	i += int32(bits.TrailingZeros32(w))
118
119	return i
120}
121
122func (bv BitVec) IsEmpty() bool {
123	for _, x := range bv.B {
124		if x != 0 {
125			return false
126		}
127	}
128	return true
129}
130
131func (bv BitVec) Count() int {
132	n := 0
133	for _, x := range bv.B {
134		n += bits.OnesCount32(x)
135	}
136	return n
137}
138
139func (bv BitVec) Not() {
140	for i, x := range bv.B {
141		bv.B[i] = ^x
142	}
143	if bv.N%wordBits != 0 {
144		bv.B[len(bv.B)-1] &= 1<<uint(bv.N%wordBits) - 1 // clear bits past N in the last word
145	}
146}
147
148// union
149func (dst BitVec) Or(src1, src2 BitVec) {
150	if len(src1.B) == 0 {
151		return
152	}
153	_, _ = dst.B[len(src1.B)-1], src2.B[len(src1.B)-1] // hoist bounds checks out of the loop
154
155	for i, x := range src1.B {
156		dst.B[i] = x | src2.B[i]
157	}
158}
159
160// intersection
161func (dst BitVec) And(src1, src2 BitVec) {
162	if len(src1.B) == 0 {
163		return
164	}
165	_, _ = dst.B[len(src1.B)-1], src2.B[len(src1.B)-1] // hoist bounds checks out of the loop
166
167	for i, x := range src1.B {
168		dst.B[i] = x & src2.B[i]
169	}
170}
171
172// difference
173func (dst BitVec) AndNot(src1, src2 BitVec) {
174	if len(src1.B) == 0 {
175		return
176	}
177	_, _ = dst.B[len(src1.B)-1], src2.B[len(src1.B)-1] // hoist bounds checks out of the loop
178
179	for i, x := range src1.B {
180		dst.B[i] = x &^ src2.B[i]
181	}
182}
183
184func (bv BitVec) String() string {
185	s := make([]byte, 2+bv.N)
186	copy(s, "#*")
187	for i := int32(0); i < bv.N; i++ {
188		ch := byte('0')
189		if bv.Get(i) {
190			ch = '1'
191		}
192		s[2+i] = ch
193	}
194	return string(s)
195}
196
197func (bv BitVec) Clear() {
198	for i := range bv.B {
199		bv.B[i] = 0
200	}
201}
202