xref: /aosp_15_r20/external/clang/utils/ABITest/Enumeration.py (revision 67e74705e28f6214e480b399dd47ea732279e315)
1*67e74705SXin Li"""Utilities for enumeration of finite and countably infinite sets.
2*67e74705SXin Li"""
3*67e74705SXin Li###
4*67e74705SXin Li# Countable iteration
5*67e74705SXin Li
# Simplifies some calculations: aleph0 lets bounds be written uniformly as
# either finite integers or "countably infinite".
class Aleph0(int):
    """Singleton standing in for the infinite cardinal aleph-0.

    The underlying int value is 0, so every operation this module relies on
    is overridden explicitly:

    * ordering: aleph0 is greater than every other value and equal only to
      itself (rich comparisons are required because Python 3 ignores
      __cmp__);
    * aleph0 + b, aleph0 * b (b != 0), aleph0 // b, aleph0 / b and
      aleph0 ** b (b != 0) all return aleph0;
    * b // aleph0 and b / aleph0 return 0 (assumes b is a non-negative
      finite number);
    * subtraction in either direction raises ValueError.
    """
    _singleton = None

    def __new__(cls):
        if cls._singleton is None:
            cls._singleton = int.__new__(cls)
        return cls._singleton

    def __repr__(self): return '<aleph0>'
    def __str__(self): return 'inf'

    # Python 2 cmp() support; kept consistent with the rich comparisons.
    def __cmp__(self, b):
        if isinstance(b, Aleph0):
            return 0
        return 1

    # Because Aleph0 subclasses int, Python tries these (reflected) methods
    # before int's own when an int is on the left, e.g. `5 < aleph0`.
    def __eq__(self, b): return isinstance(b, Aleph0)
    def __ne__(self, b): return not isinstance(b, Aleph0)
    def __lt__(self, b): return False
    def __le__(self, b): return isinstance(b, Aleph0)
    def __gt__(self, b): return not isinstance(b, Aleph0)
    def __ge__(self, b): return True
    # Defining __eq__ would otherwise make the class unhashable on Python 3.
    __hash__ = int.__hash__

    def __sub__(self, b):
        raise ValueError("Cannot subtract aleph0")
    __rsub__ = __sub__

    def __add__(self, b):
        return self
    __radd__ = __add__

    def __mul__(self, b):
        # aleph0 * 0 == 0; any other factor leaves the product infinite.
        if b == 0: return b
        return self
    __rmul__ = __mul__

    def __floordiv__(self, b):
        if b == 0: raise ZeroDivisionError
        return self
    __truediv__ = __floordiv__
    __div__ = __floordiv__

    # b / aleph0 for finite b: an infinite divisor drives the quotient to 0.
    def __rfloordiv__(self, b):
        return 0
    __rdiv__ = __rfloordiv__

    def __rtruediv__(self, b):
        return 0.0

    def __pow__(self, b):
        if b == 0: return 1
        return self
aleph0 = Aleph0()
45*67e74705SXin Li
def base(line):
    """Return the triangular number line*(line+1)/2.

    This is the number of pairs on diagonals 0..line-1 (diagonal k holds
    k+1 pairs), i.e. the index of the first pair on diagonal `line`."""
    successor = line + 1
    return (line * successor) // 2
48*67e74705SXin Li
def pairToN(pair):
    """pairToN((x, y)) -> N

    Inverse of getNthPair: return the index of the pair (x, y) in the
    diagonal enumeration, where pairs are ordered by x + y and then by y
    (the Cantor pairing function).

    The pair is taken as a single argument and unpacked in the body, because
    tuple parameters in a signature are a SyntaxError on Python 3."""
    x,y = pair
    line,index = x+y,y
    return base(line)+index
52*67e74705SXin Li
def getNthPairInfo(N):
    """Return (line, index) locating the N-th pair in diagonal order.

    `line` is the diagonal (x + y) containing the pair, chosen so that
    base(line) <= N < base(line + 1), and `index` is N - base(line), the
    pair's position along that diagonal."""
    # N == 0 is handled up front; the search below starts at diagonal 1.
    if N == 0:
        return (0, 0)

    # Exponential search: keep doubling until diagonal `high` starts past N.
    low, high = 1, 2
    while base(high) <= N:
        low, high = high, high * 2

    # Bisection, maintaining base(low) <= N < base(high).
    while high - low > 1:
        middle = (low + high) // 2
        if base(middle) > N:
            high = middle
        else:
            low = middle

    return low, N - base(low)
78*67e74705SXin Li
def getNthPair(N):
    """Return the N-th pair (x, y) of naturals in diagonal order.

    Pairs are listed by increasing x + y, and within one diagonal by
    increasing y."""
    diagonal, offset = getNthPairInfo(N)
    x = diagonal - offset
    y = offset
    return (x, y)
82*67e74705SXin Li
def getNthPairBounded(N,W=aleph0,H=aleph0,useDivmod=False):
    """getNthPairBounded(N, W, H) -> (x, y)

    Return the N-th pair such that 0 <= x < W and 0 <= y < H.

    W and H are positive integers or aleph0.  When both are aleph0 this is
    getNthPair(N).  Otherwise, by default, a diagonal is slid across the
    W x H rectangle, which gives more interesting results for large bounds
    than divmod.  With useDivmod=True the coordinate with the smaller bound
    varies fastest (N % W, N // W after orienting so that W <= H).

    Raises ValueError if either bound is not positive or N is not in
    [0, W*H)."""

    if W <= 0 or H <= 0:
        raise ValueError("Invalid bounds")
    elif N < 0 or N >= W*H:
        # Negative indices would otherwise yield pairs outside the bounds.
        raise ValueError("Invalid input (out of bounds)")

    # Simple case...
    if W is aleph0 and H is aleph0:
        return getNthPair(N)

    # Otherwise simplify by assuming W <= H: enumerate the transposed
    # rectangle and swap the result back.
    if H < W:
        x,y = getNthPairBounded(N,H,W,useDivmod=useDivmod)
        return y,x

    if useDivmod:
        return N%W,N//W

    # Conceptually we slide a diagonal line across the rectangle.
    # Diagonals 0..W-1 form a triangle of base(W) pairs in the lower-left
    # corner, enumerated exactly as in the unbounded case.
    cornerSize = base(W)
    if N < cornerSize:
        return getNthPair(N)

    # With a finite height, the last base(W) pairs form the same triangle
    # rotated into the upper-right corner, so count them from the end.
    if H is not aleph0:
        M = W*H - N - 1
        if M < cornerSize:
            x,y = getNthPair(M)
            return (W-1-x,H-1-y)

    # Every remaining diagonal spans the full width and holds exactly W
    # pairs: offset selects the diagonal (x + y == W + offset) and index
    # walks along it from x == W-1 towards x == 0.
    N = N - cornerSize
    index,offset = N%W,N//W
    # p = (W-1, 1+offset) + (-1,1)*index
    return (W-1-index, 1+offset+index)
def getNthPairBoundedChecked(N,W=aleph0,H=aleph0,useDivmod=False,GNP=getNthPairBounded):
    """Checked variant of getNthPairBounded.

    Delegates to GNP, whose default is bound when this function is defined,
    so it keeps referring to the unchecked implementation even if the
    module-level name is later rebound.  Asserts the pair lies within
    [0, W) x [0, H)."""
    col, row = GNP(N, W, H, useDivmod)
    assert 0 <= col < W
    assert 0 <= row < H
    return col, row
131*67e74705SXin Li
def getNthNTuple(N, W, H=aleph0, useLeftToRight=False):
    """getNthNTuple(N, W, H) -> (x_0, x_1, ..., x_{W-1})

    Return the N-th W-tuple, where 0 <= x_i < H for every i.

    By default the tuple is split into a left part of W//2 elements and a
    right part holding the rest.  The indices of the two parts are paired
    with getNthPairBounded, and each part is expanded recursively.

    With useLeftToRight the elements are peeled off one at a time from the
    left instead.  The last element takes whatever index remains, which
    keeps the mapping one-to-one."""

    if useLeftToRight:
        if W == 0:
            return ()
        elts = [None]*W
        for i in range(W - 1):
            # Pair x_i (bounded by H) with the index of the remaining
            # W-1-i elements, which has H**(W-1-i) possible values.
            elts[i],N = getNthPairBounded(N, H, H**(W - 1 - i))
        elts[W - 1] = N
        return tuple(elts)
    else:
        if W==0:
            return ()
        elif W==1:
            return (N,)
        elif W==2:
            return getNthPairBounded(N, H, H)
        else:
            LW,RW = W//2, W - (W//2)
            L,R = getNthPairBounded(N, H**LW, H**RW)
            return (getNthNTuple(L,LW,H=H,useLeftToRight=useLeftToRight) +
                    getNthNTuple(R,RW,H=H,useLeftToRight=useLeftToRight))
def getNthNTupleChecked(N, W, H=aleph0, useLeftToRight=False, GNT=getNthNTuple):
    """Checked variant of getNthNTuple.

    Delegates to GNT (bound when this function is defined) and asserts the
    result has exactly W elements, each in [0, H).  The lower bound is
    checked as in getNthPairBoundedChecked."""
    t = GNT(N,W,H,useLeftToRight)
    assert len(t) == W
    for i in t:
        assert 0 <= i < H
    return t
160*67e74705SXin Li
def getNthTuple(N, maxSize=aleph0, maxElement=aleph0, useDivmod=False, useLeftToRight=False):
    """getNthTuple(N, maxSize, maxElement) -> x

    Return the N-th tuple where len(x) <= maxSize and for y in x, 0 <=
    y < maxElement.

    Index 0 is the empty tuple.  Every later index is split into a pair
    (len(x) - 1, index among tuples of that length).  useDivmod only
    affects the case where maxElement is aleph0.

    Raises NotImplementedError if maxElement is finite while maxSize is
    aleph0."""

    # All zero sized tuples are isomorphic, don't ya know.
    if N == 0:
        return ()
    N -= 1
    if maxElement is not aleph0:
        if maxSize is aleph0:
            raise NotImplementedError('Max element size without max size unhandled')
        # bounds[k] is the number of tuples of length k+1.
        bounds = [maxElement**i for i in range(1, maxSize+1)]
        S,M = getNthPairVariableBounds(N, bounds)
    else:
        S,M = getNthPairBounded(N, maxSize, useDivmod=useDivmod)
    # S is the tuple length minus one; M indexes tuples of that length.
    return getNthNTuple(M, S+1, maxElement, useLeftToRight=useLeftToRight)
def getNthTupleChecked(N, maxSize=aleph0, maxElement=aleph0,
                       useDivmod=False, useLeftToRight=False, GNT=getNthTuple):
    """Checked variant of getNthTuple.

    Delegates to GNT (bound when this function is defined) and asserts the
    result has at most maxSize elements (maxSize is inclusive), each in
    [0, maxElement)."""
    t = GNT(N,maxSize,maxElement,useDivmod,useLeftToRight)
    assert len(t) <= maxSize
    for i in t:
        assert 0 <= i < maxElement
    return t
187*67e74705SXin Li
def getNthPairVariableBounds(N, bounds):
    """getNthPairVariableBounds(N, bounds) -> (x, y)

    Given a finite list of bounds (which may be finite or aleph0),
    return the N-th pair such that 0 <= x < len(bounds) and 0 <= y <
    bounds[x].

    Columns are visited in order of increasing bound.  The band of rows
    between consecutive bounds is a rectangle (columns still active x band
    height), and each band is enumerated with getNthPairBounded.

    Raises ValueError if bounds is empty or N is not in [0, sum(bounds))."""

    if not bounds:
        raise ValueError("Invalid bounds")
    if not (0 <= N < sum(bounds)):
        raise ValueError("Invalid input (out of bounds)")

    # Column indices ordered by increasing bound (stable for ties).
    active = sorted(range(len(bounds)), key=lambda i: bounds[i])
    prevLevel = 0
    for i,index in enumerate(active):
        level = bounds[index]
        # Columns active[i:] all reach at least `level`.
        W = len(active) - i
        if level is aleph0:
            H = aleph0
        else:
            H = level - prevLevel
        levelSize = W*H
        if N<levelSize: # Found the level
            idelta,delta = getNthPairBounded(N, W, H)
            return active[i+idelta],prevLevel+delta
        N -= levelSize
        prevLevel = level
    # Unreachable when 0 <= N < sum(bounds); kept as a safety net.
    raise RuntimeError("Unexpected loop completion")
220*67e74705SXin Li
def getNthPairVariableBoundsChecked(N, bounds, GNVP=getNthPairVariableBounds):
    """Checked variant of getNthPairVariableBounds.

    Delegates to GNVP (bound when this function is defined) and asserts that
    the column index is valid and the row lies below that column's bound."""
    column, row = GNVP(N, bounds)
    assert 0 <= column < len(bounds)
    assert 0 <= row < bounds[column]
    return (column, row)
225*67e74705SXin Li
226*67e74705SXin Li###
227*67e74705SXin Li
def testPairs():
    """Print the first pairs of a 3x6 rectangle under both the diagonal
    (a) and the divmod (b) orderings, then draw each ordering as a grid."""
    W = 3
    H = 6
    a = [['  ' for x in range(10)] for y in range(10)]
    b = [['  ' for x in range(10)] for y in range(10)]
    for i in range(min(W*H,40)):
        x,y = getNthPairBounded(i,W,H)
        x2,y2 = getNthPairBounded(i,W,H,useDivmod=True)
        # Single-argument print() produces the same output on Python 2 and 3.
        print('%s %s %s' % (i, (x,y), (x2,y2)))
        a[y][x] = '%2d'%i
        b[y2][x2] = '%2d'%i

    # Rows are printed top-down so that y grows upwards; empty rows skipped.
    print('-- a --')
    for ln in a[::-1]:
        if ''.join(ln).strip():
            print('  '.join(ln))
    print('-- b --')
    for ln in b[::-1]:
        if ''.join(ln).strip():
            print('  '.join(ln))
248*67e74705SXin Li
def testPairsVB():
    """Print the first pairs enumerated by getNthPairVariableBounds for a
    mix of finite and aleph0 column bounds, then draw them as a grid."""
    bounds = [2,2,4,aleph0,5,aleph0]
    a = [['  ' for x in range(15)] for y in range(15)]
    for i in range(min(sum(bounds),40)):
        x,y = getNthPairVariableBounds(i, bounds)
        # Single-argument print() produces the same output on Python 2 and 3.
        print('%s %s' % (i, (x,y)))
        a[y][x] = '%2d'%i

    # Rows are printed top-down so that y grows upwards; empty rows skipped.
    print('-- a --')
    for ln in a[::-1]:
        if ''.join(ln).strip():
            print('  '.join(ln))
262*67e74705SXin Li
263*67e74705SXin Li###
264*67e74705SXin Li
# Flip to True to route the public enumeration entry points through their
# *Checked wrappers, which assert that every result respects its bounds.
# Each wrapper captured the unchecked function as a default argument, so
# rebinding these names does not make the wrappers call themselves.
if False:
    getNthTuple = getNthTupleChecked
    getNthNTuple = getNthNTupleChecked
    getNthPairBounded = getNthPairBoundedChecked
    getNthPairVariableBounds = getNthPairVariableBoundsChecked
271*67e74705SXin Li
if __name__ == '__main__':
    # Run both enumeration demos in order: fixed-size, then variable bounds.
    for demo in (testPairs, testPairsVB):
        demo()
276*67e74705SXin Li
277