# xref: /aosp_15_r20/external/pytorch/test/onnx/test_onnx_opset.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: onnx"]
2
3import io
4import itertools
5
6import onnx
7
8import pytorch_test_common
9
10import torch
11import torch.onnx
12from torch.nn import Module
13from torch.onnx import producer_name, producer_version
14from torch.onnx._globals import GLOBALS
15from torch.testing._internal import common_utils
16
17
def check_onnx_opset_operator(
    model, ops, opset_version=GLOBALS.export_onnx_opset_version
):
    """Assert that an exported ONNX model has the expected metadata and nodes.

    Args:
        model: A loaded ONNX model (the result of ``onnx.load``).
        ops: One dict per node in ``model.graph.node``, in the same order.
            Each dict needs an ``"op_name"`` entry that is compared against
            the node's ``op_type``. It may also have an ``"attributes"``
            entry: a list with one dict per node attribute, in order, whose
            keys are attribute field names (``"name"``, ``"i"``, ``"ints"``,
            ``"type"``, ...) and whose values are the expected field values.
            Fields that are not listed are not checked.
        opset_version: Expected version of ``model.opset_import[0]``. Note
            that this default is read from ``GLOBALS`` once, when the function
            is defined, not on every call.

    Raises:
        AssertionError: On any mismatch in producer info, opset version, node
            count, node types, attribute count, or listed attribute fields.
    """
    # check_onnx_components: separate asserts so a failure names the field.
    assert model.producer_name == producer_name, (
        f"producer_name {model.producer_name!r} != {producer_name!r}"
    )
    assert model.producer_version == producer_version, (
        f"producer_version {model.producer_version!r} != {producer_version!r}"
    )
    actual_opset = model.opset_import[0].version
    assert actual_opset == opset_version, (
        f"opset_import[0].version {actual_opset} != {opset_version}"
    )

    # check the schema with the onnx checker
    onnx.checker.check_model(model)

    # Check node types, and optionally attributes, node by node in graph order.
    nodes = model.graph.node
    assert len(ops) == len(nodes), (
        f"expected {len(ops)} nodes, got {len(nodes)}: "
        f"{[node.op_type for node in nodes]}"
    )
    for index, (expected, node) in enumerate(zip(ops, nodes)):
        assert node.op_type == expected["op_name"], (
            f"node {index}: expected op {expected['op_name']!r}, "
            f"got {node.op_type!r}"
        )
        if "attributes" not in expected:
            continue
        expected_attributes = expected["attributes"]
        assert len(expected_attributes) == len(node.attribute), (
            f"node {index} ({node.op_type}): expected "
            f"{len(expected_attributes)} attributes, got {len(node.attribute)}"
        )
        # Attributes are matched positionally; only the listed fields of each
        # attribute proto are compared.
        for expected_attribute, actual_attribute in zip(
            expected_attributes, node.attribute
        ):
            for field, expected_value in expected_attribute.items():
                actual_value = getattr(actual_attribute, field)
                assert expected_value == actual_value, (
                    f"node {index} ({node.op_type}) attribute "
                    f"{actual_attribute.name!r}: field {field!r} expected "
                    f"{expected_value!r}, got {actual_value!r}"
                )
49
50
def check_onnx_opsets_operator(
    module,
    x,
    ops,
    opset_versions,
    training=torch.onnx.TrainingMode.EVAL,
    input_names=None,
    dynamic_axes=None,
):
    """Export ``module`` once per opset version and check each exported graph.

    Args:
        module: The module (eager or scripted) handed to ``torch.onnx.export``.
        x: Example input(s) handed to ``torch.onnx.export``.
        ops: Mapping from opset version to the expected node list, in the
            format accepted by ``check_onnx_opset_operator``.
        opset_versions: Opset versions to export with; each must be a key of
            ``ops``.
        training: ``torch.onnx.TrainingMode`` forwarded to the exporter.
        input_names: Optional graph input names forwarded to the exporter.
        dynamic_axes: Optional dynamic-axes spec forwarded to the exporter.
    """
    # Options that do not vary with the opset version.
    shared_export_options = {
        "training": training,
        "input_names": input_names,
        "dynamic_axes": dynamic_axes,
    }
    for version in opset_versions:
        serialized = io.BytesIO()
        torch.onnx.export(
            module, x, serialized, opset_version=version, **shared_export_options
        )
        exported_model = onnx.load(io.BytesIO(serialized.getvalue()))
        expected_nodes = ops[version]
        check_onnx_opset_operator(exported_model, expected_nodes, version)
73
74
class TestONNXOpset(pytorch_test_common.ExportTestCase):
    """Check the ONNX node sequences ``torch.onnx.export`` emits per opset."""

    @staticmethod
    def _check_training_and_eval(module_cls, module_kwargs, args, ops, opset_version):
        """Export a fresh module in TRAINING mode, then EVAL mode, checking both.

        Args:
            module_cls: Module class to instantiate for each export.
            module_kwargs: Keyword arguments for ``module_cls``.
            args: Example inputs passed to the exporter.
            ops: Mapping from ``opset_version`` to the expected node list.
            opset_version: The single opset version to export with.
        """
        for training in (
            torch.onnx.TrainingMode.TRAINING,
            torch.onnx.TrainingMode.EVAL,
        ):
            check_onnx_opsets_operator(
                module_cls(**module_kwargs),
                args,
                ops,
                opset_versions=[opset_version],
                training=training,
            )

    def test_opset_fallback(self):
        """torch.isnan exports to a single IsNaN node at opsets 9 and 10."""

        class MyModule(Module):
            def forward(self, x):
                return torch.isnan(x)

        ops = [{"op_name": "IsNaN"}]
        ops = {9: ops, 10: ops}
        x = torch.tensor([1.0, float("nan"), 2.0])
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

    def test_topk(self):
        """TopK: k is an attribute at opset 9 and a Constant input at opset 10."""

        class MyModule(Module):
            def forward(self, x):
                return torch.topk(x, 3)

        ops_9 = [
            {
                "op_name": "TopK",
                "attributes": [
                    {"name": "axis", "i": -1, "type": 2},
                    {"name": "k", "i": 3, "type": 2},
                ],
            }
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]},
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.arange(1.0, 6.0, requires_grad=True)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

        # test with dynamic k
        class MyModuleDynamic(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input, k):
                return torch.topk(input, k)

        ops_10 = [
            {"op_name": "Constant", "attributes": [{"name": "value", "type": 4}]},
            {"op_name": "Reshape"},
            {"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]},
        ]
        ops = {10: ops_10}
        x = torch.arange(1.0, 6.0, requires_grad=True)
        k = torch.tensor(3)
        module = MyModuleDynamic()
        check_onnx_opsets_operator(module, (x, k), ops, opset_versions=[10])

    def test_maxpool(self):
        """MaxPool gains ceil_mode/dilations attributes at opset 10."""
        module = torch.nn.MaxPool1d(2, stride=1)

        ops_9 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        ops_10 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "ceil_mode", "i": 0, "type": 2},
                    {"name": "dilations", "ints": [1], "type": 7},
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[9, 10])

        # add test with dilations
        module = torch.nn.MaxPool1d(2, stride=1, dilation=2)

        ops_10 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "ceil_mode", "i": 0, "type": 2},
                    {"name": "dilations", "ints": [2], "type": 7},
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        ops = {10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[10])

    def test_upsample(self):
        """Upsample scales are an attribute at opset 8, a Constant at opset 9."""

        class MyModule(Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                size = [int(i) for i in size]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        module = MyModule()
        ops8 = [
            {
                "op_name": "Upsample",
                "attributes": [
                    {"name": "mode", "s": b"nearest", "type": 3},
                    {"name": "scales", "floats": [1.0, 1.0, 2.0, 2.0], "type": 6},
                ],
            }
        ]
        ops9 = [
            {"op_name": "Constant"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops = {8: ops8, 9: ops9}
        x = torch.randn(2, 2, 2, 2)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])

    def test_cast_constant(self):
        """An int64 subtraction needs a Cast at opset 8 but not at opset 9."""

        class MyModule(Module):
            def forward(self, x):
                return x - 1

        module = MyModule()
        ops_8 = [
            {"op_name": "Constant"},
            {"op_name": "Cast", "attributes": [{"name": "to", "i": 7, "type": 2}]},
            {"op_name": "Sub"},
        ]
        ops_9 = [{"op_name": "Constant"}, {"op_name": "Sub"}]
        ops = {8: ops_8, 9: ops_9}
        x = torch.ones(5, 6, dtype=torch.long)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])

    def test_slice(self):
        """Slice uses attributes at opset 9 and Constant inputs at opset 10."""

        class MyModule(Module):
            def forward(self, x):
                return x[0:1]

        ops_9 = [
            {
                "op_name": "Slice",
                "attributes": [
                    {"name": "axes", "ints": [0], "type": 7},
                    {"name": "ends", "ints": [1], "type": 7},
                    {"name": "starts", "ints": [0], "type": 7},
                ],
            }
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(3)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

        class DynamicSliceModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x[1 : x.size(0)]

        module = DynamicSliceModel()
        x = torch.rand(1, 2)
        # With dynamic axes the slice end is computed from the input shape.
        ops_10 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {
                "op_name": "Unsqueeze",
                "attributes": [{"name": "axes", "i": 0, "type": 7}],
            },
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        check_onnx_opsets_operator(
            module,
            x,
            ops,
            opset_versions=[10],
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )

        # Without dynamic axes the slice bounds are all Constants.
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        check_onnx_opsets_operator(module, x, ops, opset_versions=[10])

    def test_flip(self):
        """torch.flip at opset 10 exports to a Slice with Constant inputs."""

        class MyModule(Module):
            def forward(self, x):
                return torch.flip(x, dims=[0])

        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        # float64 values 0..5 shaped (2, 3); built with torch to avoid a numpy
        # dependency in this test.
        x = torch.arange(6.0, dtype=torch.float64).reshape(2, 3)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[10])

    def test_dropout(self):
        """Dropout is exported only in training mode; eval mode gives Identity."""

        class MyModule(Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = torch.nn.Dropout(0.5)

            def forward(self, x):
                return self.dropout(x)

        x = torch.randn(1, 2, 3)

        # we should only export the onnx Dropout op in training mode; test both modes

        # test training mode
        ops = [
            {
                "op_name": "Dropout",
                "attributes": [{"name": "ratio", "f": 0.5, "type": 1}],
            }
        ]
        ops = {9: ops, 10: ops}
        check_onnx_opsets_operator(
            MyModule(),
            x,
            ops,
            opset_versions=[9, 10],
            training=torch.onnx.TrainingMode.TRAINING,
        )

        # test eval mode
        ops = [{"op_name": "Identity"}]
        ops = {9: ops, 10: ops}
        check_onnx_opsets_operator(
            MyModule(),
            x,
            ops,
            opset_versions=[9, 10],
            training=torch.onnx.TrainingMode.EVAL,
        )

    def test_full(self):
        """torch.full with a tensor fill value exports to ConstantOfShape + Add."""

        class MyModule(Module):
            def forward(self, x):
                return torch.full((3, 4), x)

        ops = [
            {"op_name": "Constant"},
            {"op_name": "ConstantOfShape"},
            {"op_name": "Add"},
        ]
        ops = {9: ops, 10: ops}
        x = torch.tensor(12.0)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

    def test_interpolate(self):
        """interpolate maps to Upsample at opset 9 and Resize at opset 10."""

        class MyModel(torch.nn.Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        # Fully dynamic input: output size is computed from the input shape.
        ops_9 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Concat"},
            {"op_name": "Cast"},
            {"op_name": "Shape"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Concat"},
            {"op_name": "Cast"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Resize",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]

        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        check_onnx_opsets_operator(
            MyModel(),
            x,
            ops,
            opset_versions=[9, 10],
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )

        # Static input shape: the output size becomes a Constant.
        ops_9 = [
            {"op_name": "Constant"},
            {"op_name": "Shape"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {"op_name": "Resize"},
        ]

        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        check_onnx_opsets_operator(MyModel(), x, ops, opset_versions=[9, 10])

        class MyDynamicModel(torch.nn.Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                # Workaround for now: turn the dynamic sizes into constants.
                size = [int(i) for i in size]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        ops_9 = [
            {"op_name": "Constant"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {
                "op_name": "Resize",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(MyDynamicModel(), x, ops, opset_versions=[9, 10])

    def test_affine_grid(self):
        """affine_grid exports to AffineGrid at opset 20 for 2-D and 3-D grids."""

        class MyModule(Module):
            def __init__(self, align_corners):
                super().__init__()
                self.align_corners = align_corners

            def forward(self, theta, size):
                return torch.nn.functional.affine_grid(
                    theta, size, align_corners=self.align_corners
                )

        opset_version = 20
        ops_2d = {
            opset_version: [
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Concat"},
                {"op_name": "AffineGrid"},
            ]
        }

        ops_3d = {
            opset_version: [
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Concat"},
                {"op_name": "AffineGrid"},
            ]
        }
        # 2D affine
        theta_2d = torch.empty(1, 2, 3, dtype=torch.double)
        size_2d = torch.Size([1, 1, 2, 2])
        # 3D affine
        theta_3d = torch.empty(1, 3, 4, dtype=torch.double)
        size_3d = torch.Size([1, 1, 2, 2, 2])

        for (theta, size, ops), align_corners in itertools.product(
            ((theta_2d, size_2d, ops_2d), (theta_3d, size_3d, ops_3d)),
            (True, False),
        ):
            self._check_training_and_eval(
                MyModule,
                {"align_corners": align_corners},
                (theta, size),
                ops,
                opset_version,
            )

    def test_grid_sample(self):
        """grid_sample exports to a single GridSample node at opsets 16 and 20."""

        class MyModule(torch.nn.Module):
            def __init__(self, mode, padding_mode, align_corners):
                super().__init__()
                self.mode = mode
                self.padding_mode = padding_mode
                self.align_corners = align_corners

            def forward(self, x, grid):
                return torch.nn.functional.grid_sample(
                    x,
                    grid,
                    mode=self.mode,
                    padding_mode=self.padding_mode,
                    align_corners=self.align_corners,
                )

        n, c, d_in, h_in, w_in, d_out, h_out, w_out = 1, 1, 2, 3, 2, 3, 2, 4
        for mode, padding_mode, align_corners, opset_version in itertools.product(
            ("bilinear", "nearest", "bicubic"),
            ("zeros", "border", "reflection"),
            (True, False),
            (16, 20),
        ):
            module_kwargs = {
                "mode": mode,
                "padding_mode": padding_mode,
                "align_corners": align_corners,
            }
            ops = {opset_version: [{"op_name": "GridSample"}]}
            # (x shape, grid shape) pairs: always a 4-D input; a 5-D input is
            # only exercised at opset 20 and never with bicubic mode (PyTorch's
            # grid_sample supports bicubic for 4-D input only).
            shape_pairs = [((n, c, h_in, w_in), (n, h_out, w_out, 2))]
            if opset_version == 20 and mode != "bicubic":
                shape_pairs.append(
                    ((n, c, d_in, h_in, w_in), (n, d_out, h_out, w_out, 3))
                )
            for x_shape, grid_shape in shape_pairs:
                args = (torch.randn(x_shape), torch.randn(grid_shape))
                self._check_training_and_eval(
                    MyModule, module_kwargs, args, ops, opset_version
                )

    def test_flatten(self):
        """flatten of a 0-D tensor is Constant+Reshape; of a 1-D one, Identity."""

        class MyModule(Module):
            def forward(self, x):
                return torch.flatten(x)

        module = MyModule()

        ops_0d = [{"op_name": "Constant"}, {"op_name": "Reshape"}]
        ops_1d = [{"op_name": "Identity"}]
        for shape in ([], [3]):
            x = torch.randn(shape)
            for opset_version in [9, 10]:
                ops = {opset_version: (ops_0d if len(shape) == 0 else ops_1d)}
                check_onnx_opsets_operator(
                    module, x, ops, opset_versions=[opset_version]
                )
651
652
if __name__ == "__main__":
    # When this file is executed directly, run its tests through PyTorch's
    # shared test runner from torch.testing._internal.common_utils.
    common_utils.run_tests()
655