xref: /aosp_15_r20/external/pytorch/torch/onnx/symbolic_opset20.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""This file exports ONNX ops for opset 20.
3
Note [ONNX Operators that are added/updated in opset 20]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
7https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set
New or updated operators:
9    AffineGrid
10    ConstantOfShape
11    DFT
12    Gelu
13    GridSample
14    ImageDecoder
15    IsInf
16    IsNaN
17    ReduceMax
18    ReduceMin
19    RegexFullMatch
20    StringConcat
21    StringSplit
22"""
23
24import functools
25
26import torch.nn.functional as F
27from torch import _C
28from torch.onnx import symbolic_helper
29from torch.onnx._internal import jit_utils, registration
30
31
32# EDITING THIS FILE? READ THIS FIRST!
33# see Note [Edit Symbolic Files] in symbolic_helper.py
34
# Explicit export list for ``from torch.onnx.symbolic_opset20 import *``.
__all__ = [
    "_grid_sampler",
    "_affine_grid_generator",
    "gelu",
]
36
37
def convert_grid_sample_mode(mode_s):
    """Translate a PyTorch grid-sample interpolation mode name to its opset-20 spelling.

    ``"bilinear"`` is mapped to ``"linear"`` and ``"bicubic"`` to ``"cubic"``.
    Any other value (for example ``"nearest"``) is returned unchanged.
    """
    if mode_s == "bilinear":
        return "linear"
    if mode_s == "bicubic":
        return "cubic"
    return mode_s
42
43
# ``registration.onnx_symbolic`` with ``opset=20`` pre-bound; the symbolic
# functions below are decorated with ``@_onnx_symbolic("aten::<op>")``.
_onnx_symbolic = functools.partial(
    registration.onnx_symbolic,
    opset=20,
)
45
46
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
def _grid_sampler(
    g: jit_utils.GraphContext,
    input: _C.Value,
    grid: _C.Value,
    mode_enum: int,
    padding_mode_enum: int,
    align_corners: bool,
):
    """Export ``aten::grid_sampler`` as an ONNX ``GridSample`` node.

    ``mode_enum`` and ``padding_mode_enum`` are integer codes. They are turned
    back into names by inverting ``F.GRID_SAMPLE_INTERPOLATION_MODES`` and
    ``F.GRID_SAMPLE_PADDING_MODES``, respectively; an unknown code raises
    ``KeyError``. The interpolation name is then rewritten by
    ``convert_grid_sample_mode`` before being emitted as the ``mode`` attribute.
    """
    code_to_mode = {code: name for name, code in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}  # type: ignore[call-arg, index]
    # The interpolation mode strings were renamed in opset 20, see
    # https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
    onnx_mode = convert_grid_sample_mode(code_to_mode[mode_enum])

    code_to_padding = {code: name for name, code in F.GRID_SAMPLE_PADDING_MODES.items()}  # type: ignore[call-arg, index]
    onnx_padding_mode = code_to_padding[padding_mode_enum]  # type: ignore[index]

    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=onnx_mode,
        padding_mode_s=onnx_padding_mode,
    )
71
72
@_onnx_symbolic("aten::affine_grid_generator")
@symbolic_helper.parse_args("v", "v", "b")
def _affine_grid_generator(
    g: jit_utils.GraphContext,
    theta: _C.Value,
    size: _C.Value,
    align_corners: bool,
):
    """Export ``aten::affine_grid_generator`` as an ONNX ``AffineGrid`` node.

    ``theta`` and ``size`` are forwarded unchanged as the node's two inputs.
    ``align_corners`` is emitted as the integer attribute ``align_corners``.
    """
    node_attrs = {"align_corners_i": int(align_corners)}
    return g.op("AffineGrid", theta, size, **node_attrs)
87
88
@_onnx_symbolic("aten::gelu")
@symbolic_helper.parse_args("v", "s")
def gelu(
    g: jit_utils.GraphContext,
    self: _C.Value,
    approximate: str = "none",
):
    """Export ``aten::gelu`` as a single ONNX ``Gelu`` node.

    The ``approximate`` string is passed through verbatim as the node's
    ``approximate`` attribute.
    """
    node_attrs = {"approximate_s": approximate}
    return g.op("Gelu", self, **node_attrs)
93