xref: /aosp_15_r20/external/pytorch/test/fx/test_fx_param_shape_control_flow.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: fx"]
2
3import unittest
4
5import torch
6import torch.fx
7from torch.testing._internal.common_utils import TestCase
8
9
class MyModuleBase(torch.nn.Module):
    """Compute ``torch.mm(x, matrix)``, optionally followed by a ReLU.

    Subclasses are expected to set ``self.param`` (read by the default
    :meth:`get_mul_matrix`) and to override :meth:`no_relu` with a predicate
    over the parameter's metadata. The tests in this file trace subclasses
    with ``torch.fx.Tracer(param_shapes_constant=True)`` and check which
    branch of :meth:`forward` ends up in the graph.
    """

    def forward(self, x):
        matrix = self.get_mul_matrix()
        if self.no_relu():
            return torch.mm(x, matrix)
        else:
            return torch.relu(torch.mm(x, matrix))

    def get_mul_matrix(self):
        """Return the right-hand operand of ``torch.mm`` (``self.param`` by default)."""
        return self.param

    def no_relu(self):
        """Return True to skip the ReLU in :meth:`forward`.

        Subclasses must override this.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # NotImplementedError (a subclass of Exception) signals an abstract
        # hook; callers catching Exception keep working.
        raise NotImplementedError("not implemented")
23
24
class MyModuleParamShape(MyModuleBase):
    """Random ``(in_channels, 3)`` parameter; the ReLU is skipped when it has
    fewer than 10 rows, queried through ``param.shape``."""

    def __init__(self, in_channels):
        super().__init__()
        weight = torch.randn(in_channels, 3)
        self.param = torch.nn.Parameter(weight)

    def no_relu(self):
        rows = self.param.shape[0]
        return rows < 10
32
33
class MyModuleParamSize(MyModuleBase):
    """Random ``(in_channels, 3)`` parameter; the ReLU is skipped when it has
    fewer than 10 rows, queried through ``param.size()``."""

    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn((in_channels, 3)))

    def no_relu(self):
        # Deliberately uses the zero-argument size() so that API is exercised.
        dims = self.param.size()
        return dims[0] < 10
41
42
class MyModuleParamDim(MyModuleBase):
    """Wraps a caller-supplied parameter and branches on ``param.dim()``.

    A 3-D parameter contributes its first slice as the matrix and skips the
    ReLU; any other rank uses the parameter itself and applies the ReLU.
    """

    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        if self.param.dim() == 3:
            return self.param[0]
        return self.param

    def no_relu(self):
        return self.param.dim() == 3
53
54
class MyModuleParamNDim(MyModuleBase):
    """Wraps a caller-supplied parameter and branches on ``param.ndim``.

    A 3-D parameter contributes its first slice as the matrix and skips the
    ReLU; any other rank uses the parameter itself and applies the ReLU.
    """

    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        is_3d = self.param.ndim == 3
        return self.param[0] if is_3d else self.param

    def no_relu(self):
        return self.param.ndim == 3
65
66
class MyModuleParamNumEl(MyModuleBase):
    """Random ``(in_channels, 3)`` parameter; the ReLU is skipped when it has
    fewer than 30 elements, queried through ``param.numel()``."""

    def __init__(self, in_channels):
        super().__init__()
        init_values = torch.randn(in_channels, 3)
        self.param = torch.nn.Parameter(init_values)

    def no_relu(self):
        element_limit = 10 * 3  # ten rows of width three
        return self.param.numel() < element_limit
74
75
class MyModuleParamNElement(MyModuleBase):
    """Random ``(in_channels, 3)`` parameter; the ReLU is skipped when it has
    fewer than 30 elements, queried through ``param.nelement()``."""

    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn((in_channels, 3)))

    def no_relu(self):
        count = self.param.nelement()
        return count < 10 * 3
83
84
class TestConstParamShapeInControlFlow(TestCase):
    """Check FX tracing of branches that depend on parameter metadata.

    Each test builds two instances of one module class whose parameter
    metadata (shape, size, dim, ndim, numel, nelement) selects different
    branches in ``forward``. It then traces both with
    ``torch.fx.Tracer(param_shapes_constant=True)`` and compares the graphs.
    """

    def verify_mm_relu_mods(self, mm_only_mod, relu_mod):
        """
        Verify one module only does a mm op while the other
        performs both mm and relu ops in cascade.

        Each module is run eagerly and as a ``torch.fx.GraphModule`` built
        from a ``param_shapes_constant=True`` trace. Both results must match
        the reference computation. The two traced graphs must both contain a
        ``torch.mm`` node, and only ``relu_mod``'s graph may contain a
        ``torch.relu`` node.

        Args:
            mm_only_mod: module expected to compute ``mm(x, get_mul_matrix())``.
            relu_mod: module expected to compute
                ``relu(mm(x, get_mul_matrix()))``.
        """
        # Size the input to the module's matrix so torch.mm is well-formed
        # for any parameter size, not only the ones used by the tests below.
        x = torch.randn(10, mm_only_mod.get_mul_matrix().shape[0])
        torch.testing.assert_close(
            mm_only_mod(x), torch.mm(x, mm_only_mod.get_mul_matrix())
        )
        tracer = torch.fx.Tracer(param_shapes_constant=True)
        traced_graph = tracer.trace(mm_only_mod)

        # verify the graph module calculates the same result
        graph_mod_mm = torch.fx.GraphModule(mm_only_mod, traced_graph)
        torch.testing.assert_close(
            graph_mod_mm(x), torch.mm(x, mm_only_mod.get_mul_matrix())
        )

        # The second module's different parameter metadata sends forward()
        # down the other code path (mm followed by relu).
        x = torch.randn(10, relu_mod.get_mul_matrix().shape[0])
        torch.testing.assert_close(
            relu_mod(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))
        )

        tracer2 = torch.fx.Tracer(param_shapes_constant=True)
        traced_graph2 = tracer2.trace(relu_mod)

        # verify the graph module calculates the same result
        graph_mod_relu = torch.fx.GraphModule(relu_mod, traced_graph2)
        torch.testing.assert_close(
            graph_mod_relu(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))
        )

        graph1_node_targets = [n.target for n in traced_graph.nodes]
        graph2_node_targets = [n.target for n in traced_graph2.nodes]

        # Both graphs multiply; the second graph has an extra relu call node.
        # unittest assertions are used instead of bare `assert`, which is
        # stripped under `python -O` and would make these checks vacuous.
        self.assertIn(torch.mm, graph1_node_targets)
        self.assertIn(torch.mm, graph2_node_targets)
        self.assertNotIn(torch.relu, graph1_node_targets)
        self.assertIn(torch.relu, graph2_node_targets)

    def test_param_shape_const(self):
        """Branch on ``param.shape`` for 5-row vs 15-row parameters."""
        mymod = MyModuleParamShape(in_channels=5)
        mymod2 = MyModuleParamShape(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_size_const(self):
        """Branch on ``param.size()`` for 5-row vs 15-row parameters."""
        mymod = MyModuleParamSize(in_channels=5)
        mymod2 = MyModuleParamSize(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_dim_const(self):
        """Branch on ``param.dim()`` for a 3-D vs a 2-D parameter."""
        mymod = MyModuleParamDim(torch.nn.Parameter(torch.randn(2, 5, 3)))
        mymod2 = MyModuleParamDim(torch.nn.Parameter(torch.randn(15, 3)))
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_ndim_const(self):
        """Branch on ``param.ndim`` for a 3-D vs a 2-D parameter."""
        mymod = MyModuleParamNDim(torch.nn.Parameter(torch.randn(2, 5, 3)))
        mymod2 = MyModuleParamNDim(torch.nn.Parameter(torch.randn(15, 3)))
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_numel_const(self):
        """Branch on ``param.numel()`` for 5-row vs 15-row parameters."""
        mymod = MyModuleParamNumEl(in_channels=5)
        mymod2 = MyModuleParamNumEl(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_nelement_const(self):
        """Branch on ``param.nelement()`` for 5-row vs 15-row parameters."""
        mymod = MyModuleParamNElement(in_channels=5)
        mymod2 = MyModuleParamNElement(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)
158
159
if __name__ == "__main__":
    # Run directly as a script: hand this module's test cases to unittest.
    unittest.main(module=__name__)
162