# mypy: ignore-errors

"""Define analogs of numpy dtypes supported by pytorch.

Define the scalar types and supported dtypes and numpy <--> torch dtype mappings.
"""
import builtins

import torch

from . import _dtypes_impl


# ### Scalar types ###


class generic:
    name = "generic"

    def __new__(cls, value):
        # NumPy scalars are modelled as 0-D arrays
        # so a call to np.float32(4) produces a 0-D array.

        from ._ndarray import asarray, ndarray

        # Like NumPy, accept the special float spellings given as strings,
        # e.g. np.float32("inf") or np.float64("-inf").
        special_floats = {"inf": torch.inf, "-inf": -torch.inf, "nan": torch.nan}
        if isinstance(value, str) and value in special_floats:
            value = special_floats[value]

        if isinstance(value, ndarray):
            return value.astype(cls)
        else:
            return asarray(value, dtype=cls)


##################
# abstract types #
##################


class number(generic):
    name = "number"


class integer(number):
    name = "integer"


class inexact(number):
    name = "inexact"


class signedinteger(integer):
    name = "signedinteger"


class unsignedinteger(integer):
    name = "unsignedinteger"


class floating(inexact):
    name = "floating"


class complexfloating(inexact):
    name = "complexfloating"


_abstract_dtypes = [
    "generic",
    "number",
    "integer",
    "signedinteger",
    "unsignedinteger",
    "inexact",
    "floating",
    "complexfloating",
]

# ##### concrete types

# signed integers


class int8(signedinteger):
    name = "int8"
    typecode = "b"
    torch_dtype = torch.int8


class int16(signedinteger):
    name = "int16"
    typecode = "h"
    torch_dtype = torch.int16


class int32(signedinteger):
    name = "int32"
    typecode = "i"
    torch_dtype = torch.int32


class int64(signedinteger):
    name = "int64"
    typecode = "l"
    torch_dtype = torch.int64


# unsigned integers


class uint8(unsignedinteger):
    name = "uint8"
    typecode = "B"
    torch_dtype = torch.uint8


class uint16(unsignedinteger):
    name = "uint16"
    typecode = "H"
    torch_dtype = torch.uint16


# NOTE: all unsigned types must derive from ``unsignedinteger`` so that
# issubdtype(uint32, np.unsignedinteger) is True, as in NumPy.
class uint32(unsignedinteger):
    name = "uint32"
    typecode = "I"
    torch_dtype = torch.uint32


class uint64(unsignedinteger):
    name = "uint64"
    typecode = "L"
    torch_dtype = torch.uint64


# floating point


class float16(floating):
    name = "float16"
    typecode = "e"
    torch_dtype = torch.float16


class float32(floating):
    name = "float32"
    typecode = "f"
    torch_dtype = torch.float32


class float64(floating):
    name = "float64"
    typecode = "d"
    torch_dtype = torch.float64


class complex64(complexfloating):
    name = "complex64"
    typecode = "F"
    torch_dtype = torch.complex64


class complex128(complexfloating):
    name = "complex128"
    typecode = "D"
    torch_dtype = torch.complex128


class bool_(generic):
    name = "bool_"
    typecode = "?"
    torch_dtype = torch.bool


# name aliases
_name_aliases = {
    "intp": int64,  # assumes a 64-bit platform
    "int_": int64,
    "intc": int32,
    "byte": int8,
    "short": int16,
    "longlong": int64,  # C ``long long`` is 64 bits on mainstream platforms
    "ulonglong": uint64,
    "ubyte": uint8,
    "half": float16,
    "single": float32,
    "double": float64,
    "float_": float64,
    "csingle": complex64,
    "singlecomplex": complex64,
    "cdouble": complex128,
    "cfloat": complex128,
    "complex_": complex128,
}
# Expose every alias as a module attribute, e.g. ``float_`` -> float64.
globals().update(_name_aliases)


# Replicate this NumPy-defined way of grouping scalar types,
# cf tests/core/test_scalar_methods.py
sctypes = {
    "int": [int8, int16, int32, int64],
    "uint": [uint8, uint16, uint32, uint64],
    "float": [float16, float32, float64],
    "complex": [complex64, complex128],
    "others": [bool_],
}


# Support mappings/functions

_names = {st.name: st for group in sctypes.values() for st in group}
_typecodes = {st.typecode: st for group in sctypes.values() for st in group}
_torch_dtypes = {st.torch_dtype: st for group in sctypes.values() for st in group}


# Width-based aliases ("<kind><bytes>"), as accepted by np.dtype.
_aliases = {
    "u1": uint8,
    "u2": uint16,
    "u4": uint32,
    "u8": uint64,
    "i1": int8,
    "i2": int16,
    "i4": int32,
    "i8": int64,
    "b": int8,  # NumPy's character code for int8 (also present in _typecodes)
    "f2": float16,
    "f4": float32,
    "f8": float64,
    "c8": complex64,
    "c16": complex128,
    # numpy-specific trailing underscore
    "bool_": bool_,
}


_python_types = {
    int: int64,
    float: float64,
    complex: complex128,
    builtins.bool: bool_,
    # also allow stringified names of python types
    int.__name__: int64,
    float.__name__: float64,
    complex.__name__: complex128,
    builtins.bool.__name__: bool_,
}


def sctype_from_string(s):
    """Normalize a string value: a type 'name' or a typecode or a width alias.

    Lookup order: canonical names, name aliases, typecodes, width aliases,
    then Python types (and their names).

    Raises
    ------
    TypeError
        If ``s`` matches none of the known spellings.
    """
    if s in _names:
        return _names[s]
    if s in _name_aliases:
        return _name_aliases[s]
    if s in _typecodes:
        return _typecodes[s]
    if s in _aliases:
        return _aliases[s]
    if s in _python_types:
        return _python_types[s]
    raise TypeError(f"data type {s!r} not understood")


def sctype_from_torch_dtype(torch_dtype):
    """Return the scalar type for ``torch_dtype`` (KeyError if unsupported)."""
    return _torch_dtypes[torch_dtype]


# ### DTypes. ###


def dtype(arg):
    """Build a DType from ``arg``; ``None`` selects the current default float dtype."""
    if arg is None:
        arg = _dtypes_impl.default_dtypes().float_dtype
    return DType(arg)


class DType:
    """A NumPy-like dtype object wrapping one of the scalar types above."""

    def __init__(self, arg):
        # a pytorch object?
        if isinstance(arg, torch.dtype):
            sctype = _torch_dtypes[arg]
        elif isinstance(arg, torch.Tensor):
            sctype = _torch_dtypes[arg.dtype]
        # a scalar type?
        elif issubclass_(arg, generic):
            sctype = arg
        # a dtype already?
        elif isinstance(arg, DType):
            sctype = arg._scalar_type
        # an object whose ``.dtype`` is a DType (presumably this package's ndarray)?
        elif hasattr(arg, "dtype"):
            sctype = arg.dtype._scalar_type
        else:
            sctype = sctype_from_string(arg)
        self._scalar_type = sctype

    @property
    def name(self):
        return self._scalar_type.name

    @property
    def type(self):
        return self._scalar_type

    @property
    def kind(self):
        # https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html
        # The first letter of each scalar type name matches NumPy's kind code:
        # 'b'ool_, 'i'nt*, 'u'int*, 'f'loat*, 'c'omplex*.
        return self._scalar_type.name[0]

    @property
    def typecode(self):
        return self._scalar_type.typecode

    def __eq__(self, other):
        if isinstance(other, DType):
            return self._scalar_type == other._scalar_type
        try:
            other_instance = DType(other)
        except (TypeError, KeyError):
            # Not convertible to a supported dtype (an unknown string, or a
            # torch dtype missing from _torch_dtypes such as bfloat16).
            return False
        return self._scalar_type == other_instance._scalar_type

    @property
    def torch_dtype(self):
        return self._scalar_type.torch_dtype

    def __hash__(self):
        return hash(self._scalar_type.name)

    def __repr__(self):
        return f'dtype("{self.name}")'

    __str__ = __repr__

    @property
    def itemsize(self):
        # Size in bytes of one element. Read it directly from the torch dtype
        # rather than materialising a 0-D array just to query its element size.
        return self.torch_dtype.itemsize

    def __getstate__(self):
        return self._scalar_type

    def __setstate__(self, value):
        self._scalar_type = value


# NOTE(review): the unsigned typecodes "H", "I", "L" (uint16/32/64) are not
# listed below; presumably because torch support for them is limited — confirm.
typecodes = {
    "All": "efdFDBbhil?",
    "AllFloat": "efdFD",
    "AllInteger": "Bbhil",
    "Integer": "bhil",
    "UnsignedInteger": "B",
    "Float": "efd",
    "Complex": "FD",
}


# ### Defaults and dtype discovery


def set_default_dtype(fp_dtype="numpy", int_dtype="numpy"):
    """Set the (global) defaults for fp, complex, and int dtypes.

    The complex dtype is inferred from the float (fp) dtype. It has
    a width at least twice the width of the float dtype,
    i.e., it's complex128 for float64 and complex64 for float32 and float16.

    Parameters
    ----------
    fp_dtype
        Allowed values are "numpy", "pytorch" or dtype_like things which
        can be converted into a DType instance.
        Default is "numpy" (i.e. float64); "pytorch" means float32.
    int_dtype
        Allowed values are "numpy", "pytorch" or dtype_like things which
        can be converted into a DType instance.
        Default is "numpy" (i.e. int64); "pytorch" also means int64.

    Returns
    -------
    The old default dtype state: a namedtuple with attributes ``float_dtype``,
    ``complex_dtype`` and ``int_dtype``. These attributes store *pytorch*
    dtypes.

    Raises
    ------
    KeyError
        If the float dtype is not one of float16, float32 or float64.

    Notes
    -----
    This function has a side effect: it sets the global state with the
    provided dtypes.
    """
    if fp_dtype not in ["numpy", "pytorch"]:
        fp_dtype = dtype(fp_dtype).torch_dtype
    if int_dtype not in ["numpy", "pytorch"]:
        int_dtype = dtype(int_dtype).torch_dtype

    if fp_dtype == "numpy":
        float_dtype = torch.float64
    elif fp_dtype == "pytorch":
        float_dtype = torch.float32
    else:
        float_dtype = fp_dtype

    complex_dtype = {
        torch.float64: torch.complex128,
        torch.float32: torch.complex64,
        torch.float16: torch.complex64,
    }[float_dtype]

    # Both the "numpy" and "pytorch" presets select int64.
    if int_dtype in ["numpy", "pytorch"]:
        int_dtype = torch.int64

    new_defaults = _dtypes_impl.DefaultDTypes(
        float_dtype=float_dtype, complex_dtype=complex_dtype, int_dtype=int_dtype
    )

    # Snapshot the current state *before* replacing it. ``default_dtypes`` is a
    # function (``dtype`` above calls it), so it must be called here: returning
    # the bare function would not give callers the documented namedtuple.
    old_defaults = _dtypes_impl.default_dtypes()
    _dtypes_impl._default_dtypes = new_defaults
    return old_defaults


def issubclass_(arg, klass):
    """``issubclass`` that returns False instead of raising for non-class ``arg``."""
    try:
        return issubclass(arg, klass)
    except TypeError:
        return False


def issubdtype(arg1, arg2):
    """Return True if ``arg1`` is lower than or equal to ``arg2`` in the type hierarchy."""
    # cf https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numerictypes.py#L356-L420

    # We also accept strings even if NumPy doesn't as dtypes are serialized as their
    # string representation in dynamo's graph
    def str_to_abstract(t):
        if isinstance(t, str) and t in _abstract_dtypes:
            return globals()[t]
        return t

    arg1 = str_to_abstract(arg1)
    arg2 = str_to_abstract(arg2)

    if not issubclass_(arg1, generic):
        arg1 = dtype(arg1).type
    if not issubclass_(arg2, generic):
        arg2 = dtype(arg2).type
    return issubclass(arg1, arg2)


__all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype", "sctypes"]
__all__ += list(_names.keys())  # noqa: PLE0605
__all__ += list(_name_aliases.keys())  # noqa: PLE0605
__all__ += _abstract_dtypes  # noqa: PLE0605