# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 16.

Note [ONNX Operators that are added/updated in opset 16]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
New operators:
    GridSample https://github.com/onnx/onnx/pull/3557

Updated operators:
    Identity
    If
    LeakyRelu
    Loop
    PRelu
    RoiAlign
    Scan
    ScatterElements
    ScatterND
    Where
    GreaterOrEqual
    LessOrEqual
"""

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

import functools

import torch
from torch.nn.functional import (
    GRID_SAMPLE_INTERPOLATION_MODES,
    GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx import _type_utils, errors, symbolic_helper, utils
from torch.onnx._internal import jit_utils, registration


_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)


# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
def grid_sampler(
    g: jit_utils.GraphContext,
    input,
    grid,
    mode_enum,
    padding_mode_enum,
    align_corners,
):
    # Check the input and grid tensor rank beforehand.
    if symbolic_helper._get_tensor_rank(input) == 5:
        return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")
    mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg]
    padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[  # type: ignore[call-arg]
        padding_mode_enum
    ]
    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_s,
        padding_mode_s=padding_mode_s,
    )
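

# A minimal, hypothetical sketch (not part of the exporter) of how the
# `grid_sampler` symbolic above is reached: `torch.nn.functional.grid_sample`
# dispatches to `aten::grid_sampler`, which this file lowers to the ONNX
# GridSample op when opset 16 is requested. The model name, shapes, and output
# file name below are illustrative only:
#
#     import torch
#     import torch.nn.functional as F
#
#     class GridSampleModel(torch.nn.Module):
#         def forward(self, x, grid):
#             return F.grid_sample(
#                 x, grid, mode="bilinear", padding_mode="zeros", align_corners=False
#             )
#
#     x = torch.randn(1, 3, 8, 8)  # (N, C, H_in, W_in)
#     grid = torch.rand(1, 4, 4, 2) * 2 - 1  # (N, H_out, W_out, 2), in [-1, 1]
#     torch.onnx.export(
#         GridSampleModel(), (x, grid), "grid_sample.onnx", opset_version=16
#     )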


@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
    src_type = _type_utils.JitScalarType.from_value(
        src, _type_utils.JitScalarType.UNDEFINED
    )
    src_sizes = symbolic_helper._get_tensor_sizes(src)
    index_sizes = symbolic_helper._get_tensor_sizes(index)

    if len(src_sizes) != len(index_sizes):
        return symbolic_helper._unimplemented(
            "scatter_add",
            f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
        )

    # PyTorch only allows the shape of `index` to be <= the shape of `src`, so
    # we can only treat `index` as selecting a subset of `src`, as PyTorch does.
    # When the sizes of `src` and `index` do not match, or there are dynamic
    # axes, we slice `src` down to the shape of `index`.
    if src_sizes != index_sizes or None in index_sizes:
        adjusted_shape = g.op("Shape", index)
        starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes)))
        src = g.op("Slice", src, starts, adjusted_shape)

    src = symbolic_helper._maybe_get_scalar(src)
    if symbolic_helper._is_value(src):
        return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
    else:
        # Check whether the scalar `src` has the same type as `self` (PyTorch
        # allows a different type for a scalar `src`, but not when `src` is a
        # tensor). If not, insert a Cast node.
        if _type_utils.JitScalarType.from_value(self) != src_type:
            src = g.op(
                "Cast",
                src,
                to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
            )

        return g.op(
            "ScatterElements",
            self,
            index,
            src,
            axis_i=dim,
            reduction_s="add",
        )


@_onnx_symbolic("aten::scatter_reduce")
@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b")
def scatter_reduce(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    dim: int,
    index: torch._C.Value,
    src: torch._C.Value,
    reduce: str,
    include_self: bool,
):
    if reduce == "mean":
        raise errors.OnnxExporterError(
            "ONNX does not support mean reduction for scatter_reduce"
        )
    if not include_self:
        raise errors.OnnxExporterError(
            "ONNX does not support include_self=False for scatter_reduce"
        )

    reduce_mode = {  # convert torch string name to onnx string name
        "mean": "none",  # "mean" is not supported in the ONNX 1.14 definition
        "sum": "add",
        "prod": "mul",
        "amin": "min",
        "amax": "max",
    }
    onnx_reduce = reduce_mode[reduce]

    self_rank = g.op("Size", g.op("Shape", self))

    # if self_rank == 0, then index_rank == 0 and src_rank == 0 as well
    self_rank_is_zero = g.op(
        "Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    )
    if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
        g, "If", self_rank_is_zero, n_blocks=2, outputs=3
    )
    neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))

    self_reshape = if_context.op("Reshape", self, neg_1)
    utils._add_output_to_block(if_context.block, self_reshape)
    index_reshape = if_context.op("Reshape", index, neg_1)
    utils._add_output_to_block(if_context.block, index_reshape)
    src_reshape = if_context.op("Reshape", src, neg_1)
    utils._add_output_to_block(if_context.block, src_reshape)

    self_identity = else_context.op("Identity", self)
    utils._add_output_to_block(else_context.block, self_identity)
    index_identity = else_context.op("Identity", index)
    utils._add_output_to_block(else_context.block, index_identity)
    src_identity = else_context.op("Identity", src)
    utils._add_output_to_block(else_context.block, src_identity)

    result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce)

    # if self_rank == 0, squeeze the rank-1 result back down to a scalar
    if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
        g, "If", self_rank_is_zero, n_blocks=2, outputs=1
    )
    result_squeezed = if_context.op("Squeeze", result)
    utils._add_output_to_block(if_context.block, result_squeezed)
    result_identity = else_context.op("Identity", result)
    utils._add_output_to_block(else_context.block, result_identity)
    result_final = if_op.node().output()

    return result_final
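

# A minimal, hypothetical sketch (not part of the exporter) showing the
# `scatter_reduce` symbolic above in use. Note that reduce="mean" or
# include_self=False would raise OnnxExporterError, per the checks above;
# the model name, shapes, and output file name are illustrative only:
#
#     import torch
#
#     class ScatterReduceModel(torch.nn.Module):
#         def forward(self, self_tensor, index, src):
#             return self_tensor.scatter_reduce(
#                 0, index, src, reduce="amax", include_self=True
#             )
#
#     self_tensor = torch.zeros(3)
#     index = torch.tensor([0, 1, 0, 1, 2])
#     src = torch.arange(5.0)
#     torch.onnx.export(
#         ScatterReduceModel(), (self_tensor, index, src),
#         "scatter_reduce.onnx", opset_version=16,
#     )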