xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/common_dtype.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3from typing import List
4
5import torch
6
7
8# Functions and classes for describing the dtypes a function supports
9# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros
10
11
def _validate_dtypes(*dtypes):
    """Check that every positional argument is a ``torch.dtype``.

    Returns:
        The arguments unchanged, as the tuple collected by ``*dtypes``.

    Raises:
        AssertionError: if any argument is not a ``torch.dtype``. This is an
            explicit ``raise`` rather than an ``assert`` statement so the
            check still runs under ``python -O``; the exception type is kept
            as ``AssertionError`` for compatibility with earlier behavior.
    """
    for dtype in dtypes:
        if not isinstance(dtype, torch.dtype):
            raise AssertionError(
                f"expected a torch.dtype, but got {dtype!r} "
                f"(type {type(dtype).__name__})"
            )
    return dtypes
17
18
class _dispatch_dtypes(tuple):
    """Tuple of dtypes corresponding to a PyTorch dispatch macro.

    ``+`` keeps the result typed as ``_dispatch_dtypes`` when this object is
    the left operand. No ``__radd__`` is defined, so ``plain_tuple + self``
    produces a plain ``tuple``.
    """

    def __add__(self, other):
        # Only tuples (including other _dispatch_dtypes) may be appended.
        assert isinstance(other, tuple)
        joined = super().__add__(other)
        return _dispatch_dtypes(joined)
24
25
_empty_types = _dispatch_dtypes(tuple())


def empty_types():
    """Return the shared, empty ``_dispatch_dtypes`` tuple."""
    return _empty_types
31
32
_floating_types = _dispatch_dtypes(
    (
        torch.float32,
        torch.float64,
    )
)


def floating_types():
    """Return float32 and float64 as a ``_dispatch_dtypes`` tuple."""
    return _floating_types
38
39
_floating_types_and_half = _floating_types + (torch.float16,)


def floating_types_and_half():
    """Return float32, float64 and float16 as a ``_dispatch_dtypes`` tuple."""
    return _floating_types_and_half


def floating_types_and(*dtypes):
    """Return float32/float64 followed by the validated extra *dtypes*."""
    extra = _validate_dtypes(*dtypes)
    return _floating_types + extra
49
50
_floating_and_complex_types = _floating_types + (torch.complex64, torch.complex128)


def floating_and_complex_types():
    """Return float32, float64, complex64 and complex128 as ``_dispatch_dtypes``."""
    return _floating_and_complex_types


def floating_and_complex_types_and(*dtypes):
    """Return the floating and complex dtypes followed by the validated *dtypes*."""
    extra = _validate_dtypes(*dtypes)
    return _floating_and_complex_types + extra
60
61
_double_types = _dispatch_dtypes(
    (
        torch.double,
        torch.cdouble,
    )
)


def double_types():
    """Return the double-precision dtypes: float64 and complex128."""
    return _double_types
67
68
# NB: Does not contain uint16/uint32/uint64 for BC reasons
_integral_types = _dispatch_dtypes(
    (
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    )
)


def integral_types():
    """Return uint8, int8, int16, int32 and int64 as ``_dispatch_dtypes``."""
    return _integral_types


def integral_types_and(*dtypes):
    """Return the integral dtypes followed by the validated extra *dtypes*."""
    extra = _validate_dtypes(*dtypes)
    return _integral_types + extra
81
82
# Floating dtypes first, then integral dtypes.
_all_types = _floating_types + _integral_types


def all_types():
    """Return the floating dtypes followed by the integral dtypes."""
    return _all_types


def all_types_and(*dtypes):
    """Return ``all_types()`` extended with the validated extra *dtypes*."""
    extra = _validate_dtypes(*dtypes)
    return _all_types + extra
92
93
_complex_types = _dispatch_dtypes((torch.complex64, torch.complex128))


def complex_types():
    """Return complex64 and complex128 as a ``_dispatch_dtypes`` tuple."""
    return _complex_types


def complex_types_and(*dtypes):
    """Return the complex dtypes followed by the validated extra *dtypes*."""
    extra = _validate_dtypes(*dtypes)
    return _complex_types + extra
103
104
# all_types() followed by the complex dtypes.
_all_types_and_complex = _all_types + _complex_types


def all_types_and_complex():
    """Return floating, integral and complex dtypes as ``_dispatch_dtypes``."""
    return _all_types_and_complex


def all_types_and_complex_and(*dtypes):
    """Return ``all_types_and_complex()`` plus the validated extra *dtypes*."""
    extra = _validate_dtypes(*dtypes)
    return _all_types_and_complex + extra
114
115
_all_types_and_half = _all_types + (torch.float16,)


def all_types_and_half():
    """Return ``all_types()`` with float16 appended."""
    return _all_types_and_half
121
122
def custom_types(*dtypes):
    """Create a ``_dispatch_dtypes`` tuple from arbitrary dtypes.

    Each argument is checked with ``_validate_dtypes``. The result is a
    ``_dispatch_dtypes`` (a ``tuple`` subclass), not a list, so it can be
    combined with the other dispatch helpers via ``+``.
    """
    return _empty_types + _validate_dtypes(*dtypes)
126
127
128# The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro
129
130
# See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
def get_all_dtypes(
    include_half=True,
    include_bfloat16=True,
    include_bool=True,
    include_complex=True,
    include_complex32=False,
    include_qint=False,
) -> List[torch.dtype]:
    """Return a fresh list of dtypes selected by the ``include_*`` flags.

    Order: integer dtypes, floating-point dtypes (float16/bfloat16 only when
    requested), then ``torch.bool``, complex dtypes and quantized dtypes when
    their respective flags are truthy.
    """
    result = [
        *get_all_int_dtypes(),
        *get_all_fp_dtypes(
            include_half=include_half, include_bfloat16=include_bfloat16
        ),
    ]
    if include_bool:
        result += [torch.bool]
    if include_complex:
        result.extend(get_all_complex_dtypes(include_complex32))
    if include_qint:
        result.extend(get_all_qint_dtypes())
    return result
150
151
def get_all_math_dtypes(device) -> List[torch.dtype]:
    """Return integer, floating-point and complex dtypes for math on *device*.

    float16 is included only when *device* names a CUDA device; bfloat16 and
    complex32 are never included.

    Args:
        device: a device string such as ``"cpu"`` or ``"cuda:0"``, or a
            ``torch.device``. It is normalized with ``str`` before the prefix
            check, so both forms behave the same (previously a
            ``torch.device`` raised ``AttributeError``).
    """
    is_cuda = str(device).startswith("cuda")
    return (
        get_all_int_dtypes()
        + get_all_fp_dtypes(include_half=is_cuda, include_bfloat16=False)
        + get_all_complex_dtypes()
    )
160
161
def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
    """Return a fresh list of complex dtypes, optionally led by complex32."""
    dtypes = [torch.complex64, torch.complex128]
    if include_complex32:
        dtypes.insert(0, torch.complex32)
    return dtypes
168
169
def get_all_int_dtypes() -> List[torch.dtype]:
    """Return a fresh list: uint8 followed by int8, int16, int32 and int64."""
    signed = [torch.int8, torch.int16, torch.int32, torch.int64]
    return [torch.uint8, *signed]
172
173
def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
    """Return a fresh list: float32, float64, then float16/bfloat16 if requested."""
    optional = (
        (include_half, torch.float16),
        (include_bfloat16, torch.bfloat16),
    )
    return [torch.float32, torch.float64] + [dt for wanted, dt in optional if wanted]
181
182
def get_all_qint_dtypes() -> List[torch.dtype]:
    """Return a fresh list of the quantized dtypes."""
    return [
        torch.qint8,
        torch.quint8,
        torch.qint32,
        torch.quint4x2,
        torch.quint2x4,
    ]
185
186
# Maps each floating-point dtype to the complex dtype built from two of it.
float_to_corresponding_complex_type_map = dict(
    zip(
        (torch.float16, torch.float32, torch.float64),
        (torch.complex32, torch.complex64, torch.complex128),
    )
)
192