# Owner(s): ["module: onnx"]

import copy
import functools
import io
import re
import warnings
from typing import Callable

import onnx

import parameterized
import pytorch_test_common
import torchvision
from autograd_helper import CustomFunction as CustomFunction2
from pytorch_test_common import (
    skipIfNoCuda,
    skipIfUnsupportedMaxOpsetVersion,
    skipIfUnsupportedMinOpsetVersion,
)

import torch
import torch.onnx
import torch.utils.cpp_extension
from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils
from torch.onnx._globals import GLOBALS
from torch.onnx.symbolic_helper import _unpack_list, parse_args
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNoLapack


def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str:
    """Remove the test environment prefix added to module scope names.

    Remove the prefix to normalize scope names, since different test environments
    add prefixes with slight differences.

    Example:

        >>> _remove_test_environment_prefix_from_scope_name(
        ...     "test_utility_funs.M"
        ... )
        'M'
        >>> _remove_test_environment_prefix_from_scope_name(
        ...     "test_utility_funs.test_abc.<locals>.M"
        ... )
        'M'
        >>> _remove_test_environment_prefix_from_scope_name(
        ...     "__main__.M"
        ... )
        'M'
    """
    prefixes_to_remove = ["test_utility_funs", "__main__"]
    for prefix in prefixes_to_remove:
        scope_name = re.sub(f"{prefix}\\.(.*?<locals>\\.)?", "", scope_name)
    return scope_name


class _BaseTestCase(pytorch_test_common.ExportTestCase):
    def _model_to_graph(
        self,
        model,
        input,
        do_constant_folding=True,
        training=TrainingMode.EVAL,
        operator_export_type=OperatorExportTypes.ONNX,
        input_names=None,
        dynamic_axes=None,
    ):
        torch.onnx.utils._setup_trace_module_map(model, False)
        if training == torch.onnx.TrainingMode.TRAINING:
            model.train()
        elif training == torch.onnx.TrainingMode.EVAL:
            model.eval()
        utils._validate_dynamic_axes(dynamic_axes, model, None, None)
        graph, params_dict, torch_out = utils._model_to_graph(
            model,
            input,
            do_constant_folding=do_constant_folding,
            _disable_torch_constant_prop=True,
            operator_export_type=operator_export_type,
            training=training,
            input_names=input_names,
            dynamic_axes=dynamic_axes,
        )
        return graph, params_dict, torch_out


@common_utils.instantiate_parametrized_tests
class TestUnconvertibleOps(pytorch_test_common.ExportTestCase):
    """Unit tests for the `unconvertible_ops` function."""

    def setUp(self):
        class EinsumModule(torch.nn.Module):
            def forward(self, x):
                return torch.einsum("ii", x)

        self.einsum_module = EinsumModule()

    def test_it_returns_graph_and_unconvertible_ops_at_lower_opset_version(self):
        x = torch.randn(4, 4)

        # Einsum is supported since opset 12. It should be unconvertible at opset 9.
        graph, unconvertible_ops = utils.unconvertible_ops(
            self.einsum_module, (x,), opset_version=9
        )
        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "prim::Constant")
        self.assertEqual(next(nodes).kind(), "prim::ListConstruct")
        self.assertEqual(next(nodes).kind(), "prim::Constant")
        self.assertEqual(next(nodes).kind(), "aten::einsum")
        self.assertEqual(unconvertible_ops, ["aten::einsum"])

    @common_utils.parametrize(
        "jit_function",
        [
            common_utils.subtest(
                functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
                name="traced",
            ),
            common_utils.subtest(torch.jit.script, name="scripted"),
        ],
    )
    def test_it_returns_unconvertible_ops_at_lower_opset_version_for_jit_module(
        self, jit_function: Callable
    ):
        module = jit_function(self.einsum_module)
        x = torch.randn(4, 4)

        # Einsum is supported since opset 12. It should be unconvertible at opset 9.
        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=9)
        self.assertEqual(unconvertible_ops, ["aten::einsum"])

    @common_utils.parametrize(
        "jit_function",
        [
            common_utils.subtest(lambda x: x, name="nn_module"),
            common_utils.subtest(
                functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
                name="traced",
            ),
            common_utils.subtest(torch.jit.script, name="scripted"),
        ],
    )
    def test_it_returns_empty_list_when_all_ops_convertible(
        self, jit_function: Callable
    ):
        module = jit_function(self.einsum_module)
        x = torch.randn(4, 4)

        # Einsum is supported since opset 12
        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=12)
        self.assertEqual(unconvertible_ops, [])

    def test_it_returns_empty_list_when_model_contains_supported_inplace_ops(self):
        class SkipConnectionModule(torch.nn.Module):
            def forward(self, x):
                out = x
                out += x
                out = torch.nn.functional.relu(out, inplace=True)
                return out

        module = SkipConnectionModule()
        x = torch.randn(4, 4)
        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=13)
        self.assertEqual(unconvertible_ops, [])


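    """Shared helper that lowers a model to the TorchScript IR graph used for export.

    A simplified sketch of what ``torch.onnx.export`` does before the ONNX graph is
    built: select train/eval mode, validate ``dynamic_axes``, then call
    ``utils._model_to_graph``.
    """
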
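    # Usage sketch (mirroring the tests below): ``utils.unconvertible_ops(model,
    # args, opset_version=N)`` returns the lowered TorchScript graph plus a list of
    # op kinds (e.g. "aten::einsum") with no ONNX symbolic for the requested opset.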
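# TestUtilityFuns is instantiated once per opset version, from ONNX_BASE_OPSET up to
# the maximum opset supported by the TorchScript exporter.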
@parameterized.parameterized_class(
    [
        {"opset_version": opset}
        for opset in range(
            _constants.ONNX_BASE_OPSET,
            _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET + 1,
        )
    ],
    class_name_func=lambda cls,
    num,
    params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}",
)
class TestUtilityFuns(_BaseTestCase):
    opset_version = None

    def test_is_in_onnx_export(self):
        test_self = self

        class MyModule(torch.nn.Module):
            def forward(self, x):
                test_self.assertTrue(torch.onnx.is_in_onnx_export())
                raise ValueError
                return x + 1

        x = torch.randn(3, 4)
        f = io.BytesIO()
        try:
            torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
        except ValueError:
            self.assertFalse(torch.onnx.is_in_onnx_export())

    def test_validate_dynamic_axes_invalid_input_output_name(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            utils._validate_dynamic_axes(
                {"input1": {}, "output": {}, "invalid_name1": {}, "invalid_name2": {}},
                None,
                ["input1", "input2"],
                ["output"],
            )
            messages = [str(warning.message) for warning in w]
        self.assertIn(
            "Provided key invalid_name1 for dynamic axes is not a valid input/output name",
            messages,
        )
        self.assertIn(
            "Provided key invalid_name2 for dynamic axes is not a valid input/output name",
            messages,
        )
        self.assertEqual(len(messages), 2)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_split_to_slice(self):
        class SplitModule(torch.nn.Module):
            def forward(self, x, y, t):
                splits = (x.size(1), y.size(1))
                out, out2 = torch.split(t, splits, dim=1)
                return out, out2

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.randn(2, 3)
        y = torch.randn(2, 4)
        t = torch.randn(2, 7)
        graph, _, _ = self._model_to_graph(
            SplitModule(),
            (x, y, t),
            input_names=["x", "y", "t"],
            dynamic_axes={"x": [0, 1], "y": [0, 1], "t": [0, 1]},
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::SplitToSequence")

    def test_constant_fold_transpose(self):
        class TransposeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.transpose(a, 1, 0)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(3, 2)
        graph, _, __ = self._model_to_graph(
            TransposeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Transpose")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    @skipIfUnsupportedMaxOpsetVersion(17)
    def test_constant_fold_reduceL2(self):
        class ReduceModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.norm(a, p=2, dim=-2, keepdim=False)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            ReduceModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::ReduceL2")

    @skipIfUnsupportedMaxOpsetVersion(17)
    def test_constant_fold_reduceL1(self):
        class NormModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.norm(a, p=1, dim=-2)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            NormModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::ReduceL1")

    def test_constant_fold_slice(self):
        class NarrowModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.narrow(a, 0, 0, 1)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 3)
        graph, _, __ = self._model_to_graph(
            NarrowModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Slice")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_slice_index_exceeds_dim(self):
        class SliceIndexExceedsDimModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = a[1:10]  # index exceeds dimension
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 3)
        graph, _, __ = self._model_to_graph(
            SliceIndexExceedsDimModule(),
            (x,),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Slice")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_slice_negative_index(self):
        class SliceNegativeIndexModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = a[0:-1]  # index relative to the end
                c = torch.select(a, dim=-1, index=-2)
                d = torch.select(a, dim=1, index=0)
                return b + x, c + d

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 3)
        graph, _, __ = self._model_to_graph(
            SliceNegativeIndexModule(),
            (x,),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Slice")
            self.assertNotEqual(node.kind(), "onnx::Cast")

    def test_constant_fold_gather(self):
        class GatherModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.select(a, dim=1, index=-2)
                c = torch.index_select(a, dim=-2, index=torch.tensor([0, 1]))
                return b + 1, c + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 3)
        model = GatherModule()
        model(x)
        graph, _, __ = self._model_to_graph(
            GatherModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Gather")

    def test_constant_fold_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.unsqueeze(a, -2)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 2, 3)
        graph, _, __ = self._model_to_graph(
            UnsqueezeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Unsqueeze")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

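        # These globals are normally configured inside torch.onnx.export; the tests
        # set them directly because they drive the internal _model_to_graph helper
        # rather than the public export API (an assumption about exporter internals).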
    def test_constant_fold_unsqueeze_multi_axes(self):
        class PReluModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                a = torch.randn(2, 3, 4, 5, 8, 7)
                return self.prelu(x) + a

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.randn(2, 3, 4, 5, 8, 7)
        graph, _, __ = self._model_to_graph(
            PReluModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3, 4, 5]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Unsqueeze")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 5)

    def test_constant_fold_squeeze_without_axes(self):
        class SqueezeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
                return torch.squeeze(a) + x + torch.squeeze(a)

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            SqueezeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Squeeze")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 4)

    def test_constant_fold_squeeze_with_axes(self):
        class SqueezeAxesModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
                return torch.squeeze(a, dim=-3) + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            SqueezeAxesModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Squeeze")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_concat(self):
        class ConcatModule(torch.nn.Module):
            def forward(self, x):
                # Why did I insert a Cast here?  There appears to be intentional
                # behavior in ONNX constant folding where constant tensors that
                # are not attached to any onnx operation known to be foldable
                # don't get extracted into the initializer graph.  So without
                # these casts, we would actually fail to pull out one of the
                # constants, thus failing constant folding.  I think the test
                # is wrong, but I don't have time to write a more correct one
                # (I think the right way to go about it is to set up a
                # predicate for what invariants graphs should hold after
                # constant folding, and then verify this predicate holds.
                # I think the asserts below are an attempt at this predicate,
                # but it is not right!)
                #
                # More commentary at
                # https://github.com/pytorch/pytorch/pull/18698/files#r340107552
                a = torch.tensor([[1.0, 2.0, 3.0]]).to(torch.float)
                b = torch.tensor([[4.0, 5.0, 6.0]]).to(torch.float)
                c = torch.cat((a, b), 0)
                d = b + c
                return x + d

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            ConcatModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Concat")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_lstm(self):
        class GruNet(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False)

            def forward(self, input, initial_state):
                return self.mygru(input, initial_state)

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        input = torch.randn(5, 3, 7)
        h0 = torch.randn(1, 3, 3)
        graph, _, __ = self._model_to_graph(
            GruNet(),
            (input, h0),
            input_names=["input", "h0"],
            dynamic_axes={"input": [0, 1, 2], "h0": [0, 1, 2]},
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Slice")
            self.assertNotEqual(node.kind(), "onnx::Concat")
            self.assertNotEqual(node.kind(), "onnx::Unsqueeze")

        if self.opset_version <= 12:
            self.assertEqual(len(list(graph.nodes())), 3)
        else:
            # For opset version >= 13, Unsqueeze takes "axes" as an input instead of an attribute.
            self.assertEqual(len(list(graph.nodes())), 4)

    def test_constant_fold_transpose_matmul(self):
        class MatMulNet(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.B = torch.nn.Parameter(torch.ones(5, 3))

            def forward(self, A):
                return torch.matmul(A, torch.transpose(self.B, -1, -2))

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        A = torch.randn(2, 3)
        graph, _, __ = self._model_to_graph(
            MatMulNet(), (A,), input_names=["A"], dynamic_axes={"A": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Transpose")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                b = self.weight.reshape(1, -1, 1, 1)
                return x * b

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.randn(4, 5)
        graph, _, __ = self._model_to_graph(
            ReshapeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Reshape")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_div(self):
        class Module(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                div = self.weight.div(torch.tensor([1, 2, 3, 4, 5]))
                return div * x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Div")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_mul(self):
        class Module(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                mul = self.weight.mul(torch.tensor([1, 2, 3, 4, 5]))
                return mul / x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Mul")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_add(self):
        class Module(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                add = self.weight + torch.tensor([1, 2, 3, 4, 5])
                return add - x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, params_dict, __ = self._model_to_graph(
            Module(),
            (x,),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX,
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Add")
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        self.assertEqual(weight, torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0]))

    def test_constant_fold_sub(self):
        class Module(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                sub = self.weight - torch.tensor([1, 2, 3, 4, 5])
                return sub + x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, params_dict, __ = self._model_to_graph(
            Module(),
            (x,),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX,
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Sub")
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        self.assertEqual(weight, torch.tensor([0.0, -1.0, -2.0, -3.0, -4.0]))

    def test_constant_fold_sqrt(self):
        class Module(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                sqrt = torch.sqrt(self.weight)
                return sqrt / x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Sqrt")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_shape(self):
        class ShapeModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                shape = self.weight.shape[0]
                return x + shape

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            ShapeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Shape")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_upsample_scale_fold_as_constant(self):
        # The upsample scale is a constant, not a model parameter, so it should
        # not be added as an initializer after constant folding.
        model = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        x = torch.randn(1, 32, 224, 224)
        f = io.BytesIO()
        torch.onnx.export(model, x, f)
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        self.assertEqual(len(onnx_model.graph.initializer), 0)

    def test_verbose(self):
        class MyModule(torch.nn.Module):
            def forward(self, input):
                return torch.exp(input)

        x = torch.randn(3, 4)

        def is_model_stripped(f, verbose=None):
            if verbose is None:
                torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
            else:
                torch.onnx.export(
                    MyModule(), x, f, verbose=verbose, opset_version=self.opset_version
                )
            model = onnx.load(io.BytesIO(f.getvalue()))
            model_strip = copy.copy(model)
            onnx.helper.strip_doc_string(model_strip)
            return model == model_strip

        # test verbose=False (default)
        self.assertTrue(is_model_stripped(io.BytesIO()))
        # test verbose=True
        self.assertFalse(is_model_stripped(io.BytesIO(), True))

    # NB: remove this test once DataParallel can be correctly handled
    def test_error_on_data_parallel(self):
        model = torch.nn.DataParallel(torch.nn.ReflectionPad2d((1, 2, 3, 4)))
        x = torch.randn(1, 2, 3, 4)
        f = io.BytesIO()
        with self.assertRaisesRegex(
            ValueError,
            "torch.nn.DataParallel is not supported by ONNX "
            "exporter, please use 'attribute' module to "
            "unwrap model from torch.nn.DataParallel. Try ",
        ):
            torch.onnx.export(model, x, f, opset_version=self.opset_version)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_sequence_dim(self):
        class Module(torch.nn.Module):
            def forward(self, x, y):
                return [x, y]

        model = Module()
        # Export with scripting to keep output as Sequence type.
        # Tracing unpacks the list.
        script_model = torch.jit.script(model)
        x = torch.randn(2, 3)

        # Case 1: dynamic axis
        f = io.BytesIO()
        y = torch.randn(2, 3)
        torch.onnx.export(
            script_model,
            (x, y),
            f,
            opset_version=self.opset_version,
            input_names=["x", "y"],
            dynamic_axes={"y": [1]},
        )
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        loop_output_value_info_proto = onnx_model.graph.output[0]
        ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info(
            loop_output_value_info_proto.name, 1, [2, None]
        )
        self.assertEqual(loop_output_value_info_proto, ref_value_info_proto)

        # Case 2: no dynamic axes.
        f = io.BytesIO()
        y = torch.randn(2, 3)
        torch.onnx.export(script_model, (x, y), f, opset_version=self.opset_version)
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        loop_output_value_info_proto = onnx_model.graph.output[0]
        ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info(
            loop_output_value_info_proto.name, 1, [2, 3]
        )
        self.assertEqual(loop_output_value_info_proto, ref_value_info_proto)

    def test_export_mode(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = x + 1
                return y

        model = MyModule()
        x = torch.randn(10, 3, 128, 128)
        f = io.BytesIO()

        # set model to inference mode and export in training mode
        model.eval()
        old_state = model.training
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            training=torch.onnx.TrainingMode.TRAINING,
        )
        # verify that the model state is preserved
        self.assertEqual(model.training, old_state)

        # set model to training mode and export in inference mode
        model.train()
        old_state = model.training
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            training=torch.onnx.TrainingMode.EVAL,
        )
        # verify that the model state is preserved
        self.assertEqual(model.training, old_state)

    def test_export_does_not_fail_on_frozen_scripted_module(self):
        class Inner(torch.nn.Module):
            def forward(self, x):
                if x > 0:
                    return x
                else:
                    return x * x

        class Outer(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.inner = torch.jit.script(Inner())

            def forward(self, x):
                return self.inner(x)

        x = torch.zeros(1)
        # Freezing is only implemented in eval mode. So we need to call eval()
        outer_module = Outer().eval()
        module = torch.jit.trace_module(outer_module, {"forward": (x)})
        # jit.freeze removes the training attribute in the module
        module = torch.jit.freeze(module)

        torch.onnx.export(module, (x,), io.BytesIO(), opset_version=self.opset_version)

    @skipIfUnsupportedMinOpsetVersion(15)
    def test_local_function(self):
        class N(torch.nn.Module):
            def __init__(self, prob):
                super().__init__()
                self.dropout = torch.nn.Dropout(prob)

            def forward(self, x):
                return self.dropout(x)

        class M(torch.nn.Module):
            def __init__(self, num_layers):
                super().__init__()
                self.num_layers = num_layers
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)]
                )
                self.celu1 = torch.nn.CELU(1.0)
                self.celu2 = torch.nn.CELU(2.0)
                self.dropout = N(0.5)

            def forward(self, x, y, z):
                res1 = self.celu1(x)
                res2 = self.celu2(y)
                for ln in self.lns:
                    z = ln(z)
                return res1 + res2, self.dropout(z)

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)

        # Export specified modules. Test against specifying modules that won't
        # exist in the exported model.
        # Model export in inference mode removes the dropout node, so the
        # dropout module no longer exists in the graph.
        f = io.BytesIO()
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions={
                torch.nn.CELU,
                torch.nn.Dropout,
                torch.nn.LayerNorm,
            },
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))

        # Check function definition
        funcs = onnx_model.functions
        celu_funcs = [f for f in funcs if f.name == "CELU"]
        self.assertEqual(len(celu_funcs), 1)
        self.assertEqual(celu_funcs[0].domain, "torch.nn.modules.activation")
        self.assertEqual(len(celu_funcs[0].attribute), 3)
        ln_funcs = [f for f in funcs if f.name == "LayerNorm"]
        self.assertEqual(len(ln_funcs), 1)
        self.assertEqual(ln_funcs[0].domain, "torch.nn.modules.normalization")
        self.assertEqual(len(ln_funcs[0].attribute), 3)

        # Check local function nodes
        nodes = onnx_model.graph.node
        celu_ns = [n for n in nodes if n.op_type == "CELU"]
        ln_ns = [n for n in nodes if n.op_type == "LayerNorm"]
        self.assertEqual(len(celu_ns), 2)
        self.assertEqual(celu_ns[0].domain, "torch.nn.modules.activation")
        self.assertEqual(len(celu_ns[0].attribute), 3)
        self.assertEqual(len(ln_ns), 3)
        self.assertEqual(ln_ns[0].domain, "torch.nn.modules.normalization")
        self.assertEqual(len(ln_ns[0].attribute), 3)

        # Export specified modules.
        f = io.BytesIO()
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions={torch.nn.CELU},
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertEqual(len(funcs), 1)
        self.assertEqual(funcs[0].name, "CELU")

        # Export with empty specified modules. Normal export.
        f = io.BytesIO()
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions=set(),
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertEqual(len(funcs), 0)

        # Export all modules. Should contain {M, CELU, LayerNorm}.
        f = io.BytesIO()
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions=True,
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertEqual(len(funcs), 3)

    @skipIfUnsupportedMinOpsetVersion(15)
    def test_local_function_overloads(self):
        class NWithOverloads(torch.nn.Module):
            def forward(self, x, y=None, z=None):
                if y is None:
                    return x + 1
                elif z is None:
                    return x + y
                else:
                    return x + y, x + z

        class M(torch.nn.Module):
            def __init__(self, num_layers):
                super().__init__()
                self.n = NWithOverloads()

            def forward(self, x, y, z):
                return self.n(x), self.n(x, y), self.n(x, y, z)

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)

        f = io.BytesIO()
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions={NWithOverloads},
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertEqual(len(funcs), 3)
        func_names = [f.name for f in funcs]
        self.assertIn("NWithOverloads", func_names)
        self.assertIn("NWithOverloads.1", func_names)
        self.assertIn("NWithOverloads.2", func_names)

    # Failing after ONNX 1.13.0
    @skipIfUnsupportedMaxOpsetVersion(1)
    def test_local_function_infer_scopes(self):
        class M(torch.nn.Module):
            def forward(self, x):
                # Concatenation of scalars inserts unscoped tensors in IR graph.
                new_tensor_shape = x.size()[:-1] + (1, 1, -1)
                tensor = x.view(*new_tensor_shape)
                return tensor

        x = torch.randn(4, 5)
        f = io.BytesIO()
        torch.onnx.export(
            M(),
            (x,),
            f,
            export_modules_as_functions=True,
            opset_version=self.opset_version,
            do_constant_folding=False,
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertIn("M", [f.name for f in funcs])

    @skipIfUnsupportedMinOpsetVersion(15)
    def test_local_function_predefined_attributes(self):
        class M(torch.nn.Module):
            num_layers: int

            def __init__(self, num_layers):
                super().__init__()
                self.num_layers = num_layers
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=1e-4) for _ in range(num_layers)]
                )

            def forward(self, x):
                for ln in self.lns:
                    x = ln(x)
                return x

        x = torch.randn(2, 3)
        f = io.BytesIO()
        model = M(3)
        torch.onnx.export(
            model,
            (x,),
            f,
            export_modules_as_functions=True,
            opset_version=self.opset_version,
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        m_funcs = [fn for fn in funcs if fn.name == "M"]
        self.assertEqual(m_funcs[0].attribute, ["num_layers"])
        ln_funcs = [fn for fn in funcs if fn.name == "LayerNorm"]
        self.assertEqual(ln_funcs[0].attribute, ["eps", "elementwise_affine"])

        from onnx import helper

        m_node = [n for n in onnx_model.graph.node if n.op_type == "M"]
        self.assertEqual(
            m_node[0].attribute[0],
            helper.make_attribute("num_layers", model.num_layers),
        )

        ln_nodes = [n for n in m_funcs[0].node if n.op_type == "LayerNorm"]
        expected_ln_attrs = [
            helper.make_attribute(
                "elementwise_affine", model.lns[0].elementwise_affine
            ),
            helper.make_attribute("eps", model.lns[0].eps),
        ]
        for ln_node in ln_nodes:
            self.assertIn(ln_node.attribute[0], expected_ln_attrs)
            self.assertIn(ln_node.attribute[1], expected_ln_attrs)

    # This test case checks the issue where an object does not have an attribute.
    # When enabling `export_modules_as_functions = True`, the exporter could raise an
    # AttributeError. With this test case, we check that the export passes successfully
    # without any AttributeError exceptions.
    # See https://github.com/pytorch/pytorch/pull/109759 for an example. The exception that
    # this test tries to avoid is `AttributeError: 'Embedding' object has no attribute 'freeze'`.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_local_function_subset_of_predefined_attributes(self):
        class M(torch.nn.Module):
            num_layers: int

            def __init__(self, num_layers):
                super().__init__()
                self.embed_layer = torch.nn.Embedding.from_pretrained(
                    torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
                )
                self.num_layers = num_layers
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=1e-4) for _ in range(num_layers)]
                )

            def forward(self, x):
                e = self.embed_layer(torch.LongTensor([1]))
                for ln in self.lns:
                    x = ln(x)
                return x, e

        x = torch.randn(2, 3)
        f = io.BytesIO()
        model = M(3)
        torch.onnx.export(
            model,
            (x,),
            f,
            export_modules_as_functions=True,
            opset_version=self.opset_version,
            verbose=True,  # Allows the test case to print `Skipping module attribute 'freeze'`
        )

    def test_node_scope(self):
        class N(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        class M(torch.nn.Module):
            def __init__(self, num_layers):
                super().__init__()
                self.num_layers = num_layers
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=float(i)) for i in range(num_layers)]
                )
                self.gelu1 = torch.nn.GELU()
                self.gelu2 = torch.nn.GELU()
                self.relu = N()

            def forward(self, x, y, z):
                res1 = self.gelu1(x)
                res2 = self.gelu2(y)
                for ln in self.lns:
                    z = ln(z)
                return res1 + res2, self.relu(z)

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)

        model = M(3)
        expected_scope_names = {
            "M::/torch.nn.modules.activation.GELU::gelu1",
            "M::/torch.nn.modules.activation.GELU::gelu2",
            "M::/torch.nn.modules.normalization.LayerNorm::lns.0",
            "M::/torch.nn.modules.normalization.LayerNorm::lns.1",
            "M::/torch.nn.modules.normalization.LayerNorm::lns.2",
            "M::/N::relu/torch.nn.modules.activation.ReLU::relu",
            "M::",
        }

        graph, _, _ = self._model_to_graph(
            model, (x, y, z), input_names=[], dynamic_axes={}
        )
        for node in graph.nodes():
            self.assertIn(
                _remove_test_environment_prefix_from_scope_name(node.scopeName()),
                expected_scope_names,
            )

        graph, _, _ = self._model_to_graph(
            torch.jit.script(model), (x, y, z), input_names=[], dynamic_axes={}
        )
        for node in graph.nodes():
            self.assertIn(
                _remove_test_environment_prefix_from_scope_name(node.scopeName()),
                expected_scope_names,
            )

    def test_scope_of_constants_when_combined_by_cse_pass(self):
        layer_num = 3

        class M(torch.nn.Module):
            def __init__(self, constant):
                super().__init__()
                self.constant = constant

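
        # Scope names recorded on IR nodes follow the pattern
        # "<root class>::/<qualified submodule class>::<attribute name>", with
        # nested submodules chained by "/" (see the ReLU wrapped in N above).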
            def forward(self, x):
                # 'self.constant' is designed to be the same for all layers,
                # hence it is a common subexpression.
                return x + self.constant

        class N(torch.nn.Module):
            def __init__(self, layers: int = layer_num):
                super().__init__()
                self.layers = torch.nn.ModuleList(
                    [M(constant=torch.tensor(1.0)) for i in range(layers)]
                )

            def forward(self, x):
                for layer in self.layers:
                    x = layer(x)
                return x

        graph, _, _ = self._model_to_graph(
            N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
        )

        # NOTE: Duplicated constants are populated due to implicit casting in scalar_type_analysis,
        #       so we expect 3 constants with different scopes. The 3 constants are for the 3 layers.
        #       If CSE in exporter is improved later, this test needs to be updated.
        #       It should expect 1 constant, with same scope as root.
        expected_root_scope_name = "N::"
        expected_layer_scope_name = "M::layers"
        expected_constant_scope_name = [
            f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}"
            for i in range(layer_num)
        ]

        constant_scope_names = []
        for node in graph.nodes():
            if node.kind() == "onnx::Constant":
                constant_scope_names.append(
                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
                )
        self.assertEqual(constant_scope_names, expected_constant_scope_name)

    def test_scope_of_nodes_when_combined_by_cse_pass(self):
        layer_num = 3

        class M(torch.nn.Module):
            def __init__(self, constant, bias):
                super().__init__()
                self.constant = constant
                self.bias = bias

            def forward(self, x):
                # 'constant' and 'x' are designed to be the same for all layers,
                # hence `x + self.constant` is a common subexpression.
                # 'bias' is designed to be different for each layer,
                # hence `* self.bias` is not a common subexpression.
                return (x + self.constant) * self.bias
1263
1264        class N(torch.nn.Module):
1265            def __init__(self, layers: int = layer_num):
1266                super().__init__()
1267
1268                self.layers = torch.nn.ModuleList(
1269                    [
1270                        M(constant=torch.tensor([1.0]), bias=torch.randn(1))
1271                        for i in range(layers)
1272                    ]
1273                )
1274
1275            def forward(self, x):
1276                y = []
1277                for layer in self.layers:
1278                    y.append(layer(x))
1279                return y[0], y[1], y[2]
1280
1281        graph, _, _ = self._model_to_graph(
1282            N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
1283        )
1284        expected_root_scope_name = "N::"
1285        expected_layer_scope_name = "M::layers"
1286        expected_add_scope_names = [
1287            f"{expected_root_scope_name}/{expected_layer_scope_name}.0"
1288        ]
1289        expected_mul_scope_names = [
1290            f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}"
1291            for i in range(layer_num)
1292        ]
1293
1294        add_scope_names = []
1295        mul_scope_names = []
1296        for node in graph.nodes():
1297            if node.kind() == "onnx::Add":
1298                add_scope_names.append(
1299                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
1300                )
1301            elif node.kind() == "onnx::Mul":
1302                mul_scope_names.append(
1303                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
1304                )
1305        self.assertEqual(add_scope_names, expected_add_scope_names)
1306        self.assertEqual(mul_scope_names, expected_mul_scope_names)
1307
1308    def test_aten_fallthrough(self):
1309        # Test aten export of op with no symbolic
1310        class Module(torch.nn.Module):
1311            def forward(self, x):
1312                return torch.erfc(x)
1313
1314        x = torch.randn(2, 3, 4)
1315        GLOBALS.export_onnx_opset_version = self.opset_version
1316        graph, _, __ = self._model_to_graph(
1317            Module(),
1318            (x,),
1319            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
1320            input_names=["x"],
1321            dynamic_axes={"x": [0, 1, 2]},
1322        )
1323        iter = graph.nodes()
1324        self.assertEqual(next(iter).kind(), "aten::erfc")
1325
1326    def test_custom_op_fallthrough(self):
1327        # Test custom op
1328        op_source = """
1329        #include <torch/script.h>
1330
1331        torch::Tensor custom_add(torch::Tensor self, torch::Tensor other) {
1332          return self + other;
1333        }
1334
1335        static auto registry =
1336          torch::RegisterOperators("custom_namespace::custom_op", &custom_add);
1337        """
1338
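        # load_inline compiles the C++ source above as an eager extension and
        # registers custom_namespace::custom_op, which then becomes reachable
        # from Python as torch.ops.custom_namespace.custom_op.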
1339        torch.utils.cpp_extension.load_inline(
1340            name="custom_add",
1341            cpp_sources=op_source,
1342            is_python_module=False,
1343            verbose=True,
1344        )
1345
1346        class FooModel(torch.nn.Module):
1347            def forward(self, input, other):
1348                # Calling custom op
1349                return torch.ops.custom_namespace.custom_op(input, other)
1350
1351        x = torch.randn(2, 3, 4, requires_grad=False)
1352        y = torch.randn(2, 3, 4, requires_grad=False)
1353        model = FooModel()
1354        graph, _, __ = self._model_to_graph(
1355            model,
1356            (x, y),
1357            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
1358            input_names=["x", "y"],
1359            dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]},
1360        )
1361        iter = graph.nodes()
1362        self.assertEqual(next(iter).kind(), "custom_namespace::custom_op")
1363
1364    # gelu is exported as onnx::Gelu for opset >= 20
1365    @skipIfUnsupportedMaxOpsetVersion(19)
1366    def test_custom_opsets_gelu(self):
1367        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::gelu", 9)
1368
1369        def gelu(g, self, approximate):
1370            return g.op("com.microsoft::Gelu", self).setType(self.type())
1371
1372        torch.onnx.register_custom_op_symbolic("::gelu", gelu, 9)
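        # Registering this symbolic makes the exporter emit the contributed
        # com.microsoft::Gelu op (declared via custom_opsets below) instead of the
        # default gelu lowering; the addCleanup above unregisters it after the test.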
1373        model = torch.nn.GELU(approximate="none")
1374        x = torch.randn(3, 3)
1375        f = io.BytesIO()
1376        torch.onnx.export(
1377            model,
1378            (x,),
1379            f,
1380            opset_version=self.opset_version,
1381            custom_opsets={"com.microsoft": 1},
1382        )
1383
1384        graph = onnx.load(io.BytesIO(f.getvalue()))
1385        self.assertEqual(graph.graph.node[0].op_type, "Gelu")
1386        self.assertEqual(graph.opset_import[0].version, self.opset_version)
1387        self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
1388        self.assertEqual(graph.opset_import[1].version, 1)
1389
1390    # gelu is exported as onnx::Gelu for opset >= 20
1391    @skipIfUnsupportedMaxOpsetVersion(19)
1392    def test_register_aten_custom_op_symbolic(self):
1393        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "aten::gelu", 9)
1394
1395        def gelu(g, self, approximate):
1396            return g.op("com.microsoft::Gelu", self).setType(self.type())
1397
1398        torch.onnx.register_custom_op_symbolic("aten::gelu", gelu, 9)
1399        model = torch.nn.GELU(approximate="none")
1400        x = torch.randn(3, 3)
1401        f = io.BytesIO()
1402        torch.onnx.export(model, (x,), f, opset_version=self.opset_version)
1403        graph = onnx.load(io.BytesIO(f.getvalue()))
1404
1405        self.assertEqual(graph.graph.node[0].op_type, "Gelu")
1406        self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
1407
1408    @skipIfNoLapack
1409    def test_custom_opsets_inverse(self):
1410        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)
1411
1412        class CustomInverse(torch.nn.Module):
1413            def forward(self, x):
1414                return torch.inverse(x) + x
1415
1416        def linalg_inv(g, self):
1417            return g.op("com.microsoft::Inverse", self).setType(self.type())
1418
1419        torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv, 9)
1420        model = CustomInverse()
1421        x = torch.randn(2, 3, 3)
1422        f = io.BytesIO()
1423        torch.onnx.export(
1424            model,
1425            (x,),
1426            f,
1427            opset_version=self.opset_version,
1428            custom_opsets={"com.microsoft": 1},
1429        )
1430
1431        graph = onnx.load(io.BytesIO(f.getvalue()))
1432        self.assertEqual(graph.graph.node[0].op_type, "Inverse")
1433        self.assertEqual(graph.opset_import[0].version, self.opset_version)
1434        self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
1435        self.assertEqual(graph.opset_import[1].version, 1)
1436
1437    def test_onnx_fallthrough(self):
1438        # Test ONNX_FALLTHROUGH export: the aten op (digamma) should be kept as aten::digamma in the graph.
1439        class Module(torch.nn.Module):
1440            def forward(self, x):
1441                return torch.digamma(x)
1442
1443        x = torch.randn(100, 128)
1444        graph, _, __ = self._model_to_graph(
1445            Module(),
1446            (x,),
1447            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
1448            input_names=["x"],
1449            dynamic_axes={"x": [0, 1]},
1450        )
1451        it = graph.nodes()
1452        self.assertEqual(next(it).kind(), "aten::digamma")
1453
1454    # prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11
1455    @skipIfUnsupportedMaxOpsetVersion(10)
1456    def test_prim_fallthrough(self):
1457        # Test that a prim op (prim::ListConstruct) is kept under ONNX_FALLTHROUGH.
1458        class PrimModule(torch.jit.ScriptModule):
1459            @torch.jit.script_method
1460            def forward(self, x):
1461                if isinstance(x, list):
1462                    y = x
1463                else:
1464                    y = [x]
1465                return y
1466
1467        x = torch.tensor([2])
1468        model = PrimModule()
1469        model.eval()
1470        graph, _, __ = self._model_to_graph(
1471            model,
1472            (x,),
1473            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
1474            input_names=["x"],
1475            dynamic_axes={"x": [0]},
1476        )
1477        it = graph.nodes()
1478        self.assertEqual(next(it).kind(), "prim::ListConstruct")
1479
1480    def test_custom_layer_tuple(self):
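        # An autograd.Function that defines symbolic() should export to its
        # custom op, here CustomNamespace::Custom with two outputs.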
1481        class CustomFunction(torch.autograd.Function):
1482            @staticmethod
1483            def symbolic(g, input):
1484                return g.op("CustomNamespace::Custom", input, outputs=2)
1485
1486            @staticmethod
1487            def forward(ctx, input):
1488                return input, input
1489
1490        class Custom(torch.nn.Module):
1491            def forward(self, input):
1492                return CustomFunction.apply(input)
1493
1494        model = Custom()
1495        batch = torch.FloatTensor(1, 3)
1496
1497        graph, _, _ = self._model_to_graph(
1498            model, batch, input_names=["batch"], dynamic_axes={"batch": [0, 1]}
1499        )
1500        it = graph.nodes()
1501        self.assertEqual(next(it).kind(), "CustomNamespace::Custom")
1502
1503    def test_autograd_onnx_fallthrough(self):
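        # An autograd.Function without a symbolic() should be exported as a
        # prim::PythonOp node under ONNX_FALLTHROUGH.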
1504        class CustomFunction(torch.autograd.Function):
1505            @staticmethod
1506            def forward(ctx, input):
1507                ctx.save_for_backward(input)
1508                return input.clamp(min=0)
1509
1510            @staticmethod
1511            def backward(ctx, grad_output):
1512                (input,) = ctx.saved_tensors
1513                grad_input = grad_output.clone()
1514                grad_input[input < 0] = 0
1515                return grad_input
1516
1517        class Custom(torch.nn.Module):
1518            def forward(self, input):
1519                return CustomFunction.apply(input)
1520
1521        model = Custom()
1522        batch = torch.FloatTensor(1, 3)
1523
1524        graph, _, _ = self._model_to_graph(
1525            model,
1526            batch,
1527            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
1528            input_names=["batch"],
1529            dynamic_axes={"batch": [0, 1]},
1530        )
1531        it = graph.nodes()
1532        self.assertEqual(next(it).kind(), "prim::PythonOp")
1533
1534    def test_autograd_module_name(self):
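        # The two autograd Functions are defined in different modules, so their
        # prim::PythonOp nodes should carry different "module" attributes.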
1535        class CustomFunction(torch.autograd.Function):
1536            @staticmethod
1537            def forward(ctx, input):
1538                ctx.save_for_backward(input)
1539                return input.clamp(min=0)
1540
1541            @staticmethod
1542            def backward(ctx, grad_output):
1543                (input,) = ctx.saved_tensors
1544                grad_input = grad_output.clone()
1545                grad_input[input < 0] = 0
1546                return grad_input
1547
1548        class Custom(torch.nn.Module):
1549            def forward(self, input):
1550                return CustomFunction.apply(input) + CustomFunction2.apply(input)
1551
1552        model = Custom()
1553        batch = torch.FloatTensor(1, 3)
1554
1555        graph, _, _ = self._model_to_graph(
1556            model,
1557            batch,
1558            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
1559            input_names=["batch"],
1560            dynamic_axes={"batch": [0, 1]},
1561        )
1562        it = graph.nodes()
1563        autograd1 = next(it)
1564        autograd2 = next(it)
1565        self.assertEqual(autograd1.kind(), "prim::PythonOp")
1566        self.assertEqual(autograd2.kind(), "prim::PythonOp")
1567        self.assertNotEqual(autograd1.s("module"), autograd2.s("module"))
1568
1569    def test_unused_initializers(self):
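        # k_proj is never used in forward, so only conv2's weight and bias
        # should remain in params_dict.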
1570        class Model(torch.nn.Module):
1571            def __init__(self) -> None:
1572                super().__init__()
1573                self.conv2 = torch.nn.ConvTranspose2d(
1574                    16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(1, 1)
1575                )
1576                self.k_proj = torch.nn.Linear(5, 5, bias=True)
1577
1578            def forward(self, x):
1579                x = self.conv2(x)
1580                return x
1581
1582        x = torch.randn(20, 16, 50, 100)
1583        GLOBALS.export_onnx_opset_version = self.opset_version
1584        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
1585        _, params_dict, __ = self._model_to_graph(
1586            Model(),
1587            (x,),
1588            do_constant_folding=False,
1589            operator_export_type=OperatorExportTypes.ONNX,
1590            input_names=["x"],
1591            dynamic_axes={"x": [0, 1, 2, 3]},
1592        )
1593
1594        self.assertEqual(len(params_dict), 2)
1595
1596    def test_scripting_param(self):
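        # Parameters of a scripted module exported in TRAINING mode should
        # remain graph inputs under their original names.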
1597        class MyModule(torch.nn.Module):
1598            def __init__(self) -> None:
1599                super().__init__()
1600                self.conv = torch.nn.Conv2d(
1601                    3, 16, kernel_size=1, stride=2, padding=3, bias=True
1602                )
1603                self.bn = torch.nn.BatchNorm2d(16, affine=True)
1604
1605            def forward(self, x):
1606                x = self.conv(x)
1607                bn = self.bn(x)
1608                return bn
1609
1610        model = torch.jit.script(MyModule())
1611        x = torch.randn(10, 3, 128, 128)
1612        GLOBALS.export_onnx_opset_version = self.opset_version
1613        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
1614        graph, _, __ = self._model_to_graph(
1615            model,
1616            (x,),
1617            do_constant_folding=True,
1618            operator_export_type=OperatorExportTypes.ONNX,
1619            training=torch.onnx.TrainingMode.TRAINING,
1620            input_names=["x"],
1621            dynamic_axes={"x": [0, 1, 2, 3]},
1622        )
1623
1624        graph_input_params = [param.debugName() for param in graph.inputs()]
1625        for item in dict(model.named_parameters()):
1626            self.assertIn(
1627                item,
1628                graph_input_params,
1629            "Graph parameter names do not match model parameters.",
1630            )
1631
1632    def test_fuse_conv_bn(self):
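        # In EVAL mode the Conv + BatchNorm pair should be fused into a single
        # onnx::Conv node.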
1633        class Fuse(torch.nn.Module):
1634            def __init__(self) -> None:
1635                super().__init__()
1636                self.conv = torch.nn.Conv2d(
1637                    3, 2, kernel_size=1, stride=2, padding=3, bias=True
1638                )
1639                self.bn = torch.nn.BatchNorm2d(2)
1640
1641            def forward(self, x):
1642                out = self.conv(x)
1643                return self.bn(out)
1644
1645        x = torch.randn(2, 3, 2, 2, requires_grad=True)
1646        graph, _, __ = self._model_to_graph(
1647            Fuse(),
1648            (x,),
1649            training=TrainingMode.EVAL,
1650            input_names=["x"],
1651            dynamic_axes={"x": [0, 1, 2, 3]},
1652        )
1653        for node in graph.nodes():
1654            self.assertNotEqual(node.kind(), "onnx::BatchNormalization")
1655            self.assertEqual(node.kind(), "onnx::Conv")
1656
1657        self.assertEqual(len(list(graph.nodes())), 1)
1658
1659    def test_fuse_resnet18(self):
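        # Conv/BatchNorm fusion in EVAL mode should leave no
        # onnx::BatchNormalization nodes in a full torchvision model either.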
1660        model = torchvision.models.resnet18(weights=None)
1661        x = torch.randn(2, 3, 224, 224, requires_grad=True)
1662        graph, _, __ = self._model_to_graph(
1663            model,
1664            (x,),
1665            training=TrainingMode.EVAL,
1666            input_names=["x"],
1667            dynamic_axes={"x": [0, 1, 2, 3]},
1668        )
1669
1670        for node in graph.nodes():
1671            self.assertNotEqual(node.kind(), "onnx::BatchNormalization")
1672
1673    def test_onnx_function_substitution_pass(self):
1674        @torch.jit.script
1675        def f(x: torch.Tensor, y: torch.Tensor):
1676            z = x - y
1677            return x + z
1678
1679        class MyModule(torch.nn.Module):
1680            def forward(self, x, y):
1681                return f(x, y)
1682
1683        input_1 = torch.tensor([11])
1684        input_2 = torch.tensor([12])
1685        GLOBALS.export_onnx_opset_version = self.opset_version
1686        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
1687        graph, _, __ = self._model_to_graph(
1688            MyModule(),
1689            (input_1, input_2),
1690            do_constant_folding=True,
1691            operator_export_type=OperatorExportTypes.ONNX,
1692            input_names=["input_1", "input_2"],
1693            dynamic_axes={"input_1": [0], "input_2": [0]},
1694        )
1695        # Check that the prim::Constant node representing the scripted
1696        # function `f` is removed and that the following prim::CallFunction
1697        # is replaced by an inlined graph with onnx::Sub and onnx::Add nodes.
1698        for node in graph.nodes():
1699            self.assertNotEqual(node.kind(), "prim::Constant")
1700        self.assertEqual(
1701            len(list(graph.nodes())), 2
1702        )  # onnx::Sub and onnx::Add nodes only.
1703
1704    def test_onnx_value_name(self):
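        # With keep_initializers_as_inputs=True, the parameters should appear
        # as graph inputs named after the module attributes (in_weight, in_bias).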
1705        class MyModule(torch.nn.Module):
1706            def __init__(self) -> None:
1707                super().__init__()
1708                self.in_weight = torch.nn.Parameter(torch.Tensor(3, 3))
1709                self.in_bias = torch.nn.Parameter(torch.Tensor(3))
1710
1711            def forward(self, x):
1712                start = 0
1713                end = None
1714                weight = self.in_weight
1715                bias = self.in_bias
1716                weight = weight[start:end, :]
1717                if bias is not None:
1718                    bias = bias[start:end]
1719                return torch.nn.functional.linear(x, weight, bias)
1720
1721        model = MyModule()
1722        x = torch.randn(3, 3)
1723        f = io.BytesIO()
1724
1725        model.eval()
1726        torch.onnx.export(
1727            model,
1728            (x,),
1729            f,
1730            opset_version=self.opset_version,
1731            keep_initializers_as_inputs=True,
1732        )
1733        graph = onnx.load(io.BytesIO(f.getvalue()))
1734        self.assertEqual(graph.graph.input[1].name, "in_weight")
1735        self.assertEqual(graph.graph.input[2].name, "in_bias")
1736
1737    def test_onnx_node_naming(self):
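        # Exported node names should encode the owning module's scope,
        # e.g. "/_module_1/Gemm", for both traced and scripted export.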
1738        class MainModule(torch.nn.Module):
1739            def __init__(self) -> None:
1740                super().__init__()
1741                self._module_1 = torch.nn.Linear(10, 10)
1742                self._module_2 = torch.nn.Linear(10, 10)
1743                self._module_3 = torch.nn.Linear(10, 10)
1744                self._module_4 = torch.nn.Linear(10, 10)
1745
1746            def forward(self, x):
1747                y = self._module_1(x)
1748                z = self._module_2(y)
1749                z = self._module_3(y * z)
1750                z = self._module_4(y * z)
1751                return z
1752
1753        module = MainModule()
1754        ref_node_names = [
1755            "/_module_1/Gemm",
1756            "/_module_2/Gemm",
1757            "/_module_3/Gemm",
1758            "/_module_4/Gemm",
1759            "/Mul",
1760            "/Mul_1",
1761        ]
1762        f = io.BytesIO()
1763
1764        torch.onnx.export(module, torch.ones(1, 10), f, output_names=["y"])
1765        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
1766        for n in onnx_model.graph.node:
1767            self.assertIn(n.name, ref_node_names)
1768
1769        torch.onnx.export(
1770            torch.jit.script(module), torch.ones(1, 10), f, output_names=["y"]
1771        )
1772        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
1773        for n in onnx_model.graph.node:
1774            self.assertIn(n.name, ref_node_names)
1775
1776    def _test_deduplicate_initializers(self, torchscript=False):
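        # TRAINING and PRESERVE exports keep every named parameter as an
        # initializer; eval export deduplicates parameters that are equal in
        # value, dropping param2.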
1777        class MyModule(torch.nn.Module):
1778            def __init__(self) -> None:
1779                super().__init__()
1780                self.layer1 = torch.nn.Linear(3, 3)
1781                self.layer2 = torch.nn.Linear(3, 3)
1782
1783                # Reusing layers.
1784                self.layer3 = self.layer1
1785
1786                # Reusing parameters.
1787                self.layer2.weight = self.layer1.weight
1788                self.layer1.bias = self.layer2.bias
1789
1790                # Parameter with different tensors equal in value.
1791                self.param1 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
1792                self.param2 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
1793
1794            def forward(self, x):
1795                return (
1796                    self.layer3(self.layer2(self.layer1(x))) + self.param1 + self.param2
1797                )
1798
1799        model = torch.jit.script(MyModule()) if torchscript else MyModule()
1800
1801        x = torch.randn(3, 3)
1802        param_name_set = {k for k, _ in model.named_parameters()}
1803
1804        # Test training mode.
1805        model.train()
1806        f = io.BytesIO()
1807        torch.onnx.export(
1808            model,
1809            (x,),
1810            f,
1811            training=TrainingMode.TRAINING,
1812            opset_version=self.opset_version,
1813        )
1814        graph = onnx.load(io.BytesIO(f.getvalue()))
1815        self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)
1816
1817        model.train()
1818        f = io.BytesIO()
1819        torch.onnx.export(
1820            model,
1821            (x,),
1822            f,
1823            training=TrainingMode.PRESERVE,
1824            opset_version=self.opset_version,
1825        )
1826        graph = onnx.load(io.BytesIO(f.getvalue()))
1827        self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)
1828
1829        # Test eval mode.
1830        model.eval()
1831        f = io.BytesIO()
1832        torch.onnx.export(model, (x,), f, opset_version=self.opset_version)
1833        graph = onnx.load(io.BytesIO(f.getvalue()))
1834        param_name_set.remove("param2")
1835        self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)
1836
1837    def test_deduplicate_initializers(self):
1838        self._test_deduplicate_initializers(torchscript=False)
1839
1840    def test_deduplicate_initializers_torchscript(self):
1841        self._test_deduplicate_initializers(torchscript=True)
1842
1843    @skipIfNoCuda
1844    def test_deduplicate_initializers_diff_devices(self):
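        # Parameters equal in value but on different devices should still be
        # deduplicated into a single initializer (w_cpu).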
1845        class Model(torch.nn.Module):
1846            def __init__(self) -> None:
1847                super().__init__()
1848                self.w_cpu = torch.nn.Parameter(
1849                    torch.ones(3, device=torch.device("cpu"))
1850                )
1851                self.w_cuda = torch.nn.Parameter(
1852                    torch.ones(3, device=torch.device("cuda"))
1853                )
1854
1855            def forward(self, x, y):
1856                return x + self.w_cpu, y + self.w_cuda
1857
1858        x = torch.randn(3, 3, device=torch.device("cpu"))
1859        y = torch.randn(3, 3, device=torch.device("cuda"))
1860        f = io.BytesIO()
1861        torch.onnx.export(Model(), (x, y), f, opset_version=self.opset_version)
1862        graph = onnx.load(io.BytesIO(f.getvalue()))
1863        self.assertSetEqual({i.name for i in graph.graph.initializer}, {"w_cpu"})
1864
1865    def test_duplicated_output_node(self):
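        # Repeated outputs should be exported as onnx::Identity nodes over the
        # shared Gemm results while keeping distinct output names.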
1866        class DuplicatedOutputNet(torch.nn.Module):
1867            def __init__(self, input_size, num_classes):
1868                super().__init__()
1869                self.fc1 = torch.nn.Linear(input_size, num_classes)
1870
1871            def forward(self, input0, input1):
1872                out1 = self.fc1(input0)
1873                out2 = self.fc1(input1)
1874                return out1, out1, out2, out1, out2
1875
1876        N, D_in, H, D_out = 64, 784, 500, 10
1877        pt_model = DuplicatedOutputNet(D_in, D_out)
1878
1879        f = io.BytesIO()
1880        x = torch.randn(N, D_in)
1881        dynamic_axes = {
1882            "input0": {0: "input0_dim0", 1: "input0_dim1"},
1883            "input1": {0: "input1_dim0", 1: "input1_dim1"},
1884            "output-0": {0: "output-0_dim0", 1: "output-0_dim1"},
1885            "output-1": {0: "output-1_dim0", 1: "output-1_dim1"},
1886            "output-2": {0: "output-2_dim0", 1: "output-2_dim1"},
1887            "output-3": {0: "output-3_dim0", 1: "output-3_dim1"},
1888            "output-4": {0: "output-4_dim0", 1: "output-4_dim1"},
1889        }
1890
1891        torch.onnx.export(
1892            pt_model,
1893            (x, x),
1894            f,
1895            input_names=["input0", "input1"],
1896            output_names=["output-0", "output-1", "output-2", "output-3", "output-4"],
1897            do_constant_folding=False,
1898            training=torch.onnx.TrainingMode.TRAINING,
1899            dynamic_axes=dynamic_axes,
1900            verbose=True,
1901            keep_initializers_as_inputs=True,
1902        )
1903
1904        graph = onnx.load(io.BytesIO(f.getvalue()))
1905        self.assertEqual(graph.graph.input[0].name, "input0")
1906        self.assertEqual(graph.graph.input[1].name, "input1")
1907        for i in range(5):
1908            self.assertEqual(graph.graph.output[i].name, f"output-{i}")
1909        self.assertEqual(graph.graph.node[0].op_type, "Gemm")
1910        self.assertEqual(graph.graph.node[1].op_type, "Identity")
1911        self.assertEqual(graph.graph.node[2].op_type, "Identity")
1912        self.assertEqual(graph.graph.node[3].op_type, "Gemm")
1913        self.assertEqual(graph.graph.node[4].op_type, "Identity")
1914
1915    def test_deduplicate_ignore_upsample_scale(self):
1916        # The upsample scale is a constant, not a model parameter, so it
1917        # should be ignored by shared-weight deduplication.
1918        class Model(torch.nn.Module):
1919            def __init__(self) -> None:
1920                super().__init__()
1921                self.upsample_1 = torch.nn.Upsample(scale_factor=2)
1922                self.upsample_2 = torch.nn.Upsample(scale_factor=2)
1923
1924            def forward(self, x):
1925                return self.upsample_1(x), self.upsample_2(x)
1926
1927        f = io.BytesIO()
1928        x = torch.randn(1, 32, 224, 224)
1929        torch.onnx.export(Model(), x, f)
1930        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
1931        # aten::upsample is exported as onnx::Resize
1932        resize_nodes = [n for n in onnx_model.graph.node if n.op_type == "Resize"]
1933        self.assertEqual(len(resize_nodes), 2)
1934        for resize_node in resize_nodes:
1935            scale_node = [
1936                n for n in onnx_model.graph.node if n.output[0] == resize_node.input[2]
1937            ]
1938            self.assertEqual(len(scale_node), 1)
1939            self.assertEqual(scale_node[0].op_type, "Constant")
1940
1941    def test_bad_symbolic_registration(self):
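        # The @parse_args descriptor below lists one argument while the
        # symbolic takes two, so export should fail with the error asserted below.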
1942        _onnx_opset_version = 9
1943
1944        @parse_args("v")
1945        def cat(g, tensor_list, dim):
1946            tensors = _unpack_list(tensor_list)
1947            return g.op("Concat", *tensors, axis_i=dim)
1948
1949        torch.onnx.register_custom_op_symbolic("::cat", cat, _onnx_opset_version)
1950
1951        class CatModel(torch.nn.Module):
1952            def forward(self, x):
1953                return torch.cat((x, x, x), 0)
1954
1955        model = CatModel()
1956        x = torch.randn(2, 3)
1957        f = io.BytesIO()
1958        self.assertExpectedRaisesInline(
1959            AssertionError,
1960            lambda: torch.onnx.export(
1961                model, (x,), f, opset_version=_onnx_opset_version
1962            ),
1963            (
1964                "A mismatch between the number of arguments (2) and their descriptors (1) was found at symbolic function "
1965                "'cat'. If you believe this is not due to custom symbolic implementation within your code or an external "
1966                "library, please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to "
1967                "report this bug."
1968            ),
1969        )
1970        torch.onnx.unregister_custom_op_symbolic("::cat", _onnx_opset_version)
1971
1972
1973if __name__ == "__main__":
1974    common_utils.run_tests()
1975