# mypy: ignore-errors

from typing import List

import torch


# Functions and classes for describing the dtypes a function supports
# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros


def _validate_dtypes(*dtypes):
    """Check that every argument is a ``torch.dtype`` and return them as a tuple.

    Raises:
        AssertionError: if any argument is not a ``torch.dtype``.
    """
    for dtype in dtypes:
        assert isinstance(
            dtype, torch.dtype
        ), f"expected a torch.dtype, but got {dtype!r} of type {type(dtype).__name__}"
    return dtypes


class _dispatch_dtypes(tuple):
    """Tuple of dtypes corresponding to a PyTorch dispatch macro.

    Concatenating with ``+`` (when this object is the left operand) yields
    another ``_dispatch_dtypes`` instead of decaying to a plain ``tuple``.
    """

    def __add__(self, other):
        assert isinstance(
            other, tuple
        ), f"can only concatenate a tuple to _dispatch_dtypes, got {type(other).__name__}"
        return _dispatch_dtypes(tuple.__add__(self, other))


_empty_types = _dispatch_dtypes(())


def empty_types():
    """Return the empty dtype tuple."""
    return _empty_types


_floating_types = _dispatch_dtypes((torch.float32, torch.float64))


def floating_types():
    """Return ``(float32, float64)``."""
    return _floating_types


_floating_types_and_half = _floating_types + (torch.half,)


def floating_types_and_half():
    """Return ``floating_types()`` followed by ``torch.half``."""
    return _floating_types_and_half


def floating_types_and(*dtypes):
    """Return ``floating_types()`` followed by the given (validated) dtypes."""
    return _floating_types + _validate_dtypes(*dtypes)


_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)


def floating_and_complex_types():
    """Return ``floating_types()`` followed by ``cfloat`` and ``cdouble``."""
    return _floating_and_complex_types


def floating_and_complex_types_and(*dtypes):
    """Return ``floating_and_complex_types()`` followed by the given (validated) dtypes."""
    return _floating_and_complex_types + _validate_dtypes(*dtypes)


_double_types = _dispatch_dtypes((torch.float64, torch.complex128))


def double_types():
    """Return ``(float64, complex128)``."""
    return _double_types


# NB: Does not contain uint16/uint32/uint64 for BC reasons
_integral_types = _dispatch_dtypes(
    (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
)


def integral_types():
    """Return ``(uint8, int8, int16, int32, int64)``."""
    return _integral_types


def integral_types_and(*dtypes):
    """Return ``integral_types()`` followed by the given (validated) dtypes."""
    return _integral_types + _validate_dtypes(*dtypes)


_all_types = _floating_types + _integral_types


def all_types():
    """Return ``floating_types()`` followed by ``integral_types()``."""
    return _all_types


def all_types_and(*dtypes):
    """Return ``all_types()`` followed by the given (validated) dtypes."""
    return _all_types + _validate_dtypes(*dtypes)


_complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble))


def complex_types():
    """Return ``(cfloat, cdouble)``."""
    return _complex_types


def complex_types_and(*dtypes):
    """Return ``complex_types()`` followed by the given (validated) dtypes."""
    return _complex_types + _validate_dtypes(*dtypes)


_all_types_and_complex = _all_types + _complex_types


def all_types_and_complex():
    """Return ``all_types()`` followed by ``complex_types()``."""
    return _all_types_and_complex


def all_types_and_complex_and(*dtypes):
    """Return ``all_types_and_complex()`` followed by the given (validated) dtypes."""
    return _all_types_and_complex + _validate_dtypes(*dtypes)


_all_types_and_half = _all_types + (torch.half,)


def all_types_and_half():
    """Return ``all_types()`` followed by ``torch.half``."""
    return _all_types_and_half


def custom_types(*dtypes):
    """Create a ``_dispatch_dtypes`` tuple of arbitrary (validated) dtypes."""
    return _empty_types + _validate_dtypes(*dtypes)


# The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro


# See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
def get_all_dtypes(
    include_half=True,
    include_bfloat16=True,
    include_bool=True,
    include_complex=True,
    include_complex32=False,
    include_qint=False,
) -> List[torch.dtype]:
    """Return a fresh list of dtypes, assembled according to the flags.

    The order is: integer dtypes, floating dtypes (``include_half`` /
    ``include_bfloat16`` control float16 / bfloat16), then ``torch.bool``,
    complex dtypes and quantized dtypes when the respective flag is set.
    ``include_complex32`` only has an effect when ``include_complex`` is true.
    """
    dtypes = get_all_int_dtypes() + get_all_fp_dtypes(
        include_half=include_half, include_bfloat16=include_bfloat16
    )
    if include_bool:
        dtypes.append(torch.bool)
    if include_complex:
        dtypes += get_all_complex_dtypes(include_complex32)
    if include_qint:
        dtypes += get_all_qint_dtypes()
    return dtypes


def get_all_math_dtypes(device) -> List[torch.dtype]:
    """Return integer, floating and complex dtypes for ``device``.

    float16 is included only when the device string starts with ``"cuda"``;
    bfloat16 and complex32 are never included.

    Args:
        device: a device string such as ``"cpu"`` or ``"cuda:0"``, or any
            object whose ``str()`` has that form (e.g. ``torch.device``).
    """
    # str() lets torch.device objects through; plain strings are unaffected.
    is_cuda = str(device).startswith("cuda")
    return (
        get_all_int_dtypes()
        + get_all_fp_dtypes(include_half=is_cuda, include_bfloat16=False)
        + get_all_complex_dtypes()
    )


def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
    """Return ``[complex64, complex128]``, prefixed by complex32 if requested."""
    return (
        [torch.complex32, torch.complex64, torch.complex128]
        if include_complex32
        else [torch.complex64, torch.complex128]
    )


def get_all_int_dtypes() -> List[torch.dtype]:
    """Return a fresh list ``[uint8, int8, int16, int32, int64]``."""
    return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]


def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
    """Return ``[float32, float64]`` plus float16 / bfloat16 when requested."""
    dtypes = [torch.float32, torch.float64]
    if include_half:
        dtypes.append(torch.float16)
    if include_bfloat16:
        dtypes.append(torch.bfloat16)
    return dtypes


def get_all_qint_dtypes() -> List[torch.dtype]:
    """Return the list of quantized integer dtypes."""
    return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]


# Maps a real floating dtype to the complex dtype listed for it below.
float_to_corresponding_complex_type_map = {
    torch.float16: torch.complex32,
    torch.float32: torch.complex64,
    torch.float64: torch.complex128,
}