1# mypy: ignore-errors 2 3import collections 4 5import torch 6from torch.testing._internal.common_utils import TEST_WITH_ROCM 7from torch.testing._internal.common_utils import TestCase 8 9 10class AutocastTestLists: 11 def _rnn_cell_args(self, n, num_chunks, is_lstm, dev, dtype): 12 input = (torch.randn((n, n), device=dev, dtype=torch.float32),) 13 14 hx = ((torch.randn((n, n), device=dev, dtype=torch.float32), 15 torch.randn((n, n), device=dev, dtype=torch.float32)) if is_lstm else 16 torch.randn((n, n), device=dev, dtype=torch.float32),) 17 18 weights = (torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32), # weight_ih 19 torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32), # weight_hh 20 torch.randn((num_chunks * n), device=dev, dtype=torch.float32), # bias_ih 21 torch.randn((num_chunks * n), device=dev, dtype=torch.float32)) # bias_hh 22 23 # returns args as a tuple 24 return input + hx + weights 25 26 # Supplies ops and arguments for test_autocast_* in test/test_cuda.py 27 def __init__(self, dev): 28 super().__init__() 29 n = 8 30 # Utility arguments, created as one-element tuples 31 pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),) 32 pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),) 33 pointwise2_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),) 34 mat0_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),) 35 mat1_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),) 36 mat2_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),) 37 38 dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n)) 39 conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev), 40 torch.randn(dimset, dtype=torch.float32, device=dev)) 41 for dimset in dimsets] 42 bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),) 43 element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),) 44 pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) 45 pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) 46 mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) 47 mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) 48 mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) 49 mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) 50 51 # The lists below organize ops that autocast needs to test. 52 # self.list_name corresponds to test_autocast_list_name in test/test_cuda.py. 53 # Each op is associated with a tuple of valid arguments. 54 # In addition, cudnn conv ops are not supported on ROCm and hence will 55 # be skipped by passing TEST_WITH_ROCM flag to those ops in self.torch_fp16 list. 56 57 # Some ops implement built-in type promotion. These don't need autocasting, 58 # but autocasting relies on their promotion, so we include tests to double-check. 59 self.torch_expect_builtin_promote = [ 60 ("eq", pointwise0_fp32 + pointwise1_fp16, torch.bool), 61 ("ge", pointwise0_fp32 + pointwise1_fp16, torch.bool), 62 ("gt", pointwise0_fp32 + pointwise1_fp16, torch.bool), 63 ("le", pointwise0_fp32 + pointwise1_fp16, torch.bool), 64 ("lt", pointwise0_fp32 + pointwise1_fp16, torch.bool), 65 ("ne", pointwise0_fp32 + pointwise1_fp16, torch.bool), 66 ("add", pointwise0_fp32 + pointwise1_fp16, torch.float32), 67 ("div", pointwise0_fp32 + pointwise1_fp16, torch.float32), 68 ("mul", pointwise0_fp32 + pointwise1_fp16, torch.float32), 69 ("cat", (pointwise0_fp16 + pointwise1_fp32,), torch.float32), 70 ("equal", pointwise0_fp32 + pointwise1_fp16, torch.float32), 71 ("stack", (pointwise0_fp16 + pointwise1_fp32,), torch.float32), 72 ] 73 self.methods_expect_builtin_promote = [ 74 ("__eq__", pointwise0_fp32 + pointwise1_fp16, torch.bool), 75 ("__ge__", pointwise0_fp32 + pointwise1_fp16, torch.bool), 76 ("__gt__", pointwise0_fp32 + pointwise1_fp16, torch.bool), 77 ("__le__", pointwise0_fp32 + pointwise1_fp16, torch.bool), 78 ("__lt__", pointwise0_fp32 + pointwise1_fp16, torch.bool), 79 ("__ne__", pointwise0_fp32 + pointwise1_fp16, torch.bool), 80 ("__add__", pointwise0_fp32 + pointwise1_fp16, torch.float32), 81 ("__div__", pointwise0_fp32 + pointwise1_fp16, torch.float32), 82 ("__mul__", pointwise0_fp32 + pointwise1_fp16, torch.float32), 83 ] 84 85 # The remaining lists organize ops that autocast treats explicitly. 86 self.torch_fp16 = [ 87 # deprecated _convolution 88 ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, 89 (0, 0), 1, False, True, True)), 90 # the current _convolution 91 ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, 92 (0, 0), 1, False, True, True, True)), 93 ("conv1d", conv_args_fp32[0]), 94 ("conv2d", conv_args_fp32[1]), 95 ("conv3d", conv_args_fp32[2]), 96 ("conv_tbc", conv_args_fp32[0] + bias_fp32), 97 ("conv_transpose1d", conv_args_fp32[0]), 98 ("conv_transpose2d", conv_args_fp32[1]), 99 ("conv_transpose3d", conv_args_fp32[2]), 100 ("convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, (0, 0), 1)), 101 ("cudnn_convolution", conv_args_fp32[1] + ((0, 0), (1, 1), (1, 1), 1, False, True, True), TEST_WITH_ROCM), 102 ("cudnn_convolution_transpose", conv_args_fp32[1] + ((0, 0), (0, 0), (1, 1), 103 (1, 1), 1, False, True, True), TEST_WITH_ROCM), 104 ("prelu", pointwise0_fp32 + element0_fp32), 105 ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32), 106 ("addmv", pointwise0_fp32 + mat2_fp32 + pointwise1_fp32), 107 ("addr", mat0_fp32 + pointwise0_fp32 + pointwise1_fp32), 108 ("matmul", mat0_fp32 + mat1_fp32), 109 ("einsum", "bkhd,bqhd->bqkh", mat0_fp32 + mat1_fp32), 110 ("mm", mat0_fp32 + mat1_fp32), 111 ("mv", mat0_fp32 + pointwise0_fp32), 112 ("chain_matmul", mat0_fp32 + mat1_fp32 + mat2_fp32), 113 ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32), 114 torch.randn((n, n, n), device=dev, dtype=torch.float32))), 115 ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), 116 torch.randn((n, n, n), device=dev, dtype=torch.float32), 117 torch.randn((n, n, n), device=dev, dtype=torch.float32))), 118 ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), 119 torch.randn((n, n, n), device=dev, dtype=torch.float32))), 120 # _thnn_fused_lstm_cell and _thnn_fused_gru_cell are not Python-exposed as far as I can tell. 121 # ("_thnn_fused_lstm_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32), 122 # ("_thnn_fused_gru_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32), 123 ("lstm_cell", self._rnn_cell_args(n, num_chunks=4, is_lstm=True, dev=dev, dtype=torch.float32)), 124 ("gru_cell", self._rnn_cell_args(n, num_chunks=3, is_lstm=False, dev=dev, dtype=torch.float32)), 125 ("rnn_tanh_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)), 126 ("rnn_relu_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)), 127 ] 128 self.torch_fp32 = [ 129 ("acos", (pointwise0_fp16[0].clamp(-.9, 0.9),)), 130 ("asin", (pointwise0_fp16[0].clamp(-.9, 0.9),)), 131 ("cosh", pointwise0_fp16), 132 ("erfinv", (pointwise0_fp16[0].clamp(-.9, .9),)), 133 ("exp", pointwise0_fp16), 134 ("expm1", pointwise0_fp16), 135 ("log", (pointwise0_fp16[0].clamp(0.1, 100.0),)), 136 ("log10", (pointwise0_fp16[0].clamp(0.1, 100.0),)), 137 ("log2", (pointwise0_fp16[0].clamp(0.1, 100.0),)), 138 ("log1p", (pointwise0_fp16[0].clamp(-0.9, 100.0),)), 139 ("reciprocal", pointwise0_fp16), 140 ("rsqrt", (pointwise0_fp16[0].clamp(0.0, 100.0),)), 141 ("sinh", pointwise0_fp16), 142 ("tan", (pointwise0_fp16[0].clamp(-3.1 / 2, 3.1 / 2),)), 143 ("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + pointwise1_fp16), 144 ("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + (1.7,)), 145 # ("pow", (1.7,) + pointwise0_fp16), # This variant has a backend, but is not documented in the API. 146 ("softmax", pointwise0_fp16 + (0,)), 147 ("log_softmax", pointwise0_fp16 + (0,)), 148 ("layer_norm", pointwise0_fp16 + ((pointwise0_fp16[0].numel(),),)), 149 ("group_norm", mat0_fp16 + (1,)), 150 ("norm", pointwise0_fp16), 151 ("norm", pointwise0_fp16, {"dim": 0}), 152 # these need magma 153 # ("norm", mat0_fp16, {"p": "nuc"}), 154 # ("norm", mat0_fp16, {"p": "nuc", "dim": 0}), 155 ("norm", pointwise0_fp16, {"p": 1}), 156 ("norm", pointwise0_fp16, {"p": 1, "dim": 0}), 157 ("cosine_similarity", mat0_fp16 + mat1_fp16), 158 ("poisson_nll_loss", mat0_fp16 + mat1_fp16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))), 159 ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.float16), 160 torch.tensor([[1, 3, 4]], device=dev, dtype=torch.float16), 161 torch.tensor([1], device=dev, dtype=torch.int))), 162 ("hinge_embedding_loss", mat0_fp16 + (torch.ones(n, device=dev, dtype=torch.int),)), 163 ("kl_div", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)), 164 ("margin_ranking_loss", mat0_fp16 + mat1_fp16 + (torch.ones((n,), device=dev, dtype=torch.float16),)), 165 ("triplet_margin_loss", mat0_fp16 + mat1_fp16 + mat2_fp16), 166 ("binary_cross_entropy_with_logits", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)), 167 ("cumprod", pointwise0_fp16 + (0,)), 168 ("cumsum", pointwise0_fp16 + (0,)), 169 ("dist", pointwise0_fp16 + pointwise1_fp16), 170 ("pdist", mat0_fp16), 171 ("cdist", mat0_fp16 + mat1_fp16), 172 ("prod", pointwise0_fp16), 173 ("prod", pointwise0_fp16 + (0,)), 174 ("renorm", mat0_fp16 + (2, 0, 1.0)), 175 ("sum", pointwise0_fp16), 176 ("sum", mat0_fp16 + (1,)), 177 ("logsumexp", mat0_fp16 + (1,)), 178 ] 179 self.torch_need_autocast_promote = [ 180 ("addcdiv", pointwise0_fp32 + pointwise1_fp16 + (pointwise2_fp16[0].clamp(0.1, 100),)), 181 ("addcmul", pointwise0_fp32 + pointwise1_fp16 + pointwise2_fp16), 182 ("atan2", pointwise0_fp32 + (pointwise1_fp16[0].clamp(0.1, 100),)), 183 ("bilinear", (torch.randn((1, 2), dtype=torch.float16, device=dev), 184 torch.randn((1, 2), dtype=torch.float32, device=dev), 185 torch.randn((1, 2, 2), dtype=torch.float16, device=dev), 186 torch.randn((1,), dtype=torch.float32, device=dev))), 187 ("cross", (torch.randn(3, dtype=torch.float32, device=dev), 188 torch.randn(3, dtype=torch.float16, device=dev))), 189 ("dot", pointwise0_fp16 + pointwise1_fp32), 190 ("vdot", pointwise0_fp16 + pointwise1_fp32), 191 ("grid_sampler", (torch.randn((2, 3, 33, 22), dtype=torch.float16, device=dev), 192 torch.randn((2, 22, 11, 2), dtype=torch.float32, device=dev), 193 0, 0, False)), 194 ("index_put", pointwise0_fp32 + ((torch.tensor([1], device=dev, dtype=torch.long),), 195 torch.randn(1, device=dev, dtype=torch.float16))), 196 ("index_put", pointwise0_fp16 + ((torch.tensor([1], device=dev, dtype=torch.long),), 197 torch.randn(1, device=dev, dtype=torch.float32))), 198 ("tensordot", (torch.randn((2, 2, 2), dtype=torch.float32, device=dev), 199 torch.randn((2, 2, 2), dtype=torch.float16, device=dev))), 200 ("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float32, device=dev), 201 0, 202 torch.randint(0, 2, (2, 2, 2), device=dev), 203 torch.randn((2, 2, 2), dtype=torch.float16, device=dev))), 204 ("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float16, device=dev), 205 0, 206 torch.randint(0, 2, (2, 2, 2), device=dev), 207 torch.randn((2, 2, 2), dtype=torch.float32, device=dev))), 208 ] 209 self.nn_fp16 = [ 210 ("linear", mat0_fp32 + mat1_fp32 + mat2_fp32), 211 ] 212 self.nn_fp32 = [ 213 ("softplus", pointwise0_fp16), 214 ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.float), 215 torch.zeros((n,), device=dev, dtype=torch.long))), 216 ("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.half), 217 torch.zeros((n, n, n), device=dev, dtype=torch.long))), 218 ("l1_loss", mat0_fp16 + mat1_fp16), 219 ("smooth_l1_loss", mat0_fp16 + mat1_fp16), 220 ("mse_loss", mat0_fp16 + mat1_fp16), 221 ("multilabel_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)), 222 ("soft_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)), 223 ("multi_margin_loss", mat0_fp16 + (torch.ones((n,), device=dev, dtype=torch.long),)), 224 ] 225 self.linalg_fp16 = [ 226 ("linalg_vecdot", mat0_fp32 + mat0_fp32), 227 ("linalg_multi_dot", (mat0_fp32 + mat1_fp32 + mat2_fp32,)), 228 ] 229 self.methods_fp16 = [ 230 ("__matmul__", mat0_fp32 + mat1_fp32) 231 ] 232 self.methods_fp32 = [ 233 ("__pow__", (torch.rand(n, device=dev, dtype=torch.float16), 1.5)), 234 ] 235 self.banned = [ 236 ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.float32), 237 torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn), 238 ] 239 240 241class AutocastCPUTestLists: 242 # Supplies ops and arguments for test_autocast_* in test/test_cpu.py 243 def __init__(self, dev): 244 super().__init__() 245 n = 8 246 # Utility arguments, created as one-element tuples 247 pointwise0_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) 248 pointwise1_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) 249 pointwise2_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) 250 mat0_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) 251 mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) 252 mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) 253 254 pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),) 255 pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),) 256 257 dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n)) 258 259 dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),) 260 for dimset in dummy_dimsets] 261 262 dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n)) 263 conv_args_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev), 264 torch.randn(dimset, dtype=torch.bfloat16, device=dev)) 265 for dimset in dimsets] 266 conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev), 267 torch.randn(dimset, dtype=torch.float32, device=dev)) 268 for dimset in dimsets] 269 270 bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),) 271 element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),) 272 pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) 273 pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) 274 mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) 275 mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) 276 mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) 277 mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) 278 279 dummy_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),) 280 for dimset in dummy_dimsets] 281 # The lists below organize ops that autocast needs to test. 282 # self.list_name corresponds to test_autocast_list_name in test/test_cpu.py. 283 # Each op is associated with a tuple of valid arguments. 284 285 # Some ops implement built-in type promotion. These don't need autocasting, 286 # but autocasting relies on their promotion, so we include tests to double-check. 287 self.torch_expect_builtin_promote = [ 288 ("eq", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), 289 ("ge", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), 290 ("gt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), 291 ("le", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), 292 ("lt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), 293 ("ne", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), 294 ("add", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32), 295 ("div", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32), 296 ("mul", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32), 297 ] 298 299 self.methods_expect_builtin_promote = [ 300 ("__eq__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), 301 ("__ge__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), 302 ("__gt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), 303 ("__le__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), 304 ("__lt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), 305 ("__ne__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), 306 ("__add__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32), 307 ("__div__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32), 308 ("__mul__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32), 309 ] 310 # The remaining lists organize ops that autocast treats explicitly. 311 self.torch_16 = [ 312 ("conv1d", conv_args_fp32[0]), 313 ("conv2d", conv_args_fp32[1]), 314 ("conv3d", conv_args_fp32[2]), 315 ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), 316 torch.randn((n, n, n), device=dev, dtype=torch.float32))), 317 ("mm", mat0_fp32 + mat1_fp32), 318 ("matmul", mat0_fp32 + mat1_fp32), 319 ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), 320 torch.randn((n, n, n), device=dev, dtype=torch.float32), 321 torch.randn((n, n, n), device=dev, dtype=torch.float32))), 322 ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32), 323 ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32), 324 torch.randn((n, n, n), device=dev, dtype=torch.float32))), 325 ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32), 326 torch.randn((5, 3, 5), device=dev, dtype=torch.float32), 327 torch.randn(5, device=dev, dtype=torch.float32), 328 0)), 329 ("conv_transpose1d", conv_args_fp32[0]), 330 ("conv_transpose2d", conv_args_fp32[1]), 331 ("conv_transpose3d", conv_args_fp32[2]), 332 ("prelu", pointwise0_fp32 + element0_fp32), 333 ("_native_multi_head_attention", (torch.randn((n, n, n), device=dev, dtype=torch.float32), 334 torch.randn((n, n, n), device=dev, dtype=torch.float32), 335 torch.randn((n, n, n), device=dev, dtype=torch.float32), 336 n, 4, torch.randn((3 * n, n), device=dev, dtype=torch.float32), 337 torch.randn((3 * n), device=dev, dtype=torch.float32), 338 torch.randn((n, n), device=dev, dtype=torch.float32), 339 torch.randn((n), device=dev, dtype=torch.float32))), 340 ] 341 self.torch_fp32 = [ 342 ("poisson_nll_loss", mat0_bf16 + mat1_bf16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))), 343 ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.bfloat16), 344 torch.tensor([[1, 3, 4]], device=dev, dtype=torch.bfloat16), 345 torch.tensor([1], device=dev, dtype=torch.int))), 346 ("hinge_embedding_loss", mat0_bf16 + (torch.ones(n, device=dev, dtype=torch.int),)), 347 ("margin_ranking_loss", mat0_bf16 + mat1_bf16 + (torch.ones((n,), device=dev, dtype=torch.bfloat16),)), 348 ("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16), 349 ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)), 350 ] 351 self.nn_16 = [ 352 ("linear", mat0_fp32 + mat1_fp32, {}), 353 ] 354 self.nn_fp32 = [ 355 ("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}), 356 ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),) + 357 (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)), 358 ("reflection_pad1d", dummy_bf16[2], {"padding": (3, 3)}), 359 ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.bfloat16), 360 torch.zeros((n,), device=dev, dtype=torch.long))), 361 ("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.bfloat16), 362 torch.zeros((n, n, n), device=dev, dtype=torch.long))), 363 ("l1_loss", mat0_bf16 + mat1_bf16), 364 ("smooth_l1_loss", mat0_bf16 + mat1_bf16), 365 ("mse_loss", mat0_bf16 + mat1_bf16), 366 ("multilabel_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)), 367 ("soft_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)), 368 ("multi_margin_loss", mat0_bf16 + (torch.ones((n,), device=dev, dtype=torch.long),)), 369 ("huber_loss", mat0_bf16 + mat1_bf16), 370 ] 371 self.torch_need_autocast_promote = [ 372 ("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)), 373 ("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)), 374 ] 375 376 377class TestAutocast(TestCase): 378 def args_maybe_kwargs(self, op_with_args): 379 if len(op_with_args) == 2: 380 return op_with_args[0], op_with_args[1], {} 381 else: 382 return op_with_args[0], op_with_args[1], op_with_args[2] 383 384 def _run_autocast_outofplace( 385 self, 386 op, 387 args, 388 run_as_type, 389 device, 390 out_type=None, 391 module=torch, 392 add_kwargs=None, 393 amp_dtype=torch.bfloat16, 394 ): 395 # helper to cast args 396 def cast(val, to_type): 397 if isinstance(val, torch.Tensor): 398 return val.to(to_type) if val.is_floating_point() else val 399 elif isinstance(val, collections.abc.Iterable): 400 return type(val)(cast(v, to_type) for v in val) 401 else: 402 return val 403 404 if add_kwargs is None: 405 add_kwargs = {} 406 407 self.assertFalse(torch.is_autocast_enabled(device_type=device)) 408 with torch.amp.autocast(device_type=device, dtype=amp_dtype): 409 self.assertTrue(torch.is_autocast_enabled(device_type=device)) 410 411 out_type = out_type if out_type is not None else run_as_type 412 output = output_method = None 413 414 # Try module.* variant, if requested: 415 if module is not None and hasattr(module, op): 416 output = getattr(module, op)(*args, **add_kwargs) 417 if isinstance(output, torch.Tensor): 418 self.assertTrue( 419 out_type == output.dtype, 420 f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}", 421 ) 422 # Try Tensor.* variant: 423 if hasattr(torch.Tensor, op): 424 output_method = getattr(args[0], op)(*args[1:], **add_kwargs) 425 if isinstance(output_method, torch.Tensor): 426 self.assertTrue( 427 out_type == output_method.dtype, 428 f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}", 429 ) 430 431 self.assertTrue( 432 (output is not None) or (output_method is not None), 433 f"{op} not found as an attribute on either Tensor or the requested module {module}", 434 ) 435 436 # Accounts for ops that return Tensors, iterables, and other non-Tensors. 437 # For example, lstm_cell returns a tuple and equal returns bool. 438 def compare(first, second): 439 if isinstance(first, torch.Tensor): 440 return torch.equal(first, second) 441 elif isinstance(first, collections.abc.Iterable): 442 return all(compare(f, s) for f, s in zip(first, second)) 443 else: 444 return first == second 445 446 # If both torch.* and Tensor.* variants were found, check outputs are identical 447 if (output is not None) and (output_method is not None): 448 self.assertTrue(type(output) == type(output_method)) 449 comparison = compare(output, output_method) 450 self.assertTrue( 451 comparison, f"torch.{op} result did not match Tensor.{op} result" 452 ) 453 454 # Compare numerics to Python-side "autocasting" that (we expect) does the same thing 455 # as the C++-side autocasting, and should be bitwise accurate. 456 output_to_compare = output if output is not None else output_method 457 with torch.amp.autocast(device_type=device, enabled=False): 458 self.assertFalse( 459 torch.is_autocast_enabled(device_type=device) 460 ) 461 462 if module is not None and hasattr(module, op): 463 control = getattr(module, op)( 464 *cast(args, run_as_type), **add_kwargs 465 ) 466 else: 467 control = getattr(args[0].to(run_as_type), op)( 468 *cast(args[1:], run_as_type), **add_kwargs 469 ) 470 self.assertTrue(type(output_to_compare) == type(control)) 471 comparison = compare(output_to_compare, control) 472 self.assertTrue(comparison, f"torch.{op} result did not match control") 473 self.assertTrue(torch.is_autocast_enabled(device_type=device)) 474 self.assertFalse(torch.is_autocast_enabled(device_type=device)) 475