# Owner(s): ["oncall: jit"]

import torch
from torch.cuda.amp import autocast
from typing import Optional, Tuple

import unittest
from test_jit import JitTestCase
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
from torch.testing import FileCheck
from jit.test_models import MnistNet

TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()

@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestAutocast(JitTestCase):
    def setUp(self):
        # common input tensors
        if TEST_CUDA:
            self.a_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.b_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.c_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.d_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.a_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.b_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.c_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.d_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
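        # turn on JIT autocast for these tests and remember the previous mode (restored in tearDown)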
        self.old_value = torch._C._jit_set_autocast_mode(True)
        super().setUp()

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.old_value)
        super().tearDown()

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_generic_autocast(self):
        @torch.jit.script
        def fn_cuda_autocast(a, b):
            with autocast():
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y

        @torch.jit.script
        def fn_generic_autocast(a, b):
            with torch.amp.autocast(device_type='cuda'):
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        self.assertEqual(fn_cuda_autocast(self.a_fp32, self.b_fp32), fn_generic_autocast(self.a_fp32, self.b_fp32))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA or not TEST_BFLOAT16, "No cuda bfloat16 support")
    def test_linear_bf16(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(dtype=torch.bfloat16):
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.bfloat16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_cpu(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                return torch.mm(a, b)
        result = fn(self.a_fp32.to('cpu'), self.b_fp32.to('cpu'))
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_off(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state(self):
        @torch.jit.script
        def fn(a, b, use_amp: bool):
            with autocast(enabled=use_amp):
                return torch.mm(a, b)
        # runtime values for the autocast enabled argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32, True)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state_expr(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True if a[0][0] > 0.5 else False):
                return torch.mm(a, b)
        # runtime values for the autocast enabled argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_explicit_casts(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                e = torch.mm(a.double(), b.double()).float()
                f = torch.mm(c, d).double()
            g = torch.mm(c.double(), f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float64)
        self.assertEqual(g.dtype, torch.float64)

    # multiple uses of the same input value
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_duplicate_inputs(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                e = torch.mm(a, a)
                f = torch.mm(e, e)
            return e, f
        e, f = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy(self):
        @torch.jit.script
        def fn(a):
            with autocast(enabled=True):
                return torch.log(a)
        result = fn(self.a_fp16)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy_with_fp64(self):
        @torch.jit.script
        def fn(a):
            with autocast(enabled=True):
                return torch.log(a)
        # fp32 policy should not narrow fp64 to fp32!
        result = fn(self.a_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                e = torch.mm(a, b)
                f = torch.addcmul(e, c, d, value=0.1)
            return e, f
        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy_fp64(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return torch.addcmul(a, a, b, value=0.1)
        result = fn(self.a_fp32.double(), self.b_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy(self):
        @torch.jit.script
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
            return x, y, z, w
        x, y, z, w = fn(self.a_fp16, self.b_fp16, self.c_fp16, self.d_fp16, None)
        self.assertEqual(x.dtype, torch.float32)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy_fp64(self):
        @torch.jit.script
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
            return x, y, z, w
        x, y, z, w = fn(self.a_fp32.double(), self.b_fp32.double(), self.c_fp32.double(), self.d_fp32.double(), None)
        self.assertEqual(x.dtype, torch.float64)
        self.assertEqual(y.dtype, torch.float64)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float64)

    @unittest.skipIf(True, "broken due to lack of type propagation")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_control_flow(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                if a[0][0] > 0.5:
                    e = torch.mm(a, b)
                    x = 1
                else:
                    e = torch.mm(c, d)
                    x = 2
                f = torch.mm(d, e) * x
            return e, f
        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    # this works fine in regular Python, but it creates a delicate
    # situation in TorchScript where the types are not consistent across
    # the then/else branches
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_types(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                if a[0][0] > 0.5:
                    e = torch.mm(a, b)
                    f = torch.mm(a, b).float()
                else:
                    e = torch.mm(c, d).float()
                    f = torch.mm(a, b)
            return torch.mm(e.float(), f.float())
        result = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # another, more complex case of divergent types
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            if a[0][0] > 0.5:
                with autocast_on:
                    e = torch.mm(a, b)
            else:
                with autocast_off:
                    e = torch.mm(c, d)
            return torch.mm(e, e)
        fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_conditional_autocast(self):
        @torch.jit.script
        def fn(a, b):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            with autocast_on if a[0][0] > 0.5 else autocast_off:
                return torch.mm(a, b)
        # conditional autocast expressions are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_nested_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast(enabled=False):
                e = torch.mm(a, b)
                with autocast(enabled=True):
                    f = torch.mm(e, c)
                    with autocast(enabled=False):
                        g = torch.mm(e, d)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_implicitly_nested_autocast(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False), autocast(enabled=True):
                return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_reused_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            autocast_instance = autocast(enabled=True)
            with autocast_instance:
                e = torch.mm(a, b)
                with autocast_instance:
                    e = torch.mm(c, d)
                    f = torch.mm(d, e)
            g = torch.mm(e, f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    # TODO: fix and enable this test?
    #   (we could technically fix this, but is it really worth it?)
    @unittest.skipIf(True, "unsupported autocast syntax")
    def test_reused_autocast_expr(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast(enabled=True) as autocast_instance:
                e = torch.mm(a, b)
                with autocast_instance:
                    e = torch.mm(c, d)
                    f = torch.mm(d, e)
            g = torch.mm(e, f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees(self):
        def helper(a, b):
            return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                tmp = helper(a, b)
                tmp = helper(tmp, tmp)
                tmp = helper(tmp, tmp)
                tmp = helper(tmp, tmp)
                return helper(tmp, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_on(self):
        def helper(a, b):
            with autocast(enabled=True):
                return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False):
                return helper(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_off(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return helper(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripting inside eager autocast
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_eager_and_script(self):
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)
        for i in range(8):
            use_autocast = (i % 2 == 0)
            expected_dtype = torch.float16 if use_autocast else torch.float32
            with autocast(enabled=use_autocast):
                result = fn(self.a_fp32, self.b_fp32)
            self.assertEqual(result.dtype, expected_dtype)

    # traced inside scripting
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing(self):
        def helper(a, b):
            return torch.mm(a, b)

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return traced(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # traced with autocast inside scripting
    @unittest.skipIf(True, "autocast(False) is ignored inside traced functions")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing_with_autocast(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b) * 2.0

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return traced(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripted called from traced
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_and_script(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                return torch.mm(a, b)

        def traced(a, b):
            return fn(a, b)

        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # scripted called from traced with autocast
    @unittest.skipIf(True, "scripted called from traced TorchScript is not yet working")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_with_autocast_and_script(self):
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)

        def traced(a, b):
            with autocast(enabled=True):
                return fn(a, b)

        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_module(self):
        class TestModule(torch.nn.Module):
            def __init__(self, N, M):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand((N, M), dtype=torch.float32))
                self.linear = torch.nn.Linear(N, M).float()

            def forward(self, input):
                with autocast(enabled=True):
                    output = self.weight.mv(input)
                    output = self.linear(output)
                    return output

        scripted_module = torch.jit.script(TestModule(2, 3)).cuda()
        input = torch.rand(3, dtype=torch.float32, device='cuda')
        result = scripted_module(input)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(True, "autocast decorators not supported")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator(self):
        @torch.jit.script
        @autocast(enabled=True)
        def fn(a, b):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # this is equivalent to running scripted functions inside autocast
    # (see also test_eager_and_script)
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator_outside_jit(self):
        @autocast(enabled=True)
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_inplace(self):
        @torch.jit.script
        def fn(a, b, c):
            with autocast(enabled=True):
                x = torch.addmm(a, b, c)
                y = torch.addmm(a, b, c, out=a)
                z = a.addmm_(b, c)
                return x, y, z
        x, y, z = fn(self.a_fp32, self.b_fp32, self.c_fp32)
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float32)

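    # Helper: run `func` both eagerly and as a scripted function, optionally check
    # that `cast_op` appears in the scripted graph, and compare output dtypes.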
    def _test_autocast(self, func, cast_op, *args):
        jit_func = torch.jit.script(func)
        o = func(*args)
        jit_o = jit_func(*args)
        if cast_op is not None:
            FileCheck().check(cast_op).run(jit_func.graph_for(*args))
        for o0, o1 in zip(o, jit_o):
            self.assertEqual(o0.dtype, o1.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api(self):

        def t_autocast_cpu(x, y):
            with torch.autocast("cpu", dtype=torch.bfloat16):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            with torch.autocast("cuda", dtype=torch.half):
                return torch.mm(x, y)

        def t_cuda_amp_autocast(x, y):
            with torch.cuda.amp.autocast():
                return torch.mm(x, y)

        def t_cpu_amp_autocast(x, y):
            with torch.cpu.amp.autocast():
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cuda_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cpu_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(True, "a dtype argument currently needs to be provided")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api_not_supported(self):

        def t_autocast_cpu(x, y):
            # no dtype provided is not currently supported
            with torch.autocast("cpu"):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            # no dtype provided is not currently supported
            with torch.autocast("cuda"):
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_mixed_dtypes(self):

        def t(cpu0, cpu1, cuda0, cuda1):
            with torch.autocast("cpu", torch.bfloat16):
                with torch.autocast("cuda", torch.float16):
                    cpu_o = torch.mm(cpu0, cpu1)
                    cuda_o = torch.mm(cuda0, cuda1)
                    return cpu_o, cuda_o

        jit_t = torch.jit.script(t)
        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_executor_under_autocast(self):

        def t(cpu0, cpu1, cuda0, cuda1):
            cpu_o = torch.mm(cpu0, cpu1)
            cuda_o = torch.mm(cuda0, cuda1)
            return cpu_o, cuda_o

        jit_t = torch.jit.script(t)
        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)

        with torch.autocast("cpu", torch.bfloat16):
            with torch.autocast("cuda", torch.float16):
                self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cpu", torch.bfloat16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cuda", torch.float16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        # no cast op should be observed when executing outside an autocast context
        self._test_autocast(t, None, cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_autodiff(self):
        def t(t0, t1):
            o = torch.mm(t0, t1)
            return o.relu()

        jit_t = torch.jit.script(t)
        t0 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
        t1 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()

        # run optimization
        for i in range(5):
            with torch.autocast("cuda", torch.float16):
                jit_o = jit_t(t0, t1)
            jit_o.sum().backward()

        t0.grad = None
        t1.grad = None
        ref_t0 = t0.detach().requires_grad_()
        ref_t1 = t1.detach().requires_grad_()

        with torch.autocast("cuda", torch.float16):
            o = t(ref_t0, ref_t1)
            jit_o = jit_t(t0, t1)
        jit_o.sum().backward()
        o.sum().backward()
        self.assertEqual(o, jit_o)
        self.assertEqual(t0.grad, ref_t0.grad)
        self.assertEqual(t1.grad, ref_t1.grad)
        self.assertEqual(o.dtype, jit_o.dtype)
        self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
        self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_call_method_under_autocast(self):
        @torch.jit.interface
        class Iface(torch.nn.Module):
            def forward(self, x, y) -> torch.Tensor:
                pass

        class Impl(Iface):
            def forward(self, x, y):
                return torch.mm(x, y)

        class Thing1(torch.nn.Module):
            impl: Iface

            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    a = torch.mm(x, y)
                    b = self.impl.forward(a, x)
                    return b

        scripted_impl = torch.jit.script(Impl())
        thing1 = Thing1()
        thing1.impl = scripted_impl
        scripted_thing1 = torch.jit.script(thing1)
        x = torch.rand([2, 2])
        y = torch.rand([2, 2])

        # make sure this doesn't throw an error
        with torch.cuda.amp.autocast():
            ans = scripted_thing1.forward(x, y)
        self.assertEqual(torch.mm(torch.mm(x, y), x), ans)

        # sanity check: this isn't supported currently when global autocasting
        # isn't enabled
        self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_basic(self):
        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(x, y)

        x = torch.rand((3, 4), dtype=torch.float).cuda()
        y = torch.rand((4, 5), dtype=torch.float).cuda()

        mod = TestModule().eval()

        # sanity check
        self._test_autocast(mod, "aten::_autocast_to_reduced_precision", x, y)

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(frozen_mod.graph)

        # make sure that the runtime pass doesn't duplicate autocast nodes
        frozen_mod(x, y)
        optimized_graph = frozen_mod.graph_for(x, y)
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(optimized_graph)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_constants(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.rand((3, 4), dtype=torch.float).cuda()

            def forward(self, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(self.x, y)

        y = torch.rand((4, 5), dtype=torch.float).cuda()
        mod = TestModule().eval()

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
        # freezing should pre-cast the constant self.x to remove one autocast call
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(frozen_mod.graph)

        # the runtime autocasting pass will re-insert the second autocast call,
        # but constant propagation will merge it with the constant that it's casting.
        frozen_mod(y)
        optimized_graph = frozen_mod.graph_for(y)
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(optimized_graph)

    @unittest.skipIf(TEST_CUDA, "CPU-only test")
    def test_jit_autocast_softmax_cpu(self):
        def fn(x):
            with torch.cpu.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.bfloat16)
        fn_s(x)
        y = fn_s(x)

        self.assertTrue(y.dtype == torch.bfloat16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_autocast_softmax_gpu(self):
        def fn(x):
            with torch.cuda.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.half).cuda()
        fn_s(x)
        y = fn_s(x)

        self.assertTrue(y.dtype == torch.float)

    def test_ignore_amp(self):
        @torch.jit.script
        def foo(x):
            return torch.mm(x, x)

        inp = torch.rand([10, 10], dtype=torch.float)
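        # ask the JIT to ignore AMP for this function; the optimized graph
        # should then contain no reduced-precision cast ops (checked below)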
        foo._set_ignore_amp(True)
        with torch.cpu.amp.autocast():
            foo(inp)
            foo(inp)

        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_not("_autocast_to_reduced").run(g)

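# Simple Conv2d + BatchNorm2d module used by the tracing tests below.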
class convbn(torch.nn.Module):
    def __init__(self, bias_enabled=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, 7, stride=2, bias=bias_enabled)
        self.bn = torch.nn.BatchNorm2d(64)

    def forward(self, x):
        return self.bn(self.conv(x))

@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestJitTraceAutocast(JitTestCase):
    def setUp(self):
        super().setUp()
        self.previous_default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        self.models = [MnistNet(),
                       convbn(bias_enabled=True),
                       convbn(bias_enabled=False)]
        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu')]
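        # these tests exercise eager autocast with tracing, so disable JIT autocast
        # and remember the previous mode (restored in tearDown)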
        self.previous_jit_autocast_pass = torch._C._jit_set_autocast_mode(False)

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.previous_jit_autocast_pass)
        torch.set_default_dtype(self.previous_default_dtype)
        super().tearDown()

    def test_generate_autocast_jit_trace_model(self):
        def test_generate_autocast_jit_trace_model(model, x):
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)
        for i in range(len(self.models)):
            test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_nchw_autocast_jit_trace_model(self):
        def test_nchw_autocast_jit_trace_model(model, x):
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)
            with torch.no_grad():
                y = traced_model(x.clone())
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone())
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
        for i in range(len(self.models)):
            test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_nhwc_autocast_jit_trace_model(self):
        def test_nhwc_autocast_jit_trace_model(model, x):
            model = model.to(memory_format=torch.channels_last)
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
            traced_model = torch.jit.freeze(traced_model)
            with torch.no_grad():
                y = traced_model(x.clone().to(memory_format=torch.channels_last))
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone().to(memory_format=torch.channels_last))
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
        for i in range(len(self.models)):
            if self.inputs[i].dim() == 5:
                # the NHWC 3D case is not supported yet
                continue
            test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_cat_promote(self):
        class TestModel(torch.nn.Module):
            def forward(self, a, b):
                return torch.cat([a, b], 0)

        with torch.jit.fuser("none"):
            # In this test we check whether cat performs the expected promotion in AMP with mixed dtype inputs.
            # The fuser is disabled here to avoid TE fusion groups.
            for jit_freeze_or_not in [False, True]:
                test_model = TestModel().eval()
                with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
                    a = torch.rand(24, 128, 128)
                    b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
                    c = test_model(a, b)
                    traced = torch.jit.trace(test_model, (a, b))
                if jit_freeze_or_not:
                    traced = torch.jit.freeze(traced)
                for _ in range(3):
                    c2 = traced(a, b)
                self.assertEqual(c.dtype, torch.float32)
                self.assertEqual(c2.dtype, torch.float32)
                traced_graph = traced.graph_for(a, b)
                self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes()))

    def test_script_autocast_cpu(self):
        def fn(x):
            if torch.is_autocast_cpu_enabled():
                return x.relu()
            else:
                return x.sin()

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))

        with torch.cpu.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))

        self.assertTrue(any("is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_autocast_cuda(self):
        def fn(x):
            if torch.is_autocast_enabled():
                return x.relu()
            else:
                return x.sin()

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))

        with torch.cuda.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))

        self.assertTrue(any("is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()))


    def test_scripted_aliasing(self):
        # torch.is_autocast_enabled should not be able to move inside of the autocast context.
        def fn(x):
            if torch.is_autocast_enabled():
                y = True
            else:
                y = False
            with torch.cuda.amp.autocast(enabled=True):
                z = x.relu()
            return y, z

        fn_s = torch.jit.script(fn)
        graph = fn_s.graph

        aliasdb = graph.alias_db()

        is_enabled_nodes = graph.findAllNodes("aten::is_autocast_enabled")
        enter_nodes = graph.findAllNodes("prim::Enter")

        self.assertEqual(len(is_enabled_nodes), 1)
        self.assertEqual(len(enter_nodes), 1)

        self.assertFalse(aliasdb.move_after_topologically_valid(is_enabled_nodes[0], enter_nodes[0]))


    def test_script_autocast_enable_and_check(self):
        def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
            b1 = torch.is_autocast_cpu_enabled()
            v1 = torch.mm(x, y)
            with torch.cpu.amp.autocast(enabled=True):
                b2 = torch.is_autocast_cpu_enabled()
                v2 = torch.mm(x, y)
                with torch.cpu.amp.autocast(enabled=False):
                    b3 = torch.is_autocast_cpu_enabled()
                    v3 = torch.mm(x, y)
            return (v1, b1, v2, b2, v3, b3)

        # bx = is_autocast_cpu_enabled() result should be False iff (vx = mm(x, y)).dtype is float
        def check_fn_results(arr):
            [v1, b1, v2, b2, v3, b3] = arr
            self.assertTrue((v1.dtype == torch.float) != b1)
            self.assertTrue((v2.dtype == torch.float) != b2)
            self.assertTrue((v3.dtype == torch.float) != b3)

        x = torch.rand((2, 2), dtype=torch.float)
        y = torch.rand((2, 2), dtype=torch.float)

        fn_s = torch.jit.script(fn)

        with torch.cpu.amp.autocast(enabled=False):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))

        with torch.cpu.amp.autocast(enabled=True):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))


if __name__ == "__main__":
    run_tests()