# xref: /aosp_15_r20/external/pytorch/test/inductor/indirect_assert_helper.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import sys
2
3import torch
4from torch.testing._internal.inductor_utils import GPU_TYPE
5
6
def first_arg(x, y):
    """Return ``x[y]``: index the first dimension of ``x`` with ``y``.

    The ``__main__`` driver in this file passes ``y = arange(4)`` while the
    first dimension of ``x`` has size 3 (or 1), so some indices are past
    the end there.
    """
    return x[y]
9
10
def second_arg(x, y):
    """Return ``x[:, y]``: keep the first dimension, index the second with ``y``."""
    return x[:, y]
13
14
def same_pm_one(x, y):
    """Return ``x[y + 1, y - 1]``.

    The same index tensor ``y`` feeds both leading dimensions, shifted by
    +1 for the first and by -1 for the second.
    """
    return x[y + 1, y - 1]
17
18
def same_pp_one(x, y):
    """Return ``x[y + 1, y + 1]``.

    The same shifted index ``y + 1`` is used for both leading dimensions.
    """
    return x[y + 1, y + 1]
21
22
def store(x, y, z):
    """Write ``z`` into ``x`` in place at positions ``[y + 1, y + 1]``.

    This is the store counterpart of ``same_pp_one``: the same shifted
    index is used for both leading dimensions. Returns ``None``; ``x`` is
    mutated.
    """
    x[y + 1, y + 1] = z
25
26
def upper1(x):
    """Index the first dimension of ``x`` with ``arange(4)`` (indices 0..3).

    The index tensor is created on ``x``'s device. When the first dimension
    of ``x`` has fewer than four entries, the largest index is past the end.
    """
    b = torch.arange(4, device=x.device)
    return x[b]
30
31
def lower1(x):
    """Index the first dimension of ``x`` with a 0-dim int64 tensor holding -4.

    The index is built with ``x.new_full`` and an explicit ``int64`` dtype.
    When the first dimension of ``x`` has fewer than four entries, -4 falls
    below the valid negative range.
    """
    b = x.new_full((), -4, dtype=torch.int64)
    return x[b]
35
36
def upper2(x):
    """Index the first dimension of ``x`` with a 0-dim int64 tensor holding 4.

    The index is built with ``x.new_full`` and an explicit ``int64`` dtype.
    When the first dimension of ``x`` has fewer than five entries, index 4
    is past the end.
    """
    b = x.new_full((), 4, dtype=torch.int64)
    return x[b]
40
41
def lower2(x):
    """Index the first dimension of ``x`` with ``b - 4`` where ``b`` is a 0-dim zero.

    Unlike ``lower1``, the index value -4 is not stored directly: it is
    produced by subtracting 4 from an int64 zero tensor created with
    ``x.new_zeros``.
    """
    b = x.new_zeros((), dtype=torch.int64)
    return x[b - 4]
45
46
if __name__ == "__main__":
    # Command-line driver:
    #     indirect_assert_helper.py FN_NAME DIMS DYN_SHAPE ONE_SIZE
    # FN_NAME   - name of one of the functions defined in this file
    # DIMS      - "2" or "3": rank of the input tensor x
    # DYN_SHAPE - "True"/"False": forwarded to torch.compile(dynamic=...)
    # ONE_SIZE  - "True"/"False": make the first dimension of x size 1
    #             (only accepted together with FN_NAME == "first_arg")

    # Names of the callables defined in this module. At module level
    # globals() is the same mapping that locals()/vars() return.
    fns = [
        name
        for name, obj in globals().items()
        if callable(obj) and getattr(obj, "__module__", None) == __name__
    ]

    # Validate with explicit raises rather than `assert`, so the checks are
    # still performed when Python runs with -O.
    if len(sys.argv) != 5:
        raise SystemExit(
            f"usage: {sys.argv[0]} FN_NAME DIMS DYN_SHAPE ONE_SIZE"
        )
    _, fn_name, dims, dyn_shape, one_size = sys.argv
    if fn_name not in fns:
        raise ValueError(f"unknown function {fn_name!r}; expected one of {fns}")
    if one_size not in ("True", "False"):
        raise ValueError(f"ONE_SIZE must be 'True' or 'False', got {one_size!r}")
    one_size = one_size == "True"
    if dims not in ("2", "3"):
        raise ValueError(f"DIMS must be '2' or '3', got {dims!r}")
    shape_x = [3, 2, 4] if dims == "3" else [3, 2]
    if one_size:
        if fn_name != "first_arg":
            raise ValueError(
                "only first_arg can be tested for a special case of 1-size tensor"
            )
        shape_x[0] = 1
    if dyn_shape not in ("True", "False"):
        raise ValueError(f"DYN_SHAPE must be 'True' or 'False', got {dyn_shape!r}")
    dynamic_shapes = dyn_shape == "True"

    x = torch.randn(shape_x, device=GPU_TYPE)
    # Four indices (0..3) against a first dimension of size 3 (or 1).
    y = torch.arange(4, device=GPU_TYPE)
    fn = globals()[fn_name]
    fn = torch.compile(dynamic=dynamic_shapes)(fn)
    if fn_name == "store":
        # x[y + 1, y + 1] selects one element per entry of y along the first
        # two dims, so the stored value has shape (len(y), *trailing dims).
        shape = (y.numel(),) + x.shape[2:]
        z = torch.randn(shape, device=GPU_TYPE)
        fn(x, y, z)
    elif fn_name in ("upper1", "upper2", "lower1", "lower2"):
        # These build their own index tensor and take only x.
        fn(x)
    else:
        fn(x, y)
80