xref: /aosp_15_r20/external/pytorch/torch/onnx/symbolic_opset16.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""This file exports ONNX ops for opset 16.
3
4Note [ONNX Operators that are added/updated in opset 16]
5
6~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
7https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
8New operators:
9    GridSample https://github.com/onnx/onnx/pull/3557
10
11Updated operators:
12    Identity
13    If
14    LeakyRelu
15    Loop
16    PRelu
17    RoiAlign
18    Scan
19    ScatterElements
20    ScatterND
21    Where
22    GreaterOrEqual
23    LessOrEqual
24"""
25
26# EDITING THIS FILE? READ THIS FIRST!
27# see Note [Edit Symbolic Files] in README.md
28
29import functools
30
31import torch
32from torch.nn.functional import (
33    GRID_SAMPLE_INTERPOLATION_MODES,
34    GRID_SAMPLE_PADDING_MODES,
35)
36from torch.onnx import _type_utils, errors, symbolic_helper, utils
37from torch.onnx._internal import jit_utils, registration
38
39
# Registration decorator with the opset pinned to 16; every symbolic function
# in this module is registered through it.
_onnx_symbolic = functools.partial(
    registration.onnx_symbolic,
    opset=16,
)
41
42
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
def grid_sampler(
    g: jit_utils.GraphContext,
    input,
    grid,
    mode_enum,
    padding_mode_enum,
    align_corners,
):
    """Export ``aten::grid_sampler`` as the opset-16 ONNX ``GridSample`` op.

    Args:
        g: Graph context the new node is emitted into.
        input: Input tensor value. A rank-5 (volumetric) input is reported
            through ``symbolic_helper._onnx_unsupported``.
        grid: Sampling-grid tensor value, passed straight to ``GridSample``.
        mode_enum: Integer interpolation mode; translated back to its string
            name using ``GRID_SAMPLE_INTERPOLATION_MODES``.
        padding_mode_enum: Integer padding mode; translated back to its string
            name using ``GRID_SAMPLE_PADDING_MODES``.
        align_corners: Emitted as the integer ``align_corners`` attribute.

    Returns:
        The ``GridSample`` output value, or the result of
        ``symbolic_helper._onnx_unsupported`` for a rank-5 input.

    Raises:
        KeyError: if an enum value has no entry in its name table.
    """
    # Reject volumetric input up front, before any node is built.
    if symbolic_helper._get_tensor_rank(input) == 5:
        return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")

    def _name_for(table, code):
        # The torch tables map name -> int code; invert to look up the name.
        reverse_table = {code_value: name for name, code_value in table.items()}
        return reverse_table[code]

    interpolation_name = _name_for(GRID_SAMPLE_INTERPOLATION_MODES, mode_enum)
    padding_name = _name_for(GRID_SAMPLE_PADDING_MODES, padding_mode_enum)

    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=interpolation_name,
        padding_mode_s=padding_name,
    )
70
71
@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
    """Export ``aten::scatter_add`` as ``ScatterElements`` with ``reduction="add"``.

    Args:
        g: Graph context the nodes are emitted into.
        self: Tensor being scattered into.
        dim: Axis attribute for ``ScatterElements``.
        index: Index tensor; its shape decides how much of ``src`` is used.
        src: Values to add. A scalar constant ``src`` whose dtype differs from
            ``self`` is cast to ``self``'s dtype.

    Returns:
        The ``ScatterElements`` output value, or the result of
        ``symbolic_helper._unimplemented`` when ``index`` and ``src`` differ
        in rank.
    """
    # Record src's dtype now; src may be replaced by a Slice or a scalar below.
    src_type = _type_utils.JitScalarType.from_value(
        src, _type_utils.JitScalarType.UNDEFINED
    )
    src_sizes = symbolic_helper._get_tensor_sizes(src)
    index_sizes = symbolic_helper._get_tensor_sizes(index)

    if len(index_sizes) != len(src_sizes):
        return symbolic_helper._unimplemented(
            "scatter_add",
            f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
        )

    # PyTorch only permits index shape <= src shape, so (like PyTorch) only the
    # leading index-shaped region of src is used. When the static shapes differ
    # or index has dynamic axes, slice src down to Shape(index).
    shapes_mismatch = src_sizes != index_sizes
    if shapes_mismatch or None in index_sizes:
        index_shape = g.op("Shape", index)
        zero_starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes)))
        src = g.op("Slice", src, zero_starts, index_shape)

    src = symbolic_helper._maybe_get_scalar(src)
    if not symbolic_helper._is_value(src):
        # A scalar src may have a different dtype than self in PyTorch (a tensor
        # src may not), so cast it to self's dtype when they disagree.
        self_type = _type_utils.JitScalarType.from_value(self)
        if self_type != src_type:
            src = g.op("Cast", src, to_i=self_type.onnx_type())

    return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
117
118
@_onnx_symbolic("aten::scatter_reduce")
@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b")
def scatter_reduce(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    dim: int,
    index: torch._C.Value,
    src: torch._C.Value,
    reduce: str,
    include_self: bool,
):
    """Export ``aten::scatter_reduce`` as ONNX ``ScatterElements`` with a reduction.

    Args:
        g: Graph context the nodes are emitted into; ``g.opset`` is the export
            opset.
        self: Tensor being scattered into.
        dim: Axis attribute for ``ScatterElements``.
        index: Index tensor.
        src: Values combined into ``self``.
        reduce: Torch reduction name: ``"sum"``, ``"prod"``, ``"amin"`` or
            ``"amax"``. ``"mean"`` is rejected.
        include_self: Must be ``True``; ``False`` is rejected.

    Returns:
        The output of the final ``If`` node: the ``ScatterElements`` result,
        squeezed back to rank 0 when ``self`` has rank 0.

    Raises:
        errors.OnnxExporterError: for ``reduce="mean"``, for
            ``include_self=False``, and for ``"amin"``/``"amax"`` when the export
            opset is below 18.
    """
    if reduce == "mean":
        raise errors.OnnxExporterError(
            "ONNX does not support mean reduction for scatter_reduce"
        )
    if not include_self:
        raise errors.OnnxExporterError(
            "ONNX does not support include_self=False for scatter_reduce"
        )

    # Torch reduction name -> ONNX ScatterElements `reduction` attribute.
    # ("mean" has no ONNX equivalent and was rejected above.)
    reduce_mode = {
        "sum": "add",
        "prod": "mul",
        "amin": "min",
        "amax": "max",
    }
    onnx_reduce = reduce_mode[reduce]

    # ScatterElements-16 only defines reduction "none"/"add"/"mul"; "min" and
    # "max" were introduced in ScatterElements-18. This function is registered
    # with opset=16 and may also serve later opsets, so gate on the opset
    # actually being exported instead of rejecting min/max outright.
    if onnx_reduce in ("min", "max") and g.opset < 18:
        raise errors.OnnxExporterError(
            f"ONNX ScatterElements supports reduction '{onnx_reduce}' "
            f"(scatter_reduce reduce='{reduce}') only from opset 18, "
            f"but the export opset is {g.opset}"
        )

    # Rank of self, computed in the graph (Size of its Shape).
    self_rank = g.op("Size", g.op("Shape", self))

    # A rank-0 self is flattened to 1-D before ScatterElements (whose ONNX
    # definition expects rank >= 1) and squeezed back afterwards. Both steps
    # are emitted as ONNX `If` nodes keyed on this in-graph predicate.
    # if self_rank == 0:  # assert (index_rank == 0 and rank_src == 0)
    self_rank_is_zero = g.op(
        "Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    )

    # First If: outputs (self, index, src), reshaped to 1-D in the rank-0
    # branch and passed through unchanged otherwise.
    if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
        g, "If", self_rank_is_zero, n_blocks=2, outputs=3
    )
    neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))

    self_reshape = if_context.op("Reshape", self, neg_1)
    utils._add_output_to_block(if_context.block, self_reshape)
    index_reshape = if_context.op("Reshape", index, neg_1)
    utils._add_output_to_block(if_context.block, index_reshape)
    src_reshape = if_context.op("Reshape", src, neg_1)
    utils._add_output_to_block(if_context.block, src_reshape)

    self_identity = else_context.op("Identity", self)
    utils._add_output_to_block(else_context.block, self_identity)
    index_identity = else_context.op("Identity", index)
    utils._add_output_to_block(else_context.block, index_identity)
    src_identity = else_context.op("Identity", src)
    utils._add_output_to_block(else_context.block, src_identity)

    # The If outputs are, in order, the (possibly reshaped) self, index, src.
    result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce)

    # Second If: squeeze back to rank 0 when self was rank 0.
    if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
        g, "If", self_rank_is_zero, n_blocks=2, outputs=1
    )
    result_squeezed = if_context.op("Squeeze", result)
    utils._add_output_to_block(if_context.block, result_squeezed)
    result_identity = else_context.op("Identity", result)
    utils._add_output_to_block(else_context.block, result_identity)
    result_final = if_op.node().output()

    return result_final
186