# mypy: allow-untyped-defs
"""Utilities for converting and operating on ONNX, JIT and torch types."""

from __future__ import annotations

import enum
import typing
from typing import Literal

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import errors


if typing.TYPE_CHECKING:
    # Hack to help mypy to recognize torch._C.Value
    from torch import _C  # noqa: F401

ScalarName = Literal[
    "Byte",
    "Char",
    "Double",
    "Float",
    "Half",
    "Int",
    "Long",
    "Short",
    "Bool",
    "ComplexHalf",
    "ComplexFloat",
    "ComplexDouble",
    "QInt8",
    "QUInt8",
    "QInt32",
    "BFloat16",
    "Float8E5M2",
    "Float8E4M3FN",
    "Float8E5M2FNUZ",
    "Float8E4M3FNUZ",
    "Undefined",
]

TorchName = Literal[
    "bool",
    "uint8_t",
    "int8_t",
    "double",
    "float",
    "half",
    "int",
    "int64_t",
    "int16_t",
    "complex32",
    "complex64",
    "complex128",
    "qint8",
    "quint8",
    "qint32",
    "bfloat16",
    "float8_e5m2",
    "float8_e4m3fn",
    "float8_e5m2fnuz",
    "float8_e4m3fnuz",
]


class JitScalarType(enum.IntEnum):
    """Scalar types defined in torch.

    Use ``JitScalarType`` to convert from torch and JIT scalar types to ONNX scalar types.

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> # xdoctest: +IGNORE_WANT("win32 has different output")
        >>> JitScalarType.from_value(torch.ones(1, 2)).onnx_type()
        TensorProtoDataType.FLOAT

        >>> JitScalarType.from_value(torch_c_value_with_type_float).onnx_type()
        TensorProtoDataType.FLOAT

        >>> JitScalarType.from_dtype(torch.get_default_dtype()).onnx_type()
        TensorProtoDataType.FLOAT

    """

    # Order defined in https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h
    UINT8 = 0
    INT8 = enum.auto()  # 1
    INT16 = enum.auto()  # 2
    INT = enum.auto()  # 3
    INT64 = enum.auto()  # 4
    HALF = enum.auto()  # 5
    FLOAT = enum.auto()  # 6
    DOUBLE = enum.auto()  # 7
    COMPLEX32 = enum.auto()  # 8
    COMPLEX64 = enum.auto()  # 9
    COMPLEX128 = enum.auto()  # 10
    BOOL = enum.auto()  # 11
    QINT8 = enum.auto()  # 12
    QUINT8 = enum.auto()  # 13
    QINT32 = enum.auto()  # 14
    BFLOAT16 = enum.auto()  # 15
    FLOAT8E5M2 = enum.auto()  # 16
    FLOAT8E4M3FN = enum.auto()  # 17
    FLOAT8E5M2FNUZ = enum.auto()  # 18
    FLOAT8E4M3FNUZ = enum.auto()  # 19
    UNDEFINED = enum.auto()  # 20
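    # Note: because the values above mirror c10::ScalarType, members of this
    # IntEnum compare directly against raw ScalarType codes, e.g.
    # int(JitScalarType.FLOAT) == 6 and int(JitScalarType.UNDEFINED) == 20.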
    @classmethod
    def _from_name(cls, name: ScalarName | TorchName | str | None) -> JitScalarType:
        """Convert a JIT scalar type or torch type name to JitScalarType.

        Note: DO NOT USE this API when `name` comes from a `torch._C.Value.type()` call.
        A ``RuntimeError: INTERNAL ASSERT FAILED at "../aten/src/ATen/core/jit_type_base.h"``
        can be raised in several scenarios where shape info is not present.
        Use the safer `from_value` API instead.

        Args:
            name: JIT scalar type name (e.g. ``Byte``) or torch type name (e.g. ``uint8_t``).

        Returns:
            JitScalarType

        Raises:
            OnnxExporterError: if name is not a valid scalar type name or if it is None.
        """
        if name is None:
            raise errors.OnnxExporterError("Scalar type name cannot be None")
        if valid_scalar_name(name):
            return _SCALAR_NAME_TO_TYPE[name]  # type: ignore[index]
        if valid_torch_name(name):
            return _TORCH_NAME_TO_SCALAR_TYPE[name]  # type: ignore[index]

        raise errors.OnnxExporterError(f"Unknown torch or scalar type: '{name}'")

    @classmethod
    def from_dtype(cls, dtype: torch.dtype | None) -> JitScalarType:
        """Convert a torch dtype to JitScalarType.

        Note: DO NOT USE this API when `dtype` comes from a `torch._C.Value.type()` call.
        A ``RuntimeError: INTERNAL ASSERT FAILED at "../aten/src/ATen/core/jit_type_base.h"``
        can be raised in several scenarios where shape info is not present.
        Use the safer `from_value` API instead.

        Args:
            dtype: A torch.dtype to create a JitScalarType from.

        Returns:
            JitScalarType

        Raises:
            OnnxExporterError: if dtype is not a valid torch.dtype or if it is None.
        """
        if dtype not in _DTYPE_TO_SCALAR_TYPE:
            raise errors.OnnxExporterError(f"Unknown dtype: {dtype}")
        return _DTYPE_TO_SCALAR_TYPE[dtype]

    @classmethod
    def from_onnx_type(
        cls, onnx_type: int | _C_onnx.TensorProtoDataType | None
    ) -> JitScalarType:
        """Convert an ONNX data type to JitScalarType.

        Args:
            onnx_type: A torch._C._onnx.TensorProtoDataType to create a JitScalarType from.

        Returns:
            JitScalarType

        Raises:
            OnnxExporterError: if onnx_type is not a valid ONNX data type or if it is None.
        """
        if onnx_type not in _ONNX_TO_SCALAR_TYPE:
            raise errors.OnnxExporterError(f"Unknown onnx_type: {onnx_type}")
        return _ONNX_TO_SCALAR_TYPE[typing.cast(_C_onnx.TensorProtoDataType, onnx_type)]

    @classmethod
    def from_value(
        cls, value: None | torch._C.Value | torch.Tensor, default=None
    ) -> JitScalarType:
        """Create a JitScalarType from a value's scalar type.

        Args:
            value: An object to fetch the scalar type from.
            default: The JitScalarType to return if a valid scalar type cannot be fetched from value.

        Returns:
            JitScalarType.

        Raises:
            OnnxExporterError: if value does not have a valid scalar type and default is None.
            SymbolicValueError: when value.type()'s info is empty and default is None.
        """
        if not isinstance(value, (torch._C.Value, torch.Tensor)) or (
            isinstance(value, torch._C.Value) and value.node().mustBeNone()
        ):
            # The default JitScalarType is returned when value is not valid.
            if default is None:
                raise errors.OnnxExporterError(
                    "value must be either a torch._C.Value or a torch.Tensor object."
                )
            elif not isinstance(default, JitScalarType):
                raise errors.OnnxExporterError(
                    "default value must be a JitScalarType object."
                )
            return default

        # Each value type has its own way of storing the scalar type.
        if isinstance(value, torch.Tensor):
            return cls.from_dtype(value.dtype)
        if isinstance(value.type(), torch.ListType):
            try:
                return cls.from_dtype(value.type().getElementType().dtype())
            except RuntimeError:
                return cls._from_name(str(value.type().getElementType()))
        if isinstance(value.type(), torch._C.OptionalType):
            if value.type().getElementType().dtype() is None:
                if isinstance(default, JitScalarType):
                    return default
                raise errors.OnnxExporterError(
                    "default value must be a JitScalarType object."
                )
            return cls.from_dtype(value.type().getElementType().dtype())

        scalar_type = None
        if value.node().kind() != "prim::Constant" or not isinstance(
            value.type(), torch._C.NoneType
        ):
            # value must be a non-list torch._C.Value scalar
            scalar_type = value.type().scalarType()

        if scalar_type is not None:
            return cls._from_name(scalar_type)

        # When everything fails, fall back to the default if one was given.
        if default is not None:
            return default
        raise errors.SymbolicValueError(
            f"Cannot determine scalar type for this '{type(value.type())}' instance and "
            "a default value was not provided.",
            value,
        )
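    # Illustrative behavior of the constructors above (a hedged sketch,
    # assuming a standard torch build; not executed at import time):
    #
    #   JitScalarType.from_dtype(torch.int64)    -> JitScalarType.INT64
    #   JitScalarType.from_value(torch.ones(2))  -> JitScalarType.FLOAT
    #   JitScalarType.from_value(None)           -> raises OnnxExporterError
    #   JitScalarType.from_value(None, JitScalarType.UNDEFINED)
    #                                            -> JitScalarType.UNDEFINED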
    def scalar_name(self) -> ScalarName:
        """Convert a JitScalarType to a JIT scalar type name."""
        return _SCALAR_TYPE_TO_NAME[self]

    def torch_name(self) -> TorchName:
        """Convert a JitScalarType to a torch type name."""
        return _SCALAR_TYPE_TO_TORCH_NAME[self]

    def dtype(self) -> torch.dtype:
        """Convert a JitScalarType to a torch dtype."""
        return _SCALAR_TYPE_TO_DTYPE[self]

    def onnx_type(self) -> _C_onnx.TensorProtoDataType:
        """Convert a JitScalarType to an ONNX data type."""
        if self not in _SCALAR_TYPE_TO_ONNX:
            raise errors.OnnxExporterError(
                f"Scalar type {self} cannot be converted to ONNX"
            )
        return _SCALAR_TYPE_TO_ONNX[self]

    def onnx_compatible(self) -> bool:
        """Return whether this JitScalarType is compatible with ONNX."""
        return (
            self in _SCALAR_TYPE_TO_ONNX
            and self != JitScalarType.UNDEFINED
            and self != JitScalarType.COMPLEX32
        )


def valid_scalar_name(scalar_name: ScalarName | str) -> bool:
    """Return whether the given scalar name is a valid JIT scalar type name."""
    return scalar_name in _SCALAR_NAME_TO_TYPE


def valid_torch_name(torch_name: TorchName | str) -> bool:
    """Return whether the given torch name is a valid torch type name."""
    return torch_name in _TORCH_NAME_TO_SCALAR_TYPE


# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h
_SCALAR_TYPE_TO_NAME: dict[JitScalarType, ScalarName] = {
    JitScalarType.BOOL: "Bool",
    JitScalarType.UINT8: "Byte",
    JitScalarType.INT8: "Char",
    JitScalarType.INT16: "Short",
    JitScalarType.INT: "Int",
    JitScalarType.INT64: "Long",
    JitScalarType.HALF: "Half",
    JitScalarType.FLOAT: "Float",
    JitScalarType.DOUBLE: "Double",
    JitScalarType.COMPLEX32: "ComplexHalf",
    JitScalarType.COMPLEX64: "ComplexFloat",
    JitScalarType.COMPLEX128: "ComplexDouble",
    JitScalarType.QINT8: "QInt8",
    JitScalarType.QUINT8: "QUInt8",
    JitScalarType.QINT32: "QInt32",
    JitScalarType.BFLOAT16: "BFloat16",
    JitScalarType.FLOAT8E5M2: "Float8E5M2",
    JitScalarType.FLOAT8E4M3FN: "Float8E4M3FN",
    JitScalarType.FLOAT8E5M2FNUZ: "Float8E5M2FNUZ",
    JitScalarType.FLOAT8E4M3FNUZ: "Float8E4M3FNUZ",
    JitScalarType.UNDEFINED: "Undefined",
}

_SCALAR_NAME_TO_TYPE: dict[ScalarName, JitScalarType] = {
    v: k for k, v in _SCALAR_TYPE_TO_NAME.items()
}

_SCALAR_TYPE_TO_TORCH_NAME: dict[JitScalarType, TorchName] = {
    JitScalarType.BOOL: "bool",
    JitScalarType.UINT8: "uint8_t",
    JitScalarType.INT8: "int8_t",
    JitScalarType.INT16: "int16_t",
    JitScalarType.INT: "int",
    JitScalarType.INT64: "int64_t",
    JitScalarType.HALF: "half",
    JitScalarType.FLOAT: "float",
    JitScalarType.DOUBLE: "double",
    JitScalarType.COMPLEX32: "complex32",
    JitScalarType.COMPLEX64: "complex64",
    JitScalarType.COMPLEX128: "complex128",
    JitScalarType.QINT8: "qint8",
    JitScalarType.QUINT8: "quint8",
    JitScalarType.QINT32: "qint32",
    JitScalarType.BFLOAT16: "bfloat16",
    JitScalarType.FLOAT8E5M2: "float8_e5m2",
    JitScalarType.FLOAT8E4M3FN: "float8_e4m3fn",
    JitScalarType.FLOAT8E5M2FNUZ: "float8_e5m2fnuz",
    JitScalarType.FLOAT8E4M3FNUZ: "float8_e4m3fnuz",
}
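# Illustrative lookups against the name tables (a hedged sketch; the two name
# families do not overlap):
#
#   valid_scalar_name("Float")  -> True   (JIT scalar type name)
#   valid_torch_name("float")   -> True   (torch type name)
#   valid_scalar_name("float")  -> False  (torch names are not JIT names)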
_TORCH_NAME_TO_SCALAR_TYPE: dict[TorchName, JitScalarType] = {
    v: k for k, v in _SCALAR_TYPE_TO_TORCH_NAME.items()
}

_SCALAR_TYPE_TO_ONNX = {
    JitScalarType.BOOL: _C_onnx.TensorProtoDataType.BOOL,
    JitScalarType.UINT8: _C_onnx.TensorProtoDataType.UINT8,
    JitScalarType.INT8: _C_onnx.TensorProtoDataType.INT8,
    JitScalarType.INT16: _C_onnx.TensorProtoDataType.INT16,
    JitScalarType.INT: _C_onnx.TensorProtoDataType.INT32,
    JitScalarType.INT64: _C_onnx.TensorProtoDataType.INT64,
    JitScalarType.HALF: _C_onnx.TensorProtoDataType.FLOAT16,
    JitScalarType.FLOAT: _C_onnx.TensorProtoDataType.FLOAT,
    JitScalarType.DOUBLE: _C_onnx.TensorProtoDataType.DOUBLE,
    JitScalarType.COMPLEX64: _C_onnx.TensorProtoDataType.COMPLEX64,
    JitScalarType.COMPLEX128: _C_onnx.TensorProtoDataType.COMPLEX128,
    JitScalarType.BFLOAT16: _C_onnx.TensorProtoDataType.BFLOAT16,
    JitScalarType.UNDEFINED: _C_onnx.TensorProtoDataType.UNDEFINED,
    JitScalarType.COMPLEX32: _C_onnx.TensorProtoDataType.UNDEFINED,
    JitScalarType.QINT8: _C_onnx.TensorProtoDataType.INT8,
    JitScalarType.QUINT8: _C_onnx.TensorProtoDataType.UINT8,
    JitScalarType.QINT32: _C_onnx.TensorProtoDataType.INT32,
    JitScalarType.FLOAT8E5M2: _C_onnx.TensorProtoDataType.FLOAT8E5M2,
    JitScalarType.FLOAT8E4M3FN: _C_onnx.TensorProtoDataType.FLOAT8E4M3FN,
    JitScalarType.FLOAT8E5M2FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E5M2FNUZ,
    JitScalarType.FLOAT8E4M3FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E4M3FNUZ,
}

_ONNX_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_ONNX.items()}

# source of truth is
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp
_SCALAR_TYPE_TO_DTYPE = {
    JitScalarType.BOOL: torch.bool,
    JitScalarType.UINT8: torch.uint8,
    JitScalarType.INT8: torch.int8,
    JitScalarType.INT16: torch.short,
    JitScalarType.INT: torch.int,
    JitScalarType.INT64: torch.int64,
    JitScalarType.HALF: torch.half,
    JitScalarType.FLOAT: torch.float,
    JitScalarType.DOUBLE: torch.double,
    JitScalarType.COMPLEX32: torch.complex32,
    JitScalarType.COMPLEX64: torch.complex64,
    JitScalarType.COMPLEX128: torch.complex128,
    JitScalarType.QINT8: torch.qint8,
    JitScalarType.QUINT8: torch.quint8,
    JitScalarType.QINT32: torch.qint32,
    JitScalarType.BFLOAT16: torch.bfloat16,
    JitScalarType.FLOAT8E5M2: torch.float8_e5m2,
    JitScalarType.FLOAT8E4M3FN: torch.float8_e4m3fn,
    JitScalarType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
    JitScalarType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
}

_DTYPE_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_DTYPE.items()}
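# A minimal smoke-test sketch of the conversion helpers (an illustrative
# addition, assuming a torch build with ONNX support; run this module directly
# to execute it):
if __name__ == "__main__":
    # torch.float round-trips to JIT "Float", torch name "float", ONNX FLOAT.
    scalar_type = JitScalarType.from_dtype(torch.float)
    assert scalar_type is JitScalarType.FLOAT
    assert scalar_type.scalar_name() == "Float"
    assert scalar_type.torch_name() == "float"
    assert scalar_type.onnx_type() == _C_onnx.TensorProtoDataType.FLOAT
    assert scalar_type.onnx_compatible()

    # COMPLEX32 maps to UNDEFINED on the ONNX side, so it is reported as
    # ONNX-incompatible even though onnx_type() does not raise for it.
    assert not JitScalarType.COMPLEX32.onnx_compatible()

    # Tensors resolve through their dtype via from_value.
    assert JitScalarType.from_value(torch.ones(1, 2)) is JitScalarType.FLOAT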