xref: /aosp_15_r20/external/llvm/utils/shuffle_fuzz.py (revision 9880d6810fe72a1726cb53787c6711e909410d58)
1*9880d681SAndroid Build Coastguard Worker#!/usr/bin/env python
2*9880d681SAndroid Build Coastguard Worker
3*9880d681SAndroid Build Coastguard Worker"""A shuffle vector fuzz tester.
4*9880d681SAndroid Build Coastguard Worker
5*9880d681SAndroid Build Coastguard WorkerThis is a python program to fuzz test the LLVM shufflevector instruction. It
generates a function with a random sequence of shufflevectors, maintaining the
7*9880d681SAndroid Build Coastguard Workerelement mapping accumulated across the function. It then generates a main
8*9880d681SAndroid Build Coastguard Workerfunction which calls it with a different value in each element and checks that
9*9880d681SAndroid Build Coastguard Workerthe result matches the expected mapping.
10*9880d681SAndroid Build Coastguard Worker
11*9880d681SAndroid Build Coastguard WorkerTake the output IR printed to stdout, compile it to an executable using whatever
12*9880d681SAndroid Build Coastguard Workerset of transforms you want to test, and run the program. If it crashes, it found
13*9880d681SAndroid Build Coastguard Workera bug.
14*9880d681SAndroid Build Coastguard Worker"""
15*9880d681SAndroid Build Coastguard Worker
16*9880d681SAndroid Build Coastguard Workerimport argparse
17*9880d681SAndroid Build Coastguard Workerimport itertools
18*9880d681SAndroid Build Coastguard Workerimport random
19*9880d681SAndroid Build Coastguard Workerimport sys
20*9880d681SAndroid Build Coastguard Workerimport uuid
21*9880d681SAndroid Build Coastguard Worker
# Element types the fuzzer can test, in the order they are offered to the RNG
# (changing the order changes which test a given seed produces).
_ELEMENT_TYPES = ['i8', 'i16', 'i32', 'i64', 'f32', 'f64']

# Size in bits of each element type. Determines the lane count for a fixed
# vector bit width and the modulus that keeps input lane values in range.
_ELEMENT_BITS = {'i8': 8, 'i16': 16, 'i32': 32, 'i64': 64,
                 'f32': 32, 'f64': 64}

# Lane counts to pick from when no fixed vector bit width is requested.
_RANDOM_WIDTHS = [2, 4, 8, 16, 32, 64]

# LLVM IR spelling and same-size integer type for the floating point element
# types; integer element types are used unchanged for both roles.
_FLOAT_IR_TYPES = {'f32': ('float', 'i32'), 'f64': ('double', 'i64')}


def _parse_args():
  """Parse the command line and validate values argparse cannot check."""
  parser = argparse.ArgumentParser(description=__doc__)
  parser.add_argument('-v', '--verbose', action='store_true',
                      help='Show verbose output')
  parser.add_argument('--seed', default=str(uuid.uuid4()),
                      help='A string used to seed the RNG')
  parser.add_argument('--max-shuffle-height', type=int, default=16,
                      help='Specify a fixed height of shuffle tree to test')
  parser.add_argument('--no-blends', dest='blends', action='store_false',
                      help='Exclude blends of two input vectors')
  parser.add_argument('--fixed-bit-width', type=int, choices=[128, 256],
                      help='Specify a fixed bit width of vector to test')
  parser.add_argument('--fixed-element-type', choices=_ELEMENT_TYPES,
                      help='Specify a fixed element type to test')
  parser.add_argument('--triple',
                      help='Specify a triple string to include in the IR')
  args = parser.parse_args()
  # The undef-bias formula divides by the height, and the tree needs at least
  # one level to reduce its inputs to a single result vector.
  if args.max_shuffle_height < 1:
    parser.error('--max-shuffle-height must be at least 1')
  return args


def _choose_vector_type(fixed_bit_width, element_types):
  """Return (lane count, element type) for the test, using the seeded RNG.

  With a fixed bit width only the element type is drawn and the lane count
  follows from it; otherwise the lane count is drawn first, then the element
  type.
  """
  if fixed_bit_width is not None:
    element_type = random.choice(element_types)
    return fixed_bit_width // _ELEMENT_BITS[element_type], element_type
  width = random.choice(_RANDOM_WIDTHS)
  return width, random.choice(element_types)


def _undef_probability(shuffle_range, height):
  """Return the probability of making a single mask lane undef.

  Undef (-1) saturates: a lane of the final result is defined only if every
  one of the `height` shuffles on its path picked a defined element, and an
  undef lane cannot be checked. Choosing uniformly among -1 and the
  A = shuffle_range real indices would leave a final lane defined with
  probability (A / (A + 1))^C for C = height, which vanishes for deep trees.
  Instead each node picks a defined element with probability
  (A / (A + 1))^(1 / C), so after C levels a lane is defined with probability
  A / (A + 1), as for a single uniform shuffle.

  The expression below equals 1 - (A / (A + 1))^(1 / C). It is kept in this
  unsimplified form because a rearranged float expression could round
  differently and change which tree a given seed produces.
  """
  A = float(shuffle_range)
  C = float(height)
  return 1.0 - (((A + 1.0) * pow(A, (C + 1.0) / C)) /
                (A * pow(A + 1.0, (C + 1.0) / C)))


def _build_shuffle_tree(width, shuffle_range, height, undef_prob):
  """Randomly generate the shuffle masks for every level of the tree.

  Level `level` holds `height - level` masks. Each mask has `width` entries,
  each either -1 (undef) or an index in [0, shuffle_range).
  """
  return [[[-1 if random.random() <= undef_prob
            else random.choice(range(shuffle_range))
            for _ in range(width)]
           for _ in range(height - level)]
          for level in range(height)]


def _evaluate_shuffle_tree(inputs, shuffle_tree, width):
  """Symbolically apply `shuffle_tree` to `inputs`; return the last level.

  Mask `i` of a level reads vectors `i` and `i + 1` of the previous level:
  index j < width selects lane j of the first, otherwise lane j - width of
  the second. Undef (-1) entries stay -1.
  """
  results = inputs
  for shuffles in shuffle_tree:
    results = [[(-1 if j == -1 else
                 (results[i] if j < width else results[i + 1])[j % width])
                for j in mask]
               for i, mask in enumerate(shuffles)]
  return results


def _emit_ir(subst, shuffle_tree, inputs, result, seed, triple):
  """Print the IR module for one test to stdout.

  Args:
    subst: format substitutions 'N' (lane count), 'T' (IR element type) and
      'IT' (same-size integer type).
    shuffle_tree: masks produced by _build_shuffle_tree.
    inputs: lane values of each input vector passed to @test.
    result: expected lanes of the final vector; -1 marks an undef lane.
    seed: RNG seed string, embedded in the failure messages.
    triple: optional target triple; emitted as `target triple` when given.
  """
  if triple is not None:
    print('target triple = "%s"' % (triple,))

  # @test takes one vector per input. Shuffle j of level i defines
  # %s.<i+1>.<j> from %s.<i>.<j> and %s.<i>.<j+1>.
  test_args = ', '.join(['<%(N)d x %(T)s> %%s.0.%(i)d' % dict(subst, i=i)
                         for i in range(len(inputs))])
  print("""
define internal fastcc <%(N)d x %(T)s> @test(%(arguments)s) noinline nounwind {
entry:""" % dict(subst, arguments=test_args))

  shuffle_line = ('  %%s.%(next_i)d.%(j)d = shufflevector '
                  '<%(N)d x %(T)s> %%s.%(i)d.%(j)d, '
                  '<%(N)d x %(T)s> %%s.%(i)d.%(next_j)d, '
                  '<%(N)d x i32> <%(S)s>')
  for i, shuffles in enumerate(shuffle_tree):
    for j, mask in enumerate(shuffles):
      lanes = ', '.join(['i32 ' + (str(m) if m != -1 else 'undef')
                         for m in mask])
      print(shuffle_line % dict(subst, i=i, next_i=i + 1, j=j, next_j=j + 1,
                                S=lanes))

  print("""
  ret <%(N)d x %(T)s> %%s.%(i)d.0
}
""" % dict(subst, i=len(shuffle_tree)))

  # Failure messages, one per defined lane. In `message` the Python '\n' is
  # one raw byte, while the IR escape '\\0A' is three characters encoding one
  # byte, so the text holds len(message) - 2 bytes. The '\\00' padding brings
  # the constant to exactly 128 bytes.
  # NOTE(review): a seed long enough to push the message past 128 bytes, or
  # one containing '"' or '\\', would produce an invalid constant.
  for lane, expected in enumerate(result):
    if expected != -1:
      message = ('FAIL(%(seed)s): lane %(lane)d, expected %(result)d, found %%d\n\\0A' %
                 {'seed': seed, 'lane': lane, 'result': expected})
      message += '\\00' * (128 - len(message) + 2)
      print('@error.%(i)d = private unnamed_addr global [128 x i8] c"%(s)s"' %
            {'i': lane, 's': message})

  # Define a wrapper function which is marked 'optnone' to prevent
  # interprocedural optimizations from deleting the test.
  wrapper_args = ', '.join(['<%(N)d x %(T)s> %%s.%(i)d' % dict(subst, i=i)
                            for i in range(len(inputs))])
  print("""
define internal fastcc <%(N)d x %(T)s> @test_wrapper(%(arguments)s) optnone noinline {
  %%result = call fastcc <%(N)d x %(T)s> @test(%(arguments)s)
  ret <%(N)d x %(T)s> %%result
}
""" % dict(subst, arguments=wrapper_args))

  # Finally, generate a main function which will trap if any lanes are mapped
  # incorrectly (in an observable way). Inputs are written as integer vector
  # constants bitcast to the element type so float lanes get exact bits.
  call_args = ', '.join(
      ['<%(N)d x %(T)s> bitcast (<%(N)d x %(IT)s> <%(input)s> to <%(N)d x %(T)s>)' %
       dict(subst, input=', '.join(['%(IT)s %(i)d' % dict(subst, i=value)
                                    for value in vector]))
       for vector in inputs])
  print("""
define i32 @main() {
entry:
  ; Create a scratch space to print error messages.
  %%str = alloca [128 x i8]
  %%str.ptr = getelementptr inbounds [128 x i8], [128 x i8]* %%str, i32 0, i32 0

  ; Build the input vector and call the test function.
  %%v = call fastcc <%(N)d x %(T)s> @test_wrapper(%(inputs)s)
  ; We need to cast this back to an integer type vector to easily check the
  ; result.
  %%v.cast = bitcast <%(N)d x %(T)s> %%v to <%(N)d x %(IT)s>
  br label %%test.0
""" % dict(subst, inputs=call_args))

  # One block per lane. Undef lanes are skipped; defined lanes are compared
  # through the integer view and trap with a message on mismatch.
  for i, r in enumerate(result):
    if r == -1:
      print("""
test.%(i)d:
  ; Skip this lane, its value is undef.
  br label %%test.%(next_i)d
""" % dict(subst, i=i, next_i=i + 1))
    else:
      print("""
test.%(i)d:
  %%v.%(i)d = extractelement <%(N)d x %(IT)s> %%v.cast, i32 %(i)d
  %%cmp.%(i)d = icmp ne %(IT)s %%v.%(i)d, %(r)d
  br i1 %%cmp.%(i)d, label %%die.%(i)d, label %%test.%(next_i)d

die.%(i)d:
  ; Capture the actual value and print an error message.
  %%tmp.%(i)d = zext %(IT)s %%v.%(i)d to i2048
  %%bad.%(i)d = trunc i2048 %%tmp.%(i)d to i32
  call i32 (i8*, i8*, ...) @sprintf(i8* %%str.ptr, i8* getelementptr inbounds ([128 x i8], [128 x i8]* @error.%(i)d, i32 0, i32 0), i32 %%bad.%(i)d)
  %%length.%(i)d = call i32 @strlen(i8* %%str.ptr)
  call i32 @write(i32 2, i8* %%str.ptr, i32 %%length.%(i)d)
  call void @llvm.trap()
  unreachable
""" % dict(subst, i=i, next_i=i + 1, r=r))

  print("""
test.%d:
  ret i32 0
}

declare i32 @strlen(i8*)
declare i32 @write(i32, i8*, i32)
declare i32 @sprintf(i8*, i8*, ...)
declare void @llvm.trap() noreturn nounwind
""" % (len(result),))


def main():
  """Generate one random shufflevector test and print its IR to stdout.

  All randomness comes from the `random` module seeded with --seed, so the
  same seed (under the same Python version) reproduces the same test. With
  --verbose the shuffle tree and the expected mapping go to stderr.
  """
  args = _parse_args()

  random.seed(args.seed)

  element_types = _ELEMENT_TYPES
  if args.fixed_element_type is not None:
    element_types = [args.fixed_element_type]

  width, element_type = _choose_vector_type(args.fixed_bit_width,
                                            element_types)
  element_modulus = 1 << _ELEMENT_BITS[element_type]

  # Mask indices >= width select from the second operand, so allowing blends
  # doubles the index range.
  shuffle_range = (2 * width) if args.blends else width
  height = args.max_shuffle_height

  undef_prob = _undef_probability(shuffle_range, height)
  shuffle_tree = _build_shuffle_tree(width, shuffle_range, height, undef_prob)

  if args.verbose:
    # Print out the shuffle sequence in a compact form.
    lines = ['Testing shuffle sequence "%s" (v%d%s):' %
             (args.seed, width, element_type)]
    for level, shuffles in enumerate(shuffle_tree):
      lines.append('  tree level %d:' % (level,))
      for j, mask in enumerate(shuffles):
        lines.append('    shuffle %d: %s' % (j, mask))
    lines.append('')
    sys.stderr.write('\n'.join(lines) + '\n')

  # Input lanes hold consecutive values starting at 1, reduced modulo
  # 2^bits. For narrow element types the values can therefore wrap and repeat.
  inputs = [[int(j % element_modulus)
             for j in range(i * width + 1, (i + 1) * width + 1)]
            for i in range(height + 1)]
  results = _evaluate_shuffle_tree(inputs, shuffle_tree, width)
  if len(results) != 1:
    sys.stderr.write('ERROR: Bad results: %s\n' % (results,))
    sys.exit(1)
  result = results[0]

  if args.verbose:
    sys.stderr.write('Which transforms:\n'
                     '  from: %s\n'
                     '  into: %s\n'
                     '\n' % (inputs, result))

  ir_type, integral_type = _FLOAT_IR_TYPES.get(element_type,
                                               (element_type, element_type))
  subst = {'N': width, 'T': ir_type, 'IT': integral_type}
  _emit_ir(subst, shuffle_tree, inputs, result, args.seed, args.triple)
253*9880d681SAndroid Build Coastguard Worker
if __name__ == '__main__':
  # Generate a test only when executed as a script; importing is side-effect free.
  main()
256