xref: /aosp_15_r20/external/pytorch/torch/onnx/_type_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""Utilities for converting and operating on ONNX, JIT and torch types."""
3
4from __future__ import annotations
5
6import enum
7import typing
8from typing import Literal
9
10import torch
11from torch._C import _onnx as _C_onnx
12from torch.onnx import errors
13
14
15if typing.TYPE_CHECKING:
16    # Hack to help mypy to recognize torch._C.Value
17    from torch import _C  # noqa: F401
18
# JIT/ATen scalar type names as reported by `torch._C.Value.type().scalarType()`
# (e.g. "Float", "Long"). Keep in sync with _SCALAR_TYPE_TO_NAME below.
ScalarName = Literal[
    "Byte",
    "Char",
    "Double",
    "Float",
    "Half",
    "Int",
    "Long",
    "Short",
    "Bool",
    "ComplexHalf",
    "ComplexFloat",
    "ComplexDouble",
    "QInt8",
    "QUInt8",
    "QInt32",
    "BFloat16",
    "Float8E5M2",
    "Float8E4M3FN",
    "Float8E5M2FNUZ",
    "Float8E4M3FNUZ",
    "Undefined",
]
42
# C++-style torch type names (e.g. "int64_t", "float"). Keep in sync with
# _SCALAR_TYPE_TO_TORCH_NAME below; note there is no "undefined" entry.
TorchName = Literal[
    "bool",
    "uint8_t",
    "int8_t",
    "double",
    "float",
    "half",
    "int",
    "int64_t",
    "int16_t",
    "complex32",
    "complex64",
    "complex128",
    "qint8",
    "quint8",
    "qint32",
    "bfloat16",
    "float8_e5m2",
    "float8_e4m3fn",
    "float8_e5m2fnuz",
    "float8_e4m3fnuz",
]
65
66
class JitScalarType(enum.IntEnum):
    """Scalar types defined in torch.

    Use ``JitScalarType`` to convert from torch and JIT scalar types to ONNX scalar types.

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> # xdoctest: +IGNORE_WANT("win32 has different output")
        >>> JitScalarType.from_value(torch.ones(1, 2)).onnx_type()
        TensorProtoDataType.FLOAT

        >>> JitScalarType.from_value(torch_c_value_with_type_float).onnx_type()
        TensorProtoDataType.FLOAT

        >>> JitScalarType.from_dtype(torch.get_default_dtype()).onnx_type()
        TensorProtoDataType.FLOAT

    """

    # Order defined in https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h
    UINT8 = 0
    INT8 = enum.auto()  # 1
    INT16 = enum.auto()  # 2
    INT = enum.auto()  # 3
    INT64 = enum.auto()  # 4
    HALF = enum.auto()  # 5
    FLOAT = enum.auto()  # 6
    DOUBLE = enum.auto()  # 7
    COMPLEX32 = enum.auto()  # 8
    COMPLEX64 = enum.auto()  # 9
    COMPLEX128 = enum.auto()  # 10
    BOOL = enum.auto()  # 11
    QINT8 = enum.auto()  # 12
    QUINT8 = enum.auto()  # 13
    QINT32 = enum.auto()  # 14
    BFLOAT16 = enum.auto()  # 15
    FLOAT8E5M2 = enum.auto()  # 16
    FLOAT8E4M3FN = enum.auto()  # 17
    FLOAT8E5M2FNUZ = enum.auto()  # 18
    FLOAT8E4M3FNUZ = enum.auto()  # 19
    UNDEFINED = enum.auto()  # 20

    @classmethod
    def _from_name(cls, name: ScalarName | TorchName | str | None) -> JitScalarType:
        """Convert a JIT scalar type or torch type name to ScalarType.

        Note: DO NOT USE this API when `name` comes from a `torch._C.Value.type()` calls.
            A "RuntimeError: INTERNAL ASSERT FAILED at "../aten/src/ATen/core/jit_type_base.h" can
            be raised in several scenarios where shape info is not present.
            Instead use `from_value` API which is safer.

        Args:
            name: JIT scalar type name (Byte) or torch type name (uint8_t).

        Returns:
            JitScalarType

        Raises:
            OnnxExporterError: if name is not a valid scalar type name or if it is None.
        """
        if name is None:
            raise errors.OnnxExporterError("Scalar type name cannot be None")
        # Try the JIT spelling first ("Byte"), then the torch spelling ("uint8_t").
        if valid_scalar_name(name):
            return _SCALAR_NAME_TO_TYPE[name]  # type: ignore[index]
        if valid_torch_name(name):
            return _TORCH_NAME_TO_SCALAR_TYPE[name]  # type: ignore[index]

        raise errors.OnnxExporterError(f"Unknown torch or scalar type: '{name}'")

    @classmethod
    def from_dtype(cls, dtype: torch.dtype | None) -> JitScalarType:
        """Convert a torch dtype to JitScalarType.

        Note: DO NOT USE this API when `dtype` comes from a `torch._C.Value.type()` calls.
            A "RuntimeError: INTERNAL ASSERT FAILED at "../aten/src/ATen/core/jit_type_base.h" can
            be raised in several scenarios where shape info is not present.
            Instead use `from_value` API which is safer.

        Args:
            dtype: A torch.dtype to create a JitScalarType from

        Returns:
            JitScalarType

        Raises:
            OnnxExporterError: if dtype is not a valid torch.dtype or if it is None.
        """
        if dtype not in _DTYPE_TO_SCALAR_TYPE:
            raise errors.OnnxExporterError(f"Unknown dtype: {dtype}")
        return _DTYPE_TO_SCALAR_TYPE[dtype]

    @classmethod
    def from_onnx_type(
        cls, onnx_type: int | _C_onnx.TensorProtoDataType | None
    ) -> JitScalarType:
        """Convert a ONNX data type to JitScalarType.

        Note: the reverse lookup table is built by inverting a non-injective map,
            so ONNX types shared by several scalar types resolve to the last entry
            inserted in ``_SCALAR_TYPE_TO_ONNX`` (e.g. INT8 -> QINT8,
            UNDEFINED -> COMPLEX32).

        Args:
            onnx_type: A torch._C._onnx.TensorProtoDataType to create a JitScalarType from

        Returns:
            JitScalarType

        Raises:
            OnnxExporterError: if onnx_type is not a valid TensorProtoDataType or if it is None.
        """
        if onnx_type not in _ONNX_TO_SCALAR_TYPE:
            raise errors.OnnxExporterError(f"Unknown onnx_type: {onnx_type}")
        return _ONNX_TO_SCALAR_TYPE[typing.cast(_C_onnx.TensorProtoDataType, onnx_type)]

    @classmethod
    def from_value(
        cls, value: None | torch._C.Value | torch.Tensor, default: JitScalarType | None = None
    ) -> JitScalarType:
        """Create a JitScalarType from a value's scalar type.

        Args:
            value: An object to fetch scalar type from.
            default: The JitScalarType to return if a valid scalar cannot be fetched from value

        Returns:
            JitScalarType.

        Raises:
            OnnxExporterError: if value does not have a valid scalar type and default is None.
            SymbolicValueError: when value.type()'s info are empty and default is None
        """

        if not isinstance(value, (torch._C.Value, torch.Tensor)) or (
            isinstance(value, torch._C.Value) and value.node().mustBeNone()
        ):
            # default value of type JitScalarType is returned when value is not valid
            if default is None:
                raise errors.OnnxExporterError(
                    "value must be either torch._C.Value or torch.Tensor objects."
                )
            elif not isinstance(default, JitScalarType):
                raise errors.OnnxExporterError(
                    "default value must be a JitScalarType object."
                )
            return default

        # Each value type has their own way of storing scalar type
        if isinstance(value, torch.Tensor):
            return cls.from_dtype(value.dtype)
        if isinstance(value.type(), torch.ListType):
            try:
                # List of tensors: use the element type's dtype when available.
                return cls.from_dtype(value.type().getElementType().dtype())
            except RuntimeError:
                # List of non-tensor scalars: fall back to the element type's
                # string form (e.g. "int"), resolved through _from_name.
                return cls._from_name(str(value.type().getElementType()))
        if isinstance(value.type(), torch._C.OptionalType):
            if value.type().getElementType().dtype() is None:
                # Optional with no dtype info: only an explicit default can help.
                if isinstance(default, JitScalarType):
                    return default
                raise errors.OnnxExporterError(
                    "default value must be a JitScalarType object."
                )
            return cls.from_dtype(value.type().getElementType().dtype())

        scalar_type = None
        if value.node().kind() != "prim::Constant" or not isinstance(
            value.type(), torch._C.NoneType
        ):
            # value must be a non-list torch._C.Value scalar
            scalar_type = value.type().scalarType()

        if scalar_type is not None:
            return cls._from_name(scalar_type)

        # When everything fails... try to default
        if default is not None:
            return default
        raise errors.SymbolicValueError(
            f"Cannot determine scalar type for this '{type(value.type())}' instance and "
            "a default value was not provided.",
            value,
        )

    def scalar_name(self) -> ScalarName:
        """Convert a JitScalarType to a JIT scalar type name (e.g. "Float")."""
        return _SCALAR_TYPE_TO_NAME[self]

    def torch_name(self) -> TorchName:
        """Convert a JitScalarType to a torch type name (e.g. "int64_t").

        Raises KeyError for UNDEFINED, which has no torch name entry.
        """
        return _SCALAR_TYPE_TO_TORCH_NAME[self]

    def dtype(self) -> torch.dtype:
        """Convert a JitScalarType to a torch dtype.

        Raises KeyError for UNDEFINED, which has no dtype entry.
        """
        return _SCALAR_TYPE_TO_DTYPE[self]

    def onnx_type(self) -> _C_onnx.TensorProtoDataType:
        """Convert a JitScalarType to an ONNX data type."""
        if self not in _SCALAR_TYPE_TO_ONNX:
            raise errors.OnnxExporterError(
                f"Scalar type {self} cannot be converted to ONNX"
            )
        return _SCALAR_TYPE_TO_ONNX[self]

    def onnx_compatible(self) -> bool:
        """Return whether this JitScalarType is compatible with ONNX."""
        # COMPLEX32 and UNDEFINED both map to TensorProtoDataType.UNDEFINED,
        # hence the explicit exclusions on top of the membership test.
        return (
            self in _SCALAR_TYPE_TO_ONNX
            and self != JitScalarType.UNDEFINED
            and self != JitScalarType.COMPLEX32
        )
272
273
def valid_scalar_name(scalar_name: ScalarName | str) -> bool:
    """Check whether ``scalar_name`` is a recognized JIT scalar type name (e.g. "Float")."""
    known_names = _SCALAR_NAME_TO_TYPE.keys()
    return scalar_name in known_names
278
def valid_torch_name(torch_name: TorchName | str) -> bool:
    """Check whether ``torch_name`` is a recognized torch type name (e.g. "int64_t")."""
    known_names = _TORCH_NAME_TO_SCALAR_TYPE.keys()
    return torch_name in known_names
283
# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h
# Maps each JitScalarType to the JIT/ATen scalar type name reported by
# `torch._C.Value.type().scalarType()`.
_SCALAR_TYPE_TO_NAME: dict[JitScalarType, ScalarName] = {
    JitScalarType.BOOL: "Bool",
    JitScalarType.UINT8: "Byte",
    JitScalarType.INT8: "Char",
    JitScalarType.INT16: "Short",
    JitScalarType.INT: "Int",
    JitScalarType.INT64: "Long",
    JitScalarType.HALF: "Half",
    JitScalarType.FLOAT: "Float",
    JitScalarType.DOUBLE: "Double",
    JitScalarType.COMPLEX32: "ComplexHalf",
    JitScalarType.COMPLEX64: "ComplexFloat",
    JitScalarType.COMPLEX128: "ComplexDouble",
    JitScalarType.QINT8: "QInt8",
    JitScalarType.QUINT8: "QUInt8",
    JitScalarType.QINT32: "QInt32",
    JitScalarType.BFLOAT16: "BFloat16",
    JitScalarType.FLOAT8E5M2: "Float8E5M2",
    JitScalarType.FLOAT8E4M3FN: "Float8E4M3FN",
    JitScalarType.FLOAT8E5M2FNUZ: "Float8E5M2FNUZ",
    JitScalarType.FLOAT8E4M3FNUZ: "Float8E4M3FNUZ",
    JitScalarType.UNDEFINED: "Undefined",
}

# Reverse lookup (scalar name -> JitScalarType); the forward map is bijective,
# so the inversion loses no entries.
_SCALAR_NAME_TO_TYPE: dict[ScalarName, JitScalarType] = {
    v: k for k, v in _SCALAR_TYPE_TO_NAME.items()
}
312
# Maps each JitScalarType to its C++-style torch type name. Note there is no
# entry for UNDEFINED, so `JitScalarType.UNDEFINED.torch_name()` raises KeyError.
_SCALAR_TYPE_TO_TORCH_NAME: dict[JitScalarType, TorchName] = {
    JitScalarType.BOOL: "bool",
    JitScalarType.UINT8: "uint8_t",
    JitScalarType.INT8: "int8_t",
    JitScalarType.INT16: "int16_t",
    JitScalarType.INT: "int",
    JitScalarType.INT64: "int64_t",
    JitScalarType.HALF: "half",
    JitScalarType.FLOAT: "float",
    JitScalarType.DOUBLE: "double",
    JitScalarType.COMPLEX32: "complex32",
    JitScalarType.COMPLEX64: "complex64",
    JitScalarType.COMPLEX128: "complex128",
    JitScalarType.QINT8: "qint8",
    JitScalarType.QUINT8: "quint8",
    JitScalarType.QINT32: "qint32",
    JitScalarType.BFLOAT16: "bfloat16",
    JitScalarType.FLOAT8E5M2: "float8_e5m2",
    JitScalarType.FLOAT8E4M3FN: "float8_e4m3fn",
    JitScalarType.FLOAT8E5M2FNUZ: "float8_e5m2fnuz",
    JitScalarType.FLOAT8E4M3FNUZ: "float8_e4m3fnuz",
}

# Reverse lookup (torch name -> JitScalarType); the forward map is bijective.
_TORCH_NAME_TO_SCALAR_TYPE: dict[TorchName, JitScalarType] = {
    v: k for k, v in _SCALAR_TYPE_TO_TORCH_NAME.items()
}
339
# Maps each JitScalarType to the closest ONNX TensorProtoDataType. The map is
# NOT injective: quantized types reuse their plain-integer ONNX types, and
# COMPLEX32 has no ONNX equivalent so it maps to UNDEFINED.
_SCALAR_TYPE_TO_ONNX = {
    JitScalarType.BOOL: _C_onnx.TensorProtoDataType.BOOL,
    JitScalarType.UINT8: _C_onnx.TensorProtoDataType.UINT8,
    JitScalarType.INT8: _C_onnx.TensorProtoDataType.INT8,
    JitScalarType.INT16: _C_onnx.TensorProtoDataType.INT16,
    JitScalarType.INT: _C_onnx.TensorProtoDataType.INT32,
    JitScalarType.INT64: _C_onnx.TensorProtoDataType.INT64,
    JitScalarType.HALF: _C_onnx.TensorProtoDataType.FLOAT16,
    JitScalarType.FLOAT: _C_onnx.TensorProtoDataType.FLOAT,
    JitScalarType.DOUBLE: _C_onnx.TensorProtoDataType.DOUBLE,
    JitScalarType.COMPLEX64: _C_onnx.TensorProtoDataType.COMPLEX64,
    JitScalarType.COMPLEX128: _C_onnx.TensorProtoDataType.COMPLEX128,
    JitScalarType.BFLOAT16: _C_onnx.TensorProtoDataType.BFLOAT16,
    JitScalarType.UNDEFINED: _C_onnx.TensorProtoDataType.UNDEFINED,
    JitScalarType.COMPLEX32: _C_onnx.TensorProtoDataType.UNDEFINED,
    JitScalarType.QINT8: _C_onnx.TensorProtoDataType.INT8,
    JitScalarType.QUINT8: _C_onnx.TensorProtoDataType.UINT8,
    JitScalarType.QINT32: _C_onnx.TensorProtoDataType.INT32,
    JitScalarType.FLOAT8E5M2: _C_onnx.TensorProtoDataType.FLOAT8E5M2,
    JitScalarType.FLOAT8E4M3FN: _C_onnx.TensorProtoDataType.FLOAT8E4M3FN,
    JitScalarType.FLOAT8E5M2FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E5M2FNUZ,
    JitScalarType.FLOAT8E4M3FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E4M3FNUZ,
}

# NOTE(review): because the forward map above is not injective, this inversion
# resolves duplicated ONNX values to the LAST scalar type inserted above —
# e.g. TensorProtoDataType.INT8 -> QINT8 (not INT8), UINT8 -> QUINT8,
# INT32 -> QINT32, UNDEFINED -> COMPLEX32. Confirm callers of
# `from_onnx_type` expect the quantized variants before changing insertion order.
_ONNX_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_ONNX.items()}
365
# source of truth is
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp
# Maps each JitScalarType to its torch.dtype. There is no entry for UNDEFINED,
# so `JitScalarType.UNDEFINED.dtype()` raises KeyError.
_SCALAR_TYPE_TO_DTYPE = {
    JitScalarType.BOOL: torch.bool,
    JitScalarType.UINT8: torch.uint8,
    JitScalarType.INT8: torch.int8,
    JitScalarType.INT16: torch.short,
    JitScalarType.INT: torch.int,
    JitScalarType.INT64: torch.int64,
    JitScalarType.HALF: torch.half,
    JitScalarType.FLOAT: torch.float,
    JitScalarType.DOUBLE: torch.double,
    JitScalarType.COMPLEX32: torch.complex32,
    JitScalarType.COMPLEX64: torch.complex64,
    JitScalarType.COMPLEX128: torch.complex128,
    JitScalarType.QINT8: torch.qint8,
    JitScalarType.QUINT8: torch.quint8,
    JitScalarType.QINT32: torch.qint32,
    JitScalarType.BFLOAT16: torch.bfloat16,
    JitScalarType.FLOAT8E5M2: torch.float8_e5m2,
    JitScalarType.FLOAT8E4M3FN: torch.float8_e4m3fn,
    JitScalarType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
    JitScalarType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
}

# Reverse lookup (torch.dtype -> JitScalarType); every dtype above is distinct,
# so the inversion loses no entries.
_DTYPE_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_DTYPE.items()}
392