# Owner(s): ["module: onnx"]
import pytorch_test_common

import torch
import torch._dynamo
import torch.fx
from torch.onnx._internal.fx.passes import _utils as pass_utils
from torch.testing._internal import common_utils


class TestFxPasses(common_utils.TestCase):
    def test_set_node_name_correctly_renames_when_new_name_collides_recursively(self):
        def func(x, y, z):
            return x + y + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        gm, _ = torch._dynamo.export(func)(x, y, z)
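        # `torch._dynamo.export` returns (graph_module, guards); only the graph
        # module is needed here.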
        torch._dynamo.reset()

        # Purposely name the nodes in a way that will cause a recursive collision later.
        # See :func:`set_node_name` for name collision renaming logic.
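        # The loop below names the nodes "tensor", "tensor.1", "tensor.2", ...;
        # renaming nodes[0] to "tensor" is then expected to cascade, bumping each
        # existing holder of a colliding name to the next available suffix.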
        base_name = "tensor"
        nodes = list(gm.graph.nodes)
        for i, node in enumerate(nodes[1:]):
            if i == 0:
                node.name = base_name
            else:
                node.name = f"{base_name}.{i}"

        # Run `set_node_name` and verify that the names are correct.
        name_to_node = {node.name: node for node in gm.graph.nodes}
        pass_utils.set_node_name(nodes[0], base_name, name_to_node)
        assert nodes[0].name == base_name, f"Expected {base_name}, got {nodes[0].name}"
        assert len({node.name for node in nodes}) == len(
            nodes
        ), f"Expected all names to be unique, got {nodes}"

    def test_set_node_name_succeeds_when_no_name_collisions(self):
        def func(x, y, z):
            return x + y + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        gm, _ = torch._dynamo.export(func)(x, y, z)
        torch._dynamo.reset()

        # Run `set_node_name` and verify that the names are correct.
        new_name = "some_tensor"
        nodes = list(gm.graph.nodes)
        name_to_node = {node.name: node for node in nodes}
        pass_utils.set_node_name(nodes[1], new_name, name_to_node)
        assert nodes[1].name == new_name, f"Expected {new_name}, got {nodes[1].name}"
        assert len({node.name for node in nodes}) == len(
            nodes
        ), f"Expected all names to be unique, got {nodes}"

    def test_onnx_dynamo_export_raises_when_model_contains_unsupported_fx_nodes(self):
        @torch.library.custom_op(
            "mylibrary::foo_op", device_types="cpu", mutates_args=()
        )
        def foo_op(x: torch.Tensor) -> torch.Tensor:
            return x + 1

        @torch.library.custom_op(
            "mylibrary::bar_op", device_types="cpu", mutates_args=()
        )
        def bar_op(x: torch.Tensor) -> torch.Tensor:
            return x + 2

        @foo_op.register_fake
        def _(x):
            return torch.empty_like(x)

        @bar_op.register_fake
        def _(x):
            return torch.empty_like(x)
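        # The fake (meta) kernels above let the exporter trace these custom ops on
        # fake tensors without running the real CPU implementations; neither op has
        # an ONNX lowering, which is what this test exercises.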

        def func(x, y, z):
            return foo_op(x) + bar_op(y) + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        with self.assertRaises(torch.onnx.OnnxExporterError) as ctx:
            torch.onnx.dynamo_export(func, x, y, z)
        inner_exception = ctx.exception.__cause__
        self.assertRegex(
            str(inner_exception),
            r"Unsupported FX nodes.*mylibrary\.foo_op.*mylibrary\.bar_op",
        )

        torch._dynamo.reset()


@common_utils.instantiate_parametrized_tests
class TestModularizePass(common_utils.TestCase):
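    """Tests for the modularization pass, which lifts each nn.Module call in the
    exported FX graph into its own ONNX local function (``model_proto.functions``).
    """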
    @pytorch_test_common.xfail(
        error_message="'torch_nn_modules_activation_GELU_used_gelu_1' not found",
        reason="optimizer",
    )
    @common_utils.parametrize(
        "is_exported_program",
        [
            common_utils.subtest(
                True,
                name="exported_program",
            ),
            common_utils.subtest(
                False,
                name="nn_module",
            ),
        ],
    )
    def test_modularize_pass_succeeds_when_submodule_output_is_unused(
        self, is_exported_program
    ):
        # This is an ill-formed model, but the exporter must not crash.
        # It is illegal for a submodule to have zero outputs. In the modularization pass
        # this can happen when the submodule output is unused, so no inner node is
        # connected to any outer node.
        # However, that also means DCE should erase the entire submodule, so the case
        # should never occur in practice.
        #
        # Minified repro from Background_Matting. https://github.com/pytorch/benchmark/issues/1768
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.unused_relu = torch.nn.ReLU()
                self.used_gelu = torch.nn.GELU()

            def forward(self, x, y):
                result = self.used_gelu(x + y)
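                # Intentionally unused; sets up the zero-output submodule case
                # described above.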
                unused_relu_result = self.unused_relu(x)
                return result

        if is_exported_program:
            model = torch.export.export(
                TestModule(), args=(torch.randn(3), torch.randn(3))
            )
        else:
            model = TestModule()

        onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
        model_proto = onnx_program.model_proto
        function_proto_names = [function.name for function in model_proto.functions]
        self.assertIn(
            "torch_nn_modules_activation_GELU_used_gelu_1", function_proto_names
        )
        self.assertFalse(any("ReLU" in name for name in function_proto_names))

    @pytorch_test_common.xfail(
        error_message="'torch_nn_modules_activation_ReLU_relu_1' not found",
        reason="optimizer",
    )
    @common_utils.parametrize(
        "is_exported_program",
        [
            common_utils.subtest(
                True,
                name="exported_program",
            ),
            common_utils.subtest(
                False,
                name="nn_module",
            ),
        ],
    )
    def test_modularize_pass_succeeds_when_a_submodule_is_called_multiple_times(
        self, is_exported_program
    ):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x, y):
                out = x + y
                out = self.relu(out)
                out = out + x
                out = self.relu(out)
                return out

        if is_exported_program:
            model = torch.export.export(
                TestModule(), args=(torch.randn(3), torch.randn(3))
            )
        else:
            model = TestModule()

        onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
        model_proto = onnx_program.model_proto
        function_proto_names = [function.name for function in model_proto.functions]
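        # Each call site of the shared ReLU should be lifted into its own
        # occurrence-numbered function.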
        self.assertIn("torch_nn_modules_activation_ReLU_relu_1", function_proto_names)
        self.assertIn("torch_nn_modules_activation_ReLU_relu_2", function_proto_names)

    @pytorch_test_common.xfail(
        error_message="'torch_nn_modules_activation_ReLU_inner_module_relu_1' not found",
        reason="optimizer",
    )
    @common_utils.parametrize(
        "is_exported_program",
        [
            common_utils.subtest(
                True,
                name="exported_program",
            ),
            common_utils.subtest(
                False,
                name="nn_module",
            ),
        ],
    )
    def test_modularize_pass_succeeds_when_a_submodule_is_called_from_multiple_layers(
        self, is_exported_program
    ):
        # Minified repro from basic_gnn_edgecnn.
        class InnerModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.inner_module = InnerModule()

            def forward(self, x, y):
                out = x + y
                out = self.inner_module(out)
                out = out + x
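                # Call the shared ReLU directly, from a different layer of the
                # module hierarchy than the `self.inner_module(out)` call above.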
                out = self.inner_module.relu(out)
                return out

        if is_exported_program:
            model = torch.export.export(
                TestModule(), args=(torch.randn(3), torch.randn(3))
            )
        else:
            model = TestModule()

        onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
        model_proto = onnx_program.model_proto
        function_proto_names = [function.name for function in model_proto.functions]
        self.assertIn(
            "torch_nn_modules_activation_ReLU_inner_module_relu_1", function_proto_names
        )
        self.assertIn(
            "torch_nn_modules_activation_ReLU_inner_module_relu_2", function_proto_names
        )
        # The local module's qualified name is unstable in the test environment; it
        # depends on how the tests are invoked.
        self.assertTrue(
            any("InnerModule_inner_module_1" in name for name in function_proto_names)
        )


if __name__ == "__main__":
    common_utils.run_tests()