1import sys 2 3import torch 4from torch.testing._internal.inductor_utils import GPU_TYPE 5 6 7def first_arg(x, y): 8 return x[y] 9 10 11def second_arg(x, y): 12 return x[:, y] 13 14 15def same_pm_one(x, y): 16 return x[y + 1, y - 1] 17 18 19def same_pp_one(x, y): 20 return x[y + 1, y + 1] 21 22 23def store(x, y, z): 24 x[y + 1, y + 1] = z 25 26 27def upper1(x): 28 b = torch.arange(4, device=x.device) 29 return x[b] 30 31 32def lower1(x): 33 b = x.new_full((), -4, dtype=torch.int64) 34 return x[b] 35 36 37def upper2(x): 38 b = x.new_full((), 4, dtype=torch.int64) 39 return x[b] 40 41 42def lower2(x): 43 b = x.new_zeros((), dtype=torch.int64) 44 return x[b - 4] 45 46 47if __name__ == "__main__": 48 fns = [ 49 name 50 for name, obj in locals().items() 51 if callable(obj) and obj.__module__ == __name__ 52 ] 53 54 _, fn_name, dims, dyn_shape, one_size = sys.argv 55 assert fn_name in fns 56 assert one_size in ("True", "False") 57 one_size = one_size == "True" 58 assert dims in ("2", "3") 59 shape_x = [3, 2, 4] if dims == "3" else [3, 2] 60 if one_size: 61 assert ( 62 fn_name == "first_arg" 63 ), "only first_arg can be tested for a special case of 1-size tensor" 64 shape_x[0] = 1 65 assert dyn_shape in ("True", "False") 66 dynamic_shapes = dyn_shape == "True" 67 68 x = torch.randn(shape_x, device=GPU_TYPE) 69 y = torch.arange(4, device=GPU_TYPE) 70 fn = vars()[fn_name] 71 fn = torch.compile(dynamic=dynamic_shapes)(fn) 72 if fn_name == "store": 73 shape = (y.numel(),) + x.shape[2:] 74 z = torch.randn(shape, device=GPU_TYPE) 75 fn(x, y, z) 76 elif fn_name in ("upper1", "upper2", "lower1", "lower2"): 77 fn(x) 78 else: 79 fn(x, y) 80