# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 9.

Opset 9 is supported by ONNX release 1.4.1,
released on 01/23/19.
8"""
9
10from __future__ import annotations
11
12import builtins
13import functools
14import math
15import sys
16import warnings
17from typing import Callable, Sequence, TYPE_CHECKING
18
19import torch
20import torch._C._onnx as _C_onnx
21import torch.nn.modules.utils
22import torch.onnx
23from torch import _C
24
25# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
26from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_helper
27from torch.onnx._globals import GLOBALS
28from torch.onnx._internal import jit_utils, registration
29
30
31if TYPE_CHECKING:
32    from torch.types import Number
33
34# EDITING THIS FILE? READ THIS FIRST!
35# see Note [Edit Symbolic Files] in README.md
36
37__all__ = [
38    "abs",
39    "acos",
40    "add",
41    "addcmul",
42    "addmm",
43    "alias",
44    "amax",
45    "amin",
46    "aminmax",
47    "arange",
48    "argmax",
49    "argmin",
50    "as_strided",
51    "as_tensor",
52    "asin",
53    "atan",
54    "atan2",
55    "baddbmm",
56    "batch_norm",
57    "bernoulli",
58    "bitwise_not",
59    "bitwise_or",
60    "bmm",
61    "broadcast_tensors",
62    "broadcast_to",
63    "bucketize",
64    "cat",
65    "cdist",
66    "ceil",
67    "clamp_max",
68    "clamp_min",
69    "clamp",
70    "clone",
71    "constant_pad_nd",
72    "contiguous",
73    "conv_tbc",
74    "conv_transpose1d",
75    "conv_transpose2d",
76    "conv_transpose3d",
77    "conv1d",
78    "conv2d",
79    "conv3d",
80    "convert_element_type",
81    "convolution",
82    "cos",
83    "cosine_similarity",
84    "cross",
85    "cumsum",
86    "detach",
87    "dim",
88    "div",
89    "dot",
90    "dropout",
91    "elu",
92    "embedding_bag",
93    "embedding",
94    "empty_like",
95    "empty",
96    "eq",
97    "erf",
98    "exp",
99    "expand_as",
100    "expand",
101    "eye",
102    "fill",
103    "flatten",
104    "floor_divide",
105    "floor",
106    "floordiv",
107    "frobenius_norm",
108    "full_like",
109    "full",
110    "gather",
111    "ge",
112    "gelu",
113    "get_pool_ceil_padding",
114    "glu",
115    "group_norm",
116    "gt",
117    "hann_window",
118    "hardshrink",
119    "hardsigmoid",
120    "hardswish",
121    "hardtanh",
122    "index_add",
123    "index_copy",
124    "index_fill",
125    "index_put",
126    "index_select",
127    "index",
128    "instance_norm",
129    "is_floating_point",
130    "is_pinned",
131    "isnan",
132    "item",
133    "kl_div",
134    "layer_norm",
135    "le",
136    "leaky_relu",
137    "lerp",
138    "lift",
139    "linalg_cross",
140    "linalg_matrix_norm",
141    "linalg_norm",
142    "linalg_vector_norm",
143    "linear",
144    "linspace",
145    "log_sigmoid",
146    "log_softmax",
147    "log",
148    "log10",
149    "log1p",
150    "log2",
151    "logical_and",
152    "logical_not",
153    "logical_or",
154    "logical_xor",
155    "logit",
156    "logsumexp",
157    "lstm_cell",
158    "lstm",
159    "lt",
160    "masked_fill",
161    "masked_fill_",
162    "matmul",
163    "max_pool1d_with_indices",
164    "max_pool2d_with_indices",
165    "max_pool3d_with_indices",
166    "max",
167    "maximum",
168    "meshgrid",
169    "min",
170    "minimum",
171    "mish",
172    "mm",
173    "movedim",
174    "mse_loss",
175    "mul",
176    "multinomial",
177    "mv",
178    "narrow",
179    "native_layer_norm",
180    "ne",
181    "neg",
182    "new_empty",
183    "new_full",
184    "new_ones",
185    "new_zeros",
186    "nonzero_numpy",
187    "nonzero",
188    "norm",
189    "numel",
190    "numpy_T",
191    "one_hot",
192    "ones_like",
193    "ones",
194    "onnx_placeholder",
195    "pad",
196    "pairwise_distance",
197    "permute",
198    "pixel_shuffle",
199    "pixel_unshuffle",
200    "pow",
201    "prelu",
202    "prim_constant_chunk",
203    "prim_constant_split",
204    "prim_constant",
205    "prim_data",
206    "prim_device",
207    "prim_dtype",
208    "prim_if",
209    "prim_layout",
210    "prim_list_construct",
211    "prim_list_unpack",
212    "prim_loop",
213    "prim_max",
214    "prim_min",
215    "prim_shape",
216    "prim_tolist",
217    "prim_tuple_construct",
218    "prim_type",
219    "prim_unchecked_cast",
220    "prim_uninitialized",
221    "rand_like",
222    "rand",
223    "randint_like",
224    "randint",
225    "randn_like",
226    "randn",
227    "reciprocal",
228    "reflection_pad",
229    "relu",
230    "relu6",
231    "remainder",
232    "repeat_interleave",
233    "repeat",
234    "replication_pad",
235    "reshape_as",
236    "reshape",
237    "roll",
238    "rrelu",
239    "rsqrt",
240    "rsub",
241    "scalar_tensor",
242    "scatter_add",
243    "scatter",
244    "select",
245    "selu",
246    "sigmoid",
247    "sign",
248    "silu",
249    "sin",
250    "size",
251    "slice",
252    "softmax",
253    "softplus",
254    "softshrink",
255    "sort",
256    "split_with_sizes",
257    "split",
258    "sqrt",
259    "square",
260    "squeeze",
261    "stack",
262    "std_mean",
263    "std",
264    "sub",
265    "t",
266    "take",
267    "tan",
268    "tanh",
269    "tanhshrink",
270    "tensor",
271    "threshold",
272    "to",
273    "topk",
274    "transpose",
275    "true_divide",
276    "type_as",
277    "unbind",
278    "unfold",
279    "unsafe_chunk",
280    "unsafe_split_with_sizes",
281    "unsafe_split",
282    "unsqueeze",
283    "unsupported_complex_operators",
284    "noop_complex_operators",
285    "unused",
286    "var_mean",
287    "var",
288    "view_as",
289    "view",
290    "where",
291    "wrap_logical_op_with_cast_to",
292    "wrap_logical_op_with_negation",
293    "zeros_like",
294    "zeros",
295    "zero",
296]
297
298
299_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
300
301
302def _export(name: str):
303    """Exports the function in the current global namespace."""
304
305    def wrapper(func):
306        globals()[name] = func
307        __all__.append(name)
308        return func
309
310    return wrapper
311
312
313def unused(g):
314    """Represents "missing" optional inputs."""
315    n = g.op("prim::Constant")
316    n.setType(_C.OptionalType.ofTensor())
317    return n
318
319
320@_onnx_symbolic("aten::_shape_as_tensor")
321def _shape_as_tensor(g: jit_utils.GraphContext, input):
322    return g.op("Shape", input)
323
324
325@_onnx_symbolic("aten::_reshape_from_tensor")
326def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape):
327    if isinstance(shape, list):
328        shape = g.op("Concat", *shape, axis_i=0)
329    return reshape(g, input, shape)
330
331
332@_onnx_symbolic("aten::reshape")
333@symbolic_helper.quantized_args(True)
334def reshape(g: jit_utils.GraphContext, self, shape):
335    return symbolic_helper._reshape_helper(g, self, shape)
336
337
338@_onnx_symbolic("aten::reshape_as")
339@symbolic_helper.quantized_args(True)
340def reshape_as(g: jit_utils.GraphContext, self, other):
341    shape = g.op("Shape", other)
342    return reshape(g, self, shape)
343
344
345@_onnx_symbolic("aten::add")
346def add(g: jit_utils.GraphContext, self, other, alpha=None):
347    """
348    This function takes the add function and returns the corresponding ONNX operator.
349
350    This function is not meant to be called directly by the user.
351
352    Args:
353        g (GraphContext): The graph context.
354        self (Tensor): The first operand.
355        other (Tensor): The second operand.
356        alpha (float, optional): The scaling factor for the second operand. Defaults to None.
357
358    Returns:
359        ONNX operator.
360    """
361    if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self):
362        return symbolic_helper._onnx_opset_unsupported_detailed(
363            "Add", 9, 11, "Add between list of tensors not supported", self
364        )
365    if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
366        other = g.op("Mul", other, alpha)
367    return g.op("Add", self, other)
368
369
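# Note on `add` above: when alpha != 1, the scaling of the second operand is
# made explicit in the graph. For example (an illustrative lowering, not
# upstream code), torch.add(x, y, alpha=2) becomes Add(x, Mul(y, 2)).

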
@_onnx_symbolic("aten::sub")
def sub(g: jit_utils.GraphContext, self, other, alpha=None):
    """
    Consumes sub function and returns the corresponding ONNX operator.

    This function is not meant to be called directly by the user.

    Args:
        g (GraphContext): The graph context.
        self (Tensor): The first operand.
        other (Tensor): The second operand.
        alpha (Optional[Tensor]): A scaling factor to apply to the second operand.
            If `alpha` is not provided, it defaults to 1.

    Returns:
        ONNX operator
    """
    if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
        other = g.op("Mul", other, alpha)
    return g.op("Sub", self, other)


@_onnx_symbolic("aten::rsub")
def rsub(g: jit_utils.GraphContext, self, other, alpha=None):
    return sub(g, other, self, alpha=alpha)


@_onnx_symbolic("aten::mul")
def mul(g: jit_utils.GraphContext, self, other):
    if symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other):
        # ONNX Mul doesn't support Boolean, so use And as an equivalent operator.
        return g.op("And", self, other)
    else:
        return g.op("Mul", self, other)


@_onnx_symbolic("aten::div")
def div(g: jit_utils.GraphContext, self, other, *args):
    if len(args) == 0:
        return true_divide(g, self, other)
    else:
        return _div_rounding_mode(g, self, other, *args)


@_onnx_symbolic("aten::addcmul")
@symbolic_helper.parse_args("v", "v", "v", "f")
def addcmul(g: jit_utils.GraphContext, self, tensor1, tensor2, value=1.0):
    value_tens = g.op("Constant", value_t=torch.tensor([value]))
    return add(g, self, mul(g, mul(g, tensor1, tensor2), value_tens))


@symbolic_helper.parse_args("v", "v", "s")
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    if rounding_mode is None:
        return true_divide(g, self, other)
    elif rounding_mode == "floor":
        return _floor_divide(g, self, other)
    elif rounding_mode == "trunc":
        return _trunc_divide(g, self, other)
    else:
        raise errors.SymbolicValueError(
            f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"',
            self,
        )


def _trunc_divide(g: jit_utils.GraphContext, self, other):
    out = g.op("Div", self, other)
    # The correct operation is truncation, which is not supported in ONNX.
    # We cannot call Floor, since it behaves differently for negative numbers
    # (e.g. -0.1 should become -0).
    # If scalar type information is not available, assume the operands are
    # floats and cast the result to FLOAT (see the fallback below).
    out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.INT64)

    # Matching PyTorch's behavior:
    # - if self is fp, the output's type is self's type
    # - if self is not fp and other is fp, the output is of type JitScalarType.FLOAT
    # - if self is not fp and other is not fp, the output's type is self's type
    # - the output type defaults to Float
    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if scalar_type != _type_utils.JitScalarType.UNDEFINED:
        if not symbolic_helper._is_fp(self) and symbolic_helper._is_fp(other):
            out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT)
        else:
            out = g.op(
                "Cast",
                out,
                to_i=scalar_type.onnx_type(),
            )
    else:
        out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return out


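# Sketch of how `_trunc_divide` above behaves (illustrative values, not
# upstream code): for float inputs, Div(-7.0, 2.0) yields -3.5, and the Cast
# to INT64 discards the fractional part, giving -3, the same result as
# torch.div(-7.0, 2.0, rounding_mode="trunc"). The casts that follow restore
# the dtype that PyTorch expects for the output.

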
def _floor_divide(g: jit_utils.GraphContext, self, other):
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        out = true_divide(g, self, other)
        return g.op("Floor", out)
    else:
        # Integer division does truncation rounding
        div = g.op("Div", self, other)
        # Division is negative if: self < 0 != other < 0
        zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
        negative = g.op(
            "Xor",
            symbolic_helper._lt_helper(g, self, zero),
            symbolic_helper._lt_helper(g, other, zero),
        )

        # For negative numbers with self % other != 0, subtract 1 to round down instead of up
        mod = g.op("Sub", self, g.op("Mul", div, other))
        fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero)))

        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        fixup = g.op("Mul", fixup_mask, one)
        return g.op("Sub", div, fixup)


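# Worked example for the integer branch of `_floor_divide` above (illustrative
# values, not upstream code): for self = -7, other = 2, Div truncates to -3;
# mod = -7 - (-3 * 2) = -1, which is nonzero, and exactly one operand is
# negative, so fixup = 1 and the result is -3 - 1 = -4, matching Python's
# -7 // 2.

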
@_onnx_symbolic("aten::floor_divide")
def floor_divide(g: jit_utils.GraphContext, self, other):
    # Deprecated behavior, floor_divide actually truncates
    return _trunc_divide(g, self, other)


@_onnx_symbolic("aten::floordiv")
def floordiv(g: jit_utils.GraphContext, self, other):
    return floor_divide(g, self, other)


@_onnx_symbolic("aten::true_divide")
def true_divide(g: jit_utils.GraphContext, self, other):
    """Division where both inputs are cast to floating types

    If both inputs are floating, performs div as usual
    If only one input is a floating type, the other input is cast to its type
    If neither input is a floating type, both inputs are cast to the default scalar type
    """

    # Case 1: either value is floating
    # Performs div as usual.
    # Implicit casting will be handled in scalar type analysis pass.
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        return g.op("Div", self, other)

    # Case 2: neither is floating
    # Casts both inputs to the default scalar type
    scalar_type = torch.get_default_dtype()
    onnx_scalar_type = _C_onnx.TensorProtoDataType.FLOAT
    assert scalar_type is torch.float or scalar_type is torch.double
    if torch.get_default_dtype() is torch.double:
        onnx_scalar_type = _C_onnx.TensorProtoDataType.DOUBLE

    self = g.op("Cast", self, to_i=onnx_scalar_type)
    other = g.op("Cast", other, to_i=onnx_scalar_type)
    return g.op("Div", self, other)


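# For example (illustrative, not upstream code): with two int64 inputs and the
# default dtype torch.float, `true_divide` above casts both operands to FLOAT
# first, so 3 / 2 exports as Div(Cast(3), Cast(2)) and evaluates to 1.5 rather
# than the 1 that integer Div would produce.

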
@_onnx_symbolic("aten::reciprocal")
def reciprocal(g: jit_utils.GraphContext, self):
    # torch.reciprocal implicitly casts to float, so we do the same.
    if not symbolic_helper._is_fp(self):
        self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return g.op("Reciprocal", self)


@_onnx_symbolic("aten::cat")
@symbolic_helper.parse_args("v", "i")
def cat(g: jit_utils.GraphContext, tensor_list, dim):
    """Implement concatenation of pytorch tensors in ONNX along the specified `dim` dimension.

    Parameters:
        g (jit_utils.GraphContext): Graph context.
        tensor_list (List[torch.Tensor]): List of tensors to concatenate.
        dim (int): Dimension along which to concatenate the tensors.

    Returns:
        ONNX graph node representing the concatenated tensor.
    """
    tensors = symbolic_helper._unpack_list(tensor_list)
    # torch.cat ignores empty tensors such as `torch.Tensor([])`
    # These need to be removed as inputs to ONNX's Concat too, otherwise shape inference
    # will likely fail due to inputs with different ranks (0 for empty tensor, > 0 for anything else)
    nonempty_tensors = []
    for t in tensors:
        if symbolic_helper._is_constant(t) and not symbolic_helper._get_tensor_dim_size(
            t, 0
        ):
            continue
        nonempty_tensors.append(t)
    assert len(nonempty_tensors) > 0
    assert all(
        symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None
        or symbolic_helper._get_tensor_rank(t) is None
        or symbolic_helper._get_tensor_rank(t)
        == symbolic_helper._get_tensor_rank(nonempty_tensors[0])
        for t in nonempty_tensors
    )
    tensor_list.node().removeAllInputs()
    for t in nonempty_tensors:
        tensor_list.node().addInput(t)

    tensors = symbolic_helper._unpack_list(tensor_list)
    return g.op("Concat", *tensors, axis_i=dim)


@_onnx_symbolic("aten::stack")
@symbolic_helper.parse_args("v", "i")
def stack(g: jit_utils.GraphContext, tensor_list, dim):
    unsqueezed = [
        symbolic_helper._unsqueeze_helper(g, t, [dim])
        for t in symbolic_helper._unpack_list(tensor_list)
    ]
    return g.op("Concat", *unsqueezed, axis_i=dim)


@_onnx_symbolic("aten::list")
def _list(g: jit_utils.GraphContext, self):
    return self


@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
    # Create a dummy C tensor. It is only needed for API purposes; its value
    # is irrelevant since beta = 0.
    C = g.op("Constant", value_t=torch.tensor([1]))
    return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)


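# Background for `mm` above: ONNX Gemm computes alpha * A @ B + beta * C and
# requires all three inputs, so with beta = 0 the constant C only satisfies
# the operator signature and does not affect the result.

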
@_onnx_symbolic("aten::bmm")
def bmm(g: jit_utils.GraphContext, self, other):
    return g.op("MatMul", self, other)


@_onnx_symbolic("aten::matmul")
def matmul(g: jit_utils.GraphContext, self, other):
    return g.op("MatMul", self, other)


@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
    scalar_type = None
    self_scalar_type = symbolic_helper._try_get_scalar_type(self)
    mat1_scalar_type = symbolic_helper._try_get_scalar_type(mat1)
    mat2_scalar_type = symbolic_helper._try_get_scalar_type(mat2)
    if self_scalar_type is not None:
        scalar_type = self_scalar_type
    elif mat1_scalar_type is not None:
        scalar_type = mat1_scalar_type
    elif mat2_scalar_type is not None:
        scalar_type = mat2_scalar_type

    mat1_rank = symbolic_helper._get_tensor_rank(mat1)
    mat2_rank = symbolic_helper._get_tensor_rank(mat2)

    def is_not_none_nor(v, u):
        return v is not None and v != u

    if scalar_type is not None and (
        is_not_none_nor(mat1_rank, 2) or is_not_none_nor(mat2_rank, 2)
    ):
        res1 = g.op("MatMul", mat1, mat2)
        res2 = self

        alpha = symbolic_helper._scalar(alpha)
        beta = symbolic_helper._scalar(beta)

        if alpha != 1:
            alpha = g.op(
                "Constant", value_t=torch.tensor(alpha, dtype=scalar_type.dtype())
            )
            res1 = g.op("Mul", res1, alpha)
        if beta != 1:
            beta = g.op(
                "Constant",
                value_t=torch.tensor(
                    symbolic_helper._scalar(beta), dtype=scalar_type.dtype()
                ),
            )
            res2 = g.op("Mul", res2, beta)

        return g.op("Add", res1, res2)

    return g.op(
        "Gemm",
        mat1,
        mat2,
        self,
        beta_f=symbolic_helper._scalar(beta),
        alpha_f=symbolic_helper._scalar(alpha),
    )


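# Note on `addmm` above: ONNX Gemm only accepts rank-2 inputs, so when either
# matrix is known to have a different rank the export falls back to the
# equivalent decomposition alpha * MatMul(mat1, mat2) + beta * self;
# otherwise a single Gemm node carries alpha and beta as attributes.

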
@_onnx_symbolic("aten::neg")
def neg(g: jit_utils.GraphContext, self):
    return g.op("Neg", self)


@_onnx_symbolic("aten::sqrt")
def sqrt(g: jit_utils.GraphContext, self):
    if _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    ) in {
        _type_utils.JitScalarType.UINT8,
        _type_utils.JitScalarType.INT8,
        _type_utils.JitScalarType.INT16,
        _type_utils.JitScalarType.INT,
        _type_utils.JitScalarType.INT64,
    }:
        # torch converts all integer inputs of sqrt to float
        self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)

    return g.op("Sqrt", self)


@_onnx_symbolic("aten::rsqrt")
def rsqrt(g: jit_utils.GraphContext, self):
    return g.op(
        "Div", symbolic_helper._if_scalar_type_as(torch.ones(1), self), sqrt(g, self)
    )


@_onnx_symbolic("aten::tanh")
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qtanh.cpp
@symbolic_helper.quantized_args(True, scale=2.0 / 256.0, zero_point=128)
def tanh(g: jit_utils.GraphContext, self):
    return g.op("Tanh", self)


@_onnx_symbolic("aten::sin")
def sin(g: jit_utils.GraphContext, self):
    return g.op("Sin", self)


@_onnx_symbolic("aten::cos")
def cos(g: jit_utils.GraphContext, self):
    return g.op("Cos", self)


@_onnx_symbolic("aten::tan")
def tan(g: jit_utils.GraphContext, self):
    return g.op("Tan", self)


@_onnx_symbolic("aten::asin")
def asin(g: jit_utils.GraphContext, self):
    return g.op("Asin", self)


@_onnx_symbolic("aten::acos")
def acos(g: jit_utils.GraphContext, self):
    return g.op("Acos", self)


@_onnx_symbolic("aten::atan")
def atan(g: jit_utils.GraphContext, self):
    return g.op("Atan", self)


@_onnx_symbolic("aten::atan2")
def atan2(g: jit_utils.GraphContext, self, other):
    # self is y and other is x in the coordinate plane
    slope = g.op("Div", self, other)
    atan = g.op("Atan", slope)
    const_zero = g.op("Constant", value_t=torch.tensor(0))
    const_pi = g.op("Constant", value_t=torch.tensor(math.pi))

    condition_second_or_third_quadrant = g.op("Greater", self, const_zero)
    second_third_quadrant = g.op(
        "Where",
        condition_second_or_third_quadrant,
        g.op("Add", atan, const_pi),
        g.op("Sub", atan, const_pi),
    )

    condition_14_or_23_quadrant = g.op("Less", other, const_zero)
    result = g.op("Where", condition_14_or_23_quadrant, second_third_quadrant, atan)

    return result


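# Worked example for `atan2` above (illustrative values, not upstream code):
# atan2(1, -1) lies in the second quadrant. Atan(1 / -1) = -pi/4, and since
# x < 0 and y > 0 the correction branch adds pi, yielding 3*pi/4 -- the value
# torch.atan2(torch.tensor(1.0), torch.tensor(-1.0)) returns.

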
@_onnx_symbolic("aten::sigmoid")
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
def sigmoid(g: jit_utils.GraphContext, self):
    """Converts the corresponding PyTorch function into ONNX operators.

    It is not meant to be called directly by a user.

    Args:
        g (jit_utils.GraphContext): Graph context.
        self (Tensor): the input tensor.
    Returns:
        ONNX operator
    """
    return g.op("Sigmoid", self)


@_onnx_symbolic("aten::sign")
def sign(g: jit_utils.GraphContext, self):
    return g.op("Sign", self)


@symbolic_helper.quantized_args(True)
def _slice(g: jit_utils.GraphContext, input, axes, starts, ends):
    assert len(starts) == len(ends)
    if len(starts) == 1 and starts[0] == 0 and ends[0] == _constants.INT64_MAX:
        return input
    return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)


@_onnx_symbolic(
    "aten::sum", decorate=[symbolic_helper._apply_params("ReduceSum", "sum")]
)
@_onnx_symbolic(
    "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")]
)
# torch.prod does not support multidimensional "dim"
@_onnx_symbolic(
    "aten::prod",
    decorate=[
        symbolic_helper._apply_params(
            "ReduceProd", "prod", allow_multi_dim_support=False
        )
    ],
)
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
    return symbolic_helper._reduce_with_dtype_helper(
        onnx_op, name, allow_multi_dim_support
    )


@_onnx_symbolic("aten::cumsum")
@symbolic_helper.parse_args("v", "i", "none")
def cumsum(g: jit_utils.GraphContext, input, dim, dtype):
    symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input)


@_onnx_symbolic("aten::_sample_dirichlet")
def _sample_dirichlet(g: jit_utils.GraphContext, self, generator):
    return symbolic_helper._onnx_unsupported("_sample_dirichlet", self)


@_onnx_symbolic("aten::_standard_gamma")
def _standard_gamma(g: jit_utils.GraphContext, self, generator):
    return symbolic_helper._onnx_unsupported("_standard_gamma", self)


@_onnx_symbolic("aten::t")
def t(g: jit_utils.GraphContext, self):
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is None or rank < 2:
        # The transpose of a 1d or 0d tensor is itself. ONNX does not define the behavior
        # clearly and onnxruntime fails on these cases. So we add an Identity node to
        # mirror the behavior of eager mode.
        return g.op("Identity", self)
    return g.op("Transpose", self, perm_i=(1, 0))


@_onnx_symbolic("aten::numpy_T")
@symbolic_helper.quantized_args(True)
def numpy_T(g: jit_utils.GraphContext, input):
    ndim = symbolic_helper._get_tensor_rank(input)
    assert ndim is not None
    perm = list(reversed(range(0, ndim)))
    return g.op("Transpose", input, perm_i=perm)


@_onnx_symbolic("aten::expand")
@symbolic_helper.quantized_args(True)
def expand(g: jit_utils.GraphContext, self, size, implicit):
    """Implement the expand function for a pytorch tensor in ONNX according to specified `size`"""
    size = symbolic_helper._maybe_get_const(size, "is")
    if not symbolic_helper._is_value(size):
        size = g.op("Constant", value_t=torch.LongTensor(size))
    elif symbolic_helper._is_packed_list(size):
        # Expand with -1 dim value means dim is unchanged.
        # Since onnx::expand supports two-way broadcasting,
        # -1 dim value can be exported to onnx as 1
        size = symbolic_helper._reshape_helper(
            g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1]))
        )
    dtype = _type_utils.JitScalarType.INT64
    ones = ones_like(g, size, dtype)
    neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1)))
    size = where(g, g.op("Equal", size, neg_ones), ones, size)
    return g.op("Expand", self, size)


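# For example (illustrative, not upstream code): expanding a (3, 1) tensor
# with size = [-1, 4] in `expand` above rewrites the -1 entry to 1, and ONNX
# Expand's two-way broadcasting then keeps the original dim of length 3 while
# broadcasting the second dim to 4 -- the same result torch.expand gives.

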
@_onnx_symbolic("aten::broadcast_to")
@symbolic_helper.quantized_args(True)
def broadcast_to(g: jit_utils.GraphContext, self, size):
    size = symbolic_helper._maybe_get_const(size, "is")
    if not symbolic_helper._is_value(size):
        size = g.op("Constant", value_t=torch.LongTensor(size))
    elif symbolic_helper._is_packed_list(size):
        # Expand with -1 dim value means dim is unchanged.
        # Since onnx::expand supports two-way broadcasting,
        # -1 dim value can be exported to onnx as 1
        size = symbolic_helper._reshape_helper(
            g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1]))
        )
    dtype = _type_utils.JitScalarType.INT64
    ones = ones_like(g, size, dtype)
    neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1)))
    size = where(g, g.op("Equal", size, neg_ones), ones, size)
    return g.op("Expand", self, size)


@_onnx_symbolic("aten::expand_as")
@symbolic_helper.quantized_args(True, True)
def expand_as(g: jit_utils.GraphContext, self, other):
    self_t = symbolic_helper._maybe_get_const(self, "t")
    if isinstance(self_t, torch.Tensor):
        orig_type = self_t.dtype
        self_t = self_t.to(torch.double)
        dims = []
        for d in range(self_t.dim()):
            if torch.equal(self_t.mean(d).unsqueeze(d).expand_as(self_t), self_t):
                dims.append(d)
                self = g.op(
                    "Constant", value_t=self_t.mean(dims, keepdim=True).to(orig_type)
                )

    shape = g.op("Shape", other)
    return g.op("Expand", self, shape)


@_onnx_symbolic("aten::embedding")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i", "b", "v")
def embedding(
    g: jit_utils.GraphContext,
    weight,
    indices,
    padding_idx,
    scale_grad_by_freq,
    sparse,
):
    if scale_grad_by_freq and GLOBALS.export_training:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of embedding with scale_grad_by_freq=True "
            "for training mode. ONNX does not support scaling the gradients.",
            weight,
        )
    if padding_idx >= 0 and GLOBALS.export_training:
        warnings.warn(
            "Warning: ONNX export of embedding with padding_idx >= 0 "
            "for training mode. "
            "ONNX does not support not updating the embedding vector at padding_idx during training."
        )

    return g.op("Gather", weight, indices)


@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    if not symbolic_helper._is_none(per_sample_weights):
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with per_sample_weights"
        )

    return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix)


@_onnx_symbolic("aten::size")
@symbolic_helper.quantized_args(True, quantize_output=False)
def size(g: jit_utils.GraphContext, self, dim=None):
    if dim is None:
        return g.op("Shape", self)
    if symbolic_helper._maybe_get_const(dim, "i") < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is not None:
            dim = symbolic_helper._maybe_get_const(dim, "i") + rank
            dim = g.op("Constant", value_t=torch.tensor(dim))
    return symbolic_helper._size_helper(g, self, dim)


@_onnx_symbolic("aten::transpose")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "i")
def transpose(g: jit_utils.GraphContext, self, dim0, dim1):
    if dim0 == dim1:  # micro-optimization
        return self

    # NB: Transpose in ONNX is actually a Permute
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is not None:
        axes = list(range(rank))
        axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
        return g.op("Transpose", self, perm_i=axes)
    else:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of transpose for tensor of unknown rank.",
            self,
        )


@_onnx_symbolic("aten::permute")
@symbolic_helper.parse_args("v", "is")
def permute(g: jit_utils.GraphContext, self, dims):
    if dims == list(range(0, len(dims))):
        return self
    return g.op("Transpose", self, perm_i=dims)


@_onnx_symbolic("aten::view")
@symbolic_helper.quantized_args(True)
def view(g: jit_utils.GraphContext, self, size):
    return reshape(g, self, size)


@_onnx_symbolic("aten::view_as")
def view_as(g: jit_utils.GraphContext, self, other):
    shape = g.op("Shape", other)
    return reshape(g, self, shape)


@_onnx_symbolic("aten::unsafe_chunk")
@symbolic_helper.parse_args("v", "i", "i", "i")
def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
    if _outputs is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "unsafe_chunk", 9, 11, "Dynamic number of outputs not supported", self
        )
    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        return symbolic_helper._unimplemented(
            "unsafe_chunk", "unknown dimension size", self
        )
    split_size = (size + chunks - 1) // chunks
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs)


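# Worked example for the split computation in `unsafe_chunk` above
# (illustrative values, not upstream code): for size = 10 and chunks = 3,
# split_size = (10 + 3 - 1) // 3 = 4, giving splits = [4, 4] plus a leftover
# of 10 % 4 = 2, i.e. Split sizes [4, 4, 2] -- the same chunking torch.chunk
# produces.

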
@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "split", 9, 11, "Dynamic number of outputs not supported", self
        )
    split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value")
    if split_val.dim() > 0:
        return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs)
    split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")

    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        if _outputs is not None:
            size = split_size * _outputs
        else:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "split", 9, 11, "Unknown dimension size not supported", self
            )
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs)


@_onnx_symbolic("aten::unsafe_split")
def unsafe_split(
    g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None
):
    return split(g, self, split_size_or_sizes, dim, _outputs)


@_onnx_symbolic("aten::split_with_sizes")
@symbolic_helper.parse_args("v", "is", "i", "i")
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
    if not symbolic_helper._is_split_static(split_sizes, _outputs):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "split_with_sizes", 9, 11, "Dynamic number of outputs not supported", self
        )
    return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs)


@_onnx_symbolic("aten::unsafe_split_with_sizes")
def unsafe_split_with_sizes(
    g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None
):
    return split_with_sizes(g, self, split_sizes, dim, _outputs)


@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
    if _outputs is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "unbind", 9, 11, "Dynamic number of outputs not supported", self
        )

    outputs = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs)
    outputs = [outputs] if _outputs == 1 else outputs
    squeezed_outputs = [
        symbolic_helper._squeeze_helper(g, out, [dim]) for out in outputs
    ]
    return squeezed_outputs


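# For example (illustrative, not upstream code): unbinding a (3, 4) tensor
# along dim 0 in `unbind` above emits Split into three (1, 4) pieces, then
# squeezes dim 0 of each, yielding three (4,) tensors just like
# torch.unbind(x, 0).

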
@_onnx_symbolic("aten::select")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "v")
def select(g: jit_utils.GraphContext, self, dim, index):
    """Implement the select functionality for a pytorch tensor in ONNX.

    Selects elements from the input tensor along the specified `dim` dimension based on the `index` tensor.
    """
    index = symbolic_helper._maybe_get_scalar(index)
    if (not symbolic_helper._is_value(index)) and (index < 0):
        if index == -1:
            end_index = _constants.INT64_MAX
        else:
            end_index = index + 1
        slice_node = symbolic_helper._slice_helper(
            g, self, axes=[dim], starts=[index], ends=[end_index]
        )
        return symbolic_helper._squeeze_helper(g, slice_node, [dim])
    else:
        # FIXME(justinchuby): can index be an int and not a value?
        return g.op("Gather", self, index, axis_i=dim)


@_onnx_symbolic("aten::square")
def square(g: jit_utils.GraphContext, self):
    return g.op("Mul", self, self)


@_onnx_symbolic("aten::squeeze")
def squeeze(g: jit_utils.GraphContext, self, dim=None):
    if dim is None:
        return g.op("Squeeze", self)

    squeeze_dim = symbolic_helper._get_const(dim, "i", "dim")
    # Handle negative dims
    if squeeze_dim < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is not None:
            warnings.warn(
                "ONNX export squeeze with negative axis "
                + str(squeeze_dim)
                + " might cause the onnx model to be incorrect. "
                + "Negative axis is not supported in ONNX. "
                + "Axis is converted to "
                + str(squeeze_dim + rank)
                + " based on input shape at export time. "
                + "Passing a tensor of different rank in execution will be incorrect."
            )
            squeeze_dim += rank
        else:
            return symbolic_helper._unimplemented(
                "squeeze", "negative axis with unknown input rank", self
            )

    dim_size = symbolic_helper._get_tensor_dim_size(self, squeeze_dim)
    if dim_size is None:
        warnings.warn(
            "This model contains a squeeze operation on dimension "
            + str(squeeze_dim)
            + " on an input "
            + "with unknown shape. Note that if the size of dimension "
            + str(squeeze_dim)
            + " of the input "
            + "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on "
            + "non-singleton dimensions, it is recommended to export this model using opset "
            + "version 11 or higher."
        )
        return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim])
    if dim_size > 1:
        warnings.warn(
            "This model contains a squeeze operation on dimension "
            + str(squeeze_dim)
            + ". The size of "
            + "this dimension in the given input is "
            + str(dim_size)
            + ". The model will "
            + "be exported without the squeeze node. If the model is intended to be used with dynamic "
            + "input shapes, please use opset version 11 to "
            + "export the model."
        )
        return self

    warnings.warn(
        "This model contains a squeeze operation on dimension "
        + str(squeeze_dim)
        + ". If the model is "
        + "intended to be used with dynamic input shapes, please use opset version 11 to export the model."
    )
    return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim])


@_onnx_symbolic("aten::prelu")
def prelu(g: jit_utils.GraphContext, self, weight):
    self_rank = symbolic_helper._get_tensor_rank(self)
    weight_sizes = symbolic_helper._get_tensor_sizes(weight)
    weight_rank = len(weight_sizes)
    if self_rank is not None:
        if self_rank > 2:
            # make weight unidirectional broadcastable
            weight = symbolic_helper._unsqueeze_helper(
                g, weight, list(range(1, self_rank - 1))
            )
        elif self_rank == 0 and weight_sizes == [1]:
            # self and weight are both scalar but weight has rank == 1, squeeze weight.
            weight = symbolic_helper._squeeze_helper(g, weight, [0])
            weight_rank = 0

    if self_rank is not None and weight_rank is not None:
        assert (
            self_rank >= weight_rank
        ), f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}"
    return g.op("PRelu", self, weight)


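# For example (illustrative, not upstream code): with a rank-4 NCHW input and
# a per-channel slope of shape (C,), `prelu` above unsqueezes the slope at
# dims 1 and 2 to shape (C, 1, 1), which broadcasts unidirectionally against
# (N, C, H, W) as ONNX PRelu requires.

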
@_onnx_symbolic("aten::silu")
def silu(g: jit_utils.GraphContext, input):
    return g.op("Mul", input, g.op("Sigmoid", input))


@_onnx_symbolic("aten::mish")
def mish(g: jit_utils.GraphContext, input):
    return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input)))


@_onnx_symbolic("aten::relu")
@symbolic_helper.quantized_args(True)
def relu(g: jit_utils.GraphContext, input):
    return symbolic_helper._op_with_optional_float_cast(
        g, "Relu", input, opset_before=14
    )


@_onnx_symbolic("aten::relu6")
@symbolic_helper.quantized_args(True)
def relu6(g: jit_utils.GraphContext, input):
    return clamp(g, input, 0, 6)


@_onnx_symbolic("aten::ceil")
def ceil(g: jit_utils.GraphContext, input):
    return g.op("Ceil", input)


@_onnx_symbolic("aten::floor")
def floor(g: jit_utils.GraphContext, input):
    return g.op("Floor", input)


@_onnx_symbolic("aten::len")
def _len(g: jit_utils.GraphContext, self):
    sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0])))
    return symbolic_helper._squeeze_helper(g, sz_0, [0])


@_onnx_symbolic("aten::threshold")
@symbolic_helper.parse_args("v", "t", "t")
def threshold(g: jit_utils.GraphContext, self, threshold, value):
    # See Note [Export inplace]
    if symbolic_helper._scalar(threshold) != 0:
        return symbolic_helper._unimplemented("threshold", "non-zero threshold", self)
    if symbolic_helper._scalar(value) != 0:
        return symbolic_helper._unimplemented("threshold", "non-zero value", self)
    return g.op("Relu", self)


@_onnx_symbolic("aten::leaky_relu")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "b")
def leaky_relu(
    g: jit_utils.GraphContext,
    input: _C.Value,
    negative_slope: float,
    inplace: bool = False,
):
    # See Note [Export inplace]
    return g.op("LeakyRelu", input, alpha_f=negative_slope)


@_onnx_symbolic("aten::glu")
@symbolic_helper.parse_args("v", "i")
def glu(g: jit_utils.GraphContext, input, dim):
    dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
    if dim_size is not None:
        assert dim_size % 2 == 0

    first, second = g.op("Split", input, axis_i=dim, outputs=2)
    return g.op("Mul", first, g.op("Sigmoid", second))


@_onnx_symbolic("aten::softmax")
@symbolic_helper.parse_args("v", "i", "none")
def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    # Softmax does normalization at vector level.
    # PyTorch and ONNX use different strategies to split the input tensor into vectors.
    # Thus dim and axis have different meanings.
    # PyTorch slices the input tensor into vectors along the `dim`-th dimension.
    # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced.
    # If input is a 2 x 3 tensor:
    # input = [[1.0, 1.0, 1.0],
    #          [1.0, 1.0, 1.0]]
    # with dim = 0, the result is:
    # result = [[0.5, 0.5, 0.5],
    #           [0.5, 0.5, 0.5]]
    # with axis = 0, the result is:
    # result = [[0.167, 0.167, 0.167],
    #           [0.167, 0.167, 0.167]]
    # So only when dim and axis both equal to ndim - 1 (the last dimension),
    # their semantics are equivalent.
    # So use softmax when dim and axis both equal to ndim - 1,
    # otherwise transpose the input to put the vectors to be normalized to the last dimension.
    # When input rank is not known at export time we compute softmax using a subgraph
    # with other operators
    input_dim = symbolic_helper._get_tensor_rank(input)
    if input_dim is not None:
        # TODO: remove this as onnx opset 11 spec allows negative axes
        if dim < 0:
            dim = input_dim + dim

        is_transpose_required = input_dim != dim + 1

        if is_transpose_required:
            axes = list(range(input_dim))
            axes[dim], axes[-1] = axes[-1], axes[dim]
            input = g.op("Transpose", input, perm_i=axes)
            dim = input_dim - 1

        softmax = g.op("Softmax", input, axis_i=dim)
        if dtype and dtype.node().kind() != "prim::Constant":
            parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
            softmax = g.op(
                "Cast",
                softmax,
                to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type(),
            )

        if is_transpose_required:
            softmax = g.op("Transpose", softmax, perm_i=axes)  # type: ignore[possibly-undefined]
        return softmax

    # Apply max normalization.
    input = g.op("Sub", input, g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1))

    exp = g.op("Exp", input)
    sum = symbolic_helper._reducesum_helper(g, exp, axes_i=[dim])
    softmax = g.op("Div", exp, sum)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        softmax = g.op(
            "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )
    return softmax


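# Note on the fallback branch of `softmax` above: it implements the standard
# numerically stable form softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))).
# Subtracting the maximum leaves the result unchanged mathematically but
# keeps Exp from overflowing for large inputs.

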
@_onnx_symbolic("aten::softplus")
def softplus(g: jit_utils.GraphContext, self, beta, threshold):
    beta_const = symbolic_helper._maybe_get_const(beta, "f")
    if beta_const != 1:
        return g.op("Div", g.op("Softplus", g.op("Mul", self, beta)), beta)
    return g.op("Softplus", self)


@_onnx_symbolic("aten::get_pool_ceil_padding")
def get_pool_ceil_padding(input, kernel_size, stride, padding):
    # TODO(justinchuby): Looks like this op is deprecated in torch
    sizes = symbolic_helper._get_tensor_sizes(input)
    dim = sizes[-len(padding) :] if sizes is not None else None
    if dim is None or any(i is None for i in dim):
        return symbolic_helper._unimplemented(
            "get_pool_ceil_padding", "input size not accessible", input
        )
    ceiled_output_dim = [
        int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i])))
        + 1
        for i in range(0, len(padding))
    ]
    # ensure last pooling starts inside
    ceiled_output_dim = [
        (
            ceiled_output_dim[i] - 1
            if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i]))
            else ceiled_output_dim[i]
        )
        for i in range(0, len(ceiled_output_dim))
    ]
    padding_ceil = [
        (
            0
            if (stride[i] == 1)
            else (
                kernel_size[i]
                - (
                    dim[i]
                    + 2 * padding[i]
                    - ((ceiled_output_dim[i] - 1) * stride[i] + 1)
                )
            )
        )
        for i in range(0, len(padding))
    ]
    # ensure padding is not > kernel_size
    padding_ceil = [
        (
            (
                int(padding_ceil[i])
                if padding_ceil[i] < kernel_size[i] - 1
                else int(kernel_size[i] - 1)
            )
            if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i]))
            else int(padding_ceil[i])
        )
        for i in range(0, len(padding_ceil))
    ]
    return padding_ceil


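# Worked example for `get_pool_ceil_padding` above (illustrative values, not
# upstream code): dim = [6], kernel_size = [3], stride = [2], padding = [0].
# ceil((6 + 0 - 3) / 2) + 1 = 3 output positions (vs. 2 without ceil_mode);
# the last window starts at (3 - 1) * 2 = 4 < 6, so it is kept, and the extra
# end padding is 3 - (6 - ((3 - 1) * 2 + 1)) = 2, which the pooling symbolics
# below add to the base padding on the end side.

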
@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
        ),
        _export("max_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
        ),
        _export("max_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
        ),
        _export("max_pool3d"),
    ],
)
def _max_pool(name, tuple_fn, ndims, return_indices):
    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        if set(tuple_fn(dilation)) != {1}:
            return symbolic_helper._unimplemented(name, "dilation", input)
        if not stride:
            stride = kernel_size
        padding = tuple(tuple_fn(padding))
        if ceil_mode:
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
            padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding))
        else:
            padding = padding * 2
        kwargs = {
            "kernel_shape_i": tuple_fn(kernel_size),
            "pads_i": padding,
            "strides_i": tuple_fn(stride),
        }
        # Easy but hacky way to get flattened indices values
        # so they can be converted to non-flattened indices values.
        # In ONNX the indices are computed as a flattened 1-D tensor,
        # so the values in indices are in [0, N x C x D1 x ... x Dn).
        # To convert the indices to the same format used by PyTorch,
        # we first execute a maxpool with a kernel and stride of 1 on the same input.
        # This will result in a tensor of indices in which each index has its own value.
        # Using this tensor as a reference, we extract the first index of each axis and subtract
        # it from each index of this axis in the indices to convert.
        # This step will result in a tensor where each dimension has values of indices within
        # the dimension it is in.
        # For more information:
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        if return_indices:
            r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
            _, flattened_indices = g.op(
                "MaxPool",
                input,
                outputs=2,
                kernel_shape_i=[1 for _ in range(ndims)],
                strides_i=[1 for _ in range(ndims)],
            )
            # convert indices to have non-flattened indices values
            s = symbolic_helper._slice_helper(
                g,
                flattened_indices,
                axes=[2 + i for i in range(ndims)],
                starts=list(tuple_fn(0)),
                ends=list(tuple_fn(1)),
            )
            indices = sub(g, indices, s)
            return r, indices
        else:
            r = g.op("MaxPool", input, outputs=1, **kwargs)
            return r

    return symbolic_fn


max_pool1d_with_indices = _onnx_symbolic("aten::max_pool1d_with_indices")(
    _max_pool(
        "max_pool1d_with_indices",
        torch.nn.modules.utils._single,
        1,
        return_indices=True,
    )
)
max_pool2d_with_indices = _onnx_symbolic("aten::max_pool2d_with_indices")(
    _max_pool(
        "max_pool2d_with_indices",
        torch.nn.modules.utils._pair,
        2,
        return_indices=True,
    )
)
max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")(
    _max_pool(
        "max_pool3d_with_indices",
        torch.nn.modules.utils._triple,
        3,
        return_indices=True,
    )
)


@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[
        symbolic_helper._apply_params("avg_pool1d", torch.nn.modules.utils._single),
        _export("avg_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[
        symbolic_helper._apply_params("avg_pool2d", torch.nn.modules.utils._pair),
        _export("avg_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[
        symbolic_helper._apply_params("avg_pool3d", torch.nn.modules.utils._triple),
        _export("avg_pool3d"),
    ],
)
def _avg_pool(name, tuple_fn):
    @symbolic_helper.quantized_args(True)
    @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
    def symbolic_fn(
        g,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: int | Sequence[int],
        ceil_mode: int,
        count_include_pad: int,
        divisor_override=None,
    ):
        if not stride:
            stride = kernel_size
        padding = symbolic_helper._avgpool_helper(
            tuple_fn, padding, kernel_size, stride, divisor_override, name
        )
        assert isinstance(padding, tuple)
        adjusted_padding = padding
        # Although onnx::AvgPool provides count_include_pad,
        # PyTorch's average pooling with ceil_mode allows the sliding
        # window to go off-bound, which leads to this accommodation.
        # More detail on https://github.com/pytorch/pytorch/issues/57178
        if count_include_pad:
            input = symbolic_helper._op_with_optional_float_cast(
                g,
                "Pad",
                input,
                pads_i=((0,) * 2 + padding) * 2,
                mode_s="constant",
                value_f=0.0,
                opset_before=11,
            )
            adjusted_padding = (0,) * len(padding)
        if ceil_mode:
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
            adjusted_padding = adjusted_padding + tuple(
                a + b for (a, b) in zip(padding_ceil, adjusted_padding)
            )
        else:
            adjusted_padding = adjusted_padding * 2
        output = g.op(
            "AveragePool",
            input,
            kernel_shape_i=tuple_fn(kernel_size),
            strides_i=tuple_fn(stride),
            pads_i=adjusted_padding,
        )
        return output

    return symbolic_fn


@_onnx_symbolic(
    "aten::adaptive_avg_pool1d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single
        ),
        _export("adaptive_avg_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_avg_pool2d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair
        ),
        _export("adaptive_avg_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_avg_pool3d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple
        ),
        _export("adaptive_avg_pool3d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool1d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_max_pool1d",
            "MaxPool",
            torch.nn.modules.utils._single,
            max_pool1d_with_indices,
        ),
        _export("adaptive_max_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool2d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_max_pool2d",
            "MaxPool",
            torch.nn.modules.utils._pair,
            max_pool2d_with_indices,
        ),
        _export("adaptive_max_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool3d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_max_pool3d",
            "MaxPool",
            torch.nn.modules.utils._triple,
            max_pool3d_with_indices,
        ),
        _export("adaptive_max_pool3d"),
    ],
)
def _adaptive_pool(name, type, tuple_fn, fn=None):
    @symbolic_helper.quantized_args(True, False)
    def symbolic_fn(g, input, output_size):
1659        # _adaptive_pool is supported when output_size is 1 for all dimensions,
1660        # by executing a GlobalPool.
1661        # It is also supported when the output size is a factor of the input size.
1662        # In those cases the stride and kernel size are uniform along all the
1663        # indices of the same dimension, which makes it possible to export to ONNX.
1664        # For MaxPool, GlobalMaxPool does not return indices,
1665        # so we try using max_poolxd_with_indices, and if that is not possible
1666        # (input is not a complete tensor or output size is not a factor of input size)
1667        # we call GlobalMaxPool and return None for the indices.
1668        output_size_value = output_size
1669        try:
1670            output_size = symbolic_helper._parse_arg(output_size, "is")
1671        except Exception:
1672            # FIXME(justinchuby): Avoid catching Exception.
1673            # Catch a more specific exception instead.
1674            return symbolic_helper._onnx_unsupported(
1675                "adaptive pooling, since output_size is not constant.", input
1676            )
1677        if output_size == [1] * len(output_size) and type == "AveragePool":
1678            return g.op("GlobalAveragePool", input)
1679        sizes = symbolic_helper._get_tensor_sizes(input)
1680        try:
1681            dim = sizes[2:]
1682        except Exception:
1683            # FIXME(justinchuby): Avoid catching Exception.
1684            # Catch a more specific exception instead.
1685            dim = None
1686        if dim is None or any(i is None for i in dim):
1687            if output_size == [1] * len(output_size):
1688                return g.op("GlobalMaxPool", input), None
1689            return symbolic_helper._unimplemented(
1690                name, "input size not accessible", input
1691            )
1692        # verify if output size % input size = 0 for all dim
1693        mod = [dim[i] % output_size[i] for i in range(0, len(dim))]
1694        if mod != [0] * len(mod):
1695            if output_size == [1] * len(output_size):
1696                return g.op("GlobalMaxPool", input), None
1697            return symbolic_helper._unimplemented(
1698                name, "output size that are not factor of input size", output_size_value
1699            )
1700        k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]
1701        # call max_poolxd_with_indices to get indices in the output
1702        if type == "MaxPool":
1703            return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False)
1704        output = g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k))
1705        return output
1706
1707    return symbolic_fn
1708
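# Worked example (illustrative): adaptive_avg_pool2d on an input of shape
# [N, C, 32, 32] with output_size=[16, 16] exports as AveragePool with
# kernel_shape=[2, 2] and strides=[2, 2]; output_size=[1, 1] exports as a
# single GlobalAveragePool. Output sizes that do not evenly divide the input
# size are unsupported at this opset.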
1709
1710def _prepare_onnx_paddings(dim: int, pad):
1711    """Generate paddings in ONNX order based on pad in pytorch.
1712    Args:
1713        dim: the dimension of the tensor.
1714        pad: the paddings in pytorch.
1715            The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...
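    Example (illustrative):
        dim=3 with pad=[1, 2] (last dimension: begin=1, end=2) is first
        extended to [1, 2, 0, 0, 0, 0] and then reordered to
        [0, 0, 1, 0, 0, 2], i.e. all begins in dim order, then all ends.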
1716    """
1717    # The desired order of paddings is
1718    # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
1719    # n is the dimension of input.
1720    # Assume zero padding for the leading dimensions that pad does not cover.
1721    paddings = list(pad[:]) + [0] * (dim * 2 - len(pad))
1722    # reverse order and collate first beginnings and then ends
1723    paddings = paddings[-2::-2] + paddings[-1::-2]
1724    return paddings
1725
1726
1727def _convert_padding_node(input):
1728    padding = symbolic_helper._maybe_get_const(input, "is")
1729    if symbolic_helper._is_value(padding) and symbolic_helper._is_packed_list(padding):
1730        input_list = symbolic_helper._unpack_list(padding)
1731        try:
1732            padding = [
1733                symbolic_helper._get_const(v, "i", "padding") for v in input_list
1734            ]
1735        except Exception:
1736            # FIXME(justinchuby): Avoid catching Exception.
1737            # Catch a more specific exception instead.
1738            return symbolic_helper._onnx_opset_unsupported_detailed(
1739                "Pad", 9, 11, "The sizes of the padding must be constant", input
1740            )
1741    return padding
1742
1743
1744@_onnx_symbolic("aten::constant_pad_nd")
1745def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value):
1746    mode = "constant"
1747    try:
1748        value = symbolic_helper._get_const(value, "f", "value")
1749    except Exception:
1750        # FIXME(justinchuby): Avoid catching Exception.
1751        # Catch a more specific exception instead.
1752        return symbolic_helper._onnx_opset_unsupported_detailed(
1753            "Pad", 9, 11, "The value for the padding must be constant", value
1754        )
1755
1756    padding = _convert_padding_node(padding)
1757    paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
1758    return symbolic_helper._op_with_optional_float_cast(
1759        g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11
1760    )
1761
1762
1763def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value):
1764    padding = _convert_padding_node(pad)
1765    assert len(padding) % 2 == 0
1766    ndim = len(padding) // 2
1767
1768    cur = input
1769    for idx in range(ndim):
1770        pad_r = padding[-(2 * idx + 1)]
1771        pad_l = padding[-(2 * idx + 2)]
1772        tensors = []
1773        if pad_l > 0:
1774            left = symbolic_helper._slice_helper(
1775                g, cur, axes=[2 + idx], starts=[-(pad_l)], ends=[_constants.INT64_MAX]
1776            )
1777            tensors.append(left)
1778
1779        if pad_l < 0 or pad_r < 0:
1780            start = builtins.max(0, -pad_l)
1781            end = -(builtins.max(0, -pad_r))
1782            middle = symbolic_helper._slice_helper(
1783                g,
1784                cur,
1785                axes=[2 + idx],
1786                starts=[start],
1787                ends=[end],
1788            )
1789            tensors.append(middle)
1790        else:
1791            tensors.append(cur)
1792
1793        if pad_r > 0:
1794            right = symbolic_helper._slice_helper(
1795                g, cur, axes=[2 + idx], starts=[0], ends=[pad_r]
1796            )
1797            tensors.append(right)
1798
1799        cur = g.op("Concat", *tensors, axis_i=(2 + idx))
1800
1801    return cur
1802
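# Worked example (illustrative): circular padding of x = [[[1, 2, 3]]]
# (shape [1, 1, 3]) with pad=(1, 1) concatenates the slices
# [3] + [1, 2, 3] + [1] along axis 2, producing [3, 1, 2, 3, 1]; spatial
# axes are handled starting at axis 2, one Concat per padded dimension.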
1803
1804@_onnx_symbolic("aten::reflection_pad1d")
1805@_onnx_symbolic("aten::reflection_pad2d")
1806@_onnx_symbolic("aten::reflection_pad3d")
1807def reflection_pad(g: jit_utils.GraphContext, input, padding):
1808    mode = "reflect"
1809    padding = _convert_padding_node(padding)
1810    paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
1811    return symbolic_helper._op_with_optional_float_cast(
1812        g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11
1813    )
1814
1815
1816@_onnx_symbolic("aten::replication_pad1d")
1817@_onnx_symbolic("aten::replication_pad2d")
1818@_onnx_symbolic("aten::replication_pad3d")
1819def replication_pad(g: jit_utils.GraphContext, input, padding):
1820    mode = "edge"
1821    padding = _convert_padding_node(padding)
1822    paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
1823    return symbolic_helper._op_with_optional_float_cast(
1824        g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11
1825    )
1826
1827
1828@_onnx_symbolic("aten::pad")
1829def pad(
1830    g: jit_utils.GraphContext,
1831    input: _C.Value,
1832    pad: _C.Value,
1833    mode: _C.Value,
1834    value: _C.Value,
1835):
1836    mode = symbolic_helper._parse_arg(mode, "s")
1837    if mode == "replicate":
1838        return replication_pad(g, input, pad)
1839    elif mode == "reflect":
1840        return reflection_pad(g, input, pad)
1841    elif mode == "constant":
1842        return constant_pad_nd(g, input, pad, value)
1843    elif mode == "circular":
1844        return _pad_circular(g, input, pad)
1845    else:
1846        raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input)
1847
1848
1849@_onnx_symbolic(
1850    "aten::upsample_nearest1d",
1851    decorate=[
1852        symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest"),
1853        _export("upsample_nearest1d"),
1854    ],
1855)
1856@_onnx_symbolic(
1857    "aten::upsample_nearest2d",
1858    decorate=[
1859        symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest"),
1860        _export("upsample_nearest2d"),
1861    ],
1862)
1863@_onnx_symbolic(
1864    "aten::upsample_nearest3d",
1865    decorate=[
1866        symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest"),
1867        _export("upsample_nearest3d"),
1868    ],
1869)
1870@_onnx_symbolic(
1871    "aten::upsample_linear1d",
1872    decorate=[
1873        symbolic_helper._apply_params("upsample_linear1d", 3, "linear"),
1874        _export("upsample_linear1d"),
1875    ],
1876)
1877@_onnx_symbolic(
1878    "aten::upsample_bilinear2d",
1879    decorate=[
1880        symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear"),
1881        _export("upsample_bilinear2d"),
1882    ],
1883)
1884@_onnx_symbolic(
1885    "aten::upsample_trilinear3d",
1886    decorate=[
1887        symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear"),
1888        _export("upsample_trilinear3d"),
1889    ],
1890)
1891def _interpolate(name: str, dim: int, interpolate_mode: str):
1892    def symbolic_fn(g, input, output_size, *args):
1893        scales, align_corners = symbolic_helper._get_interpolate_attributes(
1894            g, interpolate_mode, args
1895        )
1896        symbolic_helper._interpolate_warning(interpolate_mode)
1897        align_corners = symbolic_helper._maybe_get_scalar(align_corners)
1898        if align_corners:
1899            return symbolic_helper._unimplemented(name, "align_corners == True", input)
1900        if scales is None:
1901            scales = symbolic_helper._interpolate_size_to_scales(
1902                g, input, output_size, dim
1903            )
1904        return g.op("Upsample", input, scales, mode_s=interpolate_mode)
1905
1906    return symbolic_fn
1907
1908
1909@_onnx_symbolic("aten::__interpolate")
1910def __interpolate(
1911    g: jit_utils.GraphContext,
1912    input,
1913    size,
1914    scale_factor,
1915    mode,
1916    align_corners,
1917    recompute_scale_factor,
1918    antialias,
1919):
1920    scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
1921        g, input, size, scale_factor, mode, align_corners
1922    )
1923    return g.op("Upsample", input, scales, mode_s=mode)
1924
1925
1926@_onnx_symbolic("aten::bitwise_not")
1927def bitwise_not(g: jit_utils.GraphContext, input):
1928    if not symbolic_helper._is_bool(input):
1929        raise errors.SymbolicValueError(
1930            "ONNX export does NOT support exporting bitwise Not "
1931            "for non-boolean input values",
1932            input,
1933        )
1934    return g.op("Not", input)
1935
1936
1937@_onnx_symbolic("aten::bitwise_or")
1938def bitwise_or(g, self, other):
1939    if not symbolic_helper._is_bool(self):
1940        raise errors.SymbolicValueError(
1941            "ONNX export does NOT support exporting bitwise OR "
1942            "for non-boolean input values. self: ",
1943            self,
1944        )
1945    if not symbolic_helper._is_bool(other):
1946        raise errors.SymbolicValueError(
1947            "ONNX export does NOT support exporting bitwise OR "
1948            "for non-boolean input values. other: ",
1949            other,
1950        )
1951    return g.op("Or", self, other)
1952
1953
1954def wrap_logical_op_with_cast_to(to_type):
1955    def decorator(fn):
1956        @functools.wraps(fn)
1957        def wrap_with_cast(g, input, other):
1958            to_cast_func = globals()[f"_cast_{to_type}"]
1959            return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False))
1960
1961        return wrap_with_cast
1962
1963    return decorator
1964
1965
1966def wrap_logical_op_with_negation(func: Callable) -> Callable:
1967    @functools.wraps(func)
1968    def wrap_with_not(g, input, other):
1969        return g.op("Not", func(g, input, other))
1970
1971    return wrap_with_not
1972
1973
1974@_onnx_symbolic("aten::__not_")
1975def __not_(g: jit_utils.GraphContext, self):
1976    if not symbolic_helper._is_bool(self):
1977        raise errors.SymbolicValueError(
1978            "ONNX export does NOT support exporting bitwise Not "
1979            "for non-boolean input values",
1980            self,
1981        )
1982    return g.op("Not", self)
1983
1984
1985@_onnx_symbolic("aten::eq")
1986@symbolic_helper.quantized_args(True, True)
1987def eq(g: jit_utils.GraphContext, self, other):
1988    if isinstance(self.type(), _C.DeviceObjType) and isinstance(
1989        other.type(), _C.DeviceObjType
1990    ):
1991        # ONNX doesn't have devices, so consider them all to be equal.
1992        # The no-op check for equality will get constant-folded.
1993        return g.op("Constant", value_t=torch.tensor(True, dtype=torch.bool))
1994    self_node = self.node()
1995    other_node = other.node()
1996    if self_node.kind() == other_node.kind() == "onnx::Constant":
1997        if self_node.kindOf("value") == other_node.kindOf("value") == "s":
1998            # Exporting strings to ONNX is not supported.
1999            # If both strings are constant, we can compare them directly.
2000            # The no-op check for equality will get constant-folded.
2001            return g.op(
2002                "Constant",
2003                value_t=torch.tensor(
2004                    self_node.s("value") == other_node.s("value"),
2005                    dtype=torch.bool,
2006                ),
2007            )
2008
2009    return g.op("Equal", self, other)
2010
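# Illustrative note: comparing two prim::device values (e.g. a scripted
# `x.device == y.device`) folds to a constant True tensor because ONNX has no
# concept of device placement; constant string operands are likewise compared
# at export time rather than in the exported graph.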
2011
2012@_onnx_symbolic("aten::ne")
2013@symbolic_helper.quantized_args(True, True)
2014@wrap_logical_op_with_negation
2015def ne(g: jit_utils.GraphContext, self, other):
2016    return eq(g, self, other)
2017
2018
2019@_onnx_symbolic("aten::gt")
2020@symbolic_helper.quantized_args(True, True)
2021def gt(g: jit_utils.GraphContext, input, other):
2022    return _gt_impl(g, input, other)
2023
2024
2025def _gt_impl(g: jit_utils.GraphContext, input, other):
2026    if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other):
2027        input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32)
2028        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32)
2029    return g.op("Greater", input, other)
2030
2031
2032@_onnx_symbolic("aten::lt")
2033@symbolic_helper.quantized_args(True, True)
2034def lt(g: jit_utils.GraphContext, input, other):
2035    return _lt_impl(g, input, other)
2036
2037
2038def _lt_impl(g: jit_utils.GraphContext, input, other):
2039    if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other):
2040        input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32)
2041        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32)
2042    return g.op("Less", input, other)
2043
2044
2045@_onnx_symbolic("aten::ge")
2046@symbolic_helper.quantized_args(True, True)
2047@wrap_logical_op_with_negation
2048def ge(g: jit_utils.GraphContext, input, other):
2049    return _lt_impl(g, input, other)
2050
2051
2052@_onnx_symbolic("aten::le")
2053@symbolic_helper.quantized_args(True, True)
2054@wrap_logical_op_with_negation
2055def le(g: jit_utils.GraphContext, input, other):
2056    return _gt_impl(g, input, other)
2057
2058
2059@_onnx_symbolic("aten::__and_")
2060def __and_(g: jit_utils.GraphContext, input, other):
2061    if not symbolic_helper._is_bool(input):
2062        raise errors.SymbolicValueError(
2063            "ONNX export does NOT support exporting bitwise AND "
2064            "for non-boolean input values",
2065            input,
2066        )
2067    if not symbolic_helper._is_bool(other):
2068        raise errors.SymbolicValueError(
2069            "ONNX export does NOT support exporting bitwise AND "
2070            "for non-boolean input values",
2071            other,
2072        )
2073    return g.op("And", input, other)
2074
2075
2076@_onnx_symbolic("aten::__or_")
2077def __or_(g: jit_utils.GraphContext, input, other):
2078    if not symbolic_helper._is_bool(input):
2079        raise errors.SymbolicValueError(
2080            "ONNX export does NOT support exporting bitwise OR "
2081            "for non-boolean input values",
2082            input,
2083        )
2084    if not symbolic_helper._is_bool(other):
2085        raise errors.SymbolicValueError(
2086            "ONNX export does NOT support exporting bitwise OR "
2087            "for non-boolean input values",
2088            other,
2089        )
2090    return g.op("Or", input, other)
2091
2092
2093@_onnx_symbolic("aten::__xor_")
2094def __xor_(g: jit_utils.GraphContext, input, other):
2095    if not symbolic_helper._is_bool(input):
2096        raise errors.SymbolicValueError(
2097            "ONNX export does NOT support exporting bitwise XOR "
2098            "for non-boolean input values",
2099            input,
2100        )
2101    if not symbolic_helper._is_bool(other):
2102        raise errors.SymbolicValueError(
2103            "ONNX export does NOT support exporting bitwise XOR "
2104            "for non-boolean input values",
2105            other,
2106        )
2107    return g.op("Xor", input, other)
2108
2109
2110@_onnx_symbolic("aten::logical_and")
2111@wrap_logical_op_with_cast_to("Bool")
2112def logical_and(g: jit_utils.GraphContext, input, other):
2113    return g.op("And", input, other)
2114
2115
2116@_onnx_symbolic("aten::logical_or")
2117@wrap_logical_op_with_cast_to("Bool")
2118def logical_or(g: jit_utils.GraphContext, input, other):
2119    return g.op("Or", input, other)
2120
2121
2122@_onnx_symbolic("aten::logical_xor")
2123@wrap_logical_op_with_cast_to("Bool")
2124def logical_xor(g: jit_utils.GraphContext, input, other):
2125    return g.op("Xor", input, other)
2126
2127
2128@_onnx_symbolic("aten::logical_not")
2129def logical_not(g: jit_utils.GraphContext, input):
2130    return g.op("Not", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL))
2131
2132
2133@_onnx_symbolic("aten::__rshift_")
2134def __rshift_(g: jit_utils.GraphContext, self, other):
2135    # make sure to cast other to self's type
2136    # (when self is long, make sure that other is not float)
2137    self_scalar_type = _type_utils.JitScalarType.from_value(self)
2138    if (
2139        _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED)
2140        != self_scalar_type
2141    ):
2142        other = g.op(
2143            "Cast",
2144            other,
2145            to_i=self_scalar_type.onnx_type(),
2146        )
2147
2148    two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
2149    # the exponent (which now has self's type) has to be float or double in onnx::Pow
2150    if not symbolic_helper._is_fp(self):
2151        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
2152    two_pow = g.op("Pow", two, other)
2153    two_pow = g.op(
2154        "Cast",
2155        two_pow,
2156        to_i=self_scalar_type.onnx_type(),
2157    )
2158    rshift = g.op("Div", self, two_pow)
2159    return rshift
2160
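# Worked example (illustrative): for self = tensor([8], dtype=torch.int64)
# and other = 2, the exported graph computes Pow(2.0, 2.0) = 4.0, casts it
# back to int64, and emits Div(8, 4) = 2, which matches 8 >> 2 for
# non-negative values.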
2161
2162@_onnx_symbolic("aten::__lshift_")
2163def __lshift_(g: jit_utils.GraphContext, self, other):
2164    # make sure to cast other to self's type
2165    # (when self is long, make sure that other is not float)
2166    self_scalar_type = _type_utils.JitScalarType.from_value(self)
2167    if (
2168        _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED)
2169        != self_scalar_type
2170    ):
2171        other = g.op(
2172            "Cast",
2173            other,
2174            to_i=self_scalar_type.onnx_type(),
2175        )
2176
2177    two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
2178    # the exponent (which now has self's type) has to be float or double in onnx::Pow
2179    if not symbolic_helper._is_fp(self):
2180        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
2181    two_pow = g.op("Pow", two, other)
2182    two_pow = g.op(
2183        "Cast",
2184        two_pow,
2185        to_i=self_scalar_type.onnx_type(),
2186    )
2187    lshift = g.op("Mul", self, two_pow)
2188    return lshift
2189
2190
2191@_onnx_symbolic("aten::where")
2192@symbolic_helper.parse_args("v", "v", "v", "i")
2193def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None):
2194    # Assumes that torch.where's first argument takes only Bool and Byte tensors.
2195    if not symbolic_helper._is_bool(condition):
2196        condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
2197    if self is None:
2198        condition = nonzero(g, condition)
2199        return symbolic_helper._unbind_helper(
2200            g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs
2201        )
2202    return g.op("Where", condition, self, other)
2203
2204
2205@_onnx_symbolic("aten::log_softmax")
2206@symbolic_helper.parse_args("v", "i", "none")
2207def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
2208    # PyTorch dim and ONNX axis have different meanings.
2209    # See Softmax comment for details.
2210    # TODO: remove this as onnx opset 11 spec allows negative axes
2211    input_dim = symbolic_helper._get_tensor_rank(input)
2212    if input_dim is None:
2213        return symbolic_helper._unimplemented(
2214            "dim",
2215            "ONNX and PyTorch use different strategies to split the input. "
2216            "Input rank must be known at export time.",
2217        )
2218    if dim < 0:
2219        dim = input_dim + dim
2220    is_transpose_required = input_dim != dim + 1
2221    # ONNX only supports log_softmax with dim = -1. Transpose must be added before and after log_softmax to support other cases.
2222    if is_transpose_required:
2223        axes = list(range(input_dim))
2224        axes[dim], axes[-1] = axes[-1], axes[dim]
2225        input = g.op("Transpose", input, perm_i=axes)
2226        dim = input_dim - 1
2227    return_op = g.op("LogSoftmax", input, axis_i=dim)
2228    if dtype and dtype.node().kind() != "prim::Constant":
2229        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
2230        return_op = g.op(
2231            "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
2232        )
2233    if is_transpose_required:
2234        return_op = g.op("Transpose", return_op, perm_i=axes)  # type: ignore[possibly-undefined]
2235    return return_op
2236
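# Worked example (illustrative): log_softmax over dim=1 of a rank-4 input
# emits Transpose with perm=[0, 3, 2, 1], LogSoftmax with axis=3, and a
# second Transpose with the same perm to restore the layout, since the
# LogSoftmax emitted here only reduces over the trailing axis.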
2237
2238@_onnx_symbolic("aten::_log_softmax")
2239@symbolic_helper.parse_args("v", "i", "i")
2240def _log_softmax(g: jit_utils.GraphContext, input, dim, half_to_float):
2241    if (
2242        half_to_float
2243        and _type_utils.JitScalarType.from_value(
2244            input, _type_utils.JitScalarType.UNDEFINED
2245        )
2246        == _type_utils.JitScalarType.HALF
2247    ):
2248        input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT)
2249    return log_softmax(g, input, dim)
2250
2251
2252@_onnx_symbolic("aten::_convolution")
2253@symbolic_helper.parse_args(
2254    "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i"
2255)
2256def _convolution(
2257    g: jit_utils.GraphContext,
2258    input,
2259    weight,
2260    bias,
2261    stride,
2262    padding,
2263    dilation,
2264    transposed,
2265    output_padding,
2266    groups,
2267    benchmark,
2268    deterministic,
2269    cudnn_enabled,
2270    allow_tf32=None,
2271):
2272    weight_size = symbolic_helper._get_tensor_sizes(weight)
2273    try:
2274        kernel_shape = weight_size[2:]
2275    except Exception:
2276        # FIXME(justinchuby): Avoid catching Exception.
2277        # Catch a more specific exception instead.
2278        kernel_shape = None
2279
2280    if kernel_shape is None or any(i is None for i in kernel_shape):
2281        raise errors.SymbolicValueError(
2282            "Unsupported: ONNX export of convolution for kernel of unknown shape.",
2283            input,
2284        )
2285
2286    args = [input, weight]
2287    # ONNX only supports 1D bias
2288    if (
2289        not symbolic_helper._is_none(bias)
2290        and symbolic_helper._get_tensor_rank(bias) == 1
2291    ):
2292        args.append(bias)
2293
2294    kwargs = {
2295        "kernel_shape_i": weight_size[2:],
2296        "strides_i": stride,
2297        # NB: ONNX supports asymmetric padding, whereas PyTorch supports only
2298        # symmetric padding
2299        "pads_i": padding + padding,
2300        "dilations_i": dilation,
2301        "group_i": groups,
2302    }
2303
2304    if any(o != 0 for o in output_padding):
2305        # ONNX supports both output_shape and output_padding; they are equally expressive.
2306        # output_padding is more straightforward, so we use it here.
2307        # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2
2308        assert transposed
2309        assert len(stride) == len(output_padding)
2310        kwargs["output_padding_i"] = output_padding
2311
2312    n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs)
2313
2314    if (
2315        not symbolic_helper._is_none(bias)
2316        and symbolic_helper._get_tensor_rank(bias) != 1
2317    ):
2318        return g.op("Add", n, bias)
2319    else:
2320        return n
2321
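# Worked example (illustrative): a 2-D convolution with stride=[1, 1],
# padding=[1, 1], dilation=[1, 1], groups=1 and a 1-D bias exports as
# onnx::Conv(input, weight, bias) with pads=[1, 1, 1, 1], i.e. PyTorch's
# symmetric padding duplicated into ONNX's begin/end form.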
2322
2323@_onnx_symbolic("aten::_convolution_mode")
2324@symbolic_helper.parse_args(
2325    "v",
2326    "v",
2327    "v",
2328    "is",
2329    "s",
2330    "is",
2331    "i",
2332)
2333def _convolution_mode(
2334    g: jit_utils.GraphContext,
2335    input,
2336    weight,
2337    bias,
2338    stride,
2339    padding,
2340    dilation,
2341    groups,
2342):
2343    weight_size = symbolic_helper._get_tensor_sizes(weight)
2344    try:
2345        kernel_shape = weight_size[2:]
2346    except Exception:
2347        # FIXME(justinchuby): Avoid catching Exception.
2348        # Catch a more specific exception instead.
2349        kernel_shape = None
2350
2351    if kernel_shape is None or any(i is None for i in kernel_shape):
2352        raise errors.SymbolicValueError(
2353            "Unsupported: ONNX export of convolution for kernel of unknown shape.",
2354            input,
2355        )
2356
2357    args = [input, weight]
2358    # ONNX only supports 1D bias
2359    if (
2360        not symbolic_helper._is_none(bias)
2361        and symbolic_helper._get_tensor_rank(bias) == 1
2362    ):
2363        args.append(bias)
2364
2365    if padding == "valid":
2366        padding = "VALID"
2367    elif padding == "same":
2368        padding = "SAME_UPPER"
2369    kwargs = {
2370        "kernel_shape_i": weight_size[2:],
2371        "strides_i": stride,
2372        "auto_pad_s": padding,
2373        "dilations_i": dilation,
2374        "group_i": groups,
2375    }
2376
2377    n = g.op("Conv", *args, **kwargs)
2378
2379    if (
2380        not symbolic_helper._is_none(bias)
2381        and symbolic_helper._get_tensor_rank(bias) != 1
2382    ):
2383        return g.op("Add", n, bias)
2384    else:
2385        return n
2386
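# Illustrative note: string paddings map onto onnx::Conv's auto_pad
# attribute, "valid" -> "VALID" and "same" -> "SAME_UPPER"; with SAME_UPPER,
# any odd extra padding in a dimension is placed at its end.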
2387
2388@_onnx_symbolic("aten::convolution")
2389@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i")
2390def convolution(
2391    g: jit_utils.GraphContext,
2392    input,
2393    weight,
2394    bias,
2395    stride,
2396    padding,
2397    dilation,
2398    transposed,
2399    output_padding,
2400    groups,
2401):
2402    return _convolution(
2403        g,
2404        input,
2405        weight,
2406        bias,
2407        stride,
2408        padding,
2409        dilation,
2410        transposed,
2411        output_padding,
2412        groups,
2413        None,
2414        None,
2415        None,
2416        None,
2417    )
2418
2419
2420@_onnx_symbolic("aten::conv1d")
2421@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
2422def conv1d(
2423    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
2424):
2425    str_padding = symbolic_helper._parse_arg(padding, "s")
2426    if str_padding in ["valid", "same"]:
2427        return _convolution_mode(
2428            g,
2429            input,
2430            weight,
2431            bias,
2432            stride,
2433            str_padding,
2434            dilation,
2435            groups,
2436        )
2437    else:
2438        padding = symbolic_helper._parse_arg(padding, "is")
2439        return _convolution(
2440            g,
2441            input,
2442            weight,
2443            bias,
2444            stride,
2445            padding,
2446            dilation,
2447            False,
2448            (),
2449            groups,
2450            None,
2451            None,
2452            None,
2453            None,
2454        )
2455
2456
2457@_onnx_symbolic("aten::conv2d")
2458@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
2459def conv2d(
2460    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
2461):
2462    str_padding = symbolic_helper._parse_arg(padding, "s")
2463    if str_padding in ["valid", "same"]:
2464        return _convolution_mode(
2465            g,
2466            input,
2467            weight,
2468            bias,
2469            stride,
2470            str_padding,
2471            dilation,
2472            groups,
2473        )
2474    else:
2475        padding = symbolic_helper._parse_arg(padding, "is")
2476        return _convolution(
2477            g,
2478            input,
2479            weight,
2480            bias,
2481            stride,
2482            padding,
2483            dilation,
2484            False,
2485            (),
2486            groups,
2487            None,
2488            None,
2489            None,
2490            None,
2491        )
2492
2493
2494@_onnx_symbolic("aten::conv3d")
2495@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
2496def conv3d(
2497    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
2498):
2499    str_padding = symbolic_helper._parse_arg(padding, "s")
2500    if str_padding in ["valid", "same"]:
2501        return _convolution_mode(
2502            g,
2503            input,
2504            weight,
2505            bias,
2506            stride,
2507            str_padding,
2508            dilation,
2509            groups,
2510        )
2511    else:
2512        padding = symbolic_helper._parse_arg(padding, "is")
2513        return _convolution(
2514            g,
2515            input,
2516            weight,
2517            bias,
2518            stride,
2519            padding,
2520            dilation,
2521            False,
2522            (),
2523            groups,
2524            None,
2525            None,
2526            None,
2527            None,
2528        )
2529
2530
2531@_onnx_symbolic("aten::conv_transpose1d")
2532@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
2533def conv_transpose1d(
2534    g: jit_utils.GraphContext,
2535    input,
2536    weight,
2537    bias,
2538    stride,
2539    padding,
2540    output_padding,
2541    groups,
2542    dilation,
2543):
2544    return _convolution(
2545        g,
2546        input,
2547        weight,
2548        bias,
2549        stride,
2550        padding,
2551        dilation,
2552        True,
2553        output_padding,
2554        groups,
2555        None,
2556        None,
2557        None,
2558        None,
2559    )
2560
2561
2562@_onnx_symbolic("aten::conv_transpose2d")
2563@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
2564def conv_transpose2d(
2565    g: jit_utils.GraphContext,
2566    input,
2567    weight,
2568    bias,
2569    stride,
2570    padding,
2571    output_padding,
2572    groups,
2573    dilation,
2574):
2575    return _convolution(
2576        g,
2577        input,
2578        weight,
2579        bias,
2580        stride,
2581        padding,
2582        dilation,
2583        True,
2584        output_padding,
2585        groups,
2586        None,
2587        None,
2588        None,
2589        None,
2590    )
2591
2592
2593@_onnx_symbolic("aten::conv_transpose3d")
2594@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
2595def conv_transpose3d(
2596    g: jit_utils.GraphContext,
2597    input,
2598    weight,
2599    bias,
2600    stride,
2601    padding,
2602    output_padding,
2603    groups,
2604    dilation,
2605):
2606    return _convolution(
2607        g,
2608        input,
2609        weight,
2610        bias,
2611        stride,
2612        padding,
2613        dilation,
2614        True,
2615        output_padding,
2616        groups,
2617        None,
2618        None,
2619        None,
2620        None,
2621    )
2622
2623
2624@_onnx_symbolic("aten::batch_norm")
2625@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
2626def batch_norm(
2627    g: jit_utils.GraphContext,
2628    input,
2629    weight,
2630    bias,
2631    running_mean,
2632    running_var,
2633    training,
2634    momentum,
2635    eps,
2636    cudnn_enabled,
2637):
2638    symbolic_helper.check_training_mode(training, "batch_norm")
2639
2640    if (
2641        torch.is_autocast_enabled()
2642        and not symbolic_helper.args_have_same_dtype(
2643            [input, weight, bias, running_mean, running_var]
2644        )
2645        and GLOBALS.export_onnx_opset_version < 15
2646    ):
2647        return symbolic_helper._onnx_opset_unsupported_detailed(
2648            "BatchNormalization",
2649            9,
2650            15,
2651            "All input tensors must have the same `dtype`."
2652            " Turn off Autocast or export using opset version 15.",
2653            input,
2654        )
2655
2656    weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
2657        g, input, weight, bias, running_mean, running_var
2658    )
2659    out = g.op(
2660        "BatchNormalization",
2661        input,
2662        weight,
2663        bias,
2664        running_mean,
2665        running_var,
2666        epsilon_f=eps,
2667        momentum_f=1 - momentum,
2668        outputs=1 if not training else 5,
2669    )
2670    if not training:
2671        return out
2672    else:
2673        res, new_running_mean, new_running_var, saved_mean, saved_var = out
2674        new_running_mean.setType(running_mean.type())
2675        new_running_var.setType(running_var.type())
2676        saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName())
2677        saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName())
2678        return res
2679
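# Illustrative note on the momentum flip above: PyTorch updates running stats
# as new = (1 - momentum) * old + momentum * batch, while ONNX
# BatchNormalization uses new = momentum * old + (1 - momentum) * batch, so a
# PyTorch momentum of 0.1 exports as momentum_f = 0.9.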
2680
2681@_onnx_symbolic("aten::native_layer_norm")
2682@symbolic_helper.quantized_args(True, False, False, False)
2683@symbolic_helper.parse_args("v", "is", "v", "v", "f")
2684def native_layer_norm(
2685    g: jit_utils.GraphContext,
2686    input: _C.Value,
2687    normalized_shape: Sequence[int],
2688    weight: _C.Value,
2689    bias: _C.Value,
2690    eps: float,
2691) -> tuple[_C.Value, _C.Value, _C.Value]:
2692    axes = [-i for i in range(len(normalized_shape), 0, -1)]
2693
2694    two_cst = symbolic_helper._generate_wrapped_number(g, 2.0)
2695    eps_cst = symbolic_helper._generate_wrapped_number(g, eps)
2696
2697    if g.opset < 18:
2698        mean = g.op("ReduceMean", input, axes_i=axes)
2699    else:
2700        mean = g.op(
2701            "ReduceMean",
2702            input,
2703            g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)),
2704        )
2705
2706    numerator = sub(g, input, mean)
2707
2708    # Cast it to eps dtype to avoid precision loss
2709    is_type_half = (
2710        _type_utils.JitScalarType.from_value(numerator)
2711        == _type_utils.JitScalarType.HALF
2712    )
2713    if is_type_half:
2714        eps_dtype = _type_utils.JitScalarType.from_value(eps_cst)
2715        numerator = g.op(
2716            "Cast", numerator, to_i=_type_utils.JitScalarType(eps_dtype).onnx_type()
2717        )
2718
2719    # variance = E((x - E(x))^2), and (x - E(x)) is the numerator in the layer_norm formula
2720    if g.opset < 18:
2721        variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes)
2722    else:
2723        variance = g.op(
2724            "ReduceMean",
2725            pow(g, numerator, two_cst),
2726            g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)),
2727        )
2728
2729    denominator = sqrt(g, g.op("Add", variance, eps_cst))
2730    normalized = g.op("Div", numerator, denominator)
2731
2732    # Cast back to input type as eps related ops are all done
2733    if is_type_half:
2734        input_dtype = _type_utils.JitScalarType.from_value(input)
2735        normalized = g.op(
2736            "Cast", normalized, to_i=_type_utils.JitScalarType(input_dtype).onnx_type()
2737        )
2738
2739    if not (weight is None or symbolic_helper._is_none(weight)):
2740        normalized = mul(g, normalized, weight)
2741    if not (bias is None or symbolic_helper._is_none(bias)):
2742        normalized = add(g, normalized, bias)
2743
2744    # rdenominator := 1 / sqrt(variance + eps)
2745    # According to aten::native_layer_norm, rdenominator should have the same dtype as input,
2746    # mean and normalized, so we need to Cast it back
2747    if is_type_half:
2748        denominator = g.op(
2749            "Cast",
2750            denominator,
2751            to_i=_type_utils.JitScalarType(input_dtype).onnx_type(),  # type: ignore[possibly-undefined]
2752        )
2753        rdenominator = g.op("Reciprocal", denominator)
2754    else:
2755        rdenominator = reciprocal(g, denominator)
2756
2757    return normalized, mean, rdenominator
2758
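# Worked example (illustrative): normalized_shape=[C, H, W] yields reduction
# axes [-3, -2, -1], so mean and variance are taken over the trailing three
# dims; half-precision inputs are temporarily cast to the eps dtype to avoid
# precision loss and cast back once the eps-related ops are done.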
2759
2760@_onnx_symbolic("aten::layer_norm")
2761@symbolic_helper.quantized_args(True, False, False, False)
2762@symbolic_helper.parse_args("v", "is", "v", "v", "f", "b")
2763def layer_norm(
2764    g: jit_utils.GraphContext,
2765    input: _C.Value,
2766    normalized_shape: Sequence[int],
2767    weight: _C.Value,
2768    bias: _C.Value,
2769    eps: float,
2770    cudnn_enable: bool,
2771) -> _C.Value:
2772    normalized, _, _ = native_layer_norm(g, input, normalized_shape, weight, bias, eps)
2773    return normalized
2774
2775
2776@_onnx_symbolic("aten::instance_norm")
2777@symbolic_helper.parse_args("v", "v", "v", "v", "v", "b", "f", "f", "b")
2778def instance_norm(
2779    g: jit_utils.GraphContext,
2780    input,
2781    weight,
2782    bias,
2783    running_mean,
2784    running_var,
2785    use_input_stats: bool,
2786    momentum: Number,
2787    eps: Number,
2788    cudnn_enabled: bool,
2789):
2790    symbolic_helper.check_training_mode(use_input_stats, "instance_norm")
2791    channel_size = symbolic_helper._get_tensor_dim_size(input, 1)
2792    if weight is None or symbolic_helper._is_none(weight):
2793        if channel_size is None:
2794            raise errors.SymbolicValueError(
2795                "Unsupported: ONNX export of instance_norm for unknown channel size.",
2796                input,
2797            )
2798        weight_value = torch.tensor(
2799            [1.0] * channel_size,
2800            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
2801        )
2802        weight = g.op("Constant", value_t=weight_value)
2803    if bias is None or symbolic_helper._is_none(bias):
2804        if channel_size is None:
2805            raise errors.SymbolicValueError(
2806                "Unsupported: ONNX export of instance_norm for unknown channel size.",
2807                input,
2808            )
2809        bias_value = torch.tensor(
2810            [0.0] * channel_size,
2811            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
2812        )
2813        bias = g.op("Constant", value_t=bias_value)
2814    if (
2815        running_mean is None
2816        or symbolic_helper._is_none(running_mean)
2817        or running_var is None
2818        or symbolic_helper._is_none(running_var)
2819    ):
2820        return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)
2821    else:
2822        input_size = symbolic_helper._get_tensor_sizes(input)
2823        # If input shape is [N, C, H, W], reshape to [1, N * C, H, W] and call batch_norm.
2824        # For more information, see instance_norm():
2825        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L542
2826        input_size_reshape = input_size.copy()
2827        n = input_size[0]
2828        if n is None:
2829            raise errors.SymbolicValueError(
2830                "Unsupported: ONNX export of instance_norm training for unknown "
2831                "batch size.",
2832                input,
2833            )
2834        c = input_size[1]
2835        input_size_reshape[0] = 1
2836        input_size_reshape[1] = n * c
2837        weight_ = repeat(
2838            g, weight, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64))
2839        )
2840        bias_ = repeat(
2841            g, bias, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64))
2842        )
2843        running_mean_ = repeat(
2844            g,
2845            running_mean,
2846            g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)),
2847        )
2848        running_var_ = repeat(
2849            g,
2850            running_var,
2851            g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)),
2852        )
2853        input_reshaped = g.op(
2854            "Reshape",
2855            input,
2856            g.op("Constant", value_t=torch.LongTensor(input_size_reshape)),
2857        )
2858        out = batch_norm(
2859            g,
2860            input_reshaped,
2861            weight_,
2862            bias_,
2863            running_mean_,
2864            running_var_,
2865            use_input_stats,
2866            momentum,
2867            eps,
2868            cudnn_enabled,
2869        )
2870        return view(g, out, g.op("Constant", value_t=torch.tensor(input_size)))
2871
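# Worked example (illustrative): with running stats and an input of shape
# [N, C, H, W], the input is reshaped to [1, N * C, H, W], weight, bias and
# both running stats are tiled N times, batch_norm runs on the reshaped
# tensor, and the result is viewed back to [N, C, H, W].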
2872
2873@_onnx_symbolic("aten::unfold")
2874@symbolic_helper.parse_args("v", "i", "i", "i")
2875def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
2876    sizes = symbolic_helper._get_tensor_sizes(input)
2877    # FIXME(justinchuby): Get rid of the try catch here to improve readability
2878    try:
2879        sizedim = sizes[dimension]
2880    except Exception:
2881        # FIXME(justinchuby): Avoid catching Exception.
2882        # Catch a more specific exception instead.
2883        sizedim = None
2884    if sizedim is not None:
2885        low_indices = range(0, sizedim, step)
2886        hi_indices = range(size, sizedim + 1, step)
2887        stack = [
2888            symbolic_helper._slice_helper(
2889                g, input, axes=[dimension], starts=[low], ends=[hi]
2890            )
2891            for low, hi in zip(low_indices, hi_indices)
2892        ]
2893        ndim = len(sizes)
2894        perm = list(range(0, ndim))
2895        perm.append(perm.pop(dimension))
2896        unsqueeze = [
2897            symbolic_helper._unsqueeze_helper(
2898                g, g.op("Transpose", t, perm_i=perm), [dimension]
2899            )
2900            for t in stack
2901        ]
2902        return g.op("Concat", *unsqueeze, axis_i=dimension)
2903    else:
2904        return symbolic_helper._unimplemented(
2905            "Unfold", "input size not accessible", input
2906        )
2907
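# Worked example (illustrative): unfold(dimension=1, size=2, step=1) on a
# tensor with sizes[1] == 4 emits Slice windows [0:2], [1:3], [2:4] along
# dim 1, transposes each window so its axis comes last, unsqueezes, and
# concatenates, matching Tensor.unfold's trailing window dimension.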
2908
2909@_onnx_symbolic("aten::elu")
2910@symbolic_helper.quantized_args(True)
2911@symbolic_helper.parse_args("v", "t", "t", "t")
2912def elu(g: jit_utils.GraphContext, input, alpha, scale, input_scale):
2913    if scale and scale != 1.0:
2914        return symbolic_helper._unimplemented(
2915            "scale", "does not support scale in Elu", scale
2916        )
2917    if input_scale and input_scale != 1.0:
2918        return symbolic_helper._unimplemented(
2919            "input_scale", "does not support input_scale in Elu", input_scale
2920        )
2921    # See Note [Export inplace]
2922    return g.op("Elu", input, alpha_f=symbolic_helper._scalar(alpha))
2923
2924
2925@_onnx_symbolic("aten::selu")
2926@symbolic_helper.quantized_args(True)
2927def selu(g: jit_utils.GraphContext, input):
2928    return g.op("Selu", input)
2929
2930
2931@_onnx_symbolic("aten::index_select")
2932@symbolic_helper.parse_args("v", "i", "v")
2933def index_select(g: jit_utils.GraphContext, self, dim, index):
2934    # In case of a scalar index, index_select returns a tensor with the same rank as the input.
2935    # To match this behavior in ONNX, we make index a 1D tensor so that the following gather
2936    # also produces a tensor with the same rank as the input.
2937    return symbolic_helper._select_helper(g, self, dim, index)
2938
2939
2940@_onnx_symbolic("aten::index_put")
2941def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accumulate):
2942    if symbolic_helper._is_packed_list(indices_list_value):
2943        indices_list = symbolic_helper._unpack_list(indices_list_value)
2944    else:
2945        indices_list = [indices_list_value]
2946
2947    accumulate = symbolic_helper._parse_arg(accumulate, "b")
2948
2949    if len(indices_list) == 0:
2950        if accumulate:
2951            return add(g, self, values)
2952        return values
2953    symbolic_helper._onnx_opset_unsupported("index_put", 9, 11, self)
2954
2955
2956@_onnx_symbolic("aten::index_fill")
2957def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
2958    dim_value = symbolic_helper._parse_arg(dim, "i")
2959    expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
2960        g, self, dim, index
2961    )
2962    value = symbolic_helper._maybe_get_scalar(value)
2963    value = symbolic_helper._if_scalar_type_as(value, self)
2964    expanded_value = expand(g, value, expanded_index_shape, None)
2965
2966    return scatter(g, self, dim, expanded_index, expanded_value)
2967
2968
2969@_onnx_symbolic("aten::index_copy")
2970def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
2971    dim_value = symbolic_helper._parse_arg(dim, "i")
2972    expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
2973        g, self, dim, index
2974    )
2975    return scatter(g, self, dim, expanded_index, source)
2976
2977
2978@_onnx_symbolic("aten::bucketize")
2979@symbolic_helper.parse_args("v", "v", "b", "b")
2980def bucketize(
2981    g: jit_utils.GraphContext, self, boundaries, out_int32=False, right=False
2982):
2983    out_type = _C_onnx.TensorProtoDataType.INT64
2984    if out_int32:
2985        out_type = _C_onnx.TensorProtoDataType.INT32
2986    # A tensor expanded_boundaries is created such that it
2987    # contains a copy of boundaries for each element of self.
2988    new_shape = g.op("Concat", g.op("Shape", boundaries), g.op("Shape", self), axis_i=0)
2989    # Unsqueeze step is performed to respect ONNX's numpy style broadcasting for comparison ops
2990    # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
2991    tensor_rank = symbolic_helper._get_tensor_rank(self)
2992    assert tensor_rank is not None
2993    unsqueeze_axes = list(range(1, tensor_rank + 1))
2994    expanded_boundaries = expand(
2995        g,
2996        symbolic_helper._unsqueeze_helper(g, boundaries, unsqueeze_axes),
2997        new_shape,
2998        None,
2999    )
3000    # Compare each element of self to boundaries to get a tensor
3001    # with leading 1s and trailing 0s.
3002    # e.g., 4 > [1, 3, 4] = [1, 1, 0]
3003    # The index of the last 1 is the bucket where the element should go.
3004    if right:
3005        cond = ge(g, self, expanded_boundaries)
3006    else:
3007        cond = gt(g, self, expanded_boundaries)
3008    cond_out = g.op("Cast", cond, to_i=out_type)
3009    # Sum to get the number of 1s corresponding to each element,
3010    # which is the same as the bucket index.
3011    # e.g., sum(4 > [1, 3, 4]) = sum([1, 1, 0]) = 2
3012    return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0)
3013
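# Worked example (illustrative): bucketize(4, boundaries=[1, 3, 4]) with
# right=False compares 4 > [1, 3, 4] = [1, 1, 0] and sums to bucket index 2;
# with right=True the comparison is 4 >= [1, 3, 4] = [1, 1, 1], giving 3.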
3014
3015@_onnx_symbolic("aten::type_as")
3016def type_as(g: jit_utils.GraphContext, self, other):
3017    self_dtype = symbolic_helper._try_get_scalar_type(self)
3018    other_dtype = symbolic_helper._try_get_scalar_type(other)
3019    if self_dtype == other_dtype and self_dtype is not None:
3020        return self
3021    if other_dtype is not None:
3022        return g.op(
3023            "Cast",
3024            self,
3025            to_i=other_dtype.onnx_type(),
3026        )
3027
3028    raise errors.SymbolicValueError(
3029        "Unsupported: ONNX export of type_as for tensor "
3030        "of unknown dtype. Please check if the dtype of the "
3031        "parameter passed to the type_as function is correct.",
3032        other,
3033    )
3034
3035
3036@_onnx_symbolic("aten::cosine_similarity")
3037@symbolic_helper.parse_args("v", "v", "i", "f")
3038def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps):
3039    cross = symbolic_helper._reducesum_helper(
3040        g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0
3041    )
3042    x1_l2 = symbolic_helper._reducesum_helper(
3043        g, mul(g, x1, x1), axes_i=[dim], keepdims_i=0
3044    )
3045    x2_l2 = symbolic_helper._reducesum_helper(
3046        g, mul(g, x2, x2), axes_i=[dim], keepdims_i=0
3047    )
3048    div_tens = max(
3049        g, sqrt(g, mul(g, x1_l2, x2_l2)), g.op("Constant", value_t=torch.tensor([eps]))
3050    )
3051    return div(g, cross, div_tens)
3052
3053
3054@_onnx_symbolic("aten::pairwise_distance")
3055def pairwise_distance(g: jit_utils.GraphContext, input1, input2, p, eps, keepdim):
3056    if not symbolic_helper._is_value(eps):
3057        eps = g.op("Constant", value_t=torch.tensor([eps]))
3058    inv_p = div(
3059        g,
3060        g.op("Constant", value_t=torch.tensor([1], dtype=torch.float)),
3061        add(g, p, eps),
3062    )
3063    summation = symbolic_helper._reducesum_helper(
3064        g,
3065        pow(g, sub(g, input1, input2), p),
3066        axes_i=[-1],
3067        keepdims_i=symbolic_helper._parse_arg(keepdim, "i"),
3068    )
3069    return pow(g, summation, inv_p)
3070
3071
3072@_onnx_symbolic("aten::clone")
3073# ignore clone operators that are inserted by PyTorch autograd
3074def clone(g: jit_utils.GraphContext, input, unused_memory_format):
3075    return input
3076
3077
3078@_onnx_symbolic("aten::abs")
3079def abs(g: jit_utils.GraphContext, self):
3080    return g.op("Abs", self)
3081
3082
3083@_onnx_symbolic("aten::log")
3084def log(g: jit_utils.GraphContext, self):
3085    return g.op("Log", self)
3086
3087
3088@_onnx_symbolic("aten::log1p")
3089def log1p(g: jit_utils.GraphContext, self):
3090    return log(g, add(g, symbolic_helper._if_scalar_type_as(torch.ones(1), self), self))
3091
3092
3093@_onnx_symbolic("aten::log10")
3094def log10(g: jit_utils.GraphContext, self):
3095    _ln10 = 2.30258509299404568401
3096    return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor([_ln10])))
3097
3098
3099@_onnx_symbolic("aten::pow")
3100def pow(g: jit_utils.GraphContext, self, exponent):
3101    f_dtype = _type_utils.JitScalarType.from_value(self)
3102    if not symbolic_helper._is_fp(self):
3103        f_dtype = _type_utils.JitScalarType.FLOAT
3104        self = g.op("Cast", self, to_i=f_dtype.onnx_type())
3105    if not symbolic_helper._is_fp(exponent):
3106        exponent = g.op(
3107            "Cast",
3108            exponent,
3109            to_i=f_dtype.onnx_type(),
3110        )
3111    pow = g.op("Pow", self, exponent)
3112    return pow
3113
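# Illustrative note: pow on integer inputs casts both base and exponent to
# float before onnx::Pow, since Pow here accepts only floating types; the
# result therefore stays floating point rather than being cast back to the
# integer input type.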
3114
3115@_onnx_symbolic("aten::clamp")
3116def clamp(g: jit_utils.GraphContext, self, min, max):
3117    # min or max may be None that we need to dispatch to
3118    # Clip separately, as ONNX does not have None syntax
3119    if symbolic_helper._is_none(min):
3120        return clamp_max(g, self, max)
3121    elif symbolic_helper._is_none(max):
3122        return clamp_min(g, self, min)
3123    else:
3124        if symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max):
3125            return symbolic_helper._op_with_optional_float_cast(
3126                g,
3127                "Clip",
3128                self,
3129                min_f=symbolic_helper._parse_arg(min, "f"),
3130                max_f=symbolic_helper._parse_arg(max, "f"),
3131                opset_before=12,
3132            )
3133        else:
3134            return clamp_max(g, clamp_min(g, self, min), max)
3135
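# Illustrative note: clamp with constant bounds exports a single Clip with
# min_f/max_f attributes; non-constant bounds fall back to Max(self, min)
# followed by Min(result, max) so the bounds can be runtime tensors.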
3136
3137@_onnx_symbolic("aten::clamp_min")
3138@symbolic_helper.parse_args("v", "v")
3139def clamp_min(g: jit_utils.GraphContext, self, min):
3140    if symbolic_helper._is_constant(min):
3141        return symbolic_helper._op_with_optional_float_cast(
3142            g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), opset_before=12
3143        )
3144    else:
3145        dtype = _type_utils.JitScalarType.from_value(self)
3146        min = g.op("Cast", min, to_i=dtype.onnx_type())
3147        return symbolic_helper._op_with_optional_float_cast(
3148            g, "Max", self, min, opset_before=12
3149        )
3150
3151
3152@_onnx_symbolic("aten::clamp_max")
3153@symbolic_helper.parse_args("v", "v")
3154def clamp_max(g: jit_utils.GraphContext, self, max):
3155    if symbolic_helper._is_constant(max):
3156        return symbolic_helper._op_with_optional_float_cast(
3157            g, "Clip", self, max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12
3158        )
3159    else:
3160        dtype = _type_utils.JitScalarType.from_value(self)
3161        max = g.op("Cast", max, to_i=dtype.onnx_type())
3162        return symbolic_helper._op_with_optional_float_cast(
3163            g, "Min", self, max, opset_before=12
3164        )
3165
3166
3167@_onnx_symbolic("aten::max")
3168# torch.max (same for torch.min) actually has two interfaces smashed together:
3169# torch.max(x, dim, keepdim) and torch.max(x, y)
3170# TODO(justinchuby): Support multiple quantized args in output
3171def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
3172    return symbolic_helper._max_helper(g, self, dim_or_y, keepdim)
3173
3174
3175@_onnx_symbolic("aten::maximum")
3176@symbolic_helper.quantized_args(True, True)
3177def maximum(g: jit_utils.GraphContext, input, other):
3178    return max(g, input, dim_or_y=other)
3179
3180
3181@_onnx_symbolic("aten::min")
3182# TODO(justinchuby): Support multiple quantized args in output
3183def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
3184    return symbolic_helper._min_helper(g, self, dim_or_y, keepdim)
3185
3186
3187@_onnx_symbolic("aten::minimum")
3188@symbolic_helper.quantized_args(True, True)
3189def minimum(g: jit_utils.GraphContext, input, other):
3190    return min(g, input, dim_or_y=other)
3191
3192
3193@_onnx_symbolic("aten::amax")
3194@symbolic_helper.quantized_args(True)
3195@symbolic_helper.parse_args("v", "is", "i")
3196def amax(g: jit_utils.GraphContext, self, dim, keepdim):
3197    return g.op("ReduceMax", self, axes_i=dim, keepdims_i=keepdim)
3198
3199
3200@_onnx_symbolic("aten::amin")
3201@symbolic_helper.quantized_args(True)
3202@symbolic_helper.parse_args("v", "is", "i")
3203def amin(g: jit_utils.GraphContext, self, dim, keepdim):
3204    return g.op("ReduceMin", self, axes_i=dim, keepdims_i=keepdim)
3205
3206
3207@_onnx_symbolic("aten::aminmax")
3208@symbolic_helper.quantized_args(True)
3209@symbolic_helper.parse_args("v", "v", "i")
3210def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
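    # aminmax returns the (min, max) pair in one call; emit a ReduceMin and a
    # ReduceMax that share the same axes/keepdims so both reductions agree.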
3211    reduce_kwargs = {"keepdims_i": keepdim}
3212    if not symbolic_helper._is_none(dim):
3213        dim = symbolic_helper._get_const(dim, "i", "dim")
3214        reduce_kwargs["axes_i"] = [dim]
3215
3216    return g.op("ReduceMin", self, **reduce_kwargs), g.op(
3217        "ReduceMax", self, **reduce_kwargs
3218    )
3219
3220
3221@_onnx_symbolic("aten::exp")
3222def exp(g: jit_utils.GraphContext, self):
3223    return g.op("Exp", self)
3224
3225
3226@_onnx_symbolic("aten::dropout_")
3227@_onnx_symbolic("aten::dropout")
3228@symbolic_helper.parse_args("v", "f", "i")
3229def dropout(g: jit_utils.GraphContext, input, p, train):
3230    symbolic_helper.check_training_mode(train, "dropout")
3231    # if train is False, dropout is a no-op
3232    if not train:
3233        return input
3234    r, _ = g.op("Dropout", input, ratio_f=p, outputs=2)
3235    return r
3236
3237
3238@_onnx_symbolic(
3239    "aten::alpha_dropout_",
3240    decorate=[symbolic_helper._apply_params("aten::alpha_dropout_")],
3241)  # See Note [Export inplace]
3242@_onnx_symbolic(
3243    "aten::feature_alpha_dropout_",
3244    decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout_")],
3245)
3246@_onnx_symbolic(
3247    "aten::feature_dropout_",
3248    decorate=[symbolic_helper._apply_params("aten::feature_dropout_")],
3249)
3250@_onnx_symbolic(
3251    "aten::feature_alpha_dropout",
3252    decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout")],
3253)
3254@_onnx_symbolic(
3255    "aten::alpha_dropout",
3256    decorate=[symbolic_helper._apply_params("aten::alpha_dropout")],
3257)
3258@_onnx_symbolic(
3259    "aten::feature_dropout",
3260    decorate=[symbolic_helper._apply_params("aten::feature_dropout")],
3261)
3262def _unsupported_dropout(name: str):
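    # Factory producing a symbolic for the dropout variants listed above:
    # they are identity functions in inference mode and have no ONNX
    # equivalent in training mode, so training export is rejected.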
3263    @symbolic_helper.parse_args("v", "none", "b")
3264    def feature_dropout(g, input, p, train):
3265        # NB: In inference mode, FeatureDropout is exported as an identity op.
3266        if train:
3267            return symbolic_helper._unimplemented(name, "training mode", input)
3268        return input
3269
3270    return feature_dropout
3271
3272
3273@_onnx_symbolic("aten::norm")
3274@symbolic_helper.parse_args("v", "t", "is", "i", "v")
3275def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None):
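    # Only the 1- and 2-norms map directly onto ONNX reductions
    # (ReduceL1/ReduceL2); other p values have no single-op equivalent here.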
3276    if p == 1:
3277        f = symbolic_helper._reduce_op_symbolic_helper("ReduceL1")
3278    elif p == 2:
3279        f = symbolic_helper._reduce_op_symbolic_helper("ReduceL2")
3280    else:
3281        raise errors.SymbolicValueError(
3282            "ONNX export only p-norms with p of 1 or 2", self
3283        )
3284    result = f(g, self, dim=dim, keepdim=keepdim)
3285    if dtype is not None:
3286        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
3287        result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type())
3288    return result
3289
3290
3291@_onnx_symbolic("aten::conv_tbc")
3292@symbolic_helper.parse_args("v", "v", "v", "i")
3293def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad):
3294    # input must have 3 dimensions, see:
3295    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
3296    # input = (time, batch, in_channels)
3297    # weight = (kernel_width, in_channels, out_channels)
3298    # bias = (out_channels,)
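    # Reorder to Conv1d's (batch, in_channels, time) layout, convolve, then
    # restore the (time, batch, out_channels) layout expected by conv_tbc.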
3299    input = g.op("Transpose", input, perm_i=[1, 2, 0])
3300    weight = g.op("Transpose", weight, perm_i=[2, 1, 0])
3301    conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1)
3302    return g.op("Transpose", conv, perm_i=[2, 0, 1])
3303
3304
3305@_onnx_symbolic("aten::_unique")
3306@symbolic_helper.parse_args("v", "i", "i")
3307def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse):
3308    return symbolic_helper._onnx_unsupported("_unique", input)
3309
3310
3311@_onnx_symbolic("aten::_unique2")
3312@symbolic_helper.parse_args("v", "i", "i", "i")
3313def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts):
3314    symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input)
3315
3316
3317@_onnx_symbolic("aten::_cast_Byte")
3318@_deprecation.deprecated(
3319    "2.0",
3320    "the future",
3321    "Avoid using this function and create a Cast node instead",
3322)
3323def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking):
3324    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8)
3325
3326
3327@_onnx_symbolic("aten::_cast_Char")
3328@_deprecation.deprecated(
3329    "2.0",
3330    "the future",
3331    "Avoid using this function and create a Cast node instead",
3332)
3333def _cast_Char(g: jit_utils.GraphContext, input, non_blocking):
3334    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8)
3335
3336
3337@_onnx_symbolic("aten::_cast_Short")
3338@_deprecation.deprecated(
3339    "2.0",
3340    "the future",
3341    "Avoid using this function and create a Cast node instead",
3342)
3343def _cast_Short(g: jit_utils.GraphContext, input, non_blocking):
3344    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16)
3345
3346
3347@_onnx_symbolic("aten::_cast_Int")
3348@_deprecation.deprecated(
3349    "2.0",
3350    "the future",
3351    "Avoid using this function and create a Cast node instead",
3352)
3353def _cast_Int(g: jit_utils.GraphContext, input, non_blocking):
3354    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32)
3355
3356
3357@_onnx_symbolic("aten::_cast_Long")
3358@_deprecation.deprecated(
3359    "2.0",
3360    "the future",
3361    "Avoid using this function and create a Cast node instead",
3362)
3363def _cast_Long(g: jit_utils.GraphContext, input, non_blocking):
3364    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)
3365
3366
3367@_onnx_symbolic("aten::_cast_Half")
3368@_deprecation.deprecated(
3369    "2.0",
3370    "the future",
3371    "Avoid using this function and create a Cast node instead",
3372)
3373def _cast_Half(g: jit_utils.GraphContext, input, non_blocking):
3374    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
3375
3376
3377@_onnx_symbolic("aten::_cast_Float")
3378@_deprecation.deprecated(
3379    "2.0",
3380    "the future",
3381    "Avoid using this function and create a Cast node instead",
3382)
3383def _cast_Float(g: jit_utils.GraphContext, input, non_blocking):
3384    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT)
3385
3386
3387@_onnx_symbolic("aten::_cast_Double")
3388@_deprecation.deprecated(
3389    "2.0",
3390    "the future",
3391    "Avoid using this function and create a Cast node instead",
3392)
3393def _cast_Double(g: jit_utils.GraphContext, input, non_blocking):
3394    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
3395
3396
3397@_onnx_symbolic("aten::_cast_Bool")
3398@_deprecation.deprecated(
3399    "2.0",
3400    "the future",
3401    "Avoid using this function and create a Cast node instead",
3402)
3403def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking):
3404    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)
3405
3406
3407@_onnx_symbolic("aten::empty")
3408@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
3409def empty(
3410    g: jit_utils.GraphContext,
3411    sizes,
3412    dtype,
3413    layout,
3414    device,
3415    pin_memory=False,
3416    memory_format=None,
3417):
3418    return zeros(g, sizes, dtype, layout, device, pin_memory)
3419
3420
3421@_onnx_symbolic("aten::empty_like")
3422@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
3423def empty_like(
3424    g: jit_utils.GraphContext,
3425    input,
3426    dtype=None,
3427    layout=None,
3428    device=None,
3429    pin_memory=False,
3430    memory_format=None,
3431):
3432    return zeros_like(g, input, dtype, layout, device, pin_memory)
3433
3434
3435@_onnx_symbolic("aten::new_empty")
3436def new_empty(
3437    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
3438):
3439    self_dtype = symbolic_helper._try_get_scalar_type(self)
3440    if symbolic_helper._is_none(dtype) and self_dtype is not None:
3441        dtype = self_dtype
3442    return empty(g, sizes, dtype, layout, device, pin_memory)
3443
3444
3445@_onnx_symbolic("aten::scalar_tensor")
3446def scalar_tensor(g: jit_utils.GraphContext, scalar, dtype, *options):
3447    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
3448    if dtype is None:
3449        dtype = _type_utils.JitScalarType.FLOAT
3450    scalar = g.op("Cast", scalar, to_i=_type_utils.JitScalarType(dtype).onnx_type())
3451    return scalar
3452
3453
3454@_onnx_symbolic("aten::tensor")
3455def tensor(
3456    g: jit_utils.GraphContext, data, dtype=None, device=None, requires_grad=False
3457):
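    # For a packed list, every element is reshaped to shape [1], cast to a
    # common dtype (inferred from the first element when not given), and the
    # pieces are concatenated along axis 0.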
3458    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
3459    if symbolic_helper._is_packed_list(data):
3460        if dtype is None:
3461            dtype = _type_utils.JitScalarType.from_value(
3462                symbolic_helper._unpack_list(data)[0]
3463            )
3464        input_list = []
3465        for t in symbolic_helper._unpack_list(data):
3466            shape_reference = g.op("Constant", value_t=torch.LongTensor([1]))
3467            t = symbolic_helper._reshape_helper(g, t, shape_reference)
3468            t = g.op("Cast", t, to_i=_type_utils.JitScalarType(dtype).onnx_type())
3469            input_list.append(t)
3470        return g.op("Concat", *input_list, axis_i=0)
3471    else:
3472        if dtype is None:
3473            dtype = _type_utils.JitScalarType.from_value(data)
3474        if symbolic_helper._is_list(data) and (
3475            symbolic_helper._is_tensor_list(data)
3476            or symbolic_helper._is_scalar_list(data)
3477        ):
3478            data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1)
3479    return g.op("Cast", data, to_i=_type_utils.JitScalarType(dtype).onnx_type())
3480
3481
3482@_onnx_symbolic("aten::as_tensor")
3483def as_tensor(g: jit_utils.GraphContext, data, dtype=None, device=None):
3484    return tensor(g, data, dtype, device)
3485
3486
3487@_onnx_symbolic("aten::zeros")
3488@symbolic_helper.parse_args("v", "i", "v", "v", "v")
3489def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
3490    # NOTE: there is no way to set device, layout or pin_memory in ONNX, so we ignore them
3491    if dtype is None:
3492        scalar_type = _type_utils.JitScalarType.FLOAT
3493    else:
3494        scalar_type = _type_utils.JitScalarType(dtype)
3495    sizes_ = symbolic_helper._maybe_get_const(sizes, "is")
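    # An empty size list means a 0-d (scalar) result; ConstantOfShape expects
    # that as an empty 1-D INT64 shape tensor.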
3496    if isinstance(sizes_, list) and len(sizes_) == 0:
3497        sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
3498    return g.op(
3499        "ConstantOfShape",
3500        sizes,
3501        value_t=torch.tensor([0], dtype=scalar_type.dtype()),
3502    )
3503
3504
3505@_onnx_symbolic("aten::zeros_like")
3506@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
3507def zeros_like(
3508    g: jit_utils.GraphContext,
3509    input,
3510    dtype=None,
3511    layout=None,
3512    device=None,
3513    pin_memory=False,
3514    memory_format=None,
3515):
3516    shape = g.op("Shape", input)
3517    if symbolic_helper._is_none(dtype):
3518        scalar_type = _type_utils.JitScalarType.from_value(
3519            input, _type_utils.JitScalarType.FLOAT
3520        )
3521    else:
3522        scalar_type = _type_utils.JitScalarType(dtype)
3523    return g.op(
3524        "ConstantOfShape",
3525        shape,
3526        value_t=torch.tensor([0], dtype=scalar_type.dtype()),
3527    )
3528
3529
3530@_onnx_symbolic("aten::new_zeros")
3531def new_zeros(
3532    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
3533):
3534    self_dtype = symbolic_helper._try_get_scalar_type(self)
3535
3536    if symbolic_helper._is_none(dtype) and self_dtype is not None:
3537        dtype = self_dtype
3538    return zeros(g, sizes, dtype, layout, device, pin_memory)
3539
3540
3541@_onnx_symbolic("aten::zero")
3542def zero(g: jit_utils.GraphContext, self):
3543    self_dtype = symbolic_helper._try_get_scalar_type(self)
3544    return zeros_like(g, self, self_dtype)
3545
3546
3547@_onnx_symbolic("aten::ones")
3548@symbolic_helper.parse_args("v", "i", "v", "v", "v")
3549def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
3550    if dtype is None:
3551        scalar_type = _type_utils.JitScalarType.FLOAT
3552    else:
3553        scalar_type = _type_utils.JitScalarType(dtype)
3554    sizes_ = symbolic_helper._maybe_get_const(sizes, "is")
3555    if isinstance(sizes_, list) and len(sizes_) == 0:
3556        sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
3557    return g.op(
3558        "ConstantOfShape",
3559        sizes,
3560        value_t=torch.tensor([1], dtype=scalar_type.dtype()),
3561    )
3562
3563
3564@_onnx_symbolic("aten::ones_like")
3565@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
3566def ones_like(
3567    g: jit_utils.GraphContext,
3568    input,
3569    dtype=None,
3570    layout=None,
3571    device=None,
3572    pin_memory=False,
3573    memory_format=None,
3574):
3575    shape = g.op("Shape", input)
3576    if symbolic_helper._is_none(dtype):
3577        scalar_type = _type_utils.JitScalarType.from_value(
3578            input, _type_utils.JitScalarType.FLOAT
3579        )
3580    else:
3581        scalar_type = _type_utils.JitScalarType(dtype)
3582    return g.op(
3583        "ConstantOfShape",
3584        shape,
3585        value_t=torch.tensor([1], dtype=scalar_type.dtype()),
3586    )
3587
3588
3589@_onnx_symbolic("aten::new_ones")
3590def new_ones(
3591    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
3592):
3593    self_dtype = symbolic_helper._try_get_scalar_type(self)
3594    if symbolic_helper._is_none(dtype) and self_dtype is not None:
3595        dtype = self_dtype
3596    return ones(g, sizes, dtype, layout, device, pin_memory)
3597
3598
3599@_onnx_symbolic("aten::full")
3600def full(
3601    g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False
3602):
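    # A dynamic fill value cannot go into ConstantOfShape's value attribute,
    # so that case is lowered to zeros(sizes) + value; constant fill values
    # use ConstantOfShape directly.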
3603    const_value = symbolic_helper._maybe_get_const(value, "t")
3604    if symbolic_helper._is_value(const_value):
3605        dtype = _type_utils.JitScalarType.FLOAT if dtype is None else dtype
3606        tmp = zeros(g, sizes, dtype, layout, device)
3607        return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
3608    else:
3609        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
3610        if dtype is None:
3611            scalar_type = _type_utils.JitScalarType.FLOAT
3612        else:
3613            scalar_type = _type_utils.JitScalarType(dtype)
3614        sizes_ = symbolic_helper._maybe_get_const(sizes, "is")
3615        if isinstance(sizes_, list) and len(sizes_) == 0:
3616            sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
3617        return g.op(
3618            "ConstantOfShape",
3619            sizes,
3620            value_t=const_value.view(1).to(scalar_type.dtype()),
3621        )
3622
3623
3624@_onnx_symbolic("aten::full_like")
3625def full_like(
3626    g: jit_utils.GraphContext,
3627    input,
3628    fill_value,
3629    dtype=None,
3630    layout=None,
3631    device=None,
3632    pin_memory=False,
3633    memory_format=None,
3634):
3635    fill_value = symbolic_helper._maybe_get_const(fill_value, "f")
3636    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
3637    if dtype is None:
3638        scalar_type = _type_utils.JitScalarType.from_value(
3639            input, _type_utils.JitScalarType.FLOAT
3640        )
3641    else:
3642        scalar_type = _type_utils.JitScalarType(dtype)
3643    if symbolic_helper._is_value(fill_value):
3644        tmp = zeros_like(g, input, dtype, layout, device)
3645        fill_value = g.op("Cast", fill_value, to_i=scalar_type.onnx_type())
3646        return add(g, tmp, fill_value, g.op("Constant", value_t=torch.tensor(1)))
3647    else:
3648        shape = g.op("Shape", input)
3649        return g.op(
3650            "ConstantOfShape",
3651            shape,
3652            value_t=torch.tensor([fill_value], dtype=scalar_type.dtype()),
3653        )
3654
3655
3656@_onnx_symbolic("aten::new_full")
3657def new_full(
3658    g: jit_utils.GraphContext,
3659    self,
3660    size,
3661    fill_value,
3662    dtype,
3663    layout,
3664    device,
3665    pin_memory=False,
3666):
3667    self_dtype = symbolic_helper._try_get_scalar_type(self)
3668    if symbolic_helper._is_none(dtype) and self_dtype is not None:
3669        dtype = self_dtype
3670    return full(g, size, fill_value, dtype, layout, device, pin_memory)
3671
3672
3673@_onnx_symbolic("aten::eye")
3674def eye(g: jit_utils.GraphContext, *args):
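    # ONNX EyeLike derives the output shape from its input, so build a zeros
    # tensor of the requested (n, n) or (n, m) shape first and apply EyeLike
    # to it.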
3675    if len(args) == 5:
3676        # aten::eye(n, dtype, layout, device, pin_memory)
3677        n, dtype, layout, device, pin_memory = args
3678        dim_size = symbolic_helper._unsqueeze_helper(g, n, [0])
3679        shape = g.op("Concat", dim_size, dim_size, axis_i=0)
3680        tensor = zeros(g, shape, dtype, layout, device)
3681        return g.op("EyeLike", tensor)
3682    if len(args) == 6:
3683        # aten::eye(n, m, dtype, layout, device, pin_memory)
3684        n, m, dtype, layout, device, pin_memory = args
3685        shape = g.op(
3686            "Concat",
3687            symbolic_helper._unsqueeze_helper(g, n, [0]),
3688            symbolic_helper._unsqueeze_helper(g, m, [0]),
3689            axis_i=0,
3690        )
3691        tensor = zeros(g, shape, dtype, layout, device)
3692        return g.op("EyeLike", tensor)
3693
3694    return symbolic_helper._unimplemented("aten::eye", f"with {len(args)} arguments")
3695
3696
3697@_onnx_symbolic("aten::slice")
3698def slice(g: jit_utils.GraphContext, self, *args):
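    # Two overloads arrive here: a tensor slice (dim, start, end, step) and a
    # list slice (start, end, step). Opset 9 Slice takes starts/ends/axes as
    # attributes, so they must be compile-time constants; step != 1 is
    # rejected outright.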
3699    if len(args) == 4:
3700        # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor
3701        dim, start, end, step = args
3702        step = symbolic_helper._parse_arg(step, "i")
3703        if step != 1:
3704            raise errors.SymbolicValueError("step!=1 is currently not supported", self)
3705        is_start_none = start.node().kind() == "prim::Constant" and isinstance(
3706            start.type(), _C.NoneType
3707        )
3708        is_end_none = end.node().kind() == "prim::Constant" and isinstance(
3709            end.type(), _C.NoneType
3710        )
3711        is_start_onnx_const = start.node().kind() == "onnx::Constant"
3712        is_end_onnx_const = end.node().kind() == "onnx::Constant"
3713        if (
3714            ((not is_start_none) and (not is_start_onnx_const))
3715            or ((not is_end_none) and (not is_end_onnx_const))
3716            or dim.node().kind() != "onnx::Constant"
3717        ):
3718            if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
3719                raise errors.SymbolicValueError(
3720                    "Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice "
3721                    "is a deprecated experimental op. Please use statically allocated "
3722                    "variables or export to a higher opset version.",
3723                    self,
3724                )
3725            else:
3726                start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0])
3727                end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0])
3728                dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0])
3729                return g.op(
3730                    "DynamicSlice",
3731                    self,
3732                    start_unsqueezed,
3733                    end_unsqueezed,
3734                    dim_unsqueezed,
3735                )
3736        else:
3737            start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i")
3738            end = (
3739                _constants.INT64_MAX
3740                if is_end_none
3741                else symbolic_helper._parse_arg(end, "i")
3742            )
3743            dim = symbolic_helper._parse_arg(dim, "i")
3744            return symbolic_helper._slice_helper(
3745                g, self, axes=[dim], starts=[start], ends=[end]
3746            )
3747    elif len(args) == 3:
3748        # aten::slice(t[] l, int start, int end, int step) -> t[]
3749        start, end, step = args
3750        dim = 0
3751        is_start_none = start.node().kind() == "prim::Constant" and isinstance(
3752            start.type(), _C.NoneType
3753        )
3754        is_end_none = end.node().kind() == "prim::Constant" and isinstance(
3755            end.type(), _C.NoneType
3756        )
3757        start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i")
3758        end = (
3759            _constants.INT64_MAX
3760            if is_end_none
3761            else symbolic_helper._parse_arg(end, "i")
3762        )
3763        return symbolic_helper._slice_helper(
3764            g, self, axes=[dim], starts=[start], ends=[end]
3765        )
3766
3767    return symbolic_helper._unimplemented("aten::slice", f"with {len(args)} arguments")
3768
3769
3770@_onnx_symbolic("aten::hardtanh")
3771@symbolic_helper.quantized_args(True)
3772@symbolic_helper.parse_args("v", "f", "f")
3773def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float):
3774    return symbolic_helper._op_with_optional_float_cast(
3775        g, "Clip", self, min_f=min_val, max_f=max_val, opset_before=12
3776    )
3777
3778
3779@_onnx_symbolic("aten::hardswish")
3780@symbolic_helper.quantized_args(True)
3781@symbolic_helper.parse_args("v")
3782def hardswish(g: jit_utils.GraphContext, self):
3783    hs = hardsigmoid(g, self)
3784    return g.op("Mul", self, hs)
3785
3786
3787@_onnx_symbolic("aten::hardsigmoid")
3788# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp
3789@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
3790@symbolic_helper.parse_args("v")
3791def hardsigmoid(g: jit_utils.GraphContext, self):
3792    # Set alpha_f to 1 / 6 to make op equivalent to PyTorch's definition of Hardsigmoid.
3793    # See https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html
3794    return g.op("HardSigmoid", self, alpha_f=1 / 6)
3795
3796
3797@_onnx_symbolic("aten::tanhshrink")
3798@symbolic_helper.parse_args("v")
3799def tanhshrink(g: jit_utils.GraphContext, self):
3800    return g.op("Sub", self, tanh(g, self))
3801
3802
3803@_onnx_symbolic("aten::hardshrink")
3804@symbolic_helper.parse_args("v", "f")
3805def hardshrink(g: jit_utils.GraphContext, self, lambd):
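    # hardshrink(x) = x where |x| > lambd, else 0; expressed as
    # Where(x > lambd or x < -lambd, x, 0).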
3806    scalar_type = _type_utils.JitScalarType.from_value(
3807        self, _type_utils.JitScalarType.FLOAT
3808    )
3809    lambd_op = g.op(
3810        "Constant",
3811        value_t=torch.tensor(lambd, dtype=scalar_type.dtype()),
3812    )
3813    cond = logical_or(g, gt(g, self, lambd_op), lt(g, self, neg(g, lambd_op)))
3814    return g.op(
3815        "Where",
3816        cond,
3817        self,
3818        g.op(
3819            "Constant",
3820            value_t=torch.tensor(0, dtype=scalar_type.dtype()),
3821        ),
3822    )
3823
3824
3825@_onnx_symbolic("aten::softshrink")
3826@symbolic_helper.parse_args("v", "f")
3827def softshrink(g: jit_utils.GraphContext, self, lambd):
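    # softshrink(x) = x - lambd for x > lambd, x + lambd for x < -lambd,
    # else 0. The two one-sided Where results are disjoint, so adding them
    # yields the full output.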
3828    scalar_type = _type_utils.JitScalarType.from_value(
3829        self, _type_utils.JitScalarType.FLOAT
3830    )
3831    lambd_op = g.op(
3832        "Constant",
3833        value_t=torch.tensor(lambd, dtype=scalar_type.dtype()),
3834    )
3835    gt_cond = gt(g, self, lambd_op)
3836    gt_out = g.op(
3837        "Where",
3838        gt_cond,
3839        sub(g, self, lambd_op),
3840        g.op(
3841            "Constant",
3842            value_t=torch.tensor(0, dtype=scalar_type.dtype()),
3843        ),
3844    )
3845    lt_cond = lt(g, self, neg(g, lambd_op))
3846    lt_out = g.op(
3847        "Where",
3848        lt_cond,
3849        add(g, self, lambd_op),
3850        g.op(
3851            "Constant",
3852            value_t=torch.tensor(0, dtype=scalar_type.dtype()),
3853        ),
3854    )
3855    return add(g, gt_out, lt_out)
3856
3857
3858@_onnx_symbolic("aten::alias")
3859def alias(g: jit_utils.GraphContext, self):
3860    return self
3861
3862
3863@_onnx_symbolic("aten::unsqueeze")
3864@symbolic_helper.parse_args("v", "i")
3865def unsqueeze(g: jit_utils.GraphContext, self, dim):
3866    """Implement unsqueezing a pytorch tensor in ONNX by inserting a new dimension at the specified `dim`"""
3867    # Handle negative dim
3868    if dim < 0:
3869        rank = symbolic_helper._get_tensor_rank(self)
3870        if rank is not None:
3871            warnings.warn(
3872                "ONNX export unsqueeze with negative axis "
3873                + str(dim)
3874                + " might cause the onnx model to be incorrect. "
3875                + "Negative axis is not supported in ONNX. "
3876                + "Axis is converted to "
3877                + str(dim + rank + 1)
3878                + " based on input shape at export time. "
3879                + "Passing an tensor of different rank in execution will be incorrect."
3880            )
3881            dim = dim + rank + 1
3882        else:
3883            return symbolic_helper._unimplemented(
3884                "unsqueeze", "negative axis with unknown input rank", self
3885            )
3886
3887    return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim])
3888
3889
3890@_onnx_symbolic("aten::sort")
3891# TODO(justinchuby): Support multiple quantized args in output
3892@symbolic_helper.parse_args("v", "i", "i", "none")
3893def sort(g: jit_utils.GraphContext, self, dim, descending, out=None):
3894    if out is not None:
3895        symbolic_helper._unimplemented(
3896            "Sort", "Out parameter is not supported for sort", self
3897        )
3898    self_sizes = symbolic_helper._get_tensor_sizes(self)
3899    try:
3900        dim_size = self_sizes[dim]
3901    except Exception:
3902        # FIXME(justinchuby): Avoid catching Exception.
3903        # Catch a more specific exception instead.
3904        dim_size = None
3905
3906    if dim_size is None:
3907        return symbolic_helper._unimplemented("Sort", "input size not accessible", self)
3908
3909    return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2)
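    # Opset 9 TopK always returns the largest values in descending order, so
    # a TopK over the whole dimension (k == dim_size) sorts it; the
    # `descending` flag is not consulted here.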
3910
3911
3912@_onnx_symbolic("aten::numel")
3913def numel(g: jit_utils.GraphContext, self):
3914    return symbolic_helper._numel_helper(g, self)
3915
3916
3917@_onnx_symbolic("aten::topk")
3918# TODO(justinchuby): Support multiple quantized args in output
3919@symbolic_helper.parse_args("v", "i", "i", "i", "i", "none")
3920def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
3921    if out is not None:
3922        symbolic_helper._unimplemented(
3923            "TopK", "Out parameter is not supported for topk", self
3924        )
3925    if not largest:
3926        symbolic_helper._unimplemented("TopK", "Ascending TopK is not supported", self)
3927
3928    return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2)
3929
3930
3931@_onnx_symbolic("prim::convert_element_type")
3932def convert_element_type(g: jit_utils.GraphContext, self, *args):
3933    dtype = symbolic_helper._get_const(args[0], "i", "dtype")
3934    return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())
3935
3936
3937@_onnx_symbolic("aten::to")
3938def to(g: jit_utils.GraphContext, self, *args):
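    # aten::to has several overloads, distinguished below by argument count;
    # all of them reduce to either a no-op (device-only moves) or a Cast.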
3939    def is_aten_to_device_only(args):
3940        if len(args) == 4:
3941            # aten::to(Tensor, Device, bool, bool, memory_format)
3942            return (
3943                args[0].node().kind() == "prim::device"
3944                or args[0].type().isSubtypeOf(_C.ListType.ofInts())
3945                or isinstance(args[0].type(), _C.DeviceObjType)
3946            )
3947        elif len(args) == 5:
3948            # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
3949            # When dtype is None, this is a aten::to(device) call
3950            dtype = symbolic_helper._get_const(args[1], "i", "dtype")
3951            return dtype is None
3952        elif len(args) in (6, 7):
3953            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor
3954            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor
3955            # When dtype is None, this is a aten::to(device) call
3956            dtype = symbolic_helper._get_const(args[0], "i", "dtype")
3957            return dtype is None
3958        return False
3959
3960    # ONNX doesn't have a concept of a device, so we ignore device-only casts
3961    if is_aten_to_device_only(args):
3962        return self
3963
3964    if len(args) == 4:
3965        # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=<Tensor>]()
3966        # In this case, the constant value is a tensor, not an int,
3967        # so symbolic_helper._maybe_get_const(args[0], 'i') would not work.
3968        dtype = args[0]
3969        if (
3970            symbolic_helper._is_value(args[0])
3971            and args[0].node().kind() == "onnx::Constant"
3972        ):
3973            tval = symbolic_helper._node_get(args[0].node(), "value")
3974            if isinstance(tval, torch.Tensor):
3975                if len(tval.shape) == 0:
3976                    tval = tval.item()
3977                    dtype = int(tval)
3978                else:
3979                    dtype = tval
3980
3981        if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor):
3982            # aten::to(Tensor, Tensor, bool, bool, memory_format)
3983            dtype = _type_utils.JitScalarType.from_value(args[0])
3984            return g.op(
3985                "Cast",
3986                self,
3987                to_i=dtype.onnx_type(),
3988            )
3989        else:
3990            # aten::to(Tensor, ScalarType, bool, bool, memory_format)
3991            # memory_format is ignored
3992            return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())
3993    elif len(args) == 5:
3994        # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
3995        dtype = symbolic_helper._get_const(args[1], "i", "dtype")
3996        # memory_format is ignored
3997        return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())
3998    elif len(args) == 6:
3999        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor
4000        dtype = symbolic_helper._get_const(args[0], "i", "dtype")
4001        # Layout, device and memory_format are ignored
4002        return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())
4003    elif len(args) == 7:
4004        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor
4005        dtype = symbolic_helper._get_const(args[0], "i", "dtype")
4006        # Layout, device and memory_format are ignored
4007        return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())
4008
4009    return symbolic_helper._onnx_unsupported("Unknown aten::to signature", self)
4010
4011
4012@_onnx_symbolic("aten::repeat")
4013def repeat(g: jit_utils.GraphContext, self, repeats):
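    # Expand against a ones tensor shaped like `repeats` first, so the input
    # is broadcast up to len(repeats) dimensions before Tile multiplies each
    # dimension by its repeat count.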
4014    dtype = _type_utils.JitScalarType.INT64
4015    shape_ = ones_like(g, repeats, dtype)
4016    self = g.op("Expand", self, shape_)
4017    return g.op("Tile", self, repeats)
4018
4019
4020@_onnx_symbolic("aten::repeat_interleave")
4021def repeat_interleave(
4022    g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None
4023):
4024    repeats_dim = symbolic_helper._get_tensor_rank(repeats)
4025    repeats_sizes = symbolic_helper._get_tensor_sizes(repeats)
4026    input_sizes = symbolic_helper._get_tensor_sizes(self)
4027    if repeats_dim is None:
4028        raise errors.SymbolicValueError(
4029            "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.",
4030            self,
4031        )
4032    if repeats_sizes is None:
4033        raise errors.SymbolicValueError(
4034            "Unsupported: ONNX export of repeat_interleave for unknown repeats size.",
4035            self,
4036        )
4037    if input_sizes is None:
4038        raise errors.SymbolicValueError(
4039            "Unsupported: ONNX export of repeat_interleave for unknown input size.",
4040            self,
4041        )
4042
4043    # If dim is None, flatten the input:
4044    # by default, use the flattened input array and return a flat output array
4045    if symbolic_helper._is_none(dim):
4046        self = symbolic_helper._reshape_helper(
4047            g, self, g.op("Constant", value_t=torch.tensor([-1]))
4048        )
4049        dim = torch.tensor(0, dtype=torch.int64)
4050    else:
4051        dim = symbolic_helper._maybe_get_scalar(dim)
4052
4053    # Handle cases where dim is negative
4054    if dim < 0:
4055        dim += len(input_sizes)
4056
4057    input_sizes_temp = input_sizes.copy()
4058    for idx, input_size in enumerate(input_sizes):
4059        if input_size is None:
4060            input_sizes[idx], input_sizes_temp[idx] = 0, -1
4061
4062    # Cases where repeats is an int or single value tensor
4063    if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
4064        if input_sizes[dim] == 0:
4065            return symbolic_helper._onnx_opset_unsupported_detailed(
4066                "repeat_interleave",
4067                9,
4068                13,
4069                "Unsupported along dimension with unknown input size",
4070                self,
4071            )
4072        return symbolic_helper._repeat_interleave_single_value_repeat_helper(
4073            g, self, repeats, dim
4074        )
4075
4076    # Cases where repeats is a 1 dim Tensor
4077    elif repeats_dim == 1:
4078        if input_sizes[dim] == 0:
4079            return symbolic_helper._onnx_opset_unsupported_detailed(
4080                "repeat_interleave",
4081                9,
4082                13,
4083                "Unsupported along dimension with unknown input size",
4084                self,
4085            )
4086        if repeats_sizes[0] is None:
4087            return symbolic_helper._onnx_opset_unsupported_detailed(
4088                "repeat_interleave",
4089                9,
4090                13,
4091                "Unsupported for cases with dynamic repeats",
4092                self,
4093            )
4094        assert (
4095            repeats_sizes[0] == input_sizes[dim]
4096        ), "repeats must have the same size as input along dim"
4097        reps = repeats_sizes[0]
4098    else:
4099        raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self)
4100
4101    final_splits = []
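    # General 1-D repeats: split `repeats` and the input along `dim` into
    # per-index pieces, expand each input slice by its own repeat count via
    # an unsqueeze/Expand/reshape round-trip, and concatenate the results.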
4102    r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0)
4103    i_splits = symbolic_helper._repeat_interleave_split_helper(g, self, reps, dim)
4104    input_sizes[dim], input_sizes_temp[dim] = -1, 1
4105    for idx, r_split in enumerate(r_splits):
4106        i_split = unsqueeze(g, i_splits[idx], dim + 1)
4107        r_concat = [
4108            g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])),
4109            r_split,
4110            g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])),
4111        ]
4112        r_concat = g.op("Concat", *r_concat, axis_i=0)
4113        i_split = expand(g, i_split, r_concat, None)
4114        i_split = symbolic_helper._reshape_helper(
4115            g,
4116            i_split,
4117            g.op("Constant", value_t=torch.LongTensor(input_sizes)),
4118            allowzero=0,
4119        )
4120        final_splits.append(i_split)
4121    return g.op("Concat", *final_splits, axis_i=dim)
4122
4123
4124@_onnx_symbolic("aten::pixel_shuffle")
4125@symbolic_helper.parse_args("v", "i")
4126def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor):
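    # pixel_shuffle is depth-to-space: reshape the channel dimension into
    # (C_out, r, r) blocks, transpose the two r factors next to H and W, and
    # reshape to (N, C_out, H*r, W*r).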
4127    dims = symbolic_helper._get_tensor_sizes(self)
4128    if len(dims) != 4:
4129        return symbolic_helper._unimplemented(
4130            "pixel_shuffle", "only support 4d input", self
4131        )
4132    if any(i is None for i in dims[1:]):
4133        after_view = symbolic_helper._reshape_helper(
4134            g,
4135            symbolic_helper._unsqueeze_helper(g, self, [2, 3]),
4136            g.op(
4137                "Constant",
4138                value_t=torch.tensor([0, -1, upscale_factor, upscale_factor, 0, 0]),
4139            ),
4140            allowzero=0,
4141        )
4142        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
4143        # For dynamic input shapes, two reshapes are performed
4144        reshape_h = symbolic_helper._reshape_helper(
4145            g,
4146            after_transpose,
4147            g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])),
4148            allowzero=0,
4149        )
4150        reshape_w = symbolic_helper._reshape_helper(
4151            g,
4152            reshape_h,
4153            g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])),
4154            allowzero=0,
4155        )
4156        return symbolic_helper._squeeze_helper(g, reshape_w, [3, 5])
4157    else:
4158        output_channel = dims[1] // upscale_factor // upscale_factor
4159        after_view = symbolic_helper._reshape_helper(
4160            g,
4161            self,
4162            g.op(
4163                "Constant",
4164                value_t=torch.tensor(
4165                    [
4166                        -1,
4167                        output_channel,
4168                        upscale_factor,
4169                        upscale_factor,
4170                        dims[2],
4171                        dims[3],
4172                    ]
4173                ),
4174            ),
4175            allowzero=0,
4176        )
4177        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
4178        return symbolic_helper._reshape_helper(
4179            g,
4180            after_transpose,
4181            g.op(
4182                "Constant",
4183                value_t=torch.tensor(
4184                    [
4185                        -1,
4186                        output_channel,
4187                        dims[2] * upscale_factor,
4188                        dims[3] * upscale_factor,
4189                    ]
4190                ),
4191            ),
4192            allowzero=0,
4193        )
4194
4195
4196@_onnx_symbolic("aten::pixel_unshuffle")
4197@symbolic_helper.parse_args("v", "i")
4198def pixel_unshuffle(g: jit_utils.GraphContext, self, downscale_factor):
4199    dims = symbolic_helper._get_tensor_sizes(self)
4200    if len(dims) != 4:
4201        return symbolic_helper._unimplemented(
4202            "pixel_shuffle", "only support 4d input", self
4203        )
4204    if any(i is None for i in dims[1:]):
4205        # For dynamic input shapes, two reshapes are performed
4206        reshape_h = symbolic_helper._reshape_helper(
4207            g,
4208            symbolic_helper._unsqueeze_helper(g, self, [3]),
4209            g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])),
4210            allowzero=0,
4211        )
4212        reshape_w = symbolic_helper._reshape_helper(
4213            g,
4214            reshape_h,
4215            g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])),
4216            allowzero=0,
4217        )
4218        after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4])
4219        final_reshape = symbolic_helper._reshape_helper(
4220            g,
4221            after_transpose,
4222            g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])),
4223            allowzero=0,
4224        )
4225        return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3])
4226    else:
4227        output_channel = dims[1] * downscale_factor * downscale_factor
4228        after_view = symbolic_helper._reshape_helper(
4229            g,
4230            self,
4231            g.op(
4232                "Constant",
4233                value_t=torch.tensor(
4234                    [
4235                        -1,
4236                        dims[1],
4237                        dims[2] // downscale_factor,
4238                        downscale_factor,
4239                        dims[3] // downscale_factor,
4240                        downscale_factor,
4241                    ]
4242                ),
4243            ),
4244            allowzero=0,
4245        )
4246        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4])
4247        return symbolic_helper._reshape_helper(
4248            g,
4249            after_transpose,
4250            g.op(
4251                "Constant",
4252                value_t=torch.tensor(
4253                    [
4254                        -1,
4255                        output_channel,
4256                        dims[2] // downscale_factor,
4257                        dims[3] // downscale_factor,
4258                    ]
4259                ),
4260            ),
4261            allowzero=0,
4262        )
4263
4264
4265def _generic_rnn(
4266    g: jit_utils.GraphContext,
4267    variant,
4268    input,
4269    initial_states,
4270    all_weights,
4271    has_biases,
4272    num_layers,
4273    dropout,
4274    train,
4275    bidirectional,
4276    batch_first=None,
4277    batch_sizes=None,
4278):
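    # Shared lowering for aten RNN/GRU/LSTM onto the corresponding ONNX ops.
    # Each layer becomes one ONNX node; PyTorch's flat weight list is
    # regrouped per layer, gate order is permuted to ONNX conventions (see
    # reform_permutation below), and bidirectional layers get their
    # forward/backward weights concatenated along axis 0.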
4279    warnings.warn(
4280        "Exporting a model to ONNX with a batch_size other than 1, "
4281        + "with a variable length with "
4282        + variant
4283        + " can cause an error "
4284        + "when running the ONNX model with a different batch size. "
4285        + "Make sure to save the model with a batch size of 1, "
4286        + "or define the initial states (h0/c0) as inputs of the model. "
4287    )
4288
4289    onnxActivations = [
4290        "Relu",
4291        "Tanh",
4292        "Sigmoid",
4293        "Affine",
4294        "LeakyRelu",
4295        "ThresholdedRelu",
4296        "ScaledTanh",
4297        "HardSigmoid",
4298        "Elu",
4299        "Softsign",
4300        "Softplus",
4301    ]
4302    variantToOnnxActivationMap = dict(
4303        zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations)
4304    )
4305    weights_per_layer = 4 if has_biases else 2
4306    # A length mismatch here means the LSTM uses projections, which ONNX export does not support, so tell the user
4307    if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * (
4308        1 + bidirectional
4309    ):
4310        return symbolic_helper._unimplemented("LSTM", "LSTMs with projections", input)
4311    assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional)
4312    layer_weights = [
4313        all_weights[i : i + weights_per_layer]
4314        for i in range(0, len(all_weights), weights_per_layer)
4315    ]
4316    if batch_first:
4317        # batch, seq, feat -> seq, batch, feat
4318        input = g.op("Transpose", input, perm_i=[1, 0, 2])
4319    if dropout and train:
4320        return symbolic_helper._unimplemented(
4321            "RNN/GRU/LSTM", "dropout in training mode", input
4322        )
4323
4324    if variant.startswith("RNN"):
4325        nonlinearity = variantToOnnxActivationMap[variant[4:].lower()]
4326        variant = "RNN"
4327
4328    w_hh = all_weights[1]
4329    hidden_size = symbolic_helper._get_tensor_dim_size(w_hh, 1)
4330    if hidden_size is None:
4331        return symbolic_helper._unimplemented(
4332            "RNN/GRU/LSTM", "unknown hidden size", input
4333        )
4334
4335    unidirectional = not bidirectional
4336
4337    prev_output = input
4338
4339    h_outs = []
4340    if variant == "RNN" or variant == "GRU":
4341        h0 = initial_states
4342    elif variant == "LSTM":
4343        h0, c0 = initial_states
4344        c_outs = []
4345
4346    sequence_lens = unused(g) if batch_sizes is None else batch_sizes
4347
4348    if variant == "GRU":
4349        # pytorch is reset, input, hidden
4350        # onnx is    input, reset, hidden
4351        reform_permutation = [(1, 2), (0, 1), (2, 3)]
4352    elif variant == "LSTM":
4353        # pytorch is input, forget, cell, output.
4354        # onnx is    input, output, forget, cell.
4355        reform_permutation = [(0, 1), (3, 4), (1, 3)]
4356
4357    def reform_weights(g, w, n, intervals):
4358        slices = [
4359            symbolic_helper._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n])
4360            for x, y in intervals
4361        ]
4362        return g.op("Concat", *slices, axis_i=0)
4363
4364    def transform_weights_no_bias(layer_index):
4365        weights = layer_weights[layer_index]
4366        if variant == "RNN":
4367            weight_ih, weight_hh = weights
4368        elif variant == "GRU" or variant == "LSTM":
4369            weight_ih, weight_hh = (
4370                reform_weights(g, w, hidden_size, reform_permutation) for w in weights
4371            )
4372        return tuple(
4373            symbolic_helper._unsqueeze_helper(g, x, [0])
4374            for x in (weight_ih, weight_hh)  # type: ignore[possibly-undefined]
4375        )
4376
4377    def transform_weights(layer_index):
4378        weights = layer_weights[layer_index]
4379        if variant == "RNN":
4380            weight_ih, weight_hh, bias_ih, bias_hh = weights
4381        elif variant == "GRU" or variant == "LSTM":
4382            weight_ih, weight_hh, bias_ih, bias_hh = (
4383                reform_weights(g, w, hidden_size, reform_permutation) for w in weights
4384            )
4385        bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0)  # type: ignore[possibly-undefined]
4386        return tuple(
4387            symbolic_helper._unsqueeze_helper(g, x, [0])
4388            for x in (weight_ih, weight_hh, bias_concat)  # type: ignore[possibly-undefined]
4389        )
4390
4391    def retrieve_state(x, start, end):
4392        return (
4393            x
4394            if num_layers == 1
4395            else symbolic_helper._slice_helper(
4396                g, x, axes=[0], starts=[start], ends=[end]
4397            )
4398        )
4399
4400    for i in range(num_layers):
4401        if unidirectional:
4402            if weights_per_layer == 4:
4403                weight_ih, weight_hh, bias_concat = transform_weights(i)
4404            else:
4405                weight_ih, weight_hh = transform_weights_no_bias(i)
4406                bias_concat = unused(g)
4407
4408            state_indices = i, i + 1
4409        else:
4410            if weights_per_layer == 4:
4411                weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
4412                weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
4413                bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0)
4414            else:
4415                weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i)
4416                weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1)
4417                bias_concat = unused(g)
4418
4419            weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0)
4420            weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0)
4421
4422            state_indices = 2 * i, 2 * i + 2
4423
4424        inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens]
4425
4426        inputs.append(retrieve_state(h0, *state_indices))  # type: ignore[possibly-undefined]
4427        if variant == "LSTM":
4428            inputs.append(retrieve_state(c0, *state_indices))  # type: ignore[possibly-undefined]
4429
4430        extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"}
4431        if variant == "RNN":
4432            if bidirectional:
4433                activation = [nonlinearity, nonlinearity]  # type: ignore[possibly-undefined]
4434            else:
4435                activation = [nonlinearity]  # type: ignore[possibly-undefined]
4436
4437            prev_output, h_out = g.op(
4438                "RNN",
4439                *inputs,
4440                outputs=2,
4441                hidden_size_i=hidden_size,
4442                activations_s=activation,
4443                **extra_kwargs,
4444            )
4445        elif variant == "GRU":
4446            prev_output, h_out = g.op(
4447                "GRU",
4448                *inputs,
4449                outputs=2,
4450                hidden_size_i=hidden_size,
4451                linear_before_reset_i=1,
4452                **extra_kwargs,
4453            )
4454        elif variant == "LSTM":
4455            prev_output, h_out, c_out = g.op(
4456                "LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs
4457            )
4458
4459        if bidirectional:
4460            # The ONNX RNN/GRU/LSTM produce an output of dimensions
4461            #   seq_len, num_directions, batch, hidden_size
4462            # We have to convert to match pytorch's expected
4463            #   seq_len, batch, num_directions * hidden_size
4464            # by first moving num_directions before hidden_size with
4465            # Transpose, and then combining it with hidden_size
4466            # with Reshape.
4467            prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3])
4468            prev_output = symbolic_helper._reshape_helper(
4469                g,
4470                prev_output,
4471                g.op("Constant", value_t=torch.LongTensor([0, 0, -1])),
4472                allowzero=0,
4473            )
4474        else:
4475            prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1])
4476
4477        h_outs.append(h_out)  # type: ignore[possibly-undefined]
4478        if variant == "LSTM":
4479            c_outs.append(c_out)  # type: ignore[possibly-undefined]
4480    if batch_first:
4481        # seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size
4482        prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2])
4483    h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0)  # type: ignore[possibly-undefined]
4484    if variant == "RNN" or variant == "GRU":
4485        return prev_output, h_outs
4486    elif variant == "LSTM":
4487        c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0)  # type: ignore[possibly-undefined]
4488        return prev_output, h_outs, c_outs
4489
4490
4491@symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
4492def _lstm_full(
4493    g: jit_utils.GraphContext,
4494    input,
4495    hidden_v,
4496    weight_v,
4497    has_biases,
4498    num_layers,
4499    dropout,
4500    train,
4501    bidirectional,
4502    batch_first,
4503):
4504    hidden, weight = (
4505        symbolic_helper._unpack_list(hidden_v),
4506        symbolic_helper._unpack_list(weight_v),
4507    )
4508    return _generic_rnn(
4509        g,
4510        "LSTM",
4511        input,
4512        hidden,
4513        weight,
4514        has_biases,
4515        num_layers,
4516        dropout,
4517        train,
4518        bidirectional,
4519        batch_first,
4520    )
4521
4522
4523@symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i")
4524def _lstm_packed(
4525    g: jit_utils.GraphContext,
4526    input,
4527    batch_sizes,
4528    hidden_v,
4529    weight_v,
4530    has_biases,
4531    num_layers,
4532    dropout,
4533    train,
4534    bidirectional,
4535):
4536    hidden, weight = (
4537        symbolic_helper._unpack_list(hidden_v),
4538        symbolic_helper._unpack_list(weight_v),
4539    )
4540    return _generic_rnn(
4541        g,
4542        "LSTM",
4543        input,
4544        hidden,
4545        weight,
4546        has_biases,
4547        num_layers,
4548        dropout,
4549        train,
4550        bidirectional,
4551        batch_sizes=batch_sizes,
4552    )
4553
4554
4555@_onnx_symbolic("aten::lstm")
4556def lstm(g: jit_utils.GraphContext, *args):
4557    if symbolic_helper._is_tensor_list(args[3]):
4558        return _lstm_packed(g, *args)
4559    else:
4560        return _lstm_full(g, *args)
4561
4562
4563@_onnx_symbolic("aten::lstm_cell")
4564def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh):
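    # A single LSTM cell is exported as a one-layer, one-step LSTM: unsqueeze
    # the input to sequence length 1, run _generic_rnn, and squeeze the
    # leading dimension off the returned hidden and cell states.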
4565    input = symbolic_helper._unsqueeze_helper(g, self, [0])
4566    hidden = symbolic_helper._unpack_list(hidden)
4567    hidden = [symbolic_helper._unsqueeze_helper(g, x, [0]) for x in hidden]
4568    weight = (
4569        (w_ih, w_hh, b_ih, b_hh) if symbolic_helper._is_tensor(b_ih) else (w_ih, w_hh)
4570    )
4571    has_biases = symbolic_helper._is_tensor(b_ih)
4572    _, h_outs, c_outs = _generic_rnn(
4573        g,
4574        "LSTM",
4575        input,
4576        hidden,
4577        weight,
4578        has_biases,
4579        num_layers=1,
4580        dropout=0,
4581        train=0,
4582        bidirectional=False,
4583        batch_first=False,
4584    )
4585    return symbolic_helper._squeeze_helper(
4586        g, h_outs, [0]
4587    ), symbolic_helper._squeeze_helper(g, c_outs, [0])
4588
4589
4590@_onnx_symbolic(
4591    "aten::gru", decorate=[symbolic_helper._apply_params("GRU"), _export("gru")]
4592)
4593@_onnx_symbolic(
4594    "aten::rnn_tanh",
4595    decorate=[symbolic_helper._apply_params("RNN_TANH"), _export("rnn_tanh")],
4596)
4597@_onnx_symbolic(
4598    "aten::rnn_relu",
4599    decorate=[symbolic_helper._apply_params("RNN_RELU"), _export("rnn_relu")],
4600)
4601def _one_hidden_rnn(kind: str):
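    # Factory for RNN/GRU symbolics (a single hidden state, unlike LSTM).
    # It returns a dispatcher that picks the packed-sequence overload when
    # args[3] is a tensor list, mirroring the aten::lstm handling above.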
4602    @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
4603    def _rnn_full(
4604        g,
4605        input,
4606        hidden,
4607        weight_v,
4608        has_biases,
4609        num_layers,
4610        dropout,
4611        train,
4612        bidirectional,
4613        batch_first,
4614    ):
4615        weight = symbolic_helper._unpack_list(weight_v)
4616        return _generic_rnn(
4617            g,
4618            kind,
4619            input,
4620            hidden,
4621            weight,
4622            has_biases,
4623            num_layers,
4624            dropout,
4625            train,
4626            bidirectional,
4627            batch_first,
4628        )
4629
4630    @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i")
4631    def _rnn_packed(
4632        g,
4633        input,
4634        batch_sizes,
4635        hidden,
4636        weight_v,
4637        has_biases,
4638        num_layers,
4639        dropout,
4640        train,
4641        bidirectional,
4642    ):
4643        weight = symbolic_helper._unpack_list(weight_v)
4644        return _generic_rnn(
4645            g,
4646            kind,
4647            input,
4648            hidden,
4649            weight,
4650            has_biases,
4651            num_layers,
4652            dropout,
4653            train,
4654            bidirectional,
4655            batch_sizes=batch_sizes,
4656        )
4657
4658    def symbolic(g, *args):
4659        if symbolic_helper._is_tensor_list(args[3]):
4660            return _rnn_packed(g, *args)
4661        else:
4662            return _rnn_full(g, *args)
4663
4664    return symbolic
4665
4666
4667@_onnx_symbolic("aten::_dim_arange")
4668@symbolic_helper.parse_args("v", "i")
4669def _dim_arange(g: jit_utils.GraphContext, like, dim):
4670    like_shape = g.op("Shape", like)
4671    stop = g.op(
4672        "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0
4673    )
4674    # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
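    # dtype=4 selects ScalarType::Long, i.e. int64 indices.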
4675    return arange(g, stop, 4, None, None, None)
4676
4677
4678@_onnx_symbolic("aten::detach")
4679def detach(g: jit_utils.GraphContext, input):
4680    # Erase aten::detach nodes because ONNX is inference only
4681    return input
4682
4683
4684@_onnx_symbolic("aten::contiguous")
4685@symbolic_helper.parse_args("v", "i")
4686def contiguous(g: jit_utils.GraphContext, input, memory_format):
4687    if memory_format > 2:  # allowed values are any, preserve and contiguous_format
4688        raise errors.SymbolicValueError(
4689            "onnx memory_format support is not implemented", input
4690        )
4691    return input
4692
4693
4694@_onnx_symbolic("aten::_pack_padded_sequence")
4695@symbolic_helper.parse_args("v", "v", "i")
4696def _pack_padded_sequence(g: jit_utils.GraphContext, input, lengths, batch_first):
4697    # Currently there is no PackPadded operator in ONNX. We rely on an
4698    # optimization pass to remove this later. It is an error if any
4699    # PackPadded operator cannot be optimized out.
4700    if batch_first:
4701        input = g.op("Transpose", input, perm_i=[1, 0, 2])
4702    if not lengths.type().isSubtypeOf(torch._C.TensorType.get()):
4703        raise errors.SymbolicValueError(
4704            "'lengths' must be a Tensor for ONNX export", input
4705        )
4706    # We know it's a TensorType so this check is now safe.
4707    # It's really only necessary because those operators expand to something that
4708    # only works with int32 types in Caffe2...
4709    if (
4710        _type_utils.JitScalarType.from_value(
4711            lengths, _type_utils.JitScalarType.UNDEFINED
4712        )
4713        != _type_utils.JitScalarType.INT
4714    ):
4715        lengths = g.op("Cast", lengths, to_i=_C_onnx.TensorProtoDataType.INT32)
4716    return g.op("prim::PackPadded", input, lengths, outputs=2)
4717
4718
4719@_onnx_symbolic("aten::_pad_packed_sequence")
4720@symbolic_helper.parse_args("v", "v", "i", "t", "v")
4721def _pad_packed_sequence(
4722    g: jit_utils.GraphContext,
4723    data,
4724    batch_sizes,
4725    batch_first,
4726    padding_value,
4727    total_length,
4728):
4729    # Ignore total_length as it is not supported in _symbolic_pad_packed_sequence.
4730    # It is only useful/used when training with a data_parallel model, so
4731    # it shouldn't be relevant for ONNX anyway.
4732    data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2)
4733    if batch_first:
4734        data = g.op("Transpose", data, perm_i=[1, 0, 2])
4735    return data, lengths
4736
4737
4738@_onnx_symbolic("aten::randint")
4739def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options):
4740    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
4741    low_i = symbolic_helper._get_const(low, "i", "low")
4742    high_i = symbolic_helper._get_const(high, "i", "high")
4743    if dtype is None:
4744        scalar_type = _type_utils.JitScalarType.INT64
4745    else:
4746        scalar_type = _type_utils.JitScalarType(dtype)
4747    if low_i is None:
4748        raise symbolic_helper._onnx_unsupported("randint", low)
4749    if high_i is None:
4750        raise symbolic_helper._onnx_unsupported("randint", high)
4751
4752    shape = symbolic_helper._maybe_get_const(shapes, "is")
4753    if symbolic_helper._is_value(shape):
4754        shape_const = g.op(
4755            "ConstantOfShape",
4756            shapes,
4757            value_t=torch.tensor([0], dtype=torch.float),
4758        )
4759        randn = g.op(
4760            "RandomUniformLike",
4761            shape_const,
4762            low_f=low_i,
4763            high_f=high_i,
4764        )
4765    else:
4766        randn = g.op(
4767            "RandomUniform",
4768            shape_i=shape,
4769            low_f=low_i,
4770            high_f=high_i,
4771        )
4772
4773    # cast to integer type
4774    int_dtype = _type_utils.JitScalarType.INT64
4775    randint = g.op("Cast", randn, to_i=int_dtype.onnx_type())
4776    if int_dtype != scalar_type:
4777        randint = g.op("Cast", randint, to_i=scalar_type.onnx_type())
4778    return randint
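
# A minimal eager-mode sketch of the uniform-then-truncate lowering above
# (illustrative only, not part of the export path): sample U[low, high) as
# floats, then cast to the integer dtype, mirroring RandomUniform -> Cast.
#
#     import torch
#     low, high = 0, 10
#     rands = torch.empty(4).uniform_(low, high)  # RandomUniform in [low, high)
#     ints = rands.to(torch.int64)                # Cast truncates toward zero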
4779
4780
4781@_onnx_symbolic("aten::randint_like")
4782def randint_like(g: jit_utils.GraphContext, self, low, high, dtype, *options):
4783    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
4784    low_i = symbolic_helper._get_const(low, "i", "low")
4785    high_i = symbolic_helper._get_const(high, "i", "high")
4786    if dtype is None:
4787        scalar_type = _type_utils.JitScalarType.INT64
4788    else:
4789        scalar_type = _type_utils.JitScalarType(dtype)
4790    if low_i is None:
4791        raise symbolic_helper._onnx_unsupported("randint", low)
4792    if high_i is None:
4793        raise symbolic_helper._onnx_unsupported("randint", high)
4794
4795    randn = g.op(
4796        "RandomUniformLike",
4797        self,
4798        low_f=low_i,
4799        high_f=high_i,
4800    )
4801
4802    # cast to integer type
4803    int_dtype = _type_utils.JitScalarType.INT64
4804    randint = g.op("Cast", randn, to_i=int_dtype.onnx_type())
4805    if int_dtype != scalar_type:
4806        randint = g.op("Cast", randint, to_i=scalar_type.onnx_type())
4807    return randint
4808
4809
4810@_onnx_symbolic("aten::randn")
4811def randn(g: jit_utils.GraphContext, shapes, dtype, *options):
4812    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
4813    if dtype is None:
4814        scalar_type = _type_utils.JitScalarType.FLOAT
4815    else:
4816        scalar_type = _type_utils.JitScalarType(dtype)
4817    shape = symbolic_helper._maybe_get_const(shapes, "is")
4818    if symbolic_helper._is_value(shape):
4819        shape_const = g.op(
4820            "ConstantOfShape",
4821            shapes,
4822            value_t=torch.tensor([0], dtype=torch.float),
4823        )
4824        return g.op(
4825            "RandomNormalLike",
4826            shape_const,
4827            dtype_i=scalar_type.onnx_type(),
4828        )
4829    return g.op(
4830        "RandomNormal",
4831        shape_i=shape,
4832        dtype_i=scalar_type.onnx_type(),
4833    )
4834
4835
4836@_onnx_symbolic("aten::rand")
4837def rand(g: jit_utils.GraphContext, shapes, dtype, *options):
4838    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
4839    if dtype is None:
4840        scalar_type = _type_utils.JitScalarType.FLOAT
4841    else:
4842        scalar_type = _type_utils.JitScalarType(dtype)
4843    shape = symbolic_helper._maybe_get_const(shapes, "is")
4844    if symbolic_helper._is_value(shape):
4845        shape_const = g.op(
4846            "ConstantOfShape",
4847            shapes,
4848            value_t=torch.tensor([0], dtype=torch.float),
4849        )
4850        return g.op(
4851            "RandomUniformLike",
4852            shape_const,
4853            dtype_i=scalar_type.onnx_type(),
4854        )
4855    return g.op(
4856        "RandomUniform",
4857        shape_i=shape,
4858        dtype_i=scalar_type.onnx_type(),
4859    )
4860
4861
4862@_onnx_symbolic("aten::randn_like")
4863def randn_like(
4864    g: jit_utils.GraphContext,
4865    self,
4866    dtype,
4867    layout=None,
4868    device=None,
4869    pin_memory=False,
4870    memory_format=None,
4871):
4872    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
4873    if dtype is None:
4874        scalar_type = _type_utils.JitScalarType.from_value(
4875            self, _type_utils.JitScalarType.FLOAT
4876        )
4877    else:
4878        scalar_type = _type_utils.JitScalarType(dtype)
4879    return g.op("RandomNormalLike", self, dtype_i=scalar_type.onnx_type())
4880
4881
4882@_onnx_symbolic("aten::rand_like")
4883def rand_like(
4884    g: jit_utils.GraphContext,
4885    self,
4886    dtype,
4887    layout=None,
4888    device=None,
4889    pin_memory=False,
4890    memory_format=None,
4891):
4892    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
4893    if dtype is None:
4894        dtype = _type_utils.JitScalarType.from_value(
4895            self, _type_utils.JitScalarType.FLOAT
4896        )
4897    return g.op(
4898        "RandomUniformLike", self, dtype_i=_type_utils.JitScalarType(dtype).onnx_type()
4899    )
4900
4901
4902@_onnx_symbolic("aten::rrelu")
4903@symbolic_helper.parse_args("v", "f", "f", "i", "none")
4904def rrelu(g: jit_utils.GraphContext, input, lower, upper, training, generator):
4905    if not training:
4906        slope = (upper + lower) / 2.0
4907        return g.op("LeakyRelu", input, alpha_f=slope)
4908    p = g.op("RandomUniformLike", input, high_f=upper, low_f=lower)
4909    return g.op("PRelu", input, p)
4910
4911
4912@_onnx_symbolic("aten::bernoulli")
4913def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
4914    if out is not None and not symbolic_helper._is_none(out):
4915        symbolic_helper._unimplemented(
4916            "Bernoulli", "out parameter is not supported for bernoulli", input
4917        )
4918    if generator is not None and not symbolic_helper._is_none(generator):
4919        symbolic_helper._unimplemented(
4920            "Bernoulli", "generator is not supported for bernoulli", input
4921        )
4922
4923    dtype = _type_utils.JitScalarType.from_value(
4924        input, _type_utils.JitScalarType.UNDEFINED
4925    )
4926    if dtype == _type_utils.JitScalarType.UNDEFINED:
4927        return symbolic_helper._unimplemented(
4928            "Bernoulli", "input dtype not accessible", input
4929        )
4930
4931    rands = g.op(
4932        "RandomUniformLike",
4933        input,
4934        high_f=1.0,
4935        low_f=0.0,
4936        dtype_i=dtype.onnx_type(),
4937    )
4938    prob = p if p is not None and not symbolic_helper._is_none(p) else input
4939    output = g.op("Less", rands, prob)
4940    return g.op("Cast", output, to_i=dtype.onnx_type())
4941
4942
4943@_onnx_symbolic("aten::log_sigmoid")
4944@symbolic_helper.parse_args("v")
4945def log_sigmoid(g: jit_utils.GraphContext, input):
4946    p = g.op("Sigmoid", input)
4947    return g.op("Log", p)
4948
4949
4950@_onnx_symbolic("aten::erf")
4951@symbolic_helper.parse_args("v")
4952def erf(g: jit_utils.GraphContext, input):
4953    return g.op("Erf", input)
4954
4955
4956@_onnx_symbolic("aten::flatten")
4957@symbolic_helper.quantized_args(True, False, False)
4958@symbolic_helper.parse_args("v", "i", "i")
4959def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
4960    dim = symbolic_helper._get_tensor_rank(input)
4961    if dim is None:
4962        return symbolic_helper._unimplemented(
4963            "dim",
4964            "ONNX and PyTorch use different strategies to split the input. "
4965            "Input rank must be known at export time.",
4966            input,
4967        )
4968
4969    if dim == 0:
4970        return symbolic_helper._reshape_helper(g, input, [1])
4971    if dim == 1:
4972        return g.op("Identity", input)
4973    # TODO: remove this as onnx opset 11 spec allows negative axes
4974    if end_dim < 0:
4975        end_dim = dim + end_dim
4976    # use ONNX's Flatten operator for cases where the output shape is 2D
4977    if start_dim == 1 and end_dim == dim - 1:
4978        return g.op("Flatten", input, axis_i=start_dim)
4979    if start_dim == 0 and end_dim == dim - 2:
4980        return g.op("Flatten", input, axis_i=end_dim + 1)
4981
4982    return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim)
4983
4984
4985@_onnx_symbolic("aten::nonzero")
4986@symbolic_helper.parse_args("v")
4987def nonzero(g: jit_utils.GraphContext, input):
4988    """Emitted from `torch.nonzero(x, as_tuple=False)`"""
4989    return t(g, g.op("NonZero", input))
4990
4991
4992@_onnx_symbolic("aten::nonzero_numpy")
4993# Emitted from `torch.nonzero(x, as_tuple=True)`
4994def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None):
4995    return unbind(g, nonzero(g, input), 1, _outputs=_outputs)
4996
4997
4998@_onnx_symbolic("aten::isnan")
4999@symbolic_helper.parse_args("v")
5000def isnan(g: jit_utils.GraphContext, input):
5001    output = g.op("IsNaN", input)
5002    return output
5003
5004
5005@_onnx_symbolic("aten::any")
5006def _any(g: jit_utils.GraphContext, *args):
5007    # aten::any(Tensor self)
5008    if len(args) == 1:
5009        input = args[0]
5010        dim, keepdim = None, 0
5011    # aten::any(Tensor self, int[]? dim, bool keepdim)
5012    else:
5013        input, dim, keepdim = args
5014        # Can be int list or single int
5015        dim = symbolic_helper._parse_arg(dim, "t")
5016        dim = [int(d) for d in dim.view(-1)]
5017        keepdim = symbolic_helper._parse_arg(keepdim, "i")
5018    input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)
5019    input_sum = symbolic_helper._reducesum_helper(
5020        g, input, axes_i=dim, keepdims_i=keepdim
5021    )
5022    return gt(g, input_sum, g.op("Constant", value_t=torch.tensor(0, dtype=torch.long)))
5023
5024
5025@_onnx_symbolic("aten::all")
5026def _all(g: jit_utils.GraphContext, *args):
5027    input = g.op("Not", args[0])
5028    # aten::all(Tensor self)
5029    if len(args) == 1:
5030        return g.op("Not", _any(g, input))
5031    # aten::all(Tensor self, int[]? dim, bool keepdim)
5032    else:
5033        return g.op("Not", _any(g, input, args[1], args[2]))
5034
5035
5036@_onnx_symbolic("aten::narrow")
5037@symbolic_helper.parse_args("v", "i", "i", "i")
5038def narrow(g: jit_utils.GraphContext, input, dim, start, length):
5039    return symbolic_helper._slice_helper(
5040        g, input, axes=[dim], starts=[start], ends=[start + length]
5041    )
5042
5043
5044@_onnx_symbolic("aten::argmax")
5045@symbolic_helper.parse_args("v", "v", "b")
5046def argmax(
5047    g: jit_utils.GraphContext,
5048    input: torch._C.Value,
5049    dim: torch._C.Value,
5050    keepdim: bool,
5051):
5052    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")
5053
5054
5055@_onnx_symbolic("aten::argmin")
5056@symbolic_helper.parse_args("v", "v", "b")
5057def argmin(
5058    g: jit_utils.GraphContext,
5059    input: torch._C.Value,
5060    dim: torch._C.Value,
5061    keepdim: bool,
5062):
5063    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")
5064
5065
5066@_onnx_symbolic("aten::scatter")
5067@symbolic_helper.parse_args("v", "i", "v", "v")
5068def scatter(g: jit_utils.GraphContext, self, dim, index, src):
5069    src_type = _type_utils.JitScalarType.from_value(
5070        src, _type_utils.JitScalarType.UNDEFINED
5071    )
5072    src = symbolic_helper._maybe_get_scalar(src)
5073    if symbolic_helper._is_value(src):
5074        return g.op("Scatter", self, index, src, axis_i=dim)
5075    else:
5076        # Check if scalar "src" has same type as self (PyTorch allows different
5077        # type for scalar src (but not when src is tensor)). If not, insert Cast node.
5078        self_scalar_type = _type_utils.JitScalarType.from_value(self)
5079        if self_scalar_type != src_type:
5080            src = g.op("Cast", src, to_i=self_scalar_type.onnx_type())
5081        return g.op("Scatter", self, index, expand_as(g, src, index), axis_i=dim)
5082
5083
5084@_onnx_symbolic("aten::scatter_add")
5085@symbolic_helper.parse_args("v", "i", "v", "v")
5086def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
5087    scalar_type = symbolic_helper._try_get_scalar_type(self)
5088    if scalar_type is None:
5089        return symbolic_helper._unimplemented(
5090            "scatter_add", "input dtype not accessible", self
5091        )
5092    sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False)
5093    if sizes:
5094        to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=scalar_type.dtype()))
5095    else:
5096        to_add = zeros_like(g, self, scalar_type)
5097    to_add = symbolic_helper._scatter_helper(g, to_add, dim, index, src)
5098    return add(g, self, to_add)
5099
5100
5101@_onnx_symbolic("aten::log2")
5102def log2(g: jit_utils.GraphContext, self):
5103    _ln2 = 0.693147180559945309
5104    return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor(_ln2)))
5105
5106
5107@_onnx_symbolic("aten::is_floating_point")
5108def is_floating_point(g: jit_utils.GraphContext, self):
5109    if symbolic_helper._is_fp(self):
5110        return g.op("Constant", value_t=torch.BoolTensor([1]))
5111    return g.op("Constant", value_t=torch.BoolTensor([0]))
5112
5113
5114@_onnx_symbolic("aten::__is_")
5115def __is_(g: jit_utils.GraphContext, self, other):
5116    if symbolic_helper._is_none(other):
5117        if symbolic_helper._is_none(self):
5118            return g.op("Constant", value_t=torch.BoolTensor([1]))
5119        return g.op("Constant", value_t=torch.BoolTensor([0]))
5120    return eq(g, self, other)
5121
5122
5123@_onnx_symbolic("aten::__isnot_")
5124@wrap_logical_op_with_negation
5125def __isnot_(g: jit_utils.GraphContext, self, other):
5126    return __is_(g, self, other)
5127
5128
5129@_onnx_symbolic("aten::one_hot")
5130def one_hot(g: jit_utils.GraphContext, self, num_classes):
5131    values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
5132    # onnxruntime supports limited type combinations for OneHot.
5133    if _type_utils.JitScalarType.from_value(
5134        num_classes, _type_utils.JitScalarType.UNDEFINED
5135    ) in {
5136        _type_utils.JitScalarType.UINT8,
5137        _type_utils.JitScalarType.INT8,
5138        _type_utils.JitScalarType.INT,
5139        _type_utils.JitScalarType.INT16,
5140    }:
5141        num_classes = g.op("Cast", num_classes, to_i=_C_onnx.TensorProtoDataType.INT64)
5142    return g.op("OneHot", self, num_classes, values, axis_i=-1)
5143
5144
5145@_onnx_symbolic("aten::gather")
5146@symbolic_helper.parse_args("v", "i", "v", "v")
5147def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
5148    if symbolic_helper._maybe_get_const(sparse_grad, "i"):
5149        return symbolic_helper._unimplemented("gather", "sparse_grad == True", self)
5150    # NOTE: This workaround is needed because GatherElements is only available
5151    #       from opset 11 onward, and Gather in ONNX is not the same as torch.gather.
5152    scalar_type = _type_utils.JitScalarType.from_value(self)
5153    values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
5154    depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim])))
5155    index = g.op(
5156        "Cast",
5157        g.op("OneHot", index, depth, values, axis_i=dim),
5158        to_i=scalar_type.onnx_type(),
5159    )
5160    mul = g.op("Mul", symbolic_helper._unsqueeze_helper(g, self, [dim + 1]), index)
5161    return symbolic_helper._reducesum_helper(g, mul, axes_i=[dim], keepdims_i=0)
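
# A minimal eager-mode sketch of the OneHot trick above (illustrative only,
# using torch.nn.functional.one_hot, whose class axis is last and is therefore
# transposed to `dim` to match ONNX OneHot):
#
#     import torch
#     import torch.nn.functional as F
#     x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
#     idx = torch.tensor([[2, 0], [1, 1]])
#     oh = F.one_hot(idx, num_classes=x.size(1)).transpose(1, 2).to(x.dtype)
#     ref = (x.unsqueeze(2) * oh).sum(dim=1)  # Mul, then ReduceSum over dim
#     assert torch.equal(ref, torch.gather(x, 1, idx))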
5162
5163
5164@symbolic_helper.parse_args("v", "is", "i", "i")
5165def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim):
5166    return symbolic_helper._var_mean_helper(g, input, dim, correction, keepdim)
5167
5168
5169@_onnx_symbolic("aten::std")
5170def std(g: jit_utils.GraphContext, input, *args):
5171    var, _ = var_mean(g, input, *args)
5172    return g.op("Sqrt", var)
5173
5174
5175@_onnx_symbolic("aten::var")
5176def var(g: jit_utils.GraphContext, input, *args):
5177    var, _ = var_mean(g, input, *args)
5178    return var
5179
5180
5181@_onnx_symbolic("aten::var_mean")
5182def var_mean(g: jit_utils.GraphContext, input, *args):
5183    if len(args) == 1:
5184        return _var_mean(g, input, None, args[0], None)
5185    else:
5186        return _var_mean(g, input, *args)
5187
5188
5189@_onnx_symbolic("aten::std_mean")
5190def std_mean(g: jit_utils.GraphContext, input, *args):
5191    var, mean = var_mean(g, input, *args)
5192    return g.op("Sqrt", var), mean
5193
5194
5195@_onnx_symbolic("aten::logsumexp")
5196@symbolic_helper.parse_args("v", "is", "i")
5197def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
5198    return g.op("ReduceLogSumExp", input, axes_i=dim, keepdims_i=keepdim)
5199
5200
5201@_onnx_symbolic("aten::arange")
5202def arange(g: jit_utils.GraphContext, *args):
5203    def _get_arange_dtype(dtype):
5204        dtype = symbolic_helper._maybe_get_const(dtype, "i")
5205        return dtype
5206
5207    def _float_step_convert(range_tensor):
5208        if symbolic_helper._is_fp(range_tensor):
5209            range_tensor = g.op(
5210                "Cast",
5211                g.op("Ceil", range_tensor),
5212                to_i=_type_utils.JitScalarType.INT64.onnx_type(),
5213            )
5214        return range_tensor
5215
5216    if len(args) == 2 or len(args) == 5:
5217        if len(args) == 2:
5218            # aten::arange(Scalar end, Tensor out)
5219            dtype = None
5220        else:
5221            # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
5222            dtype = _get_arange_dtype(args[1])
5223        dtype, end, start, step = symbolic_helper._arange_cast_helper(
5224            g, end=args[0], dtype=dtype
5225        )
5226        end = symbolic_helper._unsqueeze_helper(g, end, [0])
5227        range_tensor = _float_step_convert(end)
5228        arange_tensor = symbolic_helper._squeeze_helper(
5229            g, nonzero(g, ones(g, range_tensor, dtype, None, None)), [1]
5230        )
5231        return g.op(
5232            "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type()
5233        )
5234    elif len(args) == 4 or len(args) == 7:
5235        if len(args) == 4:
5236            # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
5237            dtype = None
5238        else:
5239            # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
5240            dtype = _get_arange_dtype(args[3])
5241        dtype, end, start, step = symbolic_helper._arange_cast_helper(
5242            g, start=args[0], end=args[1], step=args[2], dtype=dtype
5243        )
5244        step = symbolic_helper._unsqueeze_helper(g, step, [0])
5245        end = symbolic_helper._unsqueeze_helper(g, end, [0])
5246        start = symbolic_helper._unsqueeze_helper(g, start, [0])
5247        range_tensor = _float_step_convert(g.op("Div", g.op("Sub", end, start), step))
5248        arange_tensor = symbolic_helper._squeeze_helper(
5249            g, nonzero(g, ones(g, range_tensor, None, None, None)), [1]
5250        )
5251        arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start)
5252        return g.op(
5253            "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type()
5254        )
5255    elif len(args) == 6:
5256        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
5257        dtype = _get_arange_dtype(args[2])
5258        dtype, end, start, step = symbolic_helper._arange_cast_helper(
5259            g, start=args[0], end=args[1], dtype=dtype
5260        )
5261        end = symbolic_helper._unsqueeze_helper(g, end, [0])
5262        start = symbolic_helper._unsqueeze_helper(g, start, [0])
5263        range_tensor = _float_step_convert(g.op("Sub", end, start))
5264        arange_tensor = g.op(
5265            "Add",
5266            symbolic_helper._squeeze_helper(
5267                g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1]
5268            ),
5269            start,
5270        )
5271        return g.op(
5272            "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type()
5273        )
5274
5275    return symbolic_helper._unimplemented("aten::arange", f"with {len(args)} arguments")
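
# The Range-free trick above builds a tensor of ones whose length is the
# (ceiled) element count and reads the indices back with NonZero. A minimal
# eager-mode sketch of the start/end/step case (illustrative only):
#
#     import torch
#     start, end, step = 3, 10, 2
#     n = -(-(end - start) // step)  # ceil((end - start) / step)
#     ref = torch.nonzero(torch.ones(n)).squeeze(1) * step + start
#     assert torch.equal(ref, torch.arange(start, end, step))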
5276
5277
5278@_onnx_symbolic("aten::linspace")
5279def linspace(
5280    g: jit_utils.GraphContext, start, end, steps, dtype, layout, device, pin_memory
5281):
5282    range_tensor = symbolic_helper._arange_helper(g, steps, None)
5283    step = div(
5284        g,
5285        sub(g, end, start),
5286        sub(g, steps, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))),
5287    )
5288    return add(g, mul(g, range_tensor, step), start)
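
# linspace is lowered as an affine transform of arange (illustrative only):
#
#     import torch
#     start, end, steps = 0.0, 1.0, 5
#     step = (end - start) / (steps - 1)
#     ref = torch.arange(steps) * step + start
#     assert torch.allclose(ref, torch.linspace(start, end, steps))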
5289
5290
5291@_onnx_symbolic("aten::lift")
5292def lift(g: jit_utils.GraphContext, self):
5293    # at::lift() is a no-op from the perspective of tracing for onnx
5294    return self
5295
5296
5297@_onnx_symbolic("aten::masked_fill")
5298def masked_fill(g: jit_utils.GraphContext, self, mask, value):
5299    """Implement the masked_fill functionality available for a pytorch tensor in ONNX.
5300
5301    Fills elements of the input tensor with `value` where `mask` is True.
5302    """
5303    mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
5304    value = symbolic_helper._maybe_get_scalar(value)
5305    return g.op("Where", mask, symbolic_helper._if_scalar_type_as(value, self), self)
5306
5307
5308@_onnx_symbolic("aten::masked_fill_")
5309def masked_fill_(g: jit_utils.GraphContext, self, mask, value):
5310    return masked_fill(g, self, mask, value)
5311
5312
5313@_onnx_symbolic("aten::index")
5314def index(g: jit_utils.GraphContext, self, index):
5315    if symbolic_helper._is_packed_list(index):
5316        indices = symbolic_helper._unpack_list(index)
5317    else:
5318        indices = [index]
5319
5320    def try_mask_to_index(index):
5321        if not symbolic_helper._is_none(index) and (
5322            _type_utils.JitScalarType.from_value(
5323                index, _type_utils.JitScalarType.UNDEFINED
5324            )
5325            == _type_utils.JitScalarType.UINT8
5326            or symbolic_helper._is_bool(index)
5327        ):
5328            if g.opset < 9:
5329                raise errors.SymbolicValueError(
5330                    "Exporting masked indices are only supported after ONNX opset 9.",
5331                    self,
5332                )
5333            warnings.warn(
5334                "Exporting aten::index operator with indices of type Byte. "
5335                "Only 1-D indices are supported. In any other case, "
5336                "this will produce an incorrect ONNX graph."
5337            )
5338            index = symbolic_helper._squeeze_helper(g, nonzero(g, index), [1])
5339        return index
5340
5341    indices = [try_mask_to_index(idx) for idx in indices]
5342    if len(indices) == 1:
5343        return symbolic_helper._select_helper(
5344            g, self, 0, indices[0], apply_reshape=False
5345        )
5346    else:
5347        # Multiple tensors as indices. Each tensor could either be
5348        #   1. prim::Constant()
5349        #           representing ":" in python indexing. E.g. tensor[:, :]
5350        #   2. prim::Constant[value=...] or tensor output
5351        #           representing advanced indexing. E.g. tensor[[0, 1], [2, 0]].
5352        # For more info on advanced indexing,
5353        # check https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
5354
5355        # Consider a general case of
5356        #       t: [x_1, y_1, y_2, ..., x_m, ..., y_n]
5357        # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes for ":".
5358        # The same result can be achieved by transposing t into
5359        #       t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n]
5360        # and using GatherND. However, ONNX does not have GatherND; to use 1-D Gather
5361        # we need to flatten t and linearize the tensor indices:
5362        #       t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n]
5363        #       tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j))
5364        # After gather, reshape and transpose back.
5365        adv_idx_indices = [
5366            i for i, idx in enumerate(indices) if not symbolic_helper._is_none(idx)
5367        ]
5368
5369        if len(adv_idx_indices) == 0:
5370            return self
5371        elif len(adv_idx_indices) == 1:
5372            return index_select(
5373                g, self, adv_idx_indices[0], indices[adv_idx_indices[0]]
5374            )
5375        else:
5376            rank = symbolic_helper._get_tensor_rank(self)
5377            if rank is None:
5378                return symbolic_helper._unimplemented(
5379                    "aten::index",
5380                    "operator of advanced indexing on tensor of unknown rank. ",
5381                    self,
5382                )
5383            # TODO: If indexing is supported natively in ONNX in future opsets,
5384            #       update the warning to recommend exporting with higher opset version.
5385            warnings.warn(
5386                "Exporting aten::index operator of advanced indexing in opset "
5387                f"{GLOBALS.export_onnx_opset_version}"
5388                " is achieved by combination of multiple ONNX operators, "
5389                "including Reshape, Transpose, Concat, and Gather. "
5390                "If indices include negative values, the exported graph will produce incorrect results."
5391            )
5392            adv_idx_count = len(adv_idx_indices)
5393            shape_tensor = _shape_as_tensor(g, self)
5394            dim_tensor_list = [
5395                g.op(
5396                    "Gather",
5397                    shape_tensor,
5398                    g.op("Constant", value_t=torch.LongTensor([dim])),
5399                    axis_i=0,
5400                )
5401                for dim in range(rank)
5402            ]
5403
5404            self = g.op(
5405                "Transpose",
5406                self,
5407                perm_i=adv_idx_indices
5408                + [i for i in range(rank) if i not in adv_idx_indices],
5409            )
5410            self = g.op("Flatten", self, axis_i=adv_idx_count)
5411
5412            # Note that tensor indices will be broadcasted while accumulating. Thus we get the final subarray shape as well.
5413            cum_adv_index = indices[adv_idx_indices[-1]]
5414            multiplier = dim_tensor_list[adv_idx_indices[-1]]
5415            for i in range(adv_idx_count - 2, -1, -1):
5416                adv_index = g.op("Mul", indices[adv_idx_indices[i]], multiplier)
5417                cum_adv_index = g.op("Add", cum_adv_index, adv_index)
5418                multiplier = g.op(
5419                    "Mul", multiplier, dim_tensor_list[adv_idx_indices[i]]
5420                )
5421
5422            # perform gather
5423            self = index_select(g, self, 0, cum_adv_index)
5424
5425            cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index)
5426            # check if all advanced indices are consecutive.
5427            # Refer to https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
5428            # to understand how the subarray position is decided.
5429            if adv_idx_indices == list(
5430                range(adv_idx_indices[0], adv_idx_indices[-1] + 1)
5431            ):
5432                # unfold regular index axes
5433                folded_adv_idx_shape_list = [
5434                    g.op("Constant", value_t=torch.LongTensor([-1]))
5435                ] + [
5436                    dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices
5437                ]
5438                folded_adv_idx_shape = g.op(
5439                    "Concat", *folded_adv_idx_shape_list, axis_i=0
5440                )
5441                self = symbolic_helper._reshape_helper(g, self, folded_adv_idx_shape)
5442
5443                # Transpose folded advanced indexed axis to its original location.
5444                adv_idx_permute = (
5445                    list(range(1, adv_idx_indices[0] + 1))
5446                    + [0]
5447                    + list(range(adv_idx_indices[0] + 1, rank - adv_idx_count + 1))
5448                )
5449                self = g.op("Transpose", self, perm_i=adv_idx_permute)
5450
5451                # unfold advanced index axes
5452                final_shape_list = (
5453                    [dim_tensor_list[i] for i in range(adv_idx_indices[0])]
5454                    + [cum_adv_index_shape_tensor]
5455                    + [
5456                        dim_tensor_list[i]
5457                        for i in range(adv_idx_indices[0], rank)
5458                        if i not in adv_idx_indices
5459                    ]
5460                )
5461                final_shape = g.op("Concat", *final_shape_list, axis_i=0)
5462            else:
5463                final_shape = g.op(
5464                    "Concat",
5465                    cum_adv_index_shape_tensor,
5466                    *[
5467                        dim_tensor_list[i]
5468                        for i in range(rank)
5469                        if i not in adv_idx_indices
5470                    ],
5471                    axis_i=0,
5472                )
5473
5474            return symbolic_helper._reshape_helper(g, self, final_shape)
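
# A worked example of the index linearization above (illustrative only): for a
# rank-2 tensor with advanced indices on both axes, the flattened index is
# ind_0 * x_1 + ind_1, matching the \sum_i ind_i * \prod_{j>i} x_j formula.
#
#     import torch
#     t = torch.arange(20.0).reshape(4, 5)
#     i0 = torch.tensor([0, 3])
#     i1 = torch.tensor([1, 4])
#     ref = t.reshape(-1)[i0 * t.size(1) + i1]  # tensor([ 1., 19.])
#     assert torch.equal(ref, t[i0, i1])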
5475
5476
5477@_onnx_symbolic("aten::linalg_norm")
5478@symbolic_helper.parse_args("v", "v", "is", "b", "v")
5479def linalg_norm(
5480    g: jit_utils.GraphContext,
5481    self: torch._C.Value,
5482    ord: torch._C.Value,
5483    dim: Sequence[int] | None,
5484    keepdim: bool,
5485    dtype: torch._C.Value,
5486):
5487    # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.norm.html
5488    ord_value = None
5489    if dim is None:
5490        if symbolic_helper._is_none(ord):
5491            self = symbolic_helper._reshape_helper(g, self, [-1])
5492            ord = g.op("Constant", value_t=torch.LongTensor([2]))
5493        self_dim = symbolic_helper._get_tensor_rank(self)
5494        if self_dim is None:
5495            return symbolic_helper._unimplemented(
5496                "dim", "Input rank must be known at export time.", self
5497            )
5498        if self_dim == 1:
5499            ord_value = symbolic_helper._parse_arg(ord, "f")
5500        else:
5501            dim = [0, 1]
5502    else:
5503        if len(dim) == 1:
5504            if symbolic_helper._is_none(ord):
5505                ord = g.op("Constant", value_t=torch.LongTensor([2]))
5506            ord_value = symbolic_helper._parse_arg(ord, "f")
5507    if ord_value:
5508        return linalg_vector_norm(g, self, ord_value, dim, keepdim, dtype)
5509    return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)
5510
5511
5512@_onnx_symbolic("aten::linalg_vector_norm")
5513@symbolic_helper.parse_args("v", "f", "is", "b", "v")
5514def linalg_vector_norm(
5515    g: jit_utils.GraphContext,
5516    self: torch._C.Value,
5517    ord: float,
5518    dim: Sequence[int] | None,
5519    keepdim: bool,
5520    dtype: torch._C.Value,
5521):
5522    return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)
5523
5524
5525@_onnx_symbolic("aten::linalg_matrix_norm")
5526@symbolic_helper.parse_args("v", "v", "is", "b", "v")
5527def linalg_matrix_norm(
5528    g: jit_utils.GraphContext,
5529    self: torch._C.Value,
5530    ord: torch._C.Value,
5531    dim: list[int],
5532    keepdim: bool,
5533    dtype: torch._C.Value,
5534):
5535    # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html
5536    ord_value = symbolic_helper._parse_arg(ord, "s")
5537    if ord_value == "fro":
5538        return frobenius_norm(g, self, dim, keepdim)
5539    elif ord_value == "nuc":
5540        return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==nuc", self)
5541    else:
5542        ord_value = symbolic_helper._parse_arg(ord, "f")
5543        if ord_value is None:
5544            return frobenius_norm(g, self, dim, keepdim)
5545        if ord_value == 2 or ord_value == -2:
5546            # ord = 2/-2 unimplemented due to lack of operators
5547            # used to calculate singular values
5548            return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==2", self)
5549        # Wrap the dim vector to handle negative dim values
5550        self_dim = symbolic_helper._get_tensor_rank(self)
5551        if self_dim is None:
5552            return symbolic_helper._unimplemented(
5553                "linalg.matrix_norm", "Input rank must be known at export time.", self
5554            )
5555        # Common implementation for cases with
5556        # ord = 1/-1 and ord = inf/-inf
5557        if dim[0] < 0:
5558            dim[0] += self_dim
5559        if dim[1] < 0:
5560            dim[1] += self_dim
5561
5562        if ord_value == math.inf or ord_value == -math.inf:
5563            dim[0], dim[1] = dim[1], dim[0]
5564        if dim[1] > dim[0] and not keepdim:
5565            dim[1] -= 1
5566        sum = symbolic_helper._reducesum_helper(
5567            g, g.op("Abs", self), axes_i=[dim[0]], keepdims_i=keepdim
5568        )
5569        if ord_value > 0:
5570            result, indices = max(
5571                g,
5572                sum,
5573                dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])),
5574                keepdim=keepdim,
5575            )
5576        else:
5577            result, indices = min(
5578                g,
5579                sum,
5580                dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])),
5581                keepdim=keepdim,
5582            )
5583        return result
5584
5585
5586@_onnx_symbolic("aten::linalg_cross")
5587@symbolic_helper.parse_args("v", "v", "i")
5588def linalg_cross(g: jit_utils.GraphContext, input, other, dim=-1):
5589    return cross(g, input, other, dim)
5590
5591
5592@_onnx_symbolic("aten::frobenius_norm")
5593@symbolic_helper.parse_args("v", "is", "b")
5594def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
5595    sqr = g.op("Mul", self, self)
5596    sumsqr = symbolic_helper._reducesum_helper(g, sqr, axes_i=dim, keepdims_i=keepdim)
5597    return g.op("Sqrt", sumsqr)
5598
5599
5600@_onnx_symbolic("aten::multinomial")
5601@symbolic_helper.parse_args("v", "i", "b", "v")
5602def multinomial(
5603    g: jit_utils.GraphContext, input, num_samples, replacement=False, generator=None
5604):
5605    if generator is not None and not symbolic_helper._is_none(generator):
5606        symbolic_helper._unimplemented(
5607            "Multinomial", "generator is not supported for multinomial", input
5608        )
5609    if not replacement and num_samples > 1:
5610        symbolic_helper._unimplemented(
5611            "Multinomial",
5612            "replacement=False when num_samples > 1 is not supported for multinomial",
5613            input,
5614        )
5615
5616    log_input = log(g, input)
5617    return g.op(
5618        "Multinomial",
5619        log_input,
5620        dtype_i=_C_onnx.TensorProtoDataType.INT64,
5621        sample_size_i=num_samples,
5622    )
5623
5624
5625@_onnx_symbolic("aten::baddbmm")
5626def baddbmm(g: jit_utils.GraphContext, self, batch1, batch2, beta, alpha):
5627    scalar_type = _type_utils.JitScalarType.from_value(self)
5628    batch_mul = matmul(g, batch1, batch2)
5629    mul_a = mul(
5630        g,
5631        batch_mul,
5632        g.op("Cast", alpha, to_i=scalar_type.onnx_type()),
5633    )
5634    mul_b = mul(
5635        g,
5636        self,
5637        g.op("Cast", beta, to_i=scalar_type.onnx_type()),
5638    )
5639    return add(g, mul_a, mul_b)
5640
5641
5642@_onnx_symbolic("aten::meshgrid")
5643@symbolic_helper.parse_args("v", "s")
5644def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: str | None = None):
5645    if indexing is None:
5646        indexing = "ij"
5647    elif indexing not in {"ij", "xy"}:
5648        raise errors.SymbolicValueError(
5649            f"Unsupported indexing: {indexing}", tensor_list
5650        )
5651    unpacked_tensor_list = symbolic_helper._unpack_list(tensor_list)
5652    if indexing == "xy":
5653        unpacked_tensor_list[:2] = unpacked_tensor_list[1::-1]
5654    tensors = [
5655        symbolic_helper._reshape_helper(
5656            g, t, g.op("Constant", value_t=torch.LongTensor([-1]))
5657        )
5658        for t in unpacked_tensor_list
5659    ]
5660    tensors_shape = [g.op("Shape", t) for t in tensors]
5661    out_shape = g.op("Concat", *tensors_shape, axis_i=0)
5662    out = []
5663    for i, t in enumerate(tensors):
5664        shape_i = [g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))] * len(
5665            tensors
5666        )
5667        shape_i[i] = tensors_shape[i]
5668        t_reshaped = _reshape_from_tensor(g, t, g.op("Concat", *shape_i, axis_i=0))
5669        out.append(g.op("Expand", t_reshaped, out_shape))
5670    if indexing == "xy":
5671        out[0], out[1] = out[1], out[0]
5672    return g.op("prim::ListConstruct", *out)
5673
5674
5675@_onnx_symbolic("aten::remainder")
5676def remainder(g: jit_utils.GraphContext, input, other):
5677    div = _floor_divide(g, input, other)
5678    quo = g.op("Mul", div, other)
5679    return g.op("Sub", input, quo)
5680
5681
5682@_onnx_symbolic("aten::gelu")
5683@symbolic_helper.parse_args("v", "s")
5684def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "none"):
5685    if approximate == "tanh":
5686        kBeta = math.sqrt(2 / math.pi)
5687        kKappa = 0.044715
5688
5689        beta = torch.tensor(kBeta, dtype=torch.double)
5690        kappa = torch.tensor(kKappa, dtype=torch.double)
5691        one = torch.tensor(1.0, dtype=torch.double)
5692        half = torch.tensor(0.5, dtype=torch.double)
5693
5694        self_cube = mul(g, self, mul(g, self, self))
5695        inner = mul(g, beta, add(g, self, mul(g, kappa, self_cube)))
5696        return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", inner))))
5697    else:
5698        _sqrt2 = 1.4142135623730951
5699        erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double)))
5700        erf_plusone = add(
5701            g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double))
5702        )
5703        return mul(
5704            g,
5705            mul(g, self, erf_plusone),
5706            g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)),
5707        )
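
# Both branches above implement the usual GELU formulas; a minimal eager-mode
# sketch (illustrative only):
#
#     import math, torch
#     x = torch.randn(4, dtype=torch.double)
#     k_beta = math.sqrt(2.0 / math.pi)
#     tanh_ref = 0.5 * x * (1.0 + torch.tanh(k_beta * (x + 0.044715 * x**3)))
#     erf_ref = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
#     assert torch.allclose(tanh_ref, torch.nn.functional.gelu(x, approximate="tanh"))
#     assert torch.allclose(erf_ref, torch.nn.functional.gelu(x))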
5708
5709
5710@_onnx_symbolic("aten::group_norm")
5711@symbolic_helper.quantized_args(True, False, False, False)
5712@symbolic_helper.parse_args("v", "i", "v", "v", "f", "i")
5713def group_norm(
5714    g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled
5715):
5716    channel_size = symbolic_helper._get_tensor_dim_size(input, 1)
5717    if channel_size is not None:
5718        assert channel_size % num_groups == 0
5719    input_rank = symbolic_helper._get_tensor_rank(input)
5720    if input_rank is None:
5721        return symbolic_helper._unimplemented("group_norm", "unknown input rank", input)
5722    # 0 in the shape list keeps dimension value unchanged.
5723    shape = [0, num_groups, -1]
5724    input_reshaped = symbolic_helper._reshape_helper(
5725        g, input, g.op("Constant", value_t=torch.LongTensor(shape))
5726    )
5727
5728    # C is always divisible by num_groups
5729    # Due to the shape difference, we need to apply weight and bias after
5730    # the instance norm computation and reshape
5731    weight_ = g.op(
5732        "Constant",
5733        value_t=torch.tensor(
5734            [1.0] * num_groups,
5735            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
5736        ),
5737    )
5738    bias_ = g.op(
5739        "Constant",
5740        value_t=torch.tensor(
5741            [0.0] * num_groups,
5742            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
5743        ),
5744    )
5745
5746    norm_reshaped = g.op(
5747        "InstanceNormalization", input_reshaped, weight_, bias_, epsilon_f=eps
5748    )
5749    norm = symbolic_helper._reshape_helper(g, norm_reshaped, g.op("Shape", input))
5750
5751    if weight is None or weight.node().mustBeNone():
5752        weight_value = torch.tensor(
5753            [1.0], dtype=_type_utils.JitScalarType.from_value(input).dtype()
5754        )
5755        weight = g.op("Constant", value_t=weight_value)
5756    if bias is None or bias.node().mustBeNone():
5757        bias_value = torch.tensor(
5758            [0.0], dtype=_type_utils.JitScalarType.from_value(input).dtype()
5759        )
5760        bias = g.op("Constant", value_t=bias_value)
5761
5762    # Norm has shape [N, C, *] so we reshape weight and bias to [C, *]
5763    axes = list(range(1, input_rank - 1))
5764    return add(
5765        g,
5766        mul(g, norm, symbolic_helper._unsqueeze_helper(g, weight, axes)),
5767        symbolic_helper._unsqueeze_helper(g, bias, axes),
5768    )
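
# A minimal eager-mode sketch of the grouping trick above (illustrative only):
# folding (N, C, *) into (N, G, -1) makes per-group normalization equivalent to
# per-channel instance normalization.
#
#     import torch
#     N, C, L, G = 2, 6, 4, 3
#     x = torch.randn(N, C, L, dtype=torch.double)
#     xg = x.reshape(N, G, -1)
#     mean = xg.mean(dim=2, keepdim=True)
#     var = xg.var(dim=2, unbiased=False, keepdim=True)
#     ref = ((xg - mean) / torch.sqrt(var + 1e-5)).reshape_as(x)
#     assert torch.allclose(ref, torch.nn.functional.group_norm(x, G, eps=1e-5))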
5769
5770
5771@_onnx_symbolic("aten::_weight_norm")
5772@symbolic_helper.parse_args("v", "v", "i")
5773def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim):
5774    rank = symbolic_helper._get_tensor_rank(weight_v)
5775    if rank is not None:
5776        # W = g * ((v) / ||v||)
5777        # Compute norm_except_dim for l2 norm. dim = None means over all dims
5778        # torch's weight_norm module sets dim = -1 if it's None.
5779        # This conflicts with the logic for negative axes accessing dims backwards.
5780        # TODO: Might need a fix in torch's weight_norm module
5781        axes = list(range(rank))
5782        if dim is not None:
5783            if dim < -1:
5784                dim += rank
5785            if dim != -1:
5786                axes.remove(dim)
5787        norm_v = norm(g, weight_v, 2, axes, 1)
5788        div = g.op("Div", weight_v, norm_v)
5789        return g.op("Mul", div, weight_g)
5790    raise errors.SymbolicValueError(
5791        "Unsupported: ONNX export of _weight_norm for tensor of unknown rank.",
5792        weight_v,
5793    )
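
# A minimal eager-mode sketch of W = g * v / ||v|| with the norm taken over
# every axis except `dim` (illustrative only; torch._weight_norm is the ATen
# reference used by nn.utils.weight_norm):
#
#     import torch
#     v = torch.randn(4, 3, dtype=torch.double)
#     g_ = torch.randn(4, 1, dtype=torch.double)
#     w = g_ * v / v.norm(p=2, dim=1, keepdim=True)  # dim=0 kept, dim=1 reduced
#     assert torch.allclose(w, torch._weight_norm(v, g_, 0))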
5794
5795
5796@_onnx_symbolic("aten::dim")
5797def dim(g: jit_utils.GraphContext, self):
5798    """Implement the dim functionality available for a pytorch tensor in ONNX"""
5799    # ONNX does not support dim directly in this opset, so we use two ops (Shape, then Size) to get it
5800    shape = g.op("Shape", self)
5801    return g.op("Size", shape)
5802
5803
5804@_onnx_symbolic("aten::__contains_")
5805def __contains_(g: jit_utils.GraphContext, self, element):
5806    unpacked_list = symbolic_helper._unpack_list(self)
5807    if all(
5808        symbolic_helper._is_constant(x) for x in unpacked_list
5809    ) and symbolic_helper._is_constant(element):
5810        return g.op(
5811            "Constant",
5812            value_t=torch.tensor(
5813                symbolic_helper._node_get(element.node(), "value")
5814                in (symbolic_helper._node_get(x.node(), "value") for x in unpacked_list)
5815            ),
5816        )
5817
5818    raise errors.SymbolicValueError(
5819        "Unsupported: ONNX export of __contains__ for non-constant list or element.",
5820        self,
5821    )
5822
5823
5824@_onnx_symbolic("aten::__getitem_")
5825def __getitem_(g: jit_utils.GraphContext, self, i):
5826    return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i)
5827
5828
5829@_onnx_symbolic("aten::item")
5830def item(g: jit_utils.GraphContext, self):
5831    return self
5832
5833
5834@_onnx_symbolic("aten::take")
5835def take(g: jit_utils.GraphContext, self, index):
5836    self_flattened = symbolic_helper._reshape_helper(
5837        g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
5838    )
5839    out = index_select(g, self_flattened, 0, index)
5840    out = reshape_as(g, out, index)
5841    return out
5842
5843
5844def _kl_div_log_target_impl(g: jit_utils.GraphContext, input, target):
5845    diff_ = sub(g, target, input)
5846    exp_ = exp(g, target)
5847    output = mul(g, exp_, diff_)
5848    return output
5849
5850
5851def _kl_div_non_log_target_impl(g: jit_utils.GraphContext, input, target):
5852    log_ = log(g, target)
5853    diff_ = sub(g, log_, input)
5854    output_pos = mul(g, target, diff_)
5855    zeros_ = zeros_like(g, output_pos)
5856    mask_ = gt(g, target, g.op("Constant", value_t=torch.tensor(0)))
5857    output = where(g, mask_, output_pos, zeros_)
5858    return output
5859
5860
5861@_onnx_symbolic("aten::kl_div")
5862@symbolic_helper.parse_args("v", "v", "i", "b")
5863def kl_div(g: jit_utils.GraphContext, input, target, reduction, log_target):
5864    if log_target:
5865        output = _kl_div_log_target_impl(g, input, target)
5866    else:
5867        output = _kl_div_non_log_target_impl(g, input, target)
5868
5869    if reduction == 0:
5870        return output
5871    elif reduction == 1:
5872        return g.op("ReduceMean", output, keepdims_i=0)
5873    elif reduction == 2:
5874        return symbolic_helper._reducesum_helper(g, output, keepdims_i=0)
5875    else:
5876        return symbolic_helper._onnx_unsupported(
5877            "kl_div with reduction other than none, mean, or sum.", input
5878        )
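
# A minimal eager-mode sketch of the non-log-target branch above (illustrative
# only): target * (log(target) - input), with zeros where target == 0.
#
#     import torch
#     inp = torch.randn(5, dtype=torch.double)  # log-probabilities
#     tgt = torch.rand(5, dtype=torch.double)   # probabilities
#     ref = torch.where(tgt > 0, tgt * (tgt.log() - inp), torch.zeros_like(inp))
#     assert torch.allclose(ref, torch.nn.functional.kl_div(inp, tgt, reduction="none"))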
5879
5880
5881@_onnx_symbolic("aten::mse_loss")
5882@symbolic_helper.parse_args("v", "v", "i")
5883def mse_loss(g: jit_utils.GraphContext, input, target, reduction):
5884    output = mul(g, sub(g, input, target), sub(g, input, target))
5885    if reduction == 0:
5886        return output
5887    elif reduction == 1:
5888        return g.op("ReduceMean", output, keepdims_i=0)
5889    elif reduction == 2:
5890        return symbolic_helper._reducesum_helper(g, output, keepdims_i=0)
5891    else:
5892        return symbolic_helper._onnx_unsupported(
5893            "mse_loss with reduction other than none, mean, or sum.", input
5894        )
5895
5896
5897@_onnx_symbolic("aten::as_strided")
5898@symbolic_helper.quantized_args(True)
5899@symbolic_helper.parse_args("v", "v", "is", "i")
5900def as_strided(g: jit_utils.GraphContext, self, sizes, strides, offset=None):
5901    sizes = symbolic_helper._maybe_get_const(sizes, "is")
5902    rank = len(strides)
5903    self_1d = symbolic_helper._reshape_helper(
5904        g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
5905    )
5906    ind: torch.Tensor | None
5907    if not symbolic_helper._is_value(sizes):
5908        ind = torch.tensor([0], dtype=torch.long)
5909        for i, (size, stride) in enumerate(zip(sizes, strides)):
5910            r_size = [1] * rank
5911            r_size[i] = -1
5912            ind = ind + torch.arange(size).view(r_size) * stride
5913        if offset:
5914            ind = ind + offset
5915        return g.op("Gather", self_1d, g.op("Constant", value_t=ind))
5916    else:
5917        ind = None
5918        for i, stride in enumerate(strides):
5919            r_size = [1] * rank
5920            r_size[i] = -1
5921            size = select(
5922                g,
5923                sizes,
5924                g.op("Constant", value_t=torch.tensor([0])),
5925                g.op("Constant", value_t=torch.tensor(i)),
5926            )
5927            tmp_ind = symbolic_helper._reshape_helper(
5928                g,
5929                arange(g, size, 4, None, None, None),
5930                g.op("Constant", value_t=torch.tensor(r_size)),
5931            )
5932            tmp_ind = g.op(
5933                "Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride]))
5934            )
5935            if ind is None:
5936                ind = tmp_ind
5937            else:
5938                ind = g.op("Add", ind, tmp_ind)
5939        if offset:
5940            ind = g.op("Add", ind, g.op("Constant", torch.tensor([offset])))
5941        return g.op("Gather", self_1d, ind)
5942
5943
5944@_onnx_symbolic("aten::__derive_index")
5945def __derive_index(g: jit_utils.GraphContext, index, start, step):
5946    return g.op("Add", start, g.op("Mul", index, step))
5947
5948
5949@_onnx_symbolic("aten::__range_length")
5950# Source code for aten op can be found here: pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp
5951# if (step > 0 && lo < hi) {
5952#   push(stack, 1 + (hi - 1 - lo) / step);
5953# } else if (step < 0 && lo > hi) {
5954#   push(stack, 1 + (lo - 1 - hi) / (0 - step));
5955# } else {
5956#  push(stack, 0);
5957# }
5958def __range_length(g: jit_utils.GraphContext, lo, hi, step):
5959    sub = g.op("Sub", hi, lo)
5960    div = g.op("Ceil", true_divide(g, sub, step))
5961    return g.op("Cast", div, to_i=_C_onnx.TensorProtoDataType.INT64)
5962
5963
5964@_onnx_symbolic("aten::linear")
5965def linear(g: jit_utils.GraphContext, input, weight, bias):
5966    rank = symbolic_helper._get_tensor_rank(input)
5967    weight = t(g, weight)
5968    if rank == 2 and not bias.node().mustBeNone():
5969        alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
5970        beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
5971        output = addmm(g, bias, input, weight, alpha, beta)
5972    else:
5973        output = matmul(g, input, weight)
5974        if not bias.node().mustBeNone():
5975            output = add(g, bias, output)
5976
5977    return output
5978
5979
5980@_onnx_symbolic("aten::hann_window")
5981@symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v")
5982def hann_window(
5983    g: jit_utils.GraphContext,
5984    window_length,
5985    periodic=True,
5986    dtype: int | None = None,
5987    layout=None,
5988    device=None,
5989    pin_memory=None,
5990    requires_grad=False,
5991):
5992    if dtype is None:
5993        dtype_ = torch.get_default_dtype()
5994        if not dtype_ or not dtype_.is_floating_point:
5995            dtype_ = torch.float
5996        scalar_type = _type_utils.JitScalarType.from_dtype(dtype_)
5997    else:
5998        scalar_type = _type_utils.JitScalarType(dtype)
5999
6000    n_array = arange(g, window_length, 4, None, None, None)
6001    output = g.op("Cast", n_array, to_i=_C_onnx.TensorProtoDataType.FLOAT)
6002    output = mul(
6003        g, g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float)), output
6004    )
6005
6006    if periodic is False:
6007        window_length = sub(
6008            g, window_length, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int))
6009        )
6010    output = div(g, output, window_length)
6011    output = g.op(
6012        "Cast",
6013        square(g, sin(g, output)),
6014        to_i=scalar_type.onnx_type(),
6015    )
6016
6017    return output
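
# The computation above uses the identity
# 0.5 - 0.5 * cos(2 * pi * n / N) == sin(pi * n / N) ** 2. A minimal eager-mode
# sketch (illustrative only):
#
#     import math, torch
#     n_fft = 8
#     n = torch.arange(n_fft, dtype=torch.float)
#     ref = torch.sin(math.pi * n / n_fft) ** 2  # periodic=True divides by N
#     assert torch.allclose(ref, torch.hann_window(n_fft, periodic=True))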
6018
6019
6020@_onnx_symbolic("aten::mv")
6021def mv(g: jit_utils.GraphContext, self, vec):
6022    return matmul(g, self, vec)
6023
6024
6025@_onnx_symbolic("aten::dot")
6026def dot(g: jit_utils.GraphContext, self, other):
6027    return matmul(g, self, other)
6028
6029
6030@_onnx_symbolic("aten::movedim")
6031@symbolic_helper.parse_args("v", "t", "t")
6032def movedim(g: jit_utils.GraphContext, self, source, destination):
6033    # This is a pythonic implementation mostly taken from aten/src/ATen/native/TensorShape.cpp::movedim
6034    source = source.view(-1)
6035    destination = destination.view(-1)
6036
6037    assert source.size() == destination.size()
6038
6039    if (source == destination).all():
6040        return self
6041
6042    self_rank = symbolic_helper._get_tensor_rank(self)
6043    assert self_rank is not None
6044
6045    perm = list(range(self_rank))
6046
6047    src_dims = perm.copy()
6048    dst_dims = perm.copy()
6049
6050    for src, dst in zip(source.tolist(), destination.tolist()):
6051        perm[dst] = src
6052        src_dims[src] = -1
6053        dst_dims[dst] = -1
6054
6055    src_dims = [dim for dim in src_dims if dim != -1]
6056    dst_dims = [dim for dim in dst_dims if dim != -1]
6057
6058    for src, dst in zip(src_dims, dst_dims):
6059        perm[dst] = src
6060
6061    return g.op("Transpose", self, perm_i=perm)
6062
6063
6064@_onnx_symbolic("aten::fill")
6065@symbolic_helper.parse_args("v", "v")
6066def fill(g: jit_utils.GraphContext, self, value):
6067    scalar_type = _type_utils.JitScalarType.from_value(
6068        self, _type_utils.JitScalarType.FLOAT
6069    )
6070    return full_like(g, self, value, scalar_type)
6071
6072
6073@_onnx_symbolic("aten::index_add")
6074def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None):
6075    warnings.warn(
6076        "Warning: ONNX export does not support duplicated values in 'index' field, "
6077        + "this will cause the ONNX model to be incorrect."
6078    )
6079
6080    # ONNX does not support "alpha" argument, unlike aten index_add
6081    # See: https://github.com/pytorch/pytorch/pull/65993#issuecomment-953151102 for more context
6082    if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
6083        return symbolic_helper._unimplemented("index_add", "alpha != 1", self)
6084
6085    dim = symbolic_helper._maybe_get_const(dim, "i")
6086    if dim is None:
6087        raise errors.SymbolicValueError(
6088            "ONNX export does NOT support exporting 'index_add_()' function with "
6089            "unknown 'dim' value.",
6090            self,
6091        )
6092
6093    self_dim_rank = symbolic_helper._get_tensor_rank(self)
6094    other_dim_rank = symbolic_helper._get_tensor_rank(other)
6095
6096    if self_dim_rank is None or other_dim_rank is None:
6097        raise errors.SymbolicValueError(
6098            "ONNX export does NOT support exporting 'index_add_()' function while "
6099            "the rank of self tensor or tensor to be added is unknown.",
6100            self,
6101        )
6102
6103    if other_dim_rank != self_dim_rank:
6104        delta = self_dim_rank - other_dim_rank
6105        for i in range(delta):
6106            other = symbolic_helper._unsqueeze_helper(
6107                g, other, [symbolic_helper._get_tensor_rank(other)]
6108            )
6109
6110    other_dim_size = symbolic_helper._get_tensor_dim_size(other, dim)
6111    self_dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
6112
6113    if (other_dim_size is not None) and (self_dim_size is not None):
6114        if other_dim_size > self_dim_size:
6115            raise errors.SymbolicValueError(
6116                "ONNX export does not support exporting 'index_add_()' function with "
6117                "duplicated values in 'index' parameter yet.",
6118                self,
6119            )
6120
6121    # Construct a new shape. It is almost the same as self's shape, except that the
6122    # size of the 'dim' dimension is 1, so that 'other' can be expanded as expected.
6123    new_shape_axes = list(range(self_dim_rank))
6124    new_shape_starts = [0 for i in range(self_dim_rank)]
6125    new_shape_ends = [sys.maxsize if (i != dim) else 1 for i in range(self_dim_rank)]
6126
6127    new_shape = symbolic_helper._slice_helper(
6128        g, self, axes=new_shape_axes, starts=new_shape_starts, ends=new_shape_ends
6129    )
6130    other = expand_as(g, other, new_shape)
6131
6132    for i in range(dim):
6133        index = symbolic_helper._unsqueeze_helper(g, index, [0])
6134
6135    for i in range(self_dim_rank - dim - 1):
6136        index = symbolic_helper._unsqueeze_helper(
6137            g, index, [symbolic_helper._get_tensor_rank(index)]
6138        )
6139
6140    return scatter_add(g, self, dim, expand_as(g, index, other), other)
6141
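
# An eager-mode sketch (editor's illustration; `_index_add_as_scatter_add_sketch` is a
# hypothetical helper) of the rewrite above for the common case where 'other' already
# has self's rank and 'index' contains no duplicates: broadcast the 1-D 'index' to
# other's shape along 'dim', after which index_add becomes scatter_add.
def _index_add_as_scatter_add_sketch(
    self_t: torch.Tensor, dim: int, index: torch.Tensor, other: torch.Tensor
) -> torch.Tensor:
    # Reshape the 1-D 'index' so it lies along 'dim', then expand it to other's shape.
    view = [1] * self_t.dim()
    view[dim] = -1
    expanded_index = index.view(view).expand_as(other)
    # Equals self_t.index_add(dim, index, other) when 'index' has unique values.
    return self_t.scatter_add(dim, expanded_index, other)
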
6142
6143@_onnx_symbolic("aten::roll")
6144@symbolic_helper.parse_args("v", "is", "is")
6145def roll(g: jit_utils.GraphContext, self, shifts, dims):
6146    assert len(shifts) == len(dims)
6147
6148    result = self
6149    for i in range(len(shifts)):
6150        shapes = []
6151        shape = symbolic_helper._slice_helper(
6152            g, result, axes=[dims[i]], starts=[-shifts[i]], ends=[sys.maxsize]
6153        )
6154        shapes.append(shape)
6155        shape = symbolic_helper._slice_helper(
6156            g, result, axes=[dims[i]], starts=[0], ends=[-shifts[i]]
6157        )
6158        shapes.append(shape)
6159        result = g.op("Concat", *shapes, axis_i=dims[i])
6160
6161    return result
6162
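
# An eager-mode sketch (editor's illustration; `_roll_once_sketch` is a hypothetical
# helper) of the slice-and-concat decomposition above: rolling by 'shift' along 'dim'
# is the last 'shift' slices followed by everything before them.
def _roll_once_sketch(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    tail = x.narrow(dim, x.size(dim) - shift, shift)  # x[..., -shift:]
    head = x.narrow(dim, 0, x.size(dim) - shift)  # x[..., :-shift]
    # Matches torch.roll(x, shift, dim) for 0 < shift < x.size(dim).
    return torch.cat([tail, head], dim=dim)
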
6163
6164@_onnx_symbolic("aten::cross")
6165@symbolic_helper.parse_args("v", "v", "i")
6166def cross(g: jit_utils.GraphContext, input, other, dim=None):
6167    dim = symbolic_helper._get_dim_for_cross(input, dim)
6168    # Given two tensors
6169    # A = [a, b, c] and B = [d, e, f], we permute them with rolls.
6170    # After the first pair of rolls,
6171    # A' = [b, c, a], B' = [f, d, e], so that we calculate (b*f, c*d, a*e)
6172    roll_x_1 = roll(g, input, [2], [dim])
6173    roll_y_1 = roll(g, other, [1], [dim])
6174    # After the second pair of rolls,
6175    # A'' = [c, a, b], B'' = [e, f, d], so that we calculate (c*e, a*f, b*d)
6176    roll_x_2 = roll(g, input, [1], [dim])
6177    roll_y_2 = roll(g, other, [2], [dim])
6178    # The cross product is then calculated as
6179    # result = [(b*f - c*e), (c*d - a*f), (a*e - b*d)]
6180    return sub(g, mul(g, roll_x_1, roll_y_1), mul(g, roll_x_2, roll_y_2))
6181
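
# An eager-mode sketch (editor's illustration; `_cross_via_rolls_sketch` is a
# hypothetical helper) of the roll-based identity above, checkable against torch.cross:
def _cross_via_rolls_sketch(a: torch.Tensor, b: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # For a = [a0, a1, a2] and b = [b0, b1, b2] this yields
    # [a1*b2 - a2*b1, a2*b0 - a0*b2, a0*b1 - a1*b0], i.e. torch.cross(a, b, dim=dim).
    return torch.roll(a, 2, dim) * torch.roll(b, 1, dim) - torch.roll(a, 1, dim) * torch.roll(b, 2, dim)
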
6182
6183@_onnx_symbolic("aten::cdist")
6184def cdist(
6185    g: jit_utils.GraphContext,
6186    x1,
6187    x2,
6188    p=2.0,
6189    compute_mode="use_mm_for_euclid_dist_if_necessary",
6190):
6191    # x1.shape = (B, P, D), x2.shape = (B, R, D)
6192    # In order to respect numpy-style broadcasting as demonstrated in
6193    # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
6194    # we unsqueeze both input tensors.
6195    # Currently we ignore the 'compute_mode' argument and default to using
6196    # matrix multiplication to calculate the Euclidean distance.
6197    rank = symbolic_helper._get_tensor_rank(x1)
6198    assert rank is not None
6199    broadcasted_x1 = symbolic_helper._unsqueeze_helper(g, x1, [rank - 1])
6200    broadcasted_x2 = symbolic_helper._unsqueeze_helper(g, x2, [rank - 2])
6201    return pairwise_distance(
6202        g, broadcasted_x1, broadcasted_x2, p, eps=1e-06, keepdim=False
6203    )
6204
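
# An eager-mode sketch (editor's illustration; `_cdist_sketch` is a hypothetical helper)
# of the broadcasting trick above: unsqueezing turns (B, P, D) and (B, R, D) into
# (B, P, 1, D) and (B, 1, R, D), so a p-norm over the last axis yields all pairwise
# distances with shape (B, P, R).
def _cdist_sketch(x1: torch.Tensor, x2: torch.Tensor, p: float = 2.0) -> torch.Tensor:
    diff = x1.unsqueeze(-2) - x2.unsqueeze(-3)
    # Approximates torch.cdist(x1, x2, p); the symbolic also adds a small eps for stability.
    return diff.norm(p=p, dim=-1)
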
6205
6206@_onnx_symbolic("aten::lerp")
6207def lerp(g: jit_utils.GraphContext, self, end, weight):
6208    # Conditional for better numerical stability. This has been discussed in
6209    # https://github.com/pytorch/pytorch/pull/18871
6210    diff = g.op("Sub", end, self)
6211    return where(
6212        g,
6213        g.op("Less", weight, g.op("Constant", value_t=torch.tensor(0.5))),
6214        g.op("Add", self, g.op("Mul", weight, diff)),
6215        g.op(
6216            "Sub",
6217            end,
6218            g.op(
6219                "Mul",
6220                diff,
6221                g.op("Sub", g.op("Constant", value_t=torch.tensor(1.0)), weight),
6222            ),
6223        ),
6224    )
6225
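
# An eager-mode sketch (editor's illustration; `_lerp_sketch` is a hypothetical helper)
# of the branchy lerp above: use start + w*diff for w < 0.5 and end - diff*(1 - w)
# otherwise, which keeps the result exact at both endpoints.
def _lerp_sketch(start: torch.Tensor, end: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    diff = end - start
    return torch.where(weight < 0.5, start + weight * diff, end - diff * (1.0 - weight))
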
6226
6227@_onnx_symbolic("aten::broadcast_tensors")
6228def broadcast_tensors(g: jit_utils.GraphContext, self):
6229    all_tensors = symbolic_helper._unpack_list(self)
6230    t_with_final_shape = zeros_like(g, all_tensors[0])
6231
6232    # The Add operator supports multidirectional broadcasting, so we leverage it here
6233    # to infer the final shape generated by the broadcast.
6234    for t in all_tensors:
6235        t_with_final_shape = add(g, t_with_final_shape, t)
6236
6237    t_list = [expand_as(g, t, t_with_final_shape) for t in all_tensors]
6238    return g.op("prim::ListConstruct", *t_list)
6239
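
# An eager-mode sketch (editor's illustration; `_broadcast_tensors_sketch` is a
# hypothetical helper) of the shape inference above: repeatedly adding zeros of the
# running shape broadcasts it against every input, after which each tensor is expanded
# to the common shape.
def _broadcast_tensors_sketch(tensors: Sequence[torch.Tensor]) -> list[torch.Tensor]:
    acc = torch.zeros_like(tensors[0])
    for t in tensors:
        acc = acc + t  # broadcasting grows acc to the common shape
    # Matches torch.broadcast_tensors(*tensors) up to memory layout.
    return [t.expand_as(acc) for t in tensors]
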
6240
6241@_onnx_symbolic("aten::is_pinned")
6242def is_pinned(g: jit_utils.GraphContext, self, device=None):
6243    # Unused by ONNX.
6244    return None
6245
6246
6247@_onnx_symbolic("prim::ConstantSplit")
6248def prim_constant_split(g: jit_utils.GraphContext, self, split_size, dim):
6249    size = symbolic_helper._get_tensor_dim_size(self, dim)
6250    if size is None:
6251        return symbolic_helper._unimplemented(
6252            "prim::ConstantSplit", "unknown dimension size", self
6253        )
6254    splits = [split_size] * (size // split_size)
6255    leftover = size % split_size
6256    if leftover:
6257        splits.append(leftover)
6258    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits))
6259
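
# A pure-Python sketch (editor's illustration; `_split_sizes_sketch` is a hypothetical
# helper) of the split-size computation above; e.g. _split_sizes_sketch(10, 3) ==
# [3, 3, 3, 1], mirroring torch.split(x, 3, dim) on a dimension of size 10.
def _split_sizes_sketch(size: int, split_size: int) -> list[int]:
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    return splits
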
6260
6261# TODO: It would be better to export this as a chunk directly, as this is
6262# less sensitive to changes in input size.
6263# TODO: Once we have proper scoping, stop reimplementing chunk, delete this
6264# method, and use the desugared version
6265@_onnx_symbolic("prim::ConstantChunk")
6266def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim):
6267    dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
6268    if dim_size is None:
6269        return symbolic_helper._unimplemented(
6270            "prim::ConstantChunk", "unknown dimension size", self
6271        )
6272    split_size = (dim_size + chunks - 1) // chunks
6273    return prim_constant_split(g, self, split_size, dim)
6274
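
# A pure-Python sketch (editor's illustration; `_chunk_split_size_sketch` is a
# hypothetical helper) of the ceil division above: chunking a dimension of size 10
# into 4 chunks uses split_size = (10 + 4 - 1) // 4 == 3, i.e. chunk sizes [3, 3, 3, 1],
# matching torch.chunk.
def _chunk_split_size_sketch(dim_size: int, chunks: int) -> int:
    return (dim_size + chunks - 1) // chunks
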
6275
6276@_onnx_symbolic("prim::shape")
6277def prim_shape(g: jit_utils.GraphContext, self):
6278    return g.op("Shape", self)
6279
6280
6281@_onnx_symbolic("prim::max")
6282def prim_max(g: jit_utils.GraphContext, self, other):
6283    return symbolic_helper._op_with_optional_float_cast(
6284        g, "Max", self, other, opset_before=12
6285    )
6286
6287
6288@_onnx_symbolic("prim::min")
6289def prim_min(g: jit_utils.GraphContext, self, other=None):
6290    if not other:
6291        if symbolic_helper._is_packed_list(self):
6292            self = stack(g, self, g.op("Constant", value_t=torch.tensor([0])))
6293        return min(g, self)
6294    return min(g, self, other)
6295
6296
6297@_onnx_symbolic("prim::data")
6298def prim_data(g: jit_utils.GraphContext, self):
6299    return self
6300
6301
6302@_onnx_symbolic("prim::layout")
6303def prim_layout(g: jit_utils.GraphContext, self):
6304    # Always return 'torch.strided'. Other layout types are not supported by JIT 'TensorType'.
6305    # Layout class defined in 'c10/core/Layout.h'.
6306    return g.op("Constant", value_t=torch.tensor(0))
6307
6308
6309@_onnx_symbolic("prim::ListConstruct")
6310def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
6311    return None
6312
6313
6314@_onnx_symbolic("prim::ListUnpack")
6315def prim_list_unpack(
6316    g: jit_utils.GraphContext, *inputs, **kwargs
6317) -> list[_C.Value] | None:
6318    if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct":
6319        # Cancel the previous node if it is ListConstruct by returning its inputs
6320        # TODO(justinchuby): Use a public method in the helper module
6321        return symbolic_helper._unpack_list(inputs[0])
6322
6323    return None
6324
6325
6326@_onnx_symbolic("prim::TupleConstruct")
6327def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
6328    return None
6329
6330
6331@_onnx_symbolic("prim::Uninitialized")
6332def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs):
6333    return None
6334
6335
6336# exists to refine the type of the Value
6337# if x is an optional Tensor, unchecked_cast will cast
6338# x to Tensor, so the rest of the graph knows that x is a Tensor
6339# this doesn't do anything in runtime and is a noop in ONNX
6340@_onnx_symbolic("prim::unchecked_cast")
6341def prim_unchecked_cast(g: jit_utils.GraphContext, self):
6342    return self
6343
6344
6345@_onnx_symbolic("prim::dtype")
6346def prim_dtype(g: jit_utils.GraphContext, self):
6347    scalar_type = symbolic_helper._try_get_scalar_type(self)
6348    if scalar_type is None:
6349        scalar_type = _type_utils.JitScalarType.FLOAT
6350    # This node records a torch dtype as int
6351    return g.op("Constant", value_t=torch.tensor(scalar_type))
6352
6353
6354@_onnx_symbolic("prim::tolist")
6355def prim_tolist(g: jit_utils.GraphContext, input, dim_val, elem_ty_val):
6356    """tolist is currently supported only for 1D input tensors.
6357
6358    dim_val and elem_ty_val represent dimension and type annotations
6359    that need to match dimension and type of the input tensor.
6360    """
6361    dim = symbolic_helper._maybe_get_const(dim_val, "i")
6362    if dim > 1:
6363        return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input)
6364    return input
6365
6366
6367# -----------------------------------------------------------------------------
6368# Symbolic functions that need extra context
6369# -----------------------------------------------------------------------------
6370@_onnx_symbolic("prim::device")
6371def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None:
6372    output_type = g.original_node.output().type()
6373    if isinstance(output_type, _C.DeviceObjType):
6374        return None
6375
6376    return symbolic_helper._unimplemented(
6377        "prim::device",
6378        f"output type should be 'DeviceObjType', not '{output_type.kind()}'",
6379        g.original_node.output(),
6380    )
6381
6382
6383@_onnx_symbolic("prim::Loop")
6384def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]:
6385    node = g.original_node
6386    env = g.env
6387    values_in_env = g.values_in_env
6388    params_dict = g.params_dict
6389
6390    operator_export_type = GLOBALS.operator_export_type
6391    opset_version = GLOBALS.export_onnx_opset_version
6392
6393    old_blocks = tuple(node.blocks())
6394    new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
6395        g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks)
6396    )
6397
6398    for old_block, new_block_context in zip(old_blocks, new_block_contexts):
6399        # Copy input metadata to subblock
6400        #
6401        #   prim::Loop(iter, cond, input_1, ..., input_n)
6402        #     block0(iter, input_1, ..., input_n)
6403        #
6404        # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`.
6405        for i, b_in in enumerate(old_block.inputs()):
6406            if i == 0 and i < len(inputs):
6407                b_in.setType(inputs[i].type())
6408            # Optional block inputs may switch between None and not-None inside
6409            # the loop body, so even if the loop input is not optional, the block
6410            # input may still need to be optional.
6411            if (
6412                i > 0
6413                and (i + 1) < len(inputs)
6414                and not isinstance(b_in.type(), _C.OptionalType)
6415            ):
6416                b_in.setType(inputs[i + 1].type())
6417        torch._C._jit_pass_onnx_block(
6418            old_block,
6419            new_block_context.block,
6420            operator_export_type,
6421            env,
6422            values_in_env,
6423            False,
6424        )
6425    fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
6426        new_node, opset_version
6427    )
6428    # Run shape type inference for Loop after subblock is converted.
6429    if GLOBALS.onnx_shape_inference:
6430        torch._C._jit_pass_onnx_node_shape_type_inference(
6431            new_node, params_dict, opset_version
6432        )
6433    return fixed_outputs
6434
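
# A minimal end-to-end sketch (editor's illustration; the module, tensor sizes, and
# output file name are hypothetical) of code that reaches prim_loop: under
# torch.jit.script, a Python loop with a data-dependent trip count stays a prim::Loop
# node, which this symbolic converts into an ONNX Loop with one subblock. Depending on
# the PyTorch version, scripted-module export may require additional arguments.
def _prim_loop_export_sketch() -> None:
    class RunningSum(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            total = torch.zeros(1)
            for i in range(x.size(0)):  # dynamic trip count -> prim::Loop
                total = total + x[i]
            return total

    scripted = torch.jit.script(RunningSum())
    torch.onnx.export(scripted, (torch.randn(4, 1),), "loop_sketch.onnx", opset_version=9)
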
6435
6436@_onnx_symbolic("prim::If")
6437def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]:
6438    n = g.original_node
6439    block = g.block
6440    env = g.env
6441    values_in_env = g.values_in_env
6442    params_dict = g.params_dict
6443
6444    operator_export_type = GLOBALS.operator_export_type
6445    opset_version = GLOBALS.export_onnx_opset_version
6446
6447    static_if = inputs[0].node().kind() == "onnx::Constant"
6448    if static_if:
6449        # Fold static if
6450        #
6451        # The torch IR
6452        # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu),
6453        #    %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ...
6454        # %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
6455        # %21 : Long(device=cpu) = aten::eq(%20, %64)
6456        # %22 : Long(device=cpu) = prim::If(%21)
6457        #     block0():
6458        #     %23 : Long(device=cpu) = aten::is_floating_point(%input.1)
6459        #     -> (%23)
6460        #     block1():
6461        #     -> (%65)
6462        # %input.53 : Tensor, %weight : Tensor = prim::If(%22)
6463        #     block0():
6464        #     -> (%embedding_matrix.1, %input.1)
6465        #     block1():
6466        #     -> (%input.1, %embedding_matrix.1)
6467        # %26 : int[] = aten::size(%input.53)
6468        #
6469        # The converted ONNX graph
6470        # %10 : Bool(device=cpu) = onnx::Constant[value={0}]()
6471        # %14 : Bool(device=cpu) = onnx::Equal(%13, %8)
6472        # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
6473        # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1)
6474        input_flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist()
6475        const_value = (
6476            all(input_flag) if isinstance(input_flag, list) else bool(input_flag)
6477        )
6478        block_idx = 0 if const_value else 1
6479        current_b = list(n.blocks())[block_idx]
6480        env = torch._C._jit_pass_onnx_block(
6481            current_b,
6482            block,
6483            operator_export_type,
6484            env,
6485            values_in_env,
6486            True,
6487        )
6488        if_output_list = list(n.outputs())
6489        current_b_list = list(current_b.outputs())
6490
6491        final_b_list = []
6492        for idx in range(len(if_output_list)):
6493            if current_b_list[idx] not in env:
6494                raise errors.SymbolicValueError(
6495                    f"The sub block ATen output {current_b_list[idx]} is not in env.",
6496                    current_b_list[idx],
6497                )  # type:ignore[operator]
6498            onnx_b = env[current_b_list[idx]]
6499            final_b_list.append(onnx_b)
6500        return final_b_list
6501    else:
6502        old_blocks = tuple(n.blocks())
6503        new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
6504            g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks)
6505        )
6506
6507        for old_block, new_block_context in zip(old_blocks, new_block_contexts):
6508            torch._C._jit_pass_onnx_block(
6509                old_block,
6510                new_block_context.block,
6511                operator_export_type,
6512                env,
6513                values_in_env,
6514                False,
6515            )
6516        fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
6517            new_node, opset_version
6518        )
6519        # Run shape type inference for If after subblock is converted.
6520        if GLOBALS.onnx_shape_inference:
6521            torch._C._jit_pass_onnx_node_shape_type_inference(
6522                new_node, params_dict, opset_version
6523            )
6524        return fixed_outputs
6525
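
# A pure-Python sketch (editor's illustration; `_static_if_block_index_sketch` is a
# hypothetical helper) of the constant-folding rule above: a constant condition picks
# block0 (the "then" branch) or block1 (the "else" branch) directly, so no ONNX If
# node is emitted for a static prim::If.
def _static_if_block_index_sketch(input_flag) -> int:
    const_value = all(input_flag) if isinstance(input_flag, list) else bool(input_flag)
    return 0 if const_value else 1
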
6526
6527@_onnx_symbolic("prim::Constant")
6528def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs):
6529    node = g.original_node
6530
6531    if node.mustBeNone():
6532        return None
6533    # This must go before checking for string values, because some device constants
6534    # have string values, but we want to keep them as unconverted Device types so
6535    # that eq() can work on them.
6536    if isinstance(node.output().type(), _C.DeviceObjType):
6537        return None
6538    if node.kindOf("value") == "t":
6539        return g.op("Constant", value_t=symbolic_helper._node_get(node, "value"))
6540    if node.kindOf("value") == "s":
6541        return g.op("Constant", value_s=symbolic_helper._node_get(node, "value"))
6542    if node.output().type().isSubtypeOf(
6543        _C.ListType.ofInts()
6544    ) or node.output().type().isSubtypeOf(_C.ListType.ofFloats()):
6545        return g.op(
6546            "Constant", value_t=torch.tensor(symbolic_helper._node_get(node, "value"))
6547        )
6548    if node.output().type().isSubtypeOf(_C.ListType.ofStrings()):
6549        str_constants = [
6550            g.op("Constant", value_s=s)
6551            for s in symbolic_helper._node_get(node, "value")
6552        ]
6553        return g.op("prim::ListConstruct", *str_constants)
6554
6555    raise errors.SymbolicValueError(
6556        f"Unsupported prim::Constant kind: '{node.kindOf('value')}'. "
6557        f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.",
6558        node.output(),
6559    )
6560
6561
6562@_onnx_symbolic("prim::type")
6563def prim_type(g: jit_utils.GraphContext, device_value: _C.Value, *args, **kwargs):
6564    if device_value.node().kind() == "prim::device":
6565        device = jit_utils.get_device_from_value(device_value.node().input())
6566        if device is not None:
6567            return g.op("Constant", value_s=str(device))
6568
6569    return symbolic_helper._unimplemented(
6570        "prim::type",
6571        "Device type cannot be statically determined.",
6572        device_value,
6573    )
6574
6575
6576@_onnx_symbolic("onnx::Placeholder")
6577def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs):
6578    node = g.original_node
6579    block = g.block
6580    env = g.env
6581    values_in_env = g.values_in_env
6582
6583    return torch._C._jit_onnx_convert_pattern_from_subblock(
6584        block, node, env, values_in_env
6585    )
6586
6587
6588@_onnx_symbolic("aten::resolve_conj")
6589@_onnx_symbolic("aten::resolve_neg")
6590def noop_complex_operators(g: jit_utils.GraphContext, input: _C.Value):
6591    # ONNX does not have operators to *directly* manipulate real/imaginary components.
6592    # However, a few torch APIs (e.g. .tolist()) use complex operations even when the
6593    # input is real, which results in failures due to missing operators for complex numbers.
6594
6595    # `aten::resolve_conj` and `aten::resolve_neg` can safely be implemented as no-ops.
6596    return input
6597
6598
6599@_onnx_symbolic("aten::_conj")
6600@_onnx_symbolic("aten::conj_physical")
6601def unsupported_complex_operators(g: jit_utils.GraphContext, input: _C.Value):
6602    # ONNX does not have operators to *directly* manipulate real/imaginary components
6603    # However, a few torch APIs (e.g. .tolist()) use complex operations even when the
6604    # input is real, which results in failures due to missing operators for complex numbers.
6605
6606    # `aten::_conj` and `aten::conj_physical` raise an exception when the input is complex.
6607    if symbolic_helper.is_complex_value(input):
6608        # FIXME(justinchuby): report correct name for symbolic being executed
6609        return symbolic_helper._onnx_unsupported(
6610            "aten::_conj, aten::conj_physical",
6611            input,
6612        )
6613
6614    # For real inputs they can safely be implemented as no-ops.
6615    return noop_complex_operators(g, input)
6616
6617
6618@_onnx_symbolic("aten::logit")
6619def logit(g: jit_utils.GraphContext, self: torch._C.Value, eps: torch._C.Value):
6620    one = g.op("Constant", value_t=torch.tensor(1.0))
6621
6622    if not symbolic_helper._is_none(eps):
6623        eps = g.op(
6624            "Cast", eps, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()
6625        )
6626        one_sub_eps = g.op("Sub", one, eps)
6627        self_less_equal_one_sub_eps = g.op("Greater", one_sub_eps, self)
6628        temporary_self = g.op("Where", self_less_equal_one_sub_eps, self, one_sub_eps)
6629
6630        temporary_self_less_eps = g.op("Less", temporary_self, eps)
6631        z = g.op("Where", temporary_self_less_eps, eps, temporary_self)
6632    else:
6633        z = self
6634
6635    sub = g.op("Sub", one, z)
6636    div = g.op("Div", z, sub)
6637    return g.op("Log", div)
6638
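
# An eager-mode sketch (editor's illustration; `_logit_sketch` is a hypothetical helper)
# of the graph above: with eps given, the input is clamped to [eps, 1 - eps] (the two
# Where nodes implement the min/max) before computing log(z / (1 - z)).
def _logit_sketch(x: torch.Tensor, eps: float | None = None) -> torch.Tensor:
    z = x if eps is None else x.clamp(eps, 1.0 - eps)
    # Matches torch.logit(x, eps).
    return torch.log(z / (1.0 - z))
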