xref: /aosp_15_r20/external/pytorch/torch/_numpy/_dtypes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3""" Define analogs of numpy dtypes supported by pytorch.
4Define the scalar types and supported dtypes and numpy <--> torch dtype mappings.
5"""
6import builtins
7
8import torch
9
10from . import _dtypes_impl
11
12
13# ### Scalar types ###
14
15
class generic:
    """Root of the scalar-type hierarchy.

    NumPy scalars are modelled as 0-D arrays: calling a scalar type, e.g.
    ``float32(4)``, returns a 0-D ``ndarray`` of that dtype rather than a
    dedicated scalar object.
    """

    name = "generic"

    def __new__(cls, value):
        # Imported lazily -- presumably to avoid a circular import with the
        # array module (TODO confirm).
        from ._ndarray import asarray, ndarray

        # The strings "inf" and "nan" stand for the corresponding float values.
        special_values = {"inf": torch.inf, "nan": torch.nan}
        if isinstance(value, str) and value in special_values:
            value = special_values[value]

        if not isinstance(value, ndarray):
            return asarray(value, dtype=cls)
        # An existing array is cast to this scalar type instead of re-wrapped.
        return value.astype(cls)
32
33
34##################
35# abstract types #
36##################
37
38
class number(generic):
    """Abstract base class of the numeric scalar types."""

    name = "number"


class integer(number):
    """Abstract base class of the integer scalar types."""

    name = "integer"


class inexact(number):
    """Abstract base class of the floating-point and complex scalar types."""

    name = "inexact"


class signedinteger(integer):
    """Abstract base class of the signed integer scalar types."""

    name = "signedinteger"


class unsignedinteger(integer):
    """Abstract base class of the unsigned integer scalar types."""

    name = "unsignedinteger"


class floating(inexact):
    """Abstract base class of the real floating-point scalar types."""

    name = "floating"


class complexfloating(inexact):
    """Abstract base class of the complex floating-point scalar types."""

    name = "complexfloating"


# Names of the abstract scalar types; ``issubdtype`` accepts these names as
# strings and resolves them to the classes through the module namespace.
_abstract_dtypes = [
    abstract_type.name
    for abstract_type in (
        generic,
        number,
        integer,
        signedinteger,
        unsignedinteger,
        inexact,
        floating,
        complexfloating,
    )
]
77
78# ##### concrete types
79
80# signed integers
81
82
# Every concrete scalar type carries three class attributes:
#   name        -- the NumPy-style type name, e.g. "int8"
#   typecode    -- the single-character type code, e.g. "b"
#   torch_dtype -- the torch dtype backing arrays of this type


class int8(signedinteger):
    name = "int8"
    typecode = "b"
    torch_dtype = torch.int8


class int16(signedinteger):
    name = "int16"
    typecode = "h"
    torch_dtype = torch.int16


class int32(signedinteger):
    name = "int32"
    typecode = "i"
    torch_dtype = torch.int32


class int64(signedinteger):
    name = "int64"
    typecode = "l"
    torch_dtype = torch.int64


# unsigned integers
# All unsigned widths derive from ``unsignedinteger`` (as in NumPy), so that
# e.g. ``issubdtype(uint64, unsignedinteger)`` holds and
# ``issubdtype(uint64, signedinteger)`` does not.


class uint8(unsignedinteger):
    name = "uint8"
    typecode = "B"
    torch_dtype = torch.uint8


class uint16(unsignedinteger):
    name = "uint16"
    typecode = "H"
    torch_dtype = torch.uint16


class uint32(unsignedinteger):
    name = "uint32"
    typecode = "I"
    torch_dtype = torch.uint32


class uint64(unsignedinteger):
    name = "uint64"
    typecode = "L"
    torch_dtype = torch.uint64


# floating point


class float16(floating):
    name = "float16"
    typecode = "e"
    torch_dtype = torch.float16


class float32(floating):
    name = "float32"
    typecode = "f"
    torch_dtype = torch.float32


class float64(floating):
    name = "float64"
    typecode = "d"
    torch_dtype = torch.float64


# complex


class complex64(complexfloating):
    name = "complex64"
    typecode = "F"
    torch_dtype = torch.complex64


class complex128(complexfloating):
    name = "complex128"
    typecode = "D"
    torch_dtype = torch.complex128


# bool_ is not a number: it derives directly from ``generic``.
class bool_(generic):
    name = "bool_"
    typecode = "?"
    torch_dtype = torch.bool
171
172
# Alternative (C-style and legacy NumPy) names for the concrete scalar types,
# e.g. ``double`` -> float64 and ``float_`` -> float64.
_name_aliases = {
    "intp": int64,
    "int_": int64,
    "intc": int32,
    "byte": int8,
    "short": int16,
    "longlong": int64,  # C ``long long`` is 64 bits on mainstream platforms
    "ulonglong": uint64,
    "ubyte": uint8,
    "half": float16,
    "single": float32,
    "double": float64,
    "float_": float64,
    "csingle": complex64,
    "singlecomplex": complex64,
    "cdouble": complex128,
    "cfloat": complex128,
    "complex_": complex128,
}
# Expose every alias as a module attribute (so that ``double is float64``,
# ``float_ is float64`` and so on). Updating the namespace in one call avoids
# leaking loop variables as module attributes.
globals().update(_name_aliases)
196
197
# Group the concrete scalar types the way NumPy's ``sctypes`` does,
# cf tests/core/test_scalar_methods.py
sctypes = {
    "int": [int8, int16, int32, int64],
    "uint": [uint8, uint16, uint32, uint64],
    "float": [float16, float32, float64],
    "complex": [complex64, complex128],
    "others": [bool_],
}


# Lookup tables: find a concrete scalar type by its name, its
# single-character typecode, or its backing torch dtype.

_names = {sc.name: sc for group in sctypes.values() for sc in group}
_typecodes = {sc.typecode: sc for group in sctypes.values() for sc in group}
_torch_dtypes = {sc.torch_dtype: sc for group in sctypes.values() for sc in group}
214
215
# NumPy-style "kind + byte width" codes plus a few other spellings.
# ``sctype_from_string`` consults the typecode table before this one, so
# single-character entries here (like "b") are only a fallback.
_aliases = {
    "u1": uint8,
    "u2": uint16,
    "u4": uint32,
    "u8": uint64,
    "i1": int8,
    "i2": int16,
    "i4": int32,
    "i8": int64,
    "b": int8,  # NumPy's typecode "b" (C signed char) is indeed int8
    "b1": bool_,
    "f2": float16,
    "f4": float32,
    "f8": float64,
    "c8": complex64,
    "c16": complex128,
    # numpy-specific trailing underscore
    "bool_": bool_,
}
231
232
# Map Python builtin types -- and their names as strings, e.g. "float" -- to
# the scalar types used for them here.
_python_types = {
    py_type: sctype
    for py_type, sctype in (
        (int, int64),
        (float, float64),
        (complex, complex128),
        (builtins.bool, bool_),
    )
}
# Also allow the stringified names of the Python types. The comprehension is
# fully built before ``update`` runs, so the dict is not mutated mid-iteration.
_python_types.update(
    {py_type.__name__: sctype for py_type, sctype in _python_types.items()}
)
244
245
def sctype_from_string(s):
    """Normalize a string value: a type 'name' or a typecode or a width alias.

    The lookup tables are tried in priority order: canonical names, name
    aliases, typecodes, width aliases, then Python types (which also accepts
    the builtin type objects themselves, e.g. ``int``).

    Raises
    ------
    TypeError
        If no table knows ``s``.
    """
    lookup_tables = (_names, _name_aliases, _typecodes, _aliases, _python_types)
    for table in lookup_tables:
        if s in table:
            return table[s]
    raise TypeError(f"data type {s!r} not understood")
259
260
def sctype_from_torch_dtype(torch_dtype):
    """Return the concrete scalar type backed by ``torch_dtype``.

    Raises ``KeyError`` for torch dtypes with no registered scalar type.
    """
    sctype = _torch_dtypes[torch_dtype]
    return sctype
263
264
265# ### DTypes. ###
266
267
def dtype(arg):
    """Build a ``DType`` from ``arg``.

    ``None`` selects the current default floating-point dtype.
    """
    return DType(_dtypes_impl.default_dtypes().float_dtype if arg is None else arg)
272
273
274class DType:
275    def __init__(self, arg):
276        # a pytorch object?
277        if isinstance(arg, torch.dtype):
278            sctype = _torch_dtypes[arg]
279        elif isinstance(arg, torch.Tensor):
280            sctype = _torch_dtypes[arg.dtype]
281        # a scalar type?
282        elif issubclass_(arg, generic):
283            sctype = arg
284        # a dtype already?
285        elif isinstance(arg, DType):
286            sctype = arg._scalar_type
287        # a has a right attribute?
288        elif hasattr(arg, "dtype"):
289            sctype = arg.dtype._scalar_type
290        else:
291            sctype = sctype_from_string(arg)
292        self._scalar_type = sctype
293
294    @property
295    def name(self):
296        return self._scalar_type.name
297
298    @property
299    def type(self):
300        return self._scalar_type
301
302    @property
303    def kind(self):
304        # https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html
305        return _torch_dtypes[self.torch_dtype].name[0]
306
307    @property
308    def typecode(self):
309        return self._scalar_type.typecode
310
311    def __eq__(self, other):
312        if isinstance(other, DType):
313            return self._scalar_type == other._scalar_type
314        try:
315            other_instance = DType(other)
316        except TypeError:
317            return False
318        return self._scalar_type == other_instance._scalar_type
319
320    @property
321    def torch_dtype(self):
322        return self._scalar_type.torch_dtype
323
324    def __hash__(self):
325        return hash(self._scalar_type.name)
326
327    def __repr__(self):
328        return f'dtype("{self.name}")'
329
330    __str__ = __repr__
331
332    @property
333    def itemsize(self):
334        elem = self.type(1)
335        return elem.tensor.element_size()
336
337    def __getstate__(self):
338        return self._scalar_type
339
340    def __setstate__(self, value):
341        self._scalar_type = value
342
343
# Single-character typecodes grouped by category, mirroring
# ``numpy.typecodes``. Note: the typecodes of uint16/uint32/uint64
# ("H", "I", "L") are not listed in any group.
typecodes = dict(
    All="efdFDBbhil?",
    AllFloat="efdFD",
    AllInteger="Bbhil",
    Integer="bhil",
    UnsignedInteger="B",
    Float="efd",
    Complex="FD",
)
353
354
355# ### Defaults and dtype discovery
356
357
def set_default_dtype(fp_dtype="numpy", int_dtype="numpy"):
    """Set the (global) defaults for fp, complex, and int dtypes.

    The complex dtype is inferred from the float (fp) dtype: it is
    complex128 for float64, and complex64 for float32 and float16.

    Parameters
    ----------
    fp_dtype
        Allowed values are "numpy", "pytorch" or dtype_like things which
        can be converted into a DType instance.
        Default is "numpy" (i.e. float64); "pytorch" means float32.
    int_dtype
        Allowed values are "numpy", "pytorch" or dtype_like things which
        can be converted into a DType instance.
        Default is "numpy" (i.e. int64); "pytorch" also means int64.

    Returns
    -------
    The old default dtype state: a namedtuple with attributes ``float_dtype``,
    ``complex_dtype`` and ``int_dtype``. These attributes store *pytorch*
    dtypes.

    Raises
    ------
    TypeError
        If ``fp_dtype`` or ``int_dtype`` cannot be converted into a DType.
    KeyError
        If the resolved float dtype is not float16, float32 or float64.

    Notes
    -----
    This function has a side effect: it sets the global state with the
    provided dtypes.
    """
    # Resolve both requested defaults to torch dtypes first (float, then int),
    # before deriving the complex dtype.
    if fp_dtype == "numpy":
        float_dtype = torch.float64
    elif fp_dtype == "pytorch":
        float_dtype = torch.float32
    else:
        float_dtype = dtype(fp_dtype).torch_dtype

    if int_dtype in ("numpy", "pytorch"):
        # NumPy and PyTorch agree on int64 as the default integer type.
        int_dtype = torch.int64
    else:
        int_dtype = dtype(int_dtype).torch_dtype

    # complex32 is not among the supported scalar types, so float16 pairs
    # with complex64.
    complex_dtype = {
        torch.float64: torch.complex128,
        torch.float32: torch.complex64,
        torch.float16: torch.complex64,
    }[float_dtype]

    new_defaults = _dtypes_impl.DefaultDTypes(
        float_dtype=float_dtype, complex_dtype=complex_dtype, int_dtype=int_dtype
    )

    # Snapshot the current state (call the accessor -- not just reference it)
    # before installing the new one, so the caller can restore it later.
    old_defaults = _dtypes_impl.default_dtypes()
    _dtypes_impl._default_dtypes = new_defaults
    return old_defaults
421
422
def issubclass_(arg, klass):
    """Like ``issubclass``, but return False instead of raising TypeError.

    Lets callers probe arbitrary values (instances, strings, ...) without
    first checking that ``arg`` is a class.
    """
    try:
        result = issubclass(arg, klass)
    except TypeError:
        result = False
    return result
428
429
def issubdtype(arg1, arg2):
    """Return whether ``arg1`` is a subtype of ``arg2`` in the scalar hierarchy.

    cf https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numerictypes.py#L356-L420

    Unlike NumPy, names of abstract types (e.g. "floating") are also accepted
    as strings, because dynamo serializes dtypes as their string
    representation in its graph.
    """

    def as_scalar_type(value):
        # An abstract type given by name resolves to the class itself.
        if isinstance(value, str) and value in _abstract_dtypes:
            value = globals()[value]
        # Anything that is not already a scalar type goes through dtype().
        if not issubclass_(value, generic):
            value = dtype(value).type
        return value

    return issubclass(as_scalar_type(arg1), as_scalar_type(arg2))
448
449
# Public API: the helpers below plus every concrete scalar type, name alias
# and abstract scalar type.
__all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype", "sctypes"]
__all__.extend(_names)  # noqa: PLE0605
__all__.extend(_name_aliases)  # noqa: PLE0605
__all__.extend(_abstract_dtypes)  # noqa: PLE0605
454