xref: /aosp_15_r20/external/pytorch/test/fx/test_common_passes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: fx"]
2
3import itertools
4
5import torch
6from torch.fx.experimental.proxy_tensor import make_fx
7from torch.fx.graph_module import GraphModule
8from torch.fx.passes.dialect.common.cse_pass import CSEPass
9from torch.testing._internal.common_utils import (
10    instantiate_parametrized_tests,
11    parametrize,
12    run_tests,
13    TestCase,
14)
15
16
def FactoryFunctionCall(x, device):
    """Return ``x`` plus a same-shaped tensor of threes built by a factory op.

    ``device`` is forwarded to ``torch.full`` so the factory allocates on the
    caller-chosen device.
    """
    threes = torch.full(x.shape, 3, device=device)
    return torch.add(threes, x)
21
22
def TorchTensorCall(x):
    """Return ``x`` plus a scalar tensor constructed inline with ``torch.tensor``."""
    return x + torch.tensor(3)
26
27
def TakeList(x):
    """Concatenate ``x`` with itself, exercising an op that takes a tensor list."""
    pieces = [x, x]
    return torch.cat(pieces)
31
32
def ReturnList(x):
    """Split a 5x2 ``arange`` tensor into row chunks of sizes 1 and 4.

    ``x`` is ignored. The case exists to exercise an op that returns several
    tensors at once.
    """
    grid = torch.arange(10).reshape(5, 2)
    return torch.split(grid, [1, 4])
37
38
def Mutation(x):
    """Return ``x + (x + 2 + 1)``, where the ``+ 1`` is applied in place.

    The in-place update targets an intermediate, never the input itself.
    """
    tmp = x + 2
    tmp.add_(1)
    return x + tmp
43
44
def MutationInput(x):
    """Increment ``x`` in place, then return ``x + (x + 2)``.

    The caller's tensor is modified, so the in-place op is a visible side effect.
    """
    x.add_(1)
    offset = x + 2
    return x + offset
49
50
def MutationFactory(x, device):
    """Return ``x`` plus a factory-created tensor of threes bumped in place by one."""
    filled = torch.full(x.shape, 3, device=device)
    filled.add_(1)
    return x + filled
55
56
def MutationTorchTensorCall(x):
    """Return ``x`` plus an inline ``torch.tensor`` scalar incremented in place."""
    scalar = torch.tensor(3)
    scalar.add_(1)
    return x + scalar
61
62
def MutationMetadata(x):
    """Shrink ``x`` in place to two elements and hand the same tensor back.

    ``resize_`` mutates the tensor's shape (metadata), not just its values.
    """
    x.resize_(2)
    return x
66
67
# Graph passes under test; the tests call each entry with no arguments to
# obtain the pass object.
Passes = [CSEPass]

# Functions traced and called as ``f(x)``. Each plain case is paired with a
# mutating variant where one exists, so the pass is checked both on code it
# may rewrite and on code whose in-place ops it must leave intact.
Test_Cases = [
    TakeList,
    ReturnList,
    TorchTensorCall,
    Mutation,
    MutationInput,
    MutationMetadata,
    MutationTorchTensorCall,
]

# Functions traced and called as ``f(x, device)`` because they allocate
# tensors with a factory op on an explicit device.
Factory_Test_Cases = [FactoryFunctionCall, MutationFactory]

# Devices to parametrize over; CUDA is added only when it is usable here.
Devices = ["cpu"]
if torch.cuda.is_available():
    Devices.append("cuda")
81
82
def name_fn(common_pass, f, device):
    """Build the suffix ``<PassClass>_<function>_<device>`` for a parametrized test.

    ``common_pass`` is called once so that factory callables report the
    class of the object they produce.
    """
    pass_name = type(common_pass()).__name__
    fn_name = f.__name__
    return f"{pass_name}_{fn_name}_{device}"
86
87
@instantiate_parametrized_tests
class TestCommonPass(TestCase):
    """Check that common FX passes preserve the semantics of traced programs.

    Each test traces a small function with ``make_fx``, runs a pass over the
    resulting ``GraphModule``, and compares the transformed module against
    eager execution of the original function. Both the return value and any
    in-place mutation of the input are compared.
    """

    def _check_pass(self, common_pass, f, device, *extra_args):
        """Trace ``f``, apply ``common_pass`` and compare against eager ``f``.

        Args:
            common_pass: zero-argument callable returning the pass to run.
            f: function under test, called as ``f(inp, *extra_args)``.
            device: device on which the random input tensor is created.
            *extra_args: extra positional arguments forwarded both to ``f``
                and to the transformed module.
        """
        inp = torch.randn(10, device=device)

        # make_fx's default ("real") tracing executes ``f`` on the tensors it
        # is given. Trace on a clone so input-mutating cases (MutationInput,
        # MutationMetadata) do not modify ``inp`` during tracing.
        traced_m = make_fx(f)(inp.clone(), *extra_args)

        res = common_pass()(traced_m)
        modified_m = res.graph_module
        self.assertIsInstance(modified_m, GraphModule)

        inp_copy = inp.clone()
        expected = f(inp, *extra_args)
        result = modified_m(inp_copy, *extra_args)

        self.assertEqual(result, expected)
        # In-place updates to the input are observable side effects. The pass
        # must preserve them, so both inputs must end up identical.
        self.assertEqual(inp_copy, inp)

    @parametrize(
        "common_pass,f,device", itertools.product(Passes, Test_Cases, Devices), name_fn
    )
    def test_correctness(self, common_pass, f, device):
        """Pass output matches eager execution for single-argument cases."""
        self._check_pass(common_pass, f, device)

    @parametrize(
        "common_pass,f,device",
        itertools.product(Passes, Factory_Test_Cases, Devices),
        name_fn,
    )
    def test_correctness_factory(self, common_pass, f, device):
        """Pass output matches eager execution for cases taking ``device``."""
        self._check_pass(common_pass, f, device, device)
128
129
if __name__ == "__main__":
    # Run the suite through PyTorch's shared test entry point when this file
    # is executed directly rather than imported.
    run_tests()
132