xref: /aosp_15_r20/external/pytorch/torch/onnx/symbolic_opset8.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
"""
Note [ONNX operators that are added/updated from opset 8 to opset 9]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
    Compress
    ConstantOfShape
    EyeLike
    MaxUnpool
    OneHot
    Sinh
    Cosh
    Asinh
    Acosh
    Atanh
    Shrink
    IsNaN
    Sign
    Erf
    Scatter
    Where
    NonZero
    TfIdfVectorizer
    MeanVarianceNormalization

Updated operators:
    BatchNormalization: removed spatial attribute.
    Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types supported (integers).
    Cast: more data types supported (strings).
    Upsample: moved scales from attribute to input.
    Scan
"""

import functools
import warnings

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration


_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8)
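# Symbolics registered through this partial are selected when a model is exported
# with opset_version=8, e.g. torch.onnx.export(model, example_inputs, "model.onnx",
# opset_version=8) (illustrative call, not part of this module).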

block_listed_operators = (
    "nonzero",
    "where",
    "scatter",
    "scatter_add",
    "erf",
    "sign",
    "isnan",
    "gather",
    "arange",
    "masked_fill",
    "index_fill",
    "index_copy",
    "repeat_interleave",
    "any",
    "all",
)

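# These ops rely on operators that only exist from opset 9 onward, so they cannot be
# exported at opset 8. Register a block-list symbolic for each one so that exporting
# such an op is reported as unsupported instead of producing an invalid graph.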
for block_listed_op in block_listed_operators:
    _onnx_symbolic(f"aten::{block_listed_op}")(
        symbolic_helper._block_list_in_opset(block_listed_op)
    )


@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
)
def _interpolate(name, dim, interpolate_mode):
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        align_corners = symbolic_helper._maybe_get_scalar(align_corners)
        if align_corners:
            return symbolic_helper._unimplemented(name, "align_corners == True", input)
        output_size = symbolic_helper._maybe_get_const(output_size, "is")
        if symbolic_helper._is_value(output_size):
            return symbolic_helper._unimplemented(
                name, "torch._C.Value (output_size) indexing"
            )
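        # Opset 8 Upsample takes scales as an attribute. If no explicit scales were
        # provided, derive them: 1.0 for the batch and channel dims, and the
        # output/input size ratio for each spatial dim.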
        if scales is None:
            scales = [
                1.0
                if i < 2
                else float(output_size[-(dim - i)])
                / float(input.type().sizes()[-(dim - i)])
                for i in range(0, dim)
            ]
        return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales)

    return symbolic_fn


@_onnx_symbolic("aten::__interpolate")
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    align_corners = symbolic_helper._maybe_get_const(align_corners, "b")
    if not symbolic_helper._is_none(align_corners) and align_corners:
        return symbolic_helper._unimplemented("interpolate", "align_corners == True")

    if not symbolic_helper._is_none(scale_factor) and symbolic_helper._is_value(
        scale_factor
    ):
        return symbolic_helper._unimplemented(
            "interpolate", "dynamic scales in opset 8"
        )

    if not symbolic_helper._is_none(size) and symbolic_helper._is_value(size):
        return symbolic_helper._unimplemented("interpolate", "dynamic size in opset 8")

    scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Upsample", input, mode_s=mode, scales_f=scales)


# NOTE: We should create a wrapper for this kind of operation once the shape/type
#       propagation issue for "cast" operators is resolved. Some symbolic functions
#       depend on the shape information of the input tensor, which is lost after casting.
def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
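    """Cast integer-typed args to Float for ops that only accept floating types in opset 8.

    Returns a tuple (old_type, *args): old_type is the name of the original scalar
    type when a cast was inserted, so the caller can cast the result back via
    _cast_to_type; it is None when no cast was needed or the input type is unknown.
    """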
    floating_scalar_types = {
        _type_utils.JitScalarType.HALF,
        _type_utils.JitScalarType.FLOAT,
        _type_utils.JitScalarType.DOUBLE,
    }
    old_type = None
    # Cast the input tensors to Float if the scalar type is known and is not a
    # floating-point type. If a cast is performed, return the old scalar type;
    # otherwise return None.
    arg0_type = _type_utils.JitScalarType.from_value(
        args[0], _type_utils.JitScalarType.UNDEFINED
    )
    if arg0_type != _type_utils.JitScalarType.UNDEFINED:
        old_type = arg0_type
        if old_type not in floating_scalar_types:
            old_type = old_type.scalar_name()  # type: ignore[assignment]
            args = tuple(
                g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT)
                for arg in args
            )
        else:
            return (None,) + args
    else:
        warnings.warn(
            "Only floating datatype is supported for these operators: "
            "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
            "the onnx model to be incorrect if inputs have integer datatypes."
        )
    return (old_type,) + args


def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
    if to_type is None:
        return input
    return getattr(opset9, f"_cast_{to_type}")(g, input, False)


def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):
    other = symbolic_helper._maybe_get_scalar(other)
    other = symbolic_helper._if_scalar_type_as(other, input)
    _, input, other = _try_cast_integer_to_float(g, input, other)
    return g.op(op_name, input, other)


# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
#       integer input types are not supported in opset 8. Cast to float if possible.
@_onnx_symbolic("aten::gt")
def gt(g: jit_utils.GraphContext, input, other):
    return _comparison_operator(g, input, other, "Greater")


@_onnx_symbolic("aten::lt")
def lt(g: jit_utils.GraphContext, input, other):
    return _comparison_operator(g, input, other, "Less")


@_onnx_symbolic("aten::bmm")
def bmm(g: jit_utils.GraphContext, self, other):
    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, other = _try_cast_integer_to_float(g, self, other)
        return _cast_to_type(g, g.op("MatMul", self, other), old_type)
    else:
        return g.op("MatMul", self, other)


@_onnx_symbolic("aten::matmul")
def matmul(g: jit_utils.GraphContext, self, other):
    return bmm(g, self, other)


@_onnx_symbolic("aten::prelu")
def prelu(g: jit_utils.GraphContext, self, weight):
    self_rank = symbolic_helper._get_tensor_rank(self)
    weight_sizes = symbolic_helper._get_tensor_sizes(weight)
    if self_rank is not None and self_rank > 2:
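        # Unsqueeze the weight (typically shape (C,)) to shape (C, 1, ..., 1) so it
        # broadcasts against the channel dimension of the (N, C, d1, ..., dn) input.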
        weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1)))
    elif self_rank == 0 and weight_sizes == [1]:
        # self and weight are both scalar but weight has rank == 1, squeeze weight.
        weight = symbolic_helper._squeeze_helper(g, weight, [0])
    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
        return _cast_to_type(g, g.op("PRelu", self, weight), old_type)
    else:
        return g.op("PRelu", self, weight)


@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
    # Create a dummy C tensor for Gemm. It is only needed to satisfy the operator
    # signature; its value is irrelevant since beta = 0.
    scalar_type = symbolic_helper._try_get_scalar_type(self, other)
    if scalar_type is None:
        raise errors.SymbolicValueError(
            "mm can only operate on tensors with known types", self
        )
    zero_constant = g.op(
        "Constant",
        value_t=torch.tensor([0], dtype=scalar_type.dtype()),
    )

    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, other, zero_constant = _try_cast_integer_to_float(
            g, self, other, zero_constant
        )
        return _cast_to_type(
            g,
            g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0),
            old_type,
        )
    return g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0)


@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
        return _cast_to_type(
            g,
            g.op(
                "Gemm",
                mat1,
                mat2,
                self,
                beta_f=symbolic_helper._scalar(beta),
                alpha_f=symbolic_helper._scalar(alpha),
            ),
            old_type,
        )
    else:
        return g.op(
            "Gemm",
            mat1,
            mat2,
            self,
            beta_f=symbolic_helper._scalar(beta),
            alpha_f=symbolic_helper._scalar(alpha),
        )


@_onnx_symbolic("aten::flatten")
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
    start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim")
    end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim")

    dim = input.type().dim()
    if end_dim_i < 0:
        end_dim_i = dim + end_dim_i
    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim_i == 1 and end_dim_i == dim - 1:
        if symbolic_helper._try_get_scalar_type(input):
            old_type, input = _try_cast_integer_to_float(g, input)
            return _cast_to_type(
                g, g.op("Flatten", input, axis_i=start_dim_i), old_type
            )
        else:
            return g.op("Flatten", input, axis_i=start_dim_i)
    if start_dim_i == 0 and end_dim_i == dim - 2:
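        # Flattening all leading dims also yields a 2-D result: Flatten with
        # axis = end_dim_i + 1 keeps the last dim and collapses everything before it.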
        if symbolic_helper._try_get_scalar_type(input):
            old_type, input = _try_cast_integer_to_float(g, input)
            return _cast_to_type(
                g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type
            )
        else:
            return g.op("Flatten", input, axis_i=end_dim_i + 1)

    return opset9.flatten(g, input, start_dim, end_dim)


def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value):
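    """Fill a tensor of shape `sizes` with `const_value` using the ConstantFill op.

    `dtype` is an integer scalar-type index (FLOAT when None). Non-floating dtypes
    are filled as FLOAT first and the result is then cast back to the requested type.
    """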
    if dtype is None:
        scalar_type = _type_utils.JitScalarType.FLOAT
    else:
        scalar_type = _type_utils.JitScalarType(dtype)
    if not scalar_type.dtype().is_floating_point:
        result = g.op(
            "ConstantFill",
            sizes,
            dtype_i=_type_utils.JitScalarType.FLOAT.onnx_type(),
            input_as_shape_i=1,
            value_f=const_value,
        )
        return g.op("Cast", result, to_i=scalar_type.onnx_type())
    else:
        return g.op(
            "ConstantFill",
            sizes,
            dtype_i=scalar_type.onnx_type(),
            input_as_shape_i=1,
            value_f=const_value,
        )


@_onnx_symbolic("aten::empty")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty(
    g: jit_utils.GraphContext,
    sizes,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    return zeros(g, sizes, dtype, layout, device, pin_memory)


@_onnx_symbolic("aten::empty_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    return zeros_like(g, input, dtype, layout, device, pin_memory)


@_onnx_symbolic("aten::zeros")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    # NOTE: there is no way to set device and layout in ONNX, so we ignore them.
    return _constant_fill(g, sizes, dtype, 0)


@_onnx_symbolic("aten::zeros_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def zeros_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, 0)


@_onnx_symbolic("aten::ones")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    return _constant_fill(g, sizes, dtype, 1)


@_onnx_symbolic("aten::ones_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def ones_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, 1)


@_onnx_symbolic("aten::full")
def full(
    g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False
):
    const_value = symbolic_helper._maybe_get_const(value, "t")
    if symbolic_helper._is_value(const_value):
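        # value is a dynamic graph Value rather than a constant, so ConstantFill
        # cannot be used directly; build a zero-filled tensor and add value to it.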
        tmp = zeros(g, sizes, dtype, layout, device)
        return opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
    else:
        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        return _constant_fill(g, sizes, dtype, const_value)


@_onnx_symbolic("aten::full_like")
@symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v")
def full_like(
    g: jit_utils.GraphContext,
    input,
    fill_value,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, fill_value)


@_onnx_symbolic("aten::repeat")
def repeat(g: jit_utils.GraphContext, self, repeats):
    if not symbolic_helper._is_value(repeats):
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
    if symbolic_helper._is_packed_list(repeats):
        repeat_size_len = len(symbolic_helper._unpack_list(repeats))
    else:
        const_repeats = symbolic_helper._maybe_get_const(repeats, "is")
        repeat_size_len = len(const_repeats)
    if self.isCompleteTensor():
        sizes = self.type().sizes()
        diff_dims = repeat_size_len - len(sizes)
        if diff_dims > 0:
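            # repeats has more entries than self has dims: prepend singleton dims so
            # that self and repeats have matching ranks before Tile.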
            self = opset9.view(
                g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes))
            )
    return g.op("Tile", self, repeats)
464