xref: /aosp_15_r20/external/pytorch/torch/onnx/symbolic_helper.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import functools
5import inspect
6import math
7import sys
8import typing
9import warnings
10from typing import Any, Callable, Literal, NoReturn, Sequence, TypeVar as _TypeVar
11from typing_extensions import Concatenate as _Concatenate, ParamSpec as _ParamSpec
12
13import torch
14import torch._C._onnx as _C_onnx
15from torch import _C
16
17# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
18from torch.onnx import _constants, _type_utils, errors, utils
19from torch.onnx._globals import GLOBALS
20from torch.onnx._internal import jit_utils
21
22
23if typing.TYPE_CHECKING:
24    from torch.types import Number
25
26_T = _TypeVar("_T")
27_U = _TypeVar("_U")
28_P = _ParamSpec("_P")
29
30# ---------------------------------------------------------------------------------
31# Helper functions
32# ---------------------------------------------------------------------------------
33
34_ValueDescriptor = Literal[
35    "v",
36    "i",
37    "is",
38    "f",
39    "fs",
40    "b",
41    "s",
42    "t",
43    "none",
44]
45
46
47def _parse_arg(
48    value,
49    desc: _ValueDescriptor,
50    arg_name: str | None = None,
51    node_name: str | None = None,
52):
53    if desc == "none":
54        return value
55    if desc == "v" or not _is_value(value):
56        return value
57
58    node = value.node()
59    if node.mustBeNone():
60        return None
61    if node.kind() == "onnx::Constant":
62        node_val = _node_get(node, "value")
63        if desc == "i":
64            return int(node_val)
65        elif desc == "f":
66            return float(node_val)
67        elif desc == "b":
68            return bool(node_val)
69        elif desc == "s":
70            return str(node_val)
71        elif desc == "t":
72            return node_val
73        elif desc == "is":
74            return [int(v) for v in node_val]
75        elif desc == "fs":
76            return [float(v) for v in node_val]
77        else:
78            raise errors.SymbolicValueError(
79                f"ONNX symbolic does not understand the Constant node '{node}' "
80                f"specified with descriptor '{desc}'.",
81                value,
82            )
83    elif node.kind() == "prim::ListConstruct":
84        if desc == "is":
85            for v in node.inputs():
86                element_node = v.node()
87                if element_node.kind() != "onnx::Constant":
88                    raise errors.SymbolicValueError(
89                        f"Failed to export a node '{element_node}' "
90                        f"(in list node {node}) "
91                        f"because it is not constant. "
92                        f"Please try to make things (e.g. kernel sizes) static if possible.",
93                        value,
94                    )
95            return [int(_node_get(v.node(), "value")) for v in value.node().inputs()]
96        else:
97            raise errors.SymbolicValueError(
98                f"ONNX symbolic does not know how to unpack the ListConstruct node that "
99                f"is not a list of integers: '{node}'",
100                value,
101            )
102
103    if arg_name is None or node_name is None:
104        raise errors.SymbolicValueError(
105            f"Expected node type 'onnx::Constant', got '{node.kind()}'.",
106            value,
107        )
108
109    raise errors.SymbolicValueError(
110        "Expected node type 'onnx::Constant' "
111        f"for argument '{arg_name}' of node '{node_name}', got '{node.kind()}'.",
112        value,
113    )
114
115
116def _node_get(node: _C.Node, key: str):
117    """Gets attributes of a node which is polymorphic over return type."""
118    assert isinstance(node, _C.Node)
119    sel = node.kindOf(key)
120    return getattr(node, sel)(key)
121
122
123def _is_onnx_constant(value: _C.Value):
124    """Whether a Value is an ONNX constant."""
125    return value.node().kind() == "onnx::Constant"
126
127
128def _maybe_get_const(
129    value: _C.Value | torch.Tensor | Number | Sequence | None,
130    descriptor: _ValueDescriptor,
131):
132    # NOTE: prim::Constant at this stage usually means something not compatible in ONNX,
133    # otherwise it'd be converted to onnx::Constant
134    # TODO(justinchuby): Replace insinstance with _is_value once we figure out mypy
135    if isinstance(value, _C.Value) and _is_onnx_constant(value):
136        return _parse_arg(value, descriptor)
137    return value
138
139
140def _maybe_get_scalar(value):
141    value_t = _maybe_get_const(value, "t")
142    if isinstance(value_t, torch.Tensor) and value_t.shape == ():
143        return value_t
144    return value
145
146
147def _get_const(value, desc, arg_name):
148    if not _is_constant(value):
149        raise errors.SymbolicValueError(
150            f"ONNX symbolic expected a constant value of the '{arg_name}' argument, "
151            f"got '{value}'",
152            value,
153        )
154    return _parse_arg(value, desc)
155
156
157def _unpack_list(list_value: _C.Value) -> list[_C.Value]:
158    list_node = list_value.node()
159    if list_node.kind() != "prim::ListConstruct":
160        raise errors.SymbolicValueError(
161            f"ONNX symbolic expected node type prim::ListConstruct, "
162            f"got '{list_node}'.",
163            list_value,
164        )
165    return list(list_node.inputs())
166
167
168def _unpack_tuple(tuple_value: _C.Value) -> tuple[_C.Value, ...]:
169    tuple_node = tuple_value.node()
170    if not _is_tuple_construct(tuple_value):
171        raise errors.SymbolicValueError(
172            f"ONNX symbolic expected node type 'prim::TupleConstruct', "
173            f"got '{tuple_node.kind()}'.",
174            tuple_value,
175        )
176    return tuple(tuple_node.inputs())
177
178
179def _unpack_quantized_tensor(tuple_value: _C.Value) -> tuple[_C.Value, ...]:
180    """Unpacks a quantized tensor into a tuple of tensor and scale/zero_point.
181    Args:
182        tuple_value: A tuple of tensor, scale, zero_point, and optionally axis.
183    Returns:
184        A tuple of tensor, scale, zero_point, and optionally axis.
185    """
186    tuple_node = tuple_value.node()
187    # A quantized tensor is represented as tuple of the form (tensor, scale, zero_point, <axis>)
188    if not _is_tuple_construct(tuple_value):
189        raise errors.SymbolicValueError(
190            f"ONNX symbolic expected the output of `{tuple_node}` to be a quantized "
191            f"tensor. Is this likely due to missing support for quantized "
192            f"`{tuple_node.kind()}`. Please create an issue on {_constants.PYTORCH_GITHUB_ISSUES_URL}",
193            tuple_value,
194        )
195    unpacked = tuple(tuple_node.inputs())
196    assert len(unpacked) == 3 or len(unpacked) == 4
197    return unpacked
198
199
200# Check if list_value is output from prim::ListConstruct
201# This is usually called before _unpack_list to ensure the list can be unpacked.
202def _is_packed_list(list_value: Any) -> bool:
203    return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct"
204
205
206def parse_args(
207    *arg_descriptors: _ValueDescriptor,
208) -> Callable[[Callable[_Concatenate[_U, _P], _T]], Callable[_Concatenate[_U, _P], _T]]:
209    """A decorator which converts args from torch._C.Value to built-in types.
210
211    For example:
212
213    ```
214    @parse_args('v', 'i', 'fs')
215    foo(g, a, b, c):
216        assert isinstance(a, torch._C.Value)
217        assert isinstance(b, int)
218        assert isinstance(c, list)
219        assert isinstance(c[0], float)
220    ```
221
222    Args:
223        arg_descriptors: list of str, where each element is
224            a string that specifies the type to convert to. Valid descriptors:
225            "v": no conversion, keep torch._C.Value.
226            "i": int
227            "is": list of int
228            "f": float
229            "fs": list of float
230            "b": bool
231            "s": str
232            "t": torch.Tensor
233            "none": the variable is unused
234    """
235
236    def decorator(
237        fn: Callable[_Concatenate[_U, _P], _T],
238    ) -> Callable[_Concatenate[_U, _P], _T]:
239        fn._arg_descriptors = arg_descriptors  # type: ignore[attr-defined]
240
241        @functools.wraps(fn)
242        def wrapper(g: _U, *args: _P.args, **kwargs: _P.kwargs) -> _T:
243            # some args may be optional, so the length may be smaller
244            FILE_BUG_MSG = (
245                "If you believe this is not due to custom symbolic implementation within your code or "
246                "an external library, please file an issue at "
247                "https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to report this bug."
248            )
249            assert len(arg_descriptors) >= len(args), (
250                f"A mismatch between the number of arguments ({len(args)}) and "
251                f"their descriptors ({len(arg_descriptors)}) was found at symbolic function '{fn.__name__}'. "
252                f"{FILE_BUG_MSG}"
253            )
254
255            try:
256                sig = inspect.signature(fn)
257                arg_names = list(sig.parameters.keys())[1:]
258                fn_name = fn.__name__
259            except Exception:
260                # FIXME(justinchuby): Avoid catching Exception.
261                # Catch a more specific exception instead.
262                arg_names = [None] * len(args)  # type: ignore[list-item]
263                fn_name = None
264            args = [
265                _parse_arg(arg, arg_desc, arg_name, fn_name)  # type: ignore[method-assign]
266                for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names)
267            ]
268            # only support _outputs in kwargs
269            assert len(kwargs) <= 1, (
270                f"Symbolic function {fn.__name__}'s '**kwargs' can contain a single "
271                f"key/value entry. "
272                f"{FILE_BUG_MSG}"
273            )
274
275            if len(kwargs) == 1:
276                assert "_outputs" in kwargs, (
277                    f"Symbolic function {fn.__name__}'s '**kwargs' can only contain "
278                    f"'_outputs' key at '**kwargs'. "
279                    f"{FILE_BUG_MSG}"
280                )
281            return fn(g, *args, **kwargs)
282
283        return wrapper
284
285    return decorator
286
287
288def quantized_args(
289    *arg_q_descriptors: bool,
290    scale: float | None = None,
291    zero_point: int | None = None,
292    quantize_output: bool = True,
293) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
294    """A decorator which extends support for quantized version of the base operator.
295
296    Quantization is detected by examining the arguments that are annotated by
297    `arg_q_descriptors`.
298
299    If quantization is detected, the base operator symbolic function will be wrapped with
300    argument de-quantization and output quantization.
301
302    Otherwise, only the base symbolic function will be invoked.
303
304    For example:
305
306    ```
307    @quantized_args(True, False)
308    def foo(g, x, y):
309        return x + y
310    ```
311
312    is equivalent to
313
314    ```
315    def q_foo(g, x, y):
316        if is_quantized_tensor(x):
317            x = dequantize(x)
318            out = foo(g, x, y)
319            return quantize(out)
320        else:
321            return foo(g, x, y)
322    ```
323
324    Args:
325        arg_q_descriptors: A sequence of bool, where each element represents if the
326          argument is QTensor for quantized version of this operator. It defaults
327          to False for unspecified (variable length) arguments.
328        scale: Quantized output scale. If None, derive from
329          the first quantized input scale.
330        zero_point: Quantized output zero point. If None,
331          derive from the first quantized input zero point.
332        quantize_output: If True, quantize the output of the base operator. Default is True
333    """
334
335    def decorator(fn):
336        @functools.wraps(fn)
337        def wrapper(g, *args, **kwargs):
338            nonlocal scale
339            nonlocal zero_point
340            if scale is not None:
341                _scale = g.op("Constant", value_t=torch.tensor(scale))
342            else:
343                _scale = None
344            if zero_point is not None:
345                _zero_point = g.op("Constant", value_t=torch.tensor(zero_point))
346            else:
347                _zero_point = None
348
349            # Support variable length arguments by marking unspecified ones as non-quantized
350            arg_q_descriptors_extended = arg_q_descriptors + (False,) * (
351                len(args) - len(arg_q_descriptors)
352            )
353            descriptor_args = tuple(zip(arg_q_descriptors_extended, args))
354
355            def _is_arg_quantized(descriptor, arg):
356                return descriptor and _is_value(arg) and _is_tuple_construct(arg)
357
358            # Run regular symbolic function if none of the argument is QTensor.
359            is_quantized = []
360            for descriptor, arg in descriptor_args:
361                # ListConstruct
362                if _is_packed_list(arg):
363                    for arg_input in arg.node().inputs():
364                        is_quantized.append(_is_arg_quantized(descriptor, arg_input))
365                else:
366                    is_quantized.append(_is_arg_quantized(descriptor, arg))
367
368            if not any(is_quantized):
369                return fn(g, *args, **kwargs)
370
371            # Dequantize arguments that are quantized
372            non_quantized_args = []
373            for descriptor, arg in descriptor_args:
374                if _is_arg_quantized(descriptor, arg):
375                    # Quantized arg is a tuple of (value, scale, zero_point)
376                    dequantized_arg, arg_scale, arg_zero_point, _ = dequantize_helper(
377                        g, arg
378                    )
379                    non_quantized_args.append(dequantized_arg)
380                    # Set scale and zero_point to the first quantized input if not already set
381                    if _scale is None:
382                        _scale = arg_scale
383                    if _zero_point is None:
384                        _zero_point = arg_zero_point
385                # ListConstruct
386                elif _is_packed_list(arg):
387                    for arg_input in arg.node().inputs():
388                        if _is_arg_quantized(descriptor, arg_input):
389                            # Quantized arg is a tuple of (value, scale, zero_point)
390                            (
391                                dequantized_arg,
392                                arg_scale,
393                                arg_zero_point,
394                                _,
395                            ) = dequantize_helper(g, arg_input)
396                            # Set scale and zero_point to the first quantized input if not already set
397                            if _scale is None:
398                                _scale = arg_scale
399                            if _zero_point is None:
400                                _zero_point = arg_zero_point
401                            arg_input.replaceAllUsesWith(dequantized_arg)
402                    non_quantized_args.append(arg)
403                else:
404                    # Non-quantized arg
405                    non_quantized_args.append(arg)
406            # TODO(justinchuby): Only single output is supported for now. We may want to
407            # support multiple outputs in the future.
408            output = fn(g, *non_quantized_args, **kwargs)
409
410            assert _scale is not None, "Bug: Scale must be set for quantized operator"
411            assert (
412                _zero_point is not None
413            ), "Bug: Zero point must be set for quantized operator"
414
415            if quantize_output:
416                return quantize_helper(g, output, _scale, _zero_point)
417            return output
418
419        return wrapper
420
421    return decorator
422
423
424def _scalar(x: Any) -> Number | None:
425    """Convert a scalar tensor into a Python value."""
426    if isinstance(x, torch.Tensor) and x.shape == ():
427        return x.item()
428    return None
429
430
431def _if_scalar_type_as(self, tensor):
432    """
433    Convert self into the same type of tensor, as necessary.
434    We only support implicit casting for scalars, so we never
435    actually need to insert an ONNX cast operator here; just
436    fix up the scalar.
437    """
438    if isinstance(self, _C.Value):
439        return self
440
441    scalar_type = _type_utils.JitScalarType.from_value(
442        tensor, _type_utils.JitScalarType.UNDEFINED
443    )
444    if scalar_type != _type_utils.JitScalarType.UNDEFINED:
445        ty = scalar_type.scalar_name().lower()
446        return getattr(self, ty)()
447    return self
448
449
450def _is_none(x: Any) -> bool:
451    return x is None or (x.node().mustBeNone() if isinstance(x, _C.Value) else False)
452
453
454def _is_value(x: Any) -> bool:
455    return isinstance(x, _C.Value)
456
457
458def _is_constant(value: Any) -> bool:
459    return not _is_value(value) or value.node().kind() in {
460        "onnx::Constant",
461        "prim::Constant",
462    }
463
464
465def _is_tensor(x: _C.Value) -> bool:
466    return x.type().isSubtypeOf(_C.TensorType.get())
467
468
469# Note: _C.JitType is not exposed to Python and cannot be checked in runtime.
470def _as_list_type(jit_type: _C.JitType) -> _C.ListType | None:
471    if isinstance(jit_type, _C.ListType):
472        return jit_type
473    return None
474
475
476def _is_list(x: _C.Value) -> bool:
477    return _as_list_type(x.type()) is not None
478
479
480def _is_tensor_list(x: _C.Value) -> bool:
481    x_type = _as_list_type(x.type())
482    if x_type is None:
483        return False
484    return isinstance(x_type.getElementType(), _C.TensorType)
485
486
487def _is_scalar_list(x: _C.Value) -> bool:
488    """Checks if x is a scalar list, for example: List[float], List[int].
489
490    Besides checking the type is ListType, we also check if the data type is
491    a valid ONNX data type.
492    """
493    x_type = _as_list_type(x.type())
494    if x_type is None:
495        return False
496    scalar_type = _type_utils.JitScalarType.from_value(x)
497    return scalar_type.onnx_compatible()
498
499
500def _is_tuple_construct(x: _C.Value) -> bool:
501    return x.node().kind() == "prim::TupleConstruct"
502
503
504def is_complex_value(x: _C.Value) -> bool:
505    assert _is_value(x)
506    return _type_utils.JitScalarType.from_value(
507        x, _type_utils.JitScalarType.UNDEFINED
508    ) in {
509        _type_utils.JitScalarType.COMPLEX32,
510        _type_utils.JitScalarType.COMPLEX64,
511        _type_utils.JitScalarType.COMPLEX128,
512    }
513
514
515def _get_tensor_rank(x: _C.Value) -> int | None:
516    if not _is_tensor(x) or x.type() is None:
517        return None
518    x_type = x.type()
519    x_type = typing.cast(_C.TensorType, x_type)
520    return x_type.dim()
521
522
523def _get_tensor_sizes(x: _C.Value, allow_nonstatic: bool = True):
524    if not _is_tensor(x) or x.type() is None:
525        return None
526    x_type = x.type()
527    x_type = typing.cast(_C.TensorType, x_type)
528    if allow_nonstatic:
529        # Each individual symbol is returned as None.
530        # e.g. [1, "a", "b"] -> [1, None, None]
531        return x_type.varyingSizes()
532    # returns None, if exists any symbol in sizes.
533    # e.g. [1, "a", "b"] -> None
534    return x_type.sizes()
535
536
537def _get_tensor_dim_size(x: _C.Value, dim: int) -> int | None:
538    sizes = _get_tensor_sizes(x)
539    return sizes[dim] if sizes else None
540
541
542def _get_dim_for_cross(x: _C.Value, dim: int | None):
543    if dim == -1:
544        tensor_rank = _get_tensor_rank(x)
545        assert tensor_rank is not None
546        return dim + tensor_rank
547    # If dim is not given, it defaults to the first dimension found with the size 3
548    if dim is None:
549        sizes = _get_tensor_sizes(x)
550        assert sizes is not None
551        for index, size in enumerate(sizes):
552            if size is not None and size == 3:
553                return index
554    return dim
555
556
557def _unimplemented(op: str, msg: str, value: _C.Value | None = None) -> None:
558    # For BC reasons, the behavior for Caffe2 does not raise exception for unimplemented operators
559    if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
560        _onnx_unsupported(f"{op}, {msg}", value)
561
562
563def _onnx_unsupported(op_name: str, value: _C.Value | None = None) -> NoReturn:
564    message = (
565        f"Unsupported: ONNX export of operator {op_name}. "
566        f"Please feel free to request support or submit a pull request "
567        f"on PyTorch GitHub: {_constants.PYTORCH_GITHUB_ISSUES_URL}"
568    )
569    if isinstance(value, _C.Value):
570        raise errors.SymbolicValueError(
571            message,
572            value,
573        )
574    raise errors.OnnxExporterError(message)
575
576
577def _onnx_opset_unsupported(
578    op_name: str,
579    current_opset: int,
580    supported_opset: int,
581    value: _C.Value | None = None,
582) -> NoReturn:
583    message = (
584        f"Unsupported: ONNX export of {op_name} in opset {current_opset}. "
585        f"Please try opset version {supported_opset}."
586    )
587    if isinstance(value, _C.Value):
588        raise errors.SymbolicValueError(
589            message,
590            value,
591        )
592    raise errors.OnnxExporterError(message)
593
594
595def _onnx_opset_unsupported_detailed(
596    op_name: str,
597    current_opset: int,
598    supported_opset: int,
599    reason: str,
600    value: _C.Value | None = None,
601) -> NoReturn:
602    message = (
603        f"Unsupported: ONNX export of {op_name} in "
604        f"opset {current_opset}. {reason}. Please try opset version {supported_opset}."
605    )
606    if isinstance(value, _C.Value):
607        raise errors.SymbolicValueError(
608            message,
609            value,
610        )
611    raise errors.OnnxExporterError(message)
612
613
614def _block_list_in_opset(name: str):
615    def symbolic_fn(*args, **kwargs):
616        raise errors.OnnxExporterError(
617            f"ONNX export failed on {name}, which is not implemented for opset "
618            f"{GLOBALS.export_onnx_opset_version}. "
619            "Try exporting with other opset versions."
620        )
621
622    return symbolic_fn
623
624
625def _try_get_scalar_type(*args) -> _type_utils.JitScalarType | None:
626    for arg in args:
627        scalar_type = _type_utils.JitScalarType.from_value(
628            arg, _type_utils.JitScalarType.UNDEFINED
629        )
630        if scalar_type != _type_utils.JitScalarType.UNDEFINED:
631            return scalar_type
632    return None
633
634
635def _type_promote_from_values(*args) -> _type_utils.JitScalarType:
636    undef = _type_utils.JitScalarType.UNDEFINED
637    jit_types = [_try_get_scalar_type(arg) for arg in args]
638    if len(jit_types) == 0:
639        return undef
640    if len(jit_types) == 1:
641        return jit_types[0]  # type: ignore[return-value]
642    new_dtype = jit_types[0].dtype()  # type: ignore[union-attr]
643    for t in jit_types:
644        new_dtype = torch.promote_types(new_dtype, t.dtype())  # type: ignore[union-attr]
645    return _type_utils.JitScalarType.from_dtype(new_dtype)
646
647
648def _maybe_cast_to_type(
649    g: jit_utils.GraphContext, value, jit_type: _type_utils.JitScalarType
650):
651    if (
652        _type_utils.JitScalarType.from_value(value, _type_utils.JitScalarType.UNDEFINED)
653        != jit_type
654    ):
655        return g.op(
656            "Cast",
657            value,
658            to_i=jit_type.onnx_type(),
659        )
660    return value
661
662
663def _select_helper(g: jit_utils.GraphContext, self, dim, index, apply_reshape=True):
664    index_const = _maybe_get_scalar(index)
665    index_dim = _get_tensor_rank(index)
666    if not _is_value(index_const):
667        # Index is a constant scalar. Make it a size 1 constant tensor.
668        index = g.op("Constant", value_t=torch.LongTensor([index_const]))
669    elif index_dim is not None and apply_reshape:
670        if index_dim == 0:
671            # Index is a scalar. Reshape it to a size 1 tensor.
672            index = _reshape_helper(
673                g, index, g.op("Constant", value_t=torch.LongTensor([1]))
674            )
675
676    index_scalar_type = _type_utils.JitScalarType.from_value(
677        index, _type_utils.JitScalarType.UNDEFINED
678    )
679    if index_scalar_type not in {
680        _type_utils.JitScalarType.INT64,
681        _type_utils.JitScalarType.INT,
682    }:
683        index = g.op("Cast", index, to_i=_C_onnx.TensorProtoDataType.INT64)
684    return g.op("Gather", self, index, axis_i=dim)
685
686
687def _slice_helper(
688    g: jit_utils.GraphContext,
689    input,
690    axes,
691    starts,
692    ends,
693    steps=None,
694):
695    if g.opset <= 9:
696        from torch.onnx.symbolic_opset9 import _slice as _slice9
697
698        return _slice9(g, input, axes, starts, ends)
699    else:
700        from torch.onnx.symbolic_opset10 import _slice as _slice10
701
702        return _slice10(g, input, axes, starts, ends, steps)
703
704
705def _is_fp(value) -> bool:
706    return _type_utils.JitScalarType.from_value(
707        value, _type_utils.JitScalarType.UNDEFINED
708    ) in {
709        _type_utils.JitScalarType.FLOAT,
710        _type_utils.JitScalarType.DOUBLE,
711        _type_utils.JitScalarType.HALF,
712        _type_utils.JitScalarType.BFLOAT16,
713    }
714
715
716def _is_bool(value) -> bool:
717    return _type_utils.JitScalarType.from_value(
718        value, _type_utils.JitScalarType.UNDEFINED
719    ) in {_type_utils.JitScalarType.BOOL}
720
721
722def _generate_wrapped_number(g: jit_utils.GraphContext, scalar):
723    """Creates a wrapped number based on https://github.com/pytorch/pytorch/issues/9515.
724
725    A Tensor is a considered a "wrapped number" if it is
726    auto-wrapped from a C++ or Python number type. Integer types are
727    wrapped as 0-dim int64 tensors and floating-point types are
728    wrapped as 0-dim double tensors.
729
730    The input to this function is constant value. If the data type
731    is a floating point type, it is converted to a 0-dim double
732    tensor, else it is converted to a 0-dim tensor of its original type
733    """
734    assert not isinstance(scalar, torch.Tensor)
735    if isinstance(scalar, float):
736        return g.op("Constant", value_t=torch.tensor(scalar, dtype=torch.double))
737    return g.op("Constant", value_t=torch.tensor(scalar))
738
739
740def _sort_helper(g: jit_utils.GraphContext, input, dim, decending=True, out=None):
741    if out is not None:
742        _unimplemented("Sort", "Out parameter is not supported")
743    shape_ = g.op("Shape", input)
744    dim_size_ = g.op(
745        "Gather",
746        shape_,
747        g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64)),
748    )
749    if g.opset <= 10:
750        if not decending:
751            _unimplemented("Sort", "Ascending is not supported")
752        return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2)
753    else:
754        return g.op(
755            "TopK", input, dim_size_, axis_i=dim, largest_i=decending, outputs=2
756        )
757
758
759def _topk_helper(
760    g: jit_utils.GraphContext, input, k, dim, largest=True, sorted=False, out=None
761):
762    if out is not None:
763        _unimplemented("TopK", "Out parameter is not supported")
764    if not _is_value(k):
765        k = g.op("Constant", value_t=torch.tensor([k], dtype=torch.int64))
766    else:
767        k = _reshape_helper(g, k, g.op("Constant", value_t=torch.tensor([1])))
768        if _try_get_scalar_type(k) != _type_utils.JitScalarType.INT64:
769            k = g.op("Cast", k, to_i=_C_onnx.TensorProtoDataType.INT64)
770    if g.opset <= 10:
771        if not largest:
772            _unimplemented("TopK", "Ascending is not supported")
773        return g.op("TopK", input, k, axis_i=dim, outputs=2)
774    else:
775        return g.op(
776            "TopK", input, k, axis_i=dim, largest_i=largest, sorted_i=sorted, outputs=2
777        )
778
779
780def _lt_helper(g: jit_utils.GraphContext, input, other):
781    if g.opset <= 8:
782        from torch.onnx.symbolic_opset8 import lt as _lt8
783
784        return _lt8(g, input, other)
785    else:
786        from torch.onnx.symbolic_opset9 import lt as _lt9
787
788        return _lt9(g, input, other)
789
790
791def _interpolate_warning(interpolate_mode):
792    onnx_op = (
793        "onnx:Resize" if GLOBALS.export_onnx_opset_version >= 10 else "onnx:Upsample"
794    )
795    warnings.warn(
796        "You are trying to export the model with "
797        + onnx_op
798        + " for ONNX opset version "
799        "" + str(GLOBALS.export_onnx_opset_version) + ". "
800        "This operator might cause results to not match the expected results by PyTorch.\n"
801        "ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. "
802        "Attributes to determine how to transform the input were added in onnx:Resize in opset 11 "
803        "to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).\n"
804        "We recommend using opset 11 and above for models using this operator."
805    )
806
807
808def _unsqueeze_helper(g: jit_utils.GraphContext, input, axes_i):
809    if _is_constant(axes_i[0]):
810        if g.opset >= 13:
811            axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
812            return g.op("Unsqueeze", input, axes)
813        return g.op("Unsqueeze", input, axes_i=axes_i)
814    # Tensor type
815    if g.opset < 13:
816        raise errors.SymbolicValueError(
817            "Opset version must be >= 13 for Unsqueeze with dynamic axes.", input
818        )
819    return g.op("Unsqueeze", input, axes_i[0])
820
821
822def _squeeze_helper(g: jit_utils.GraphContext, input, axes_i):
823    if _is_constant(axes_i[0]):
824        if g.opset >= 13:
825            axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
826            return g.op("Squeeze", input, axes)
827        return g.op("Squeeze", input, axes_i=axes_i)
828    # Tensor type
829    if g.opset < 13:
830        raise errors.SymbolicValueError(
831            "Opset version must be >= 13 for Squeeze with dynamic axes.", input
832        )
833    axes_t = axes_i[0]
834    axes_rank = _get_tensor_rank(axes_t)
835    assert axes_rank is not None
836    if axes_rank > 1:
837        raise errors.SymbolicValueError(
838            "For Squeeze axses as input, the axes rank must be one in ONNX spec.", input
839        )
840    elif axes_rank == 0:
841        # The axes is a scalar. Unsqueeze it to a rank 1 tensor.
842        axes_t = _unsqueeze_helper(g, axes_t, [0])
843        return g.op("Squeeze", input, axes_t)
844    return g.op("Squeeze", input, axes_t)
845
846
847def _reducesum_helper(
848    g: jit_utils.GraphContext,
849    input,
850    axes_i=None,
851    keepdims_i=1,
852    noop_with_empty_axes_i=0,
853):
854    keepdims_i = _maybe_get_const(keepdims_i, "i")
855    if g.opset >= 13:
856        if axes_i:
857            if not _is_value(axes_i):
858                axes_i = g.op(
859                    "Constant", value_t=torch.tensor(axes_i, dtype=torch.long)
860                )
861            return g.op(
862                "ReduceSum",
863                input,
864                axes_i,
865                keepdims_i=keepdims_i,
866                noop_with_empty_axes_i=noop_with_empty_axes_i,
867            )
868        return g.op(
869            "ReduceSum",
870            input,
871            keepdims_i=keepdims_i,
872            noop_with_empty_axes_i=noop_with_empty_axes_i,
873        )
874    else:
875        return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i)
876
877
878def _interpolate_size_to_scales(g: jit_utils.GraphContext, input, output_size, dim):
879    output_size = _maybe_get_const(output_size, "is")
880    if _is_value(output_size):
881        offset = 2
882        offsets = g.op("Constant", value_t=torch.ones(offset, dtype=torch.float32))
883        dividend = g.op("Cast", output_size, to_i=_C_onnx.TensorProtoDataType.FLOAT)
884        divisor = _slice_helper(
885            g, g.op("Shape", input), axes=[0], ends=[sys.maxsize], starts=[offset]
886        )
887        divisor = g.op("Cast", divisor, to_i=_C_onnx.TensorProtoDataType.FLOAT)
888        scale_dims = g.op("Div", dividend, divisor)
889        scales = g.op("Concat", offsets, scale_dims, axis_i=0)
890    else:
891        scales_constant = [
892            1.0
893            if i < 2
894            else float(output_size[-(dim - i)])
895            / float(input.type().sizes()[-(dim - i)])
896            for i in range(0, dim)
897        ]
898        scales = g.op(
899            "Constant", value_t=torch.tensor(scales_constant, dtype=torch.float32)
900        )
901    return scales
902
903
904def _interpolate_get_scales_if_available(g: jit_utils.GraphContext, scales):
905    available_scales = _maybe_get_const(scales[0], "fs") != -1 and not _is_none(
906        scales[0]
907    )
908
909    if not available_scales:
910        return None
911
912    offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32))
913    scales_list = g.op(
914        "Constant", value_t=torch.tensor(_maybe_get_const(scales[0], "fs"))
915    )
916    scales = g.op("Concat", offsets, scales_list, axis_i=0)
917    return scales
918
919
920def _get_interpolate_attributes(g: jit_utils.GraphContext, mode, args):
921    if mode == "nearest":
922        align_corners = None
923        scales = args[0:]
924    else:
925        align_corners = args[0]
926        scales = args[1:]
927    scales = _interpolate_get_scales_if_available(g, scales)
928    return scales, align_corners
929
930
931def _interpolate_get_scales(g: jit_utils.GraphContext, scale_factor, dim):
932    offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32))
933    scale_factor_rank = _get_tensor_rank(scale_factor)
934    if isinstance(scale_factor.type(), _C.ListType) or (
935        scale_factor_rank is not None and scale_factor_rank > 0
936    ):
937        return g.op("Concat", offsets, scale_factor, axis_i=0)
938    else:
939        scale_factor = _unsqueeze_helper(g, scale_factor, [0])
940        scale_factor = g.op(
941            "Cast", scale_factor, to_i=_C_onnx.TensorProtoDataType.FLOAT
942        )
943        scales = [scale_factor for i in range(dim - 2)]
944    scale_factor = g.op("Concat", offsets, *scales, axis_i=0)
945    return scale_factor
946
947
948def _interpolate_get_scales_and_mode(
949    g: jit_utils.GraphContext, input, size, scale_factor, mode, align_corners
950):
951    mode = _maybe_get_const(mode, "s")
952    if "linear" in mode:
953        mode = "linear"
954    if "cubic" in mode:
955        mode = "cubic"
956    _interpolate_warning(mode)
957
958    align_corners = _maybe_get_const(align_corners, "b")
959    if isinstance(align_corners, bool) and align_corners:
960        return _unimplemented("interpolate", "align_corners == True")
961
962    if not input.type().dim():
963        return _unimplemented("interpolate", "missing input shape")
964    dim = input.type().dim()
965
966    if not _is_none(scale_factor):
967        scale_factor = _interpolate_get_scales(g, scale_factor, dim)
968    elif not _is_none(size):
969        if not _is_packed_list(size):
970            is_scalar = _maybe_get_const(size, "t").dim() == 0
971            if is_scalar:
972                size = _unsqueeze_helper(g, size, [0])
973                size = [size for i in range(dim - 2)]
974                size = g.op("Concat", *size, axis_i=0)
975        scale_factor = _interpolate_size_to_scales(g, input, size, dim)
976    else:
977        return _unimplemented(
978            "interpolate", "Both size and scales are None in __interpolate"
979        )
980    return scale_factor, mode
981
982
983def _argmin_argmax_helper(
984    g: jit_utils.GraphContext,
985    input: torch._C.Value,
986    dim: torch._C.Value,
987    keepdim: bool,
988    op_name: str,
989):
990    def op_wrapper(input, axis_i, keepdims_i):
991        if g.opset >= 12:
992            return g.op(
993                op_name,
994                input,
995                axis_i=axis_i,
996                keepdims_i=keepdims_i,
997                select_last_index_i=False,
998            )
999        return g.op(op_name, input, axis_i=axis_i, keepdims_i=keepdims_i)
1000
1001    if _is_none(dim):
1002        flattened = _reshape_helper(
1003            g, input, g.op("Constant", value_t=torch.tensor([-1]))
1004        )
1005        output = op_wrapper(flattened, axis_i=0, keepdims_i=False)
1006        if keepdim:
1007            input_shape = g.op("Shape", input)
1008            input_shape_shape = g.op("Shape", input_shape)
1009            new_shape = g.op(
1010                "ConstantOfShape",
1011                input_shape_shape,
1012                value_t=torch.tensor([1], dtype=torch.int64),
1013            )
1014            output = g.op("Reshape", output, new_shape)
1015        return output
1016
1017    dim = _parse_arg(dim, "i")
1018    return op_wrapper(input, axis_i=dim, keepdims_i=keepdim)
1019
1020
1021def _interpolate_helper(name, dim, interpolate_mode):
1022    @quantized_args(True, False, False)
1023    def symbolic_fn(g, input, output_size, *args):
1024        scales, align_corners = _get_interpolate_attributes(g, interpolate_mode, args)
1025        align_corners = _maybe_get_scalar(align_corners)
1026        coordinate_transformation_mode = (
1027            "asymmetric"
1028            if interpolate_mode == "nearest"
1029            else "align_corners"
1030            if align_corners
1031            else "half_pixel"
1032        )
1033
1034        if scales is None:
1035            input_size = g.op("Shape", input)
1036            input_size_beg = _slice_helper(
1037                g, input_size, axes=[0], ends=[2], starts=[0]
1038            )
1039            output_size = g.op(
1040                "Cast", output_size, to_i=_C_onnx.TensorProtoDataType.INT64
1041            )
1042            output_size = g.op("Concat", input_size_beg, output_size, axis_i=0)
1043
1044            if g.opset >= 13:
1045                empty_roi = _optional_input_placeholder_tensor(g)
1046                empty_scales = _optional_input_placeholder_tensor(g)
1047            else:
1048                empty_roi = g.op(
1049                    "Constant", value_t=torch.tensor([], dtype=torch.float32)
1050                )
1051                empty_scales = g.op(
1052                    "Constant", value_t=torch.tensor([], dtype=torch.float32)
1053                )
1054
1055            return g.op(
1056                "Resize",
1057                input,
1058                empty_roi,
1059                empty_scales,
1060                output_size,
1061                coordinate_transformation_mode_s=coordinate_transformation_mode,
1062                cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
1063                mode_s=interpolate_mode,  # nearest, linear, or cubic
1064                nearest_mode_s="floor",
1065            )  # only valid when mode="nearest"
1066        else:
1067            if g.opset >= 13:
1068                empty_roi = _optional_input_placeholder_tensor(g)
1069            else:
1070                empty_roi = g.op(
1071                    "Constant", value_t=torch.tensor([], dtype=torch.float32)
1072                )
1073
1074            return g.op(
1075                "Resize",
1076                input,
1077                empty_roi,
1078                scales,
1079                coordinate_transformation_mode_s=coordinate_transformation_mode,
1080                cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
1081                mode_s=interpolate_mode,  # nearest, linear, or cubic
1082                nearest_mode_s="floor",
1083            )  # only valid when mode="nearest"
1084
1085    return symbolic_fn
1086
1087
1088def __interpolate_helper(
1089    g: jit_utils.GraphContext,
1090    input,
1091    size,
1092    scale_factor,
1093    mode,
1094    align_corners,
1095    recompute_scale_factor,
1096):
1097    mode = _maybe_get_const(mode, "s")
1098    if "linear" in mode:
1099        mode = "linear"
1100    if "cubic" in mode:
1101        mode = "cubic"
1102    align_corners = _maybe_get_const(align_corners, "b")
1103    align_corners = False if not isinstance(align_corners, bool) else align_corners
1104    coordinate_transformation_mode = (
1105        "asymmetric"
1106        if mode == "nearest"
1107        else "align_corners"
1108        if align_corners
1109        else "half_pixel"
1110    )
1111
1112    if not _is_none(size):
1113        input_size = g.op("Shape", input)
1114        input_size = _slice_helper(g, input_size, axes=[0], ends=[2], starts=[0])
1115        # in some cases size is not a packed list but size is a scalar
1116        # We need to also verify that (_maybe_get_const(size, "t").dim() == 0)
1117        # but this information is not always available. Try to get the dim,
1118        # and if not assume that it is not a scalar.
1119        try:
1120            is_scalar = not _is_packed_list(size) and (
1121                _maybe_get_const(size, "t").dim() == 0
1122            )
1123        except AttributeError:
1124            is_scalar = not _is_packed_list(size)
1125            if not is_scalar:
1126                warnings.warn(
1127                    "Cannot verify if the output_size is a scalar "
1128                    "while exporting interpolate. Assuming that it is not a scalar."
1129                )
1130
1131        if is_scalar:
1132            rank = _get_tensor_rank(input)
1133            if rank is None:
1134                return _unimplemented(
1135                    "interpolate (with a scalar output_size)",
1136                    "missing input shape (try giving an array of output_size values)",
1137                )
1138            size = _unsqueeze_helper(g, size, [0])
1139            size = [size for i in range(rank - 2)]
1140            size = g.op("Concat", *size, axis_i=0)
1141        size = g.op("Cast", size, to_i=_C_onnx.TensorProtoDataType.INT64)
1142        size = g.op("Concat", input_size, size, axis_i=0)
1143
1144        if g.opset >= 13:
1145            empty_roi = _optional_input_placeholder_tensor(g)
1146            empty_scales = _optional_input_placeholder_tensor(g)
1147        else:
1148            empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
1149            empty_scales = g.op(
1150                "Constant", value_t=torch.tensor([], dtype=torch.float32)
1151            )
1152
1153        return g.op(
1154            "Resize",
1155            input,
1156            empty_roi,
1157            empty_scales,
1158            size,
1159            coordinate_transformation_mode_s=coordinate_transformation_mode,
1160            cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
1161            mode_s=mode,  # nearest, linear, or cubic
1162            nearest_mode_s="floor",
1163        )
1164    else:  # if not _is_none(scales)
1165        rank = _get_tensor_rank(input)
1166        if rank is None:
1167            return _unimplemented("interpolate (with scales)", "missing input shape")
1168
1169        if g.opset >= 13:
1170            empty_roi = _optional_input_placeholder_tensor(g)
1171        else:
1172            empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
1173
1174        scales = _interpolate_get_scales(g, scale_factor, rank)
1175        return g.op(
1176            "Resize",
1177            input,
1178            empty_roi,
1179            scales,
1180            coordinate_transformation_mode_s=coordinate_transformation_mode,
1181            cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
1182            mode_s=mode,  # nearest, linear, or cubic
1183            nearest_mode_s="floor",
1184        )  # only valid when mode="nearest"
1185
1186
1187def _unbind_helper(g: jit_utils.GraphContext, self, dim, _outputs):
1188    if g.opset < 11:
1189        from torch.onnx.symbolic_opset9 import unbind
1190    elif g.opset <= 12:
1191        from torch.onnx.symbolic_opset11 import unbind  # type: ignore[no-redef]
1192    else:
1193        from torch.onnx.symbolic_opset13 import unbind  # type: ignore[no-redef]
1194    return unbind(g, self, dim, _outputs)
1195
1196
1197def _scatter_helper(g: jit_utils.GraphContext, self, dim, index, src):
1198    if g.opset <= 10:
1199        from torch.onnx.symbolic_opset9 import scatter
1200    else:
1201        # for mypy, scatter was imported two lines above
1202        from torch.onnx.symbolic_opset11 import scatter  # type: ignore[no-redef]
1203    return scatter(g, self, dim, index, src)
1204
1205
1206def _repeat_interleave_split_helper(g: jit_utils.GraphContext, self, reps, dim):
1207    if g.opset <= 12:
1208        split_out = g.op("Split", self, split_i=[1] * reps, axis_i=dim, outputs=reps)
1209    else:
1210        from torch.onnx.symbolic_opset13 import split
1211
1212        repeats = g.op("Constant", value_t=torch.tensor([1] * reps))
1213        split_out = split(g, self, repeats, dim, _outputs=reps)
1214    return split_out if reps > 1 else [split_out]
1215
1216
1217def _repeat_interleave_single_value_repeat_helper(
1218    g: jit_utils.GraphContext, self, repeats, dim
1219):
1220    from torch.onnx.symbolic_opset9 import flatten, unsqueeze
1221
1222    if not _is_tensor(repeats):
1223        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
1224
1225    const_repeats: bool = _is_constant(repeats)
1226    reps = _maybe_get_const(repeats, "t")
1227
1228    # Convert 'repeats' to 1-d if it is 0-d.
1229    if _get_tensor_rank(repeats) == 0:
1230        repeats = g.op("Reshape", repeats, g.op("Constant", value_t=torch.tensor([1])))
1231
1232    # Create a new dim of size 1, then expand it to be 'repeats' long, and finally collapse it.
1233    unsqueezed = unsqueeze(g, self, dim + 1)
1234
1235    # repeats_per_dim is 1 for all dims except for the new unsqueezed dim, where it has value 'repeats'.
1236    if const_repeats:
1237        # 'Repeats' is a constant, 'repeats_per_dim' can be a constant.
1238        onehot = torch.ones(_get_tensor_rank(unsqueezed), dtype=torch.int64)  # type: ignore[arg-type]
1239        onehot[dim + 1] = reps
1240        repeats_per_dim = g.op("Constant", value_t=onehot)
1241    else:
1242        # 'Repeats' is a variable, 'repeats_per_dim' cannot be a constant.
1243        onehot = g.op(
1244            "OneHot",
1245            unsqueeze(g, dim + 1, 0),  # indices, must be >= 1-dimensional
1246            g.op(
1247                "Constant", value_t=torch.tensor(_get_tensor_rank(unsqueezed))
1248            ),  # depth
1249            g.op(
1250                "Concat", g.op("Constant", value_t=torch.tensor([1])), repeats, axis_i=0
1251            ),  # on/off values
1252        )
1253        repeats_per_dim = flatten(g, onehot, 0, 1)
1254
1255    tiled = g.op("Tile", unsqueezed, repeats_per_dim)
1256    return flatten(g, tiled, dim, dim + 1)
1257
1258
1259def _arange_cast_helper(
1260    g: jit_utils.GraphContext, end, start=None, step=None, dtype=None
1261) -> tuple[
1262    _type_utils.JitScalarType,
1263    _C.Value | None,
1264    _C.Value | None,
1265    _C.Value | None,
1266]:
1267    def _is_all_integral(scalars):
1268        for scalar in scalars:
1269            scalar_type = _type_utils.JitScalarType.from_value(
1270                scalar, _type_utils.JitScalarType.UNDEFINED
1271            )
1272            if (
1273                scalar_type != _type_utils.JitScalarType.INT64
1274                and scalar_type != _type_utils.JitScalarType.UNDEFINED
1275            ):
1276                return False
1277        return True
1278
1279    # This logic is based on torch.arange docs. If "dtype" is provided,
1280    # infer input types from dtype. If not, then check if any of start, stop,
1281    # or step are floating point, and infer the type from get_default.
1282    # Otherwise, the dtype is inferred to be torch.int64.
1283    if dtype is None or (_is_value(dtype) and _is_none(dtype)):
1284        if _is_all_integral([start, end, step]):
1285            scalar_type = _type_utils.JitScalarType.INT64
1286        else:
1287            scalar_type = _type_utils.JitScalarType.from_dtype(
1288                torch.get_default_dtype()
1289            )
1290    else:
1291        assert isinstance(dtype, int)
1292        # TODO(justinchuby): Check if dtype is indeed a int.
1293        scalar_type = _type_utils.JitScalarType(dtype)
1294
1295    start = g.op("Cast", start, to_i=scalar_type.onnx_type()) if start else None
1296    end = g.op("Cast", end, to_i=scalar_type.onnx_type()) if end else None
1297    step = g.op("Cast", step, to_i=scalar_type.onnx_type()) if step else None
1298    return scalar_type, end, start, step
1299
1300
1301def _arange_helper(g: jit_utils.GraphContext, *args):
1302    if g.opset <= 10:
1303        from torch.onnx.symbolic_opset9 import arange
1304    else:
1305        from torch.onnx.symbolic_opset11 import arange  # type: ignore[no-redef]
1306    return arange(g, *args)
1307
1308
1309def _size_helper(g: jit_utils.GraphContext, self, dim):
1310    full_shape = g.op("Shape", self)
1311    from torch.onnx.symbolic_opset9 import select
1312
1313    return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim)
1314
1315
1316def _index_fill_reshape_helper(g: jit_utils.GraphContext, self, dim, index):
1317    # 1. reshape index => [1, ..., 1, dim, 1, ..., 1]
1318    # 2. expand index => [..., dim, ...], same shape as self except for dim.
1319    # 3. expand value as well.
1320    # 4. apply onnx::scatter.
1321
1322    from torch.onnx.symbolic_opset9 import expand
1323
1324    if g.opset <= 10:
1325        from torch.onnx.symbolic_opset9 import scatter
1326    else:
1327        # for mypy, scatter was imported two lines above
1328        from torch.onnx.symbolic_opset11 import scatter  # type: ignore[no-redef]
1329
1330    if self.type().dim() is None:
1331        return _unimplemented("index_fill", "input rank not accessible")
1332    self_dim = self.type().dim()
1333    dim_value = _parse_arg(dim, "i")
1334    if dim_value < 0:
1335        dim_value += self_dim
1336    unsqueezed_index = _unsqueeze_helper(
1337        g, index, [i for i in range(self_dim) if i != dim_value]
1338    )
1339    expanded_index_shape = scatter(
1340        g, g.op("Shape", self), 0, _unsqueeze_helper(g, dim, [0]), g.op("Shape", index)
1341    )
1342    expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None)
1343    return expanded_index_shape, expanded_index
1344
1345
1346# By default, when any value in the 'shape' input is equal to zero
1347# the corresponding dimension value is copied from the input tensor dynamically.
1348# allowzero=1 indicates that if any value in the 'shape' input is set to zero,
1349# the zero value is honored, similar to NumPy.
1350# allowzero=1 is only supported for opset version >= 14.
1351def _reshape_helper(g: jit_utils.GraphContext, input, shape, allowzero=0):
1352    shape = _maybe_get_const(shape, "is")
1353    if not _is_value(shape):
1354        shape = g.op("Constant", value_t=torch.LongTensor(shape))
1355    if g.opset <= 13:
1356        if allowzero == 1:
1357            _onnx_opset_unsupported(
1358                "Reshape with allowzero=1", GLOBALS.export_onnx_opset_version, 14, input
1359            )
1360        return g.op("Reshape", input, shape)
1361    else:
1362        return g.op("Reshape", input, shape, allowzero_i=allowzero)
1363
1364
1365def _batchnorm_helper(
1366    g: jit_utils.GraphContext, input, weight, bias, running_mean, running_var
1367):
1368    from torch.onnx.symbolic_opset9 import _var_mean
1369
1370    batch_size = _get_tensor_dim_size(input, 0)
1371    channel_size = _get_tensor_dim_size(input, 1)
1372
1373    if weight is None or _is_none(weight):
1374        if channel_size is None:
1375            raise errors.SymbolicValueError(
1376                "Unsupported: ONNX export of batch_norm for unknown channel size.",
1377                input,
1378            )
1379        weight_value = torch.tensor(
1380            [1.0] * channel_size,
1381            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
1382        )
1383        weight = g.op("Constant", value_t=weight_value)
1384    if bias is None or _is_none(bias):
1385        if channel_size is None:
1386            raise errors.SymbolicValueError(
1387                "Unsupported: ONNX export of batch_norm for unknown channel size.",
1388                input,
1389            )
1390        bias_value = torch.tensor(
1391            [0.0] * channel_size,
1392            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
1393        )
1394        bias = g.op("Constant", value_t=bias_value)
1395    # If track_running_stats is set to False batch statistics are instead used during evaluation time
1396    if (
1397        running_mean is None
1398        or _is_none(running_mean)
1399        or running_var is None
1400        or _is_none(running_var)
1401    ):
1402        assert batch_size is not None and channel_size is not None
1403        reshape_in = _reshape_helper(
1404            g,
1405            input,
1406            g.op(
1407                "Constant",
1408                value_t=torch.tensor([batch_size, channel_size, -1], dtype=torch.int64),
1409            ),
1410        )
1411        trans_in = g.op("Transpose", reshape_in, perm_i=[0, 2, 1])
1412        running_var, running_mean = _var_mean(
1413            g,
1414            trans_in,
1415            g.op("Constant", value_t=torch.tensor([0, 1], dtype=torch.int64)),
1416            False,
1417            False,
1418        )
1419    return weight, bias, running_mean, running_var
1420
1421
1422def _avgpool_helper(
1423    tuple_fn: Callable[[Any], Sequence[int]],
1424    padding: int | Sequence[int],
1425    kernel_size,
1426    stride,
1427    divisor_override,
1428    name,
1429) -> tuple[int, ...]:
1430    if divisor_override and divisor_override.node().kind() != "prim::Constant":
1431        _unimplemented(name, "divisor_override")
1432    return tuple(tuple_fn(padding))
1433
1434
1435def check_training_mode(op_train_mode: int, op_name: str) -> None:
1436    """Warns the user if the model's training mode and the export mode do not agree."""
1437    if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE:
1438        return
1439
1440    if op_train_mode:
1441        op_mode_enum = _C_onnx.TrainingMode.TRAINING
1442    else:
1443        op_mode_enum = _C_onnx.TrainingMode.EVAL
1444    if op_mode_enum == GLOBALS.training_mode:
1445        # The modes agree. Do nothing
1446        return
1447
1448    op_mode_text = f"train={bool(op_train_mode)}"
1449    # Setting the model mode could result in op_mode != GLOBALS.training_mode
1450    # if the model is a FuncModule. In this case we warn the user of
1451    # the state and export depending on op_mode
1452    # This is to support use-cases of fixing certain layer weights
1453    # in training.
1454    warnings.warn(
1455        f"ONNX export mode is set to {GLOBALS.training_mode}, but operator '{op_name}' "
1456        f"is set to {op_mode_text}. Exporting with {op_mode_text}."
1457    )
1458
1459
1460def _flatten_helper(g: jit_utils.GraphContext, input, start_dim, end_dim, dim):
1461    input_size = g.op("Shape", input)
1462    slice1 = _slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim])
1463    slices = [slice1, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))]
1464    if end_dim < dim - 1:
1465        slice3 = _slice_helper(
1466            g, input_size, axes=[0], starts=[end_dim + 1], ends=[dim]
1467        )
1468        slices = [
1469            slice1,
1470            g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
1471            slice3,
1472        ]
1473
1474    final_shape = g.op("Concat", *slices, axis_i=0)
1475    from torch.onnx.symbolic_opset9 import _reshape_from_tensor
1476
1477    return _reshape_from_tensor(g, input, final_shape)
1478
1479
1480def _is_split_static(split_size_or_sizes, _outputs):
1481    if _outputs is None:
1482        return False
1483    if (
1484        _is_value(split_size_or_sizes)
1485        and split_size_or_sizes.node().kind() != "onnx::Constant"
1486    ):
1487        return False
1488    return True
1489
1490
1491def _optional_input_placeholder_tensor(g):
1492    n = g.op("prim::Constant")
1493    n.setType(_C.OptionalType.ofTensor())
1494    return n
1495
1496
1497def _handle_reduce_dim_none(g: jit_utils.GraphContext, self, op_name):
1498    rank = _get_tensor_rank(self)
1499    if rank is not None and any(
1500        _get_tensor_dim_size(self, i) == 0 for i in range(rank)
1501    ):
1502        # If input tensor is empty, according to ONNX ReduceSum definition,
1503        # set keepdims=1 so that the resulted tensor has the same rank as the input.
1504        return g.op(op_name, self, keepdims_i=1)
1505    return g.op(op_name, self, keepdims_i=0)
1506
1507
1508def dequantize_helper(
1509    g: jit_utils.GraphContext,
1510    qtensor: _C.Value,
1511    qdtype: _C_onnx.TensorProtoDataType | None = None,
1512) -> tuple[_C.Value, _C.Value, _C.Value, _C.Value | None]:
1513    """Appends to graph `g` ONNX nodes that dequantizes `qtensor` into `tensor`.
1514
1515    Args:
1516        g: Graph, the ONNX IR graph that is under construction.
1517        qtensor: torch._C.Value, either a tuple of (quantized_tensor, scale, zero_point)
1518            for per tensor quantization, or
1519            (quantized_tensor, scale, zero_point, axis) for per channel quantization,
1520            representing the quantized tensor.
1521        qdtype: torch.onnx.TensorProtoDataType default None, if not None, represents the
1522            data type of quantized tensor. It must be either
1523            torch.onnx.TensorProtoDataType.UINT8 or torch.onnx.TensorProtoDataType.INT8.
1524    """
1525    unpacked_qtensors = _unpack_quantized_tensor(qtensor)
1526    tensor, scale, zero_point = unpacked_qtensors[:3]
1527    axis = unpacked_qtensors[3] if len(unpacked_qtensors) >= 4 else None
1528    axis_i = _get_const(axis, "i", "axis")
1529    input_qdtype = _type_utils.JitScalarType.from_value(tensor)
1530    if qdtype is None:
1531        if input_qdtype is not None:
1532            qdtype = input_qdtype.onnx_type()
1533        else:
1534            qdtype = _C_onnx.TensorProtoDataType.UINT8
1535    value = g.op("Cast", tensor, to_i=qdtype)
1536    scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
1537    zero_point = g.op("Cast", zero_point, to_i=qdtype)
1538
1539    if axis_i is not None and GLOBALS.export_onnx_opset_version < 13:
1540        _onnx_opset_unsupported_detailed(
1541            "DequantizeLinear",
1542            GLOBALS.export_onnx_opset_version,
1543            13,
1544            "Attribute axis is not supported.",
1545            qtensor,
1546        )
1547
1548    return (
1549        g.op("DequantizeLinear", value, scale, zero_point, axis_i=axis_i),
1550        scale,
1551        zero_point,
1552        axis,
1553    )
1554
1555
1556def quantize_helper(
1557    g: jit_utils.GraphContext,
1558    tensor: _C.Value,
1559    scale: _C.Value,
1560    zero_point: _C.Value,
1561    axis: _C.Value | None = None,
1562) -> _C.Value:
1563    """Appends to graph `g` ONNX nodes that quantizes `tensor` based on `scale`, `zero_point` and `axis`.
1564
1565    Args:
1566        g: Graph, the ONNX IR graph that is under construction.
1567        tensor: torch._C.Value, representing the tensor to be quantized.
1568        scale: torch._C.Value, quantized scale.
1569        zero_point: torch._C.Value, quantized zero point.
1570        axis: Optional[torch._C.Value] default None, if None, represents per tensor quantization.
1571            Otherwise, represents per channel quantization, along given axis.
1572
1573    Returns:
1574        A TupleConstruct storing information of the quantized tensor.
1575    """
1576    if (
1577        axis is not None
1578        and not _is_none(axis)
1579        and GLOBALS.export_onnx_opset_version < 13
1580    ):
1581        _onnx_opset_unsupported_detailed(
1582            "QuantizeLinear",
1583            GLOBALS.export_onnx_opset_version,
1584            13,
1585            "Attribute axis is not supported.",
1586            tensor,
1587        )
1588
1589    assert scale is not None
1590    if (
1591        _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
1592        != _type_utils.JitScalarType.FLOAT
1593    ):
1594        scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
1595
1596    assert zero_point is not None
1597    if _type_utils.JitScalarType.from_value(
1598        zero_point, _type_utils.JitScalarType.UNDEFINED
1599    ) not in {
1600        _type_utils.JitScalarType.UINT8,
1601        _type_utils.JitScalarType.INT8,
1602    }:
1603        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
1604    output = g.op(
1605        "QuantizeLinear",
1606        tensor,
1607        scale,
1608        zero_point,
1609        axis_i=_get_const(axis, "i", "axis"),
1610    )
1611    args = [output, scale, zero_point]
1612    if axis is not None and not _is_none(axis):
1613        args.append(axis)
1614    return g.op("prim::TupleConstruct", *args)
1615
1616
1617def requantize_bias_helper(
1618    g: jit_utils.GraphContext, bias, input_scale, weight_scale, axis=None
1619):
1620    """In PyTorch, bias is float and is quantized to int32 implicitly inside the quantized ATen op kernel.
1621    In ONNX we need to make the quantization explicit because operators expect all of their inputs to be quantized.
1622    Since int32 is not a supported output type by ONNX operator `QuantizeLinear`, quantization is exported using
1623    regular operators.
1624    """
1625    bias_scale = g.op("Mul", weight_scale, input_scale)
1626    bias_scale_shape = g.op("Shape", bias_scale)
1627    bias_zero_point = g.op(
1628        "ConstantOfShape", bias_scale_shape, value_t=torch.tensor([0], dtype=torch.int)
1629    )
1630    q_bias = g.op(
1631        "Cast", g.op("Div", bias, bias_scale), to_i=_C_onnx.TensorProtoDataType.INT32
1632    )
1633    axis_args = []
1634    if axis is not None and not _is_none(axis):
1635        axis_args.append(axis)
1636    return g.op("prim::TupleConstruct", q_bias, bias_scale, bias_zero_point, *axis_args)
1637
1638
1639def args_have_same_dtype(args):
1640    assert args
1641    base_dtype = _type_utils.JitScalarType.from_value(args[0])
1642    has_same_dtype = all(
1643        _type_utils.JitScalarType.from_value(elem) == base_dtype for elem in args
1644    )
1645    return has_same_dtype
1646
1647
1648def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs):
1649    """Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types.
1650    This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch
1651    operator data type. For example, `Cast<int>(Clip<float>(Cast<float>(INPUT)))` can be used to mimic
1652    `Clip<int>(INPUT)` (opset version < 12).
1653
1654    Args:
1655        g (torch._C.Graph): graph to write the ONNX representation into.
1656        op_name (str): operator name in ONNX.
1657        *args (tuple): operands to the operator.
1658        **kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default)
1659            indicating the smallest opset version to trigger such casting behavior and "target_float_t"
1660            (optional, torch.onnx.JitScalarType.FLOAT by default) indicating the data type of internal operator.
1661
1662    Returns:
1663        Optional[torch._C.Value, Tuple[torch._C.Value, ...]]: output(s) of the operator.
1664    """
1665    opset_before = kwargs.pop("opset_before", None)
1666    target_float_t = kwargs.pop("target_float_t", _type_utils.JitScalarType.FLOAT)
1667
1668    inputs = list(args)
1669    dtype_0 = _type_utils.JitScalarType.from_value(inputs[0])
1670
1671    require_cast = not _is_fp(inputs[0]) and (
1672        opset_before is None or GLOBALS.export_onnx_opset_version < opset_before
1673    )
1674
1675    if require_cast:
1676        for input in inputs:
1677            if input.isCompleteTensor():
1678                input_scalar_type = _type_utils.JitScalarType.from_value(input)
1679                if input_scalar_type != dtype_0:
1680                    raise errors.SymbolicValueError(
1681                        f"Inputs of {op_name} must have same dtype."
1682                        f"Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}",
1683                        input,
1684                    )
1685        for i, input in enumerate(inputs):
1686            if input.isCompleteTensor() and not _is_fp(input):
1687                inputs[i] = g.op(
1688                    "Cast",
1689                    input,
1690                    to_i=target_float_t.onnx_type(),
1691                )
1692
1693    self = g.op(op_name, *inputs, **kwargs)
1694
1695    if require_cast:
1696        self = g.op("Cast", self, to_i=dtype_0.onnx_type())
1697
1698    return self
1699
1700
1701def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self):
1702    scalar_type = _type_utils.JitScalarType.from_value(
1703        self, _type_utils.JitScalarType.UNDEFINED
1704    )
1705    if scalar_type != _type_utils.JitScalarType.UNDEFINED:
1706        # This check only covers traced modules where dtype is present
1707        # pytorch reduce-ops cast all other integral types to int64
1708        if not _is_fp(self) and scalar_type != _type_utils.JitScalarType.INT64:
1709            self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64)
1710    return self
1711
1712
1713def _apply_params(*args, **kwargs):
1714    """Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
1715
1716    def _apply(fn):
1717        return fn(*args, **kwargs)
1718
1719    return _apply
1720
1721
1722def _reduce_op_symbolic_helper(onnx_op_name, allow_multi_dim_support=True):
1723    def symbolic(g, self, dim=None, keepdim=None):
1724        self = _maybe_cast_reduce_op_input(g, self)
1725        if dim is None or dim == ():
1726            # Dim can be 0, which will cause (not dim) == True. So we don't want to do
1727            # (not dim)
1728            # all-reduce path
1729            return _handle_reduce_dim_none(g, self, onnx_op_name)
1730        else:
1731            # dim-reduce path
1732            keepdim = _get_const(keepdim, "i", "keepdim")
1733            if g.opset < 18:
1734                desc = "is" if allow_multi_dim_support else "i"
1735                dim = _get_const(dim, desc, "dim")
1736                dim_list = dim if allow_multi_dim_support else [dim]
1737                return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim)
1738            else:
1739                if _is_value(dim):
1740                    axes = dim
1741                else:
1742                    if allow_multi_dim_support:
1743                        axes = g.op(
1744                            "Constant", value_t=torch.tensor(dim, dtype=torch.long)
1745                        )
1746                    else:
1747                        axes = g.op(
1748                            "Constant", value_t=torch.tensor([dim], dtype=torch.long)
1749                        )
1750                return g.op(onnx_op_name, self, axes, keepdims_i=keepdim)
1751
1752    return symbolic
1753
1754
1755def _overload_by_arg_count(fn):
1756    @functools.wraps(fn)
1757    def wrapper(g, *args):
1758        overloads = fn(g, *args)
1759        for overload in overloads:
1760            arg_descriptors = overload._arg_descriptors
1761            if len(arg_descriptors) == len(args):
1762                return overload(g, *args)
1763        return _unimplemented(f"aten::{fn.__name__}", f"with {len(args)} arguments")
1764
1765    return wrapper
1766
1767
1768def _reduce_with_dtype_helper(
1769    onnx_op: str, name: str, allow_multi_dim_support: bool = True
1770):
1771    symbolic = _reduce_op_symbolic_helper(
1772        onnx_op, allow_multi_dim_support=allow_multi_dim_support
1773    )
1774
1775    @_overload_by_arg_count
1776    def reduce(g, *args, **kwargs):
1777        @quantized_args(True)
1778        @parse_args("v", "none")
1779        def reduce_nodim(g, self, dtype):
1780            dtype_onnx = None
1781            if dtype.node().kind() == "onnx::Constant":
1782                dtype = _get_const(dtype, "i", "dtype")
1783                dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
1784                self = g.op("Cast", self, to_i=dtype_onnx)
1785            elif dtype.node().kind() != "prim::Constant":
1786                return _unimplemented(name, "dtype", dtype)
1787            result = symbolic(g, self)
1788            if dtype_onnx is not None:
1789                result_dtype_onnx = _type_utils.JitScalarType.from_value(
1790                    result
1791                ).onnx_type()
1792                if result_dtype_onnx != dtype_onnx:
1793                    result = g.op("Cast", result, to_i=dtype_onnx)
1794            return result
1795
1796        dim_desc = "is" if allow_multi_dim_support else "i"
1797
1798        @quantized_args(True)
1799        @parse_args("v", dim_desc, "i", "none")  # type: ignore[arg-type]
1800        def reduce_dim(g, self, dim, keepdim, dtype):
1801            dtype_onnx = None
1802            if dtype.node().kind() == "onnx::Constant":
1803                dtype = _get_const(dtype, "i", "dtype")
1804                dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
1805                self = g.op("Cast", self, to_i=dtype_onnx)
1806            elif dtype.node().kind() != "prim::Constant":
1807                return _unimplemented(name, "dtype", dtype)
1808            result = symbolic(g, self, dim, keepdim)
1809            if dtype_onnx is not None:
1810                result_dtype_onnx = _type_utils.JitScalarType.from_value(
1811                    result
1812                ).onnx_type()
1813                if result_dtype_onnx != dtype_onnx:
1814                    result = g.op("Cast", result, to_i=dtype_onnx)
1815            return result
1816
1817        return reduce_nodim, reduce_dim
1818
1819    return reduce
1820
1821
1822def _max_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
1823    # torch.max(input)
1824    if dim_or_y is None and keepdim is None:
1825        return g.op("ReduceMax", self, keepdims_i=0)
1826    # torch.max(input, other)
1827    if keepdim is None:
1828        return _op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12)
1829    # torch.max(input, dim, keepdim)
1830    else:
1831        keepdim = _get_const(keepdim, "i", "keepdim")
1832        dim = _get_const(dim_or_y, "i", "dim")
1833        if g.opset < 18:
1834            max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim)
1835        else:
1836            axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
1837            max = g.op("ReduceMax", self, axes, keepdims_i=keepdim)
1838        indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim)
1839        return max, indices
1840
1841
1842def _min_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
1843    # torch.min(input)
1844    if dim_or_y is None and keepdim is None:
1845        return g.op("ReduceMin", self, keepdims_i=0)
1846    # torch.min(input, other)
1847    if keepdim is None:
1848        return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12)
1849    # torch.min(input, dim, keepdim)
1850    else:
1851        keepdim = _get_const(keepdim, "i", "keepdim")
1852        dim = _get_const(dim_or_y, "i", "dim")
1853        if g.opset < 18:
1854            min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim)
1855        else:
1856            axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
1857            min = g.op("ReduceMin", self, axes, keepdims_i=keepdim)
1858        indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim)
1859        return min, indices
1860
1861
1862def _numel_helper(g: jit_utils.GraphContext, self):
1863    shape = g.op("Shape", self)
1864    return g.op("ReduceProd", shape, keepdims_i=0)
1865
1866
1867@parse_args("v", "is", "i", "i")
1868def _var_mean_helper(g: jit_utils.GraphContext, input, dim, correction, keepdim):
1869    if g.opset < 18:
1870        if dim is None:
1871            mean = g.op("ReduceMean", input, keepdims_i=0)
1872            t_mean = mean
1873            num_elements = _numel_helper(g, input)
1874        else:
1875            mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim)
1876            t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1)
1877            redudced_dims = g.op("Shape", input)
1878            # dim could contain one or multiple dimensions
1879            redudced_dims = g.op(
1880                "Gather",
1881                redudced_dims,
1882                g.op("Constant", value_t=torch.tensor(dim)),
1883                axis_i=0,
1884            )
1885            num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0)
1886        sub_v = g.op("Sub", input, t_mean)
1887        sqr_sub = g.op("Mul", sub_v, sub_v)
1888        keepdim_mean = 0 if dim is None else keepdim
1889        var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean)
1890        # Correct bias in calculating variance, by dividing it over (N - correction) instead on N
1891        if correction is None:
1892            correction = 1
1893        if correction != 0:
1894            num_elements = g.op(
1895                "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT
1896            )
1897            one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float))
1898            mul = g.op("Mul", var, num_elements)
1899            var = g.op("Div", mul, g.op("Sub", num_elements, one))
1900        return var, mean
1901    else:
1902        axes = None
1903        if dim is None:
1904            mean = g.op("ReduceMean", input, keepdims_i=0)
1905            t_mean = mean
1906            num_elements = _numel_helper(g, input)
1907        else:
1908            axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
1909            mean = g.op("ReduceMean", input, axes, keepdims_i=keepdim)
1910            t_mean = g.op("ReduceMean", input, axes, keepdims_i=1)
1911            redudced_dims = g.op("Shape", input)
1912            # dim could contain one or multiple dimensions
1913            redudced_dims = g.op(
1914                "Gather",
1915                redudced_dims,
1916                g.op("Constant", value_t=torch.tensor(dim)),
1917                axis_i=0,
1918            )
1919            num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0)
1920        sub_v = g.op("Sub", input, t_mean)
1921        sqr_sub = g.op("Mul", sub_v, sub_v)
1922        keepdim_mean = 0 if dim is None else keepdim
1923        if axes is None:
1924            var = g.op("ReduceMean", sqr_sub, keepdims_i=keepdim_mean)
1925        else:
1926            var = g.op("ReduceMean", sqr_sub, axes, keepdims_i=keepdim_mean)
1927        # Correct bias in calculating variance, by dividing it over (N - correction) instead on N
1928        if correction is None:
1929            correction = 1
1930        if correction != 0:
1931            num_elements = g.op(
1932                "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT
1933            )
1934            one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float))
1935            mul = g.op("Mul", var, num_elements)
1936            var = g.op("Div", mul, g.op("Sub", num_elements, one))
1937        return var, mean
1938
1939
1940def _embedding_bag_helper(
1941    g: jit_utils.GraphContext,
1942    embedding_matrix,
1943    indices,
1944    offsets,
1945    scale_grad_by_freq,
1946    mode,
1947    sparse,
1948    per_sample_weights,
1949    include_last_offset,
1950    padding_idx,
1951):
1952    if scale_grad_by_freq and GLOBALS.export_training:
1953        return _onnx_unsupported(
1954            "embedding_bag with scale_grad_by_freq for training mode"
1955        )
1956    if padding_idx is not None and padding_idx >= 0:
1957        raise RuntimeError("embedding_bag with padding_idx")
1958
1959    loop_condition = g.op("Constant", value_t=torch.tensor(1))
1960    loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
1961    zero = g.op("Constant", value_t=torch.tensor([0]))
1962
1963    indices_len = _unsqueeze_helper(
1964        g,
1965        _size_helper(g, indices, g.op("Constant", value_t=torch.tensor(0))),
1966        [0],
1967    )
1968    if not include_last_offset:
1969        offsets = [offsets, indices_len]
1970        offsets = g.op("Concat", *offsets, axis_i=0)
1971
1972    # Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by
1973    # offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings.
1974    # The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in.
1975    offsets_starts = _slice_helper(
1976        g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1]
1977    )
1978    offsets_ends = _slice_helper(
1979        g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1]
1980    )
1981
1982    loop_len = _size_helper(g, offsets_ends, g.op("Constant", value_t=torch.tensor(0)))
1983
1984    loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
1985        g, "Loop", loop_len, loop_condition, n_blocks=1
1986    )
1987    loop_block = loop_context.block
1988
1989    # FIXME(justinchuby): We need to handle what happens when we call b.op on a node return
1990    block_input_iter = utils._add_input_to_block(loop_block)
1991    cond = utils._add_input_to_block(loop_block)
1992
1993    indices_start = loop_context.op(
1994        "Gather", offsets_starts, block_input_iter, axis_i=0
1995    )
1996    indices_end = loop_context.op("Gather", offsets_ends, block_input_iter, axis_i=0)
1997    indices_start = _unsqueeze_helper(loop_context, indices_start, [0])
1998    indices_end = _unsqueeze_helper(loop_context, indices_end, [0])
1999
2000    indices_row = loop_context.op("Slice", indices, indices_start, indices_end, zero)
2001    embeddings = loop_context.op("Gather", embedding_matrix, indices_row, axis_i=0)
2002    if not _is_none(per_sample_weights):
2003        per_sample_weights_row = loop_context.op(
2004            "Slice", per_sample_weights, indices_start, indices_end, zero
2005        )
2006        per_sample_weights_row = _unsqueeze_helper(
2007            loop_context, per_sample_weights_row, [1]
2008        )
2009        embeddings = loop_context.op("Mul", embeddings, per_sample_weights_row)
2010    if mode == 0:
2011        embeddings = _reducesum_helper(
2012            loop_context, embeddings, axes_i=[0], keepdims_i=0
2013        )
2014    elif mode == 1:
2015        if loop_context.opset < 18:
2016            embeddings = loop_context.op(
2017                "ReduceMean", embeddings, axes_i=[0], keepdims_i=0
2018            )
2019        else:
2020            axes = loop_context.op(
2021                "Constant", value_t=torch.tensor([0], dtype=torch.long)
2022            )
2023            embeddings = loop_context.op("ReduceMean", embeddings, axes, keepdims_i=0)
2024    else:
2025        if loop_context.opset < 18:
2026            embeddings = loop_context.op(
2027                "ReduceMax", embeddings, axes_i=[0], keepdims_i=0
2028            )
2029        else:
2030            axes = loop_context.op(
2031                "Constant", value_t=torch.tensor([0], dtype=torch.long)
2032            )
2033            embeddings = loop_context.op("ReduceMax", embeddings, axes, keepdims_i=0)
2034
2035    cond_out = loop_context.op(
2036        "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
2037    )
2038    utils._add_output_to_block(loop_block, cond_out)
2039    utils._add_output_to_block(loop_block, embeddings)
2040
2041    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
2042    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
2043    return loop.node().output(), None, None, None
2044
2045
2046def _linalg_vector_norm_helper(
2047    g: jit_utils.GraphContext,
2048    self: torch._C.Value,
2049    ord: float,
2050    dim: Sequence[int] | None,
2051    keepdim: bool,
2052    dtype: torch._C.Value,
2053):
2054    axes = None
2055    # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html
2056    if _is_none(dim):
2057        self = _reshape_helper(g, self, [-1])
2058        keepdim = False
2059    elif g.opset >= 18:
2060        axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
2061
2062    if ord == math.inf:
2063        if g.opset < 18:
2064            result = g.op(
2065                "ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim
2066            )
2067        else:
2068            if axes is None:
2069                result = g.op("ReduceMax", g.op("Abs", self), keepdims_i=keepdim)
2070            else:
2071                result = g.op("ReduceMax", g.op("Abs", self), axes, keepdims_i=keepdim)
2072    elif ord == -math.inf:
2073        if g.opset < 18:
2074            result = g.op(
2075                "ReduceMin", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim
2076            )
2077        else:
2078            if axes is None:
2079                result = g.op("ReduceMin", g.op("Abs", self), keepdims_i=keepdim)
2080            else:
2081                result = g.op("ReduceMin", g.op("Abs", self), axes, keepdims_i=keepdim)
2082    elif ord == 0:
2083        if g.opset < 11:
2084            return _onnx_opset_unsupported_detailed(
2085                "linalg_vector_norm", 9, 11, "ord=0 not supported", self
2086            )
2087        else:
2088            if dim is None:
2089                self = _reshape_helper(
2090                    g,
2091                    self,
2092                    g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)),
2093                )
2094                keepdim = False
2095
2096            cond_op = g.op(
2097                "Not",
2098                g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0]))),
2099            )
2100            cond_op = g.op(
2101                "Cast",
2102                cond_op,
2103                to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
2104            )
2105            return _reducesum_helper(g, cond_op, axes_i=dim, keepdims_i=keepdim)
2106    elif ord == 1:
2107        if g.opset < 18:
2108            result = _reduce_op_symbolic_helper("ReduceL1")(
2109                g, self, dim=dim, keepdim=keepdim
2110            )
2111        else:
2112            if axes is None:
2113                result = _reduce_op_symbolic_helper("ReduceL1")(
2114                    g, self, keepdim=keepdim
2115                )
2116            else:
2117                result = _reduce_op_symbolic_helper("ReduceL1")(
2118                    g, self, axes, keepdim=keepdim
2119                )
2120    elif ord == 2:
2121        if g.opset < 18:
2122            result = _reduce_op_symbolic_helper("ReduceL2")(
2123                g, self, dim=dim, keepdim=keepdim
2124            )
2125        else:
2126            if axes is None:
2127                result = _reduce_op_symbolic_helper("ReduceL2")(
2128                    g, self, keepdim=keepdim
2129                )
2130            else:
2131                result = _reduce_op_symbolic_helper("ReduceL2")(
2132                    g, self, axes, keepdim=keepdim
2133                )
2134    else:
2135        ord_op = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32))
2136        result = _reducesum_helper(
2137            g, g.op("Pow", g.op("Abs", self), ord_op), axes_i=dim, keepdims_i=keepdim
2138        )
2139        result = g.op(
2140            "Pow",
2141            result,
2142            g.op(
2143                "Div",
2144                g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32)),
2145                ord_op,
2146            ),
2147        )
2148
2149    if not _is_none(dtype):
2150        dtype = _get_const(dtype, "i", "dtype")
2151        result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type())  # type: ignore[arg-type]
2152    return result
2153
2154
2155# Deprecated. Internally use _type_utils.ScalarType
2156# TODO: remove these once we support Type's in the JIT IR and we can once again
2157# use the unified toType operator
2158cast_pytorch_to_onnx = {
2159    "Byte": _C_onnx.TensorProtoDataType.UINT8,
2160    "Char": _C_onnx.TensorProtoDataType.INT8,
2161    "Double": _C_onnx.TensorProtoDataType.DOUBLE,
2162    "Float": _C_onnx.TensorProtoDataType.FLOAT,
2163    "Half": _C_onnx.TensorProtoDataType.FLOAT16,
2164    "Int": _C_onnx.TensorProtoDataType.INT32,
2165    "Long": _C_onnx.TensorProtoDataType.INT64,
2166    "Short": _C_onnx.TensorProtoDataType.INT16,
2167    "Bool": _C_onnx.TensorProtoDataType.BOOL,
2168    "ComplexFloat": _C_onnx.TensorProtoDataType.COMPLEX64,
2169    "ComplexDouble": _C_onnx.TensorProtoDataType.COMPLEX128,
2170    "BFloat16": _C_onnx.TensorProtoDataType.BFLOAT16,
2171    "Undefined": _C_onnx.TensorProtoDataType.UNDEFINED,
2172}
2173
2174# Deprecated. Internally use _type_utils.ScalarType
2175scalar_name_to_pytorch = {
2176    "uint8_t": "Byte",
2177    "int8_t": "Char",
2178    "double": "Double",
2179    "float": "Float",
2180    "half": "Half",
2181    "int": "Int",
2182    "int64_t": "Long",
2183    "int16_t": "Short",
2184    "bool": "Bool",
2185    "complex64": "ComplexFloat",
2186    "complex128": "ComplexDouble",
2187    "qint8": "QInt8",
2188    "quint8": "QUInt8",
2189    "qint32": "QInt32",
2190    "bfloat16": "BFloat16",
2191}
2192
2193
2194# Deprecated. Internally use _type_utils.ScalarType
2195# This indicates each scalar type's corresponding
2196# torch type. Related source:
2197# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h
2198scalar_type_to_pytorch_type = [
2199    torch.uint8,  # 0
2200    torch.int8,  # 1
2201    torch.short,  # 2
2202    torch.int,  # 3
2203    torch.int64,  # 4
2204    torch.half,  # 5
2205    torch.float,  # 6
2206    torch.double,  # 7
2207    torch.complex32,  # 8
2208    torch.complex64,  # 9
2209    torch.complex128,  # 10
2210    torch.bool,  # 11
2211    torch.qint8,  # 12
2212    torch.quint8,  # 13
2213    torch.qint32,  # 14
2214    torch.bfloat16,  # 15
2215]
2216
2217# Deprecated. Internally use _type_utils.ScalarType
2218# source of truth is
2219# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp
2220pytorch_name_to_type = {
2221    "Byte": torch.uint8,
2222    "Char": torch.int8,
2223    "Double": torch.double,
2224    "Float": torch.float,
2225    "Half": torch.half,
2226    "Int": torch.int,
2227    "Long": torch.int64,
2228    "Short": torch.short,
2229    "Bool": torch.bool,
2230    "ComplexFloat": torch.complex64,
2231    "ComplexDouble": torch.complex128,
2232    "QInt8": torch.qint8,
2233    "QUInt8": torch.quint8,
2234    "QInt32": torch.qint32,
2235    "BFloat16": torch.bfloat16,
2236}
2237
2238
2239# Deprecated. Internally use _type_utils.ScalarType
2240scalar_type_to_onnx = [
2241    cast_pytorch_to_onnx["Byte"],  # 0
2242    cast_pytorch_to_onnx["Char"],  # 1
2243    cast_pytorch_to_onnx["Short"],  # 2
2244    cast_pytorch_to_onnx["Int"],  # 3
2245    cast_pytorch_to_onnx["Long"],  # 4
2246    cast_pytorch_to_onnx["Half"],  # 5
2247    cast_pytorch_to_onnx["Float"],  # 6
2248    cast_pytorch_to_onnx["Double"],  # 7
2249    cast_pytorch_to_onnx["Undefined"],  # 8
2250    cast_pytorch_to_onnx["ComplexFloat"],  # 9
2251    cast_pytorch_to_onnx["ComplexDouble"],  # 10
2252    cast_pytorch_to_onnx["Bool"],  # 11
2253    cast_pytorch_to_onnx["Char"],  # 12
2254    cast_pytorch_to_onnx["Byte"],  # 13
2255    cast_pytorch_to_onnx["Int"],  # 14
2256    cast_pytorch_to_onnx["BFloat16"],  # 15
2257]
2258
2259# Global set to store the list of quantized operators in the network.
2260# This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX.
2261_quantized_ops: set[int] = set()
2262