# mypy: ignore-errors

import collections

import torch
from torch.testing._internal.common_utils import TEST_WITH_ROCM
from torch.testing._internal.common_utils import TestCase


class AutocastTestLists:
    def _rnn_cell_args(self, n, num_chunks, is_lstm, dev, dtype):
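        # Builds the (input, hx, weights) argument tuple consumed by torch.lstm_cell,
        # torch.gru_cell, and the torch.rnn_*_cell ops. All tensors are created in
        # float32 regardless of the dtype argument, which is currently unused.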
        input = (torch.randn((n, n), device=dev, dtype=torch.float32),)

        hx = ((torch.randn((n, n), device=dev, dtype=torch.float32),
               torch.randn((n, n), device=dev, dtype=torch.float32)) if is_lstm else
              torch.randn((n, n), device=dev, dtype=torch.float32),)

        weights = (torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32),  # weight_ih
                   torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32),  # weight_hh
                   torch.randn((num_chunks * n), device=dev, dtype=torch.float32),  # bias_ih
                   torch.randn((num_chunks * n), device=dev, dtype=torch.float32))  # bias_hh

        # returns args as a tuple
        return input + hx + weights

    # Supplies ops and arguments for test_autocast_* in test/test_cuda.py
    def __init__(self, dev):
        super().__init__()
        n = 8
        # Utility arguments, created as one-element tuples
        pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
        pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
        pointwise2_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
        mat0_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
        mat1_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
        mat2_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)

        dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
        conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
                           torch.randn(dimset, dtype=torch.float32, device=dev))
                          for dimset in dimsets]
        bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),)
        element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
        pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
        pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
        mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
        mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
        mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
        mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)

        # The lists below organize ops that autocast needs to test.
        # self.list_name corresponds to test_autocast_list_name in test/test_cuda.py.
        # Each op is associated with a tuple of valid arguments.
        # In addition, cudnn conv ops are not supported on ROCm and hence will
        # be skipped by passing the TEST_WITH_ROCM flag to those ops in the self.torch_fp16 list.

        # Some ops implement built-in type promotion.  These don't need autocasting,
        # but autocasting relies on their promotion, so we include tests to double-check.
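        # For example, even with autocast enabled, adding a float32 tensor to a
        # float16 tensor still yields float32 via ordinary type promotion, and
        # comparison ops such as torch.eq still return torch.bool.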
        self.torch_expect_builtin_promote = [
            ("eq", pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("ge", pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("gt", pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("le", pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("lt", pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("ne", pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("add", pointwise0_fp32 + pointwise1_fp16, torch.float32),
            ("div", pointwise0_fp32 + pointwise1_fp16, torch.float32),
            ("mul", pointwise0_fp32 + pointwise1_fp16, torch.float32),
            ("cat", (pointwise0_fp16 + pointwise1_fp32,), torch.float32),
            ("equal", pointwise0_fp32 + pointwise1_fp16, torch.float32),
            ("stack", (pointwise0_fp16 + pointwise1_fp32,), torch.float32),
        ]
        self.methods_expect_builtin_promote = [
            ("__eq__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("__ge__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("__gt__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("__le__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("__lt__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("__ne__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("__add__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
            ("__div__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
            ("__mul__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
        ]

        # The remaining lists organize ops that autocast treats explicitly.
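        # torch_fp16 holds ops that autocast runs in float16: convolutions and
        # matmul-style ops that are fast and numerically safe in half precision.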
        self.torch_fp16 = [
            # deprecated _convolution
            ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False,
                                                              (0, 0), 1, False, True, True)),
            # the current _convolution
            ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False,
                                                              (0, 0), 1, False, True, True, True)),
            ("conv1d", conv_args_fp32[0]),
            ("conv2d", conv_args_fp32[1]),
            ("conv3d", conv_args_fp32[2]),
            ("conv_tbc", conv_args_fp32[0] + bias_fp32),
            ("conv_transpose1d", conv_args_fp32[0]),
            ("conv_transpose2d", conv_args_fp32[1]),
            ("conv_transpose3d", conv_args_fp32[2]),
            ("convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, (0, 0), 1)),
            ("cudnn_convolution", conv_args_fp32[1] + ((0, 0), (1, 1), (1, 1), 1, False, True, True), TEST_WITH_ROCM),
            ("cudnn_convolution_transpose", conv_args_fp32[1] + ((0, 0), (0, 0), (1, 1),
                                                                 (1, 1), 1, False, True, True), TEST_WITH_ROCM),
            ("prelu", pointwise0_fp32 + element0_fp32),
            ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
            ("addmv", pointwise0_fp32 + mat2_fp32 + pointwise1_fp32),
            ("addr", mat0_fp32 + pointwise0_fp32 + pointwise1_fp32),
            ("matmul", mat0_fp32 + mat1_fp32),
            ("einsum", "bkhd,bqhd->bqkh", mat0_fp32 + mat1_fp32),
            ("mm", mat0_fp32 + mat1_fp32),
            ("mv", mat0_fp32 + pointwise0_fp32),
            ("chain_matmul", mat0_fp32 + mat1_fp32 + mat2_fp32),
            ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
                                    torch.randn((n, n, n), device=dev, dtype=torch.float32))),
            ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
                         torch.randn((n, n, n), device=dev, dtype=torch.float32),
                         torch.randn((n, n, n), device=dev, dtype=torch.float32))),
            ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
                     torch.randn((n, n, n), device=dev, dtype=torch.float32))),
            # _thnn_fused_lstm_cell and _thnn_fused_gru_cell are not Python-exposed as far as I can tell.
            # ("_thnn_fused_lstm_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32),
            # ("_thnn_fused_gru_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32),
            ("lstm_cell", self._rnn_cell_args(n, num_chunks=4, is_lstm=True, dev=dev, dtype=torch.float32)),
            ("gru_cell", self._rnn_cell_args(n, num_chunks=3, is_lstm=False, dev=dev, dtype=torch.float32)),
            ("rnn_tanh_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)),
            ("rnn_relu_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)),
        ]
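        # torch_fp32 holds ops that autocast runs in float32 even when handed
        # float16 inputs (reductions, losses, and numerically sensitive pointwise ops).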
        self.torch_fp32 = [
            ("acos", (pointwise0_fp16[0].clamp(-.9, 0.9),)),
            ("asin", (pointwise0_fp16[0].clamp(-.9, 0.9),)),
            ("cosh", pointwise0_fp16),
            ("erfinv", (pointwise0_fp16[0].clamp(-.9, .9),)),
            ("exp", pointwise0_fp16),
            ("expm1", pointwise0_fp16),
            ("log", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
            ("log10", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
            ("log2", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
            ("log1p", (pointwise0_fp16[0].clamp(-0.9, 100.0),)),
            ("reciprocal", pointwise0_fp16),
            ("rsqrt", (pointwise0_fp16[0].clamp(0.0, 100.0),)),
            ("sinh", pointwise0_fp16),
            ("tan", (pointwise0_fp16[0].clamp(-3.1 / 2, 3.1 / 2),)),
            ("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + pointwise1_fp16),
            ("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + (1.7,)),
            # ("pow", (1.7,) + pointwise0_fp16), # This variant has a backend, but is not documented in the API.
            ("softmax", pointwise0_fp16 + (0,)),
            ("log_softmax", pointwise0_fp16 + (0,)),
            ("layer_norm", pointwise0_fp16 + ((pointwise0_fp16[0].numel(),),)),
            ("group_norm", mat0_fp16 + (1,)),
            ("norm", pointwise0_fp16),
            ("norm", pointwise0_fp16, {"dim": 0}),
            # these need magma
            # ("norm", mat0_fp16, {"p": "nuc"}),
            # ("norm", mat0_fp16, {"p": "nuc", "dim": 0}),
            ("norm", pointwise0_fp16, {"p": 1}),
            ("norm", pointwise0_fp16, {"p": 1, "dim": 0}),
            ("cosine_similarity", mat0_fp16 + mat1_fp16),
            ("poisson_nll_loss", mat0_fp16 + mat1_fp16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))),
            ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.float16),
                                       torch.tensor([[1, 3, 4]], device=dev, dtype=torch.float16),
                                       torch.tensor([1], device=dev, dtype=torch.int))),
            ("hinge_embedding_loss", mat0_fp16 + (torch.ones(n, device=dev, dtype=torch.int),)),
            ("kl_div", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)),
            ("margin_ranking_loss", mat0_fp16 + mat1_fp16 + (torch.ones((n,), device=dev, dtype=torch.float16),)),
            ("triplet_margin_loss", mat0_fp16 + mat1_fp16 + mat2_fp16),
            ("binary_cross_entropy_with_logits", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)),
            ("cumprod", pointwise0_fp16 + (0,)),
            ("cumsum", pointwise0_fp16 + (0,)),
            ("dist", pointwise0_fp16 + pointwise1_fp16),
            ("pdist", mat0_fp16),
            ("cdist", mat0_fp16 + mat1_fp16),
            ("prod", pointwise0_fp16),
            ("prod", pointwise0_fp16 + (0,)),
            ("renorm", mat0_fp16 + (2, 0, 1.0)),
            ("sum", pointwise0_fp16),
            ("sum", mat0_fp16 + (1,)),
            ("logsumexp", mat0_fp16 + (1,)),
        ]
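        # torch_need_autocast_promote holds ops whose mixed float16/float32 inputs
        # autocast promotes to the widest floating dtype (float32) before dispatch.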
        self.torch_need_autocast_promote = [
            ("addcdiv", pointwise0_fp32 + pointwise1_fp16 + (pointwise2_fp16[0].clamp(0.1, 100),)),
            ("addcmul", pointwise0_fp32 + pointwise1_fp16 + pointwise2_fp16),
            ("atan2", pointwise0_fp32 + (pointwise1_fp16[0].clamp(0.1, 100),)),
            ("bilinear", (torch.randn((1, 2), dtype=torch.float16, device=dev),
                          torch.randn((1, 2), dtype=torch.float32, device=dev),
                          torch.randn((1, 2, 2), dtype=torch.float16, device=dev),
                          torch.randn((1,), dtype=torch.float32, device=dev))),
            ("cross", (torch.randn(3, dtype=torch.float32, device=dev),
                       torch.randn(3, dtype=torch.float16, device=dev))),
            ("dot", pointwise0_fp16 + pointwise1_fp32),
            ("vdot", pointwise0_fp16 + pointwise1_fp32),
            ("grid_sampler", (torch.randn((2, 3, 33, 22), dtype=torch.float16, device=dev),
                              torch.randn((2, 22, 11, 2), dtype=torch.float32, device=dev),
                              0, 0, False)),
            ("index_put", pointwise0_fp32 + ((torch.tensor([1], device=dev, dtype=torch.long),),
                                             torch.randn(1, device=dev, dtype=torch.float16))),
            ("index_put", pointwise0_fp16 + ((torch.tensor([1], device=dev, dtype=torch.long),),
                                             torch.randn(1, device=dev, dtype=torch.float32))),
            ("tensordot", (torch.randn((2, 2, 2), dtype=torch.float32, device=dev),
                           torch.randn((2, 2, 2), dtype=torch.float16, device=dev))),
            ("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float32, device=dev),
                             0,
                             torch.randint(0, 2, (2, 2, 2), device=dev),
                             torch.randn((2, 2, 2), dtype=torch.float16, device=dev))),
            ("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float16, device=dev),
                             0,
                             torch.randint(0, 2, (2, 2, 2), device=dev),
                             torch.randn((2, 2, 2), dtype=torch.float32, device=dev))),
        ]
        self.nn_fp16 = [
            ("linear", mat0_fp32 + mat1_fp32 + mat2_fp32),
        ]
        self.nn_fp32 = [
            ("softplus", pointwise0_fp16),
            ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.float),
                          torch.zeros((n,), device=dev, dtype=torch.long))),
            ("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.half),
                            torch.zeros((n, n, n), device=dev, dtype=torch.long))),
            ("l1_loss", mat0_fp16 + mat1_fp16),
            ("smooth_l1_loss", mat0_fp16 + mat1_fp16),
            ("mse_loss", mat0_fp16 + mat1_fp16),
            ("multilabel_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
            ("soft_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
            ("multi_margin_loss", mat0_fp16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
        ]
        self.linalg_fp16 = [
            ("linalg_vecdot", mat0_fp32 + mat0_fp32),
            ("linalg_multi_dot", (mat0_fp32 + mat1_fp32 + mat2_fp32,)),
        ]
        self.methods_fp16 = [
            ("__matmul__", mat0_fp32 + mat1_fp32)
        ]
        self.methods_fp32 = [
            ("__pow__", (torch.rand(n, device=dev, dtype=torch.float16), 1.5)),
        ]
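        # Ops in self.banned are disallowed under autocast; the corresponding test
        # expects them to raise. The third element is the namespace (torch._C._nn)
        # used to look the op up.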
        self.banned = [
            ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.float32),
                                      torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn),
        ]


class AutocastCPUTestLists:
    # Supplies ops and arguments for test_autocast_* in test/test_cpu.py
    def __init__(self, dev):
        super().__init__()
        n = 8
        # Utility arguments, created as one-element tuples
        pointwise0_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
        pointwise1_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
        pointwise2_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
        mat0_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
        mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
        mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)

        pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
        pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)

        dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n))

        dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),)
                      for dimset in dummy_dimsets]

        dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
        conv_args_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),
                           torch.randn(dimset, dtype=torch.bfloat16, device=dev))
                          for dimset in dimsets]
        conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
                           torch.randn(dimset, dtype=torch.float32, device=dev))
                          for dimset in dimsets]

        bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),)
        element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
        pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
        pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
        mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
        mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
        mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
        mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)

        dummy_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),)
                      for dimset in dummy_dimsets]
        # The lists below organize ops that autocast needs to test.
        # self.list_name corresponds to test_autocast_list_name in test/test_cpu.py.
        # Each op is associated with a tuple of valid arguments.

        # Some ops implement built-in type promotion.  These don't need autocasting,
        # but autocasting relies on their promotion, so we include tests to double-check.
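        # Unlike the CUDA lists above, each entry below carries two argument tuples:
        # one with bfloat16 inputs and one with float16 inputs, since CPU autocast
        # supports both lower-precision dtypes.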
        self.torch_expect_builtin_promote = [
            ("eq", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("ge", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("gt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("le", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("lt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("ne", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("add", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
            ("div", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
            ("mul", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
        ]

        self.methods_expect_builtin_promote = [
            ("__eq__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("__ge__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("__gt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("__le__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("__lt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("__ne__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
            ("__add__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
            ("__div__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
            ("__mul__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
        ]
        # The remaining lists organize ops that autocast treats explicitly.
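        # torch_16 and nn_16 hold ops that CPU autocast runs in the active
        # lower-precision dtype (bfloat16 or float16, depending on the dtype the
        # test passes), hence the dtype-neutral "16" suffix.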
        self.torch_16 = [
            ("conv1d", conv_args_fp32[0]),
            ("conv2d", conv_args_fp32[1]),
            ("conv3d", conv_args_fp32[2]),
            ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
                     torch.randn((n, n, n), device=dev, dtype=torch.float32))),
            ("mm", mat0_fp32 + mat1_fp32),
            ("matmul", mat0_fp32 + mat1_fp32),
            ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
                         torch.randn((n, n, n), device=dev, dtype=torch.float32),
                         torch.randn((n, n, n), device=dev, dtype=torch.float32))),
            ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
            ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
                                    torch.randn((n, n, n), device=dev, dtype=torch.float32))),
            ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32),
                          torch.randn((5, 3, 5), device=dev, dtype=torch.float32),
                          torch.randn(5, device=dev, dtype=torch.float32),
                          0)),
            ("conv_transpose1d", conv_args_fp32[0]),
            ("conv_transpose2d", conv_args_fp32[1]),
            ("conv_transpose3d", conv_args_fp32[2]),
            ("prelu", pointwise0_fp32 + element0_fp32),
            ("_native_multi_head_attention", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
                                              torch.randn((n, n, n), device=dev, dtype=torch.float32),
                                              torch.randn((n, n, n), device=dev, dtype=torch.float32),
                                              n, 4, torch.randn((3 * n, n), device=dev, dtype=torch.float32),
                                              torch.randn((3 * n), device=dev, dtype=torch.float32),
                                              torch.randn((n, n), device=dev, dtype=torch.float32),
                                              torch.randn((n), device=dev, dtype=torch.float32))),
        ]
        self.torch_fp32 = [
            ("poisson_nll_loss", mat0_bf16 + mat1_bf16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))),
            ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.bfloat16),
                                       torch.tensor([[1, 3, 4]], device=dev, dtype=torch.bfloat16),
                                       torch.tensor([1], device=dev, dtype=torch.int))),
            ("hinge_embedding_loss", mat0_bf16 + (torch.ones(n, device=dev, dtype=torch.int),)),
            ("margin_ranking_loss", mat0_bf16 + mat1_bf16 + (torch.ones((n,), device=dev, dtype=torch.bfloat16),)),
            ("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16),
            ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
        ]
        self.nn_16 = [
            ("linear", mat0_fp32 + mat1_fp32, {}),
        ]
        self.nn_fp32 = [
            ("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}),
            ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),) +
                                     (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
            ("reflection_pad1d", dummy_bf16[2], {"padding": (3, 3)}),
            ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),
                          torch.zeros((n,), device=dev, dtype=torch.long))),
            ("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.bfloat16),
                            torch.zeros((n, n, n), device=dev, dtype=torch.long))),
            ("l1_loss", mat0_bf16 + mat1_bf16),
            ("smooth_l1_loss", mat0_bf16 + mat1_bf16),
            ("mse_loss", mat0_bf16 + mat1_bf16),
            ("multilabel_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
            ("soft_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
            ("multi_margin_loss", mat0_bf16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
            ("huber_loss", mat0_bf16 + mat1_bf16),
        ]
        self.torch_need_autocast_promote = [
            ("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
            ("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
        ]


class TestAutocast(TestCase):
    def args_maybe_kwargs(self, op_with_args):
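        # Each op entry is either (op, args) or (op, args, kwargs); normalize it to
        # a uniform (op, args, kwargs) triple.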
        if len(op_with_args) == 2:
            return op_with_args[0], op_with_args[1], {}
        else:
            return op_with_args[0], op_with_args[1], op_with_args[2]

    def _run_autocast_outofplace(
        self,
        op,
        args,
        run_as_type,
        device,
        out_type=None,
        module=torch,
        add_kwargs=None,
        amp_dtype=torch.bfloat16,
    ):
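        # Runs `op` under autocast via the module-level function and/or the Tensor
        # method, checks that the output dtype matches `out_type` (defaulting to
        # `run_as_type`), and compares the result against a control computed with
        # autocast disabled on explicitly cast inputs.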
        # Helper that casts floating-point tensors (recursing through nested
        # iterables) to to_type, leaving integer tensors and non-tensor values unchanged.
        def cast(val, to_type):
            if isinstance(val, torch.Tensor):
                return val.to(to_type) if val.is_floating_point() else val
            elif isinstance(val, collections.abc.Iterable):
                return type(val)(cast(v, to_type) for v in val)
            else:
                return val

        if add_kwargs is None:
            add_kwargs = {}

        self.assertFalse(torch.is_autocast_enabled(device_type=device))
        with torch.amp.autocast(device_type=device, dtype=amp_dtype):
            self.assertTrue(torch.is_autocast_enabled(device_type=device))

            out_type = out_type if out_type is not None else run_as_type
            output = output_method = None

            # Try module.* variant, if requested:
            if module is not None and hasattr(module, op):
                output = getattr(module, op)(*args, **add_kwargs)
                if isinstance(output, torch.Tensor):
                    self.assertTrue(
                        out_type == output.dtype,
                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
                    )
            # Try Tensor.* variant:
            if hasattr(torch.Tensor, op):
                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
                if isinstance(output_method, torch.Tensor):
                    self.assertTrue(
                        out_type == output_method.dtype,
                        f"autocast for Tensor.{op} produced {output_method.dtype}, should produce {out_type}",
                    )

            self.assertTrue(
                (output is not None) or (output_method is not None),
                f"{op} not found as an attribute on either Tensor or the requested module {module}",
            )

            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
            # For example, lstm_cell returns a tuple and equal returns bool.
            def compare(first, second):
                if isinstance(first, torch.Tensor):
                    return torch.equal(first, second)
                elif isinstance(first, collections.abc.Iterable):
                    return all(compare(f, s) for f, s in zip(first, second))
                else:
                    return first == second

            # If both torch.* and Tensor.* variants were found, check outputs are identical
            if (output is not None) and (output_method is not None):
                self.assertTrue(type(output) == type(output_method))
                comparison = compare(output, output_method)
                self.assertTrue(
                    comparison, f"torch.{op} result did not match Tensor.{op} result"
                )

            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
            # as the C++-side autocasting, and should be bitwise accurate.
            output_to_compare = output if output is not None else output_method
            with torch.amp.autocast(device_type=device, enabled=False):
                self.assertFalse(
                    torch.is_autocast_enabled(device_type=device)
                )

                if module is not None and hasattr(module, op):
                    control = getattr(module, op)(
                        *cast(args, run_as_type), **add_kwargs
                    )
                else:
                    control = getattr(args[0].to(run_as_type), op)(
                        *cast(args[1:], run_as_type), **add_kwargs
                    )
                self.assertTrue(type(output_to_compare) == type(control))
                comparison = compare(output_to_compare, control)
                self.assertTrue(comparison, f"torch.{op} result did not match control")
            self.assertTrue(torch.is_autocast_enabled(device_type=device))
        self.assertFalse(torch.is_autocast_enabled(device_type=device))
475