# Owner(s): ["oncall: jit"]

import torch
from torch.cuda.amp import autocast
from typing import Optional, Tuple

import unittest
from test_jit import JitTestCase
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
from torch.testing import FileCheck
from jit.test_models import MnistNet

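# the bf16-specific tests below only run when CUDA is available and the device reports bfloat16 support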
TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()

@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestAutocast(JitTestCase):
    def setUp(self):
        # common input tensors
        if TEST_CUDA:
            self.a_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.b_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.c_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.d_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.a_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.b_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.c_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.d_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
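        # enable the TorchScript autocast pass; the call returns the previous mode, which tearDown restores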
        self.old_value = torch._C._jit_set_autocast_mode(True)
        super().setUp()

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.old_value)
        super().tearDown()

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_generic_autocast(self):
        @torch.jit.script
        def fn_cuda_autocast(a, b):
            with autocast():
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y

        @torch.jit.script
        def fn_generic_autocast(a, b):
            with torch.amp.autocast(device_type='cuda'):
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        self.assertEqual(fn_cuda_autocast(self.a_fp32, self.b_fp32), fn_generic_autocast(self.a_fp32, self.b_fp32))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA or not TEST_BFLOAT16, "No cuda bfloat16 support")
    def test_linear_bf16(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(dtype=torch.bfloat16):
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.bfloat16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_cpu(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                return torch.mm(a, b)
        result = fn(self.a_fp32.to('cpu'), self.b_fp32.to('cpu'))
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_off(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state(self):
        @torch.jit.script
        def fn(a, b, use_amp: bool):
            with autocast(enabled=use_amp):
                return torch.mm(a, b)
        # runtime values for autocast enable argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32, True)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state_expr(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True if a[0][0] > 0.5 else False):
                return torch.mm(a, b)
        # runtime values for autocast enable argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_explicit_casts(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                e = torch.mm(a.double(), b.double()).float()
                f = torch.mm(c, d).double()
            g = torch.mm(c.double(), f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float64)
        self.assertEqual(g.dtype, torch.float64)

    # multiple uses of the same input value
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_duplicate_inputs(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                e = torch.mm(a, a)
                f = torch.mm(e, e)
            return e, f
        e, f = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy(self):
        @torch.jit.script
        def fn(a):
            with autocast(enabled=True):
                return torch.log(a)
        result = fn(self.a_fp16)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy_with_fp64(self):
        @torch.jit.script
        def fn(a):
            with autocast(enabled=True):
                return torch.log(a)
        # fp32 policy should not narrow fp64 to fp32!
        result = fn(self.a_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                e = torch.mm(a, b)
                f = torch.addcmul(e, c, d, value=0.1)
            return e, f
        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy_fp64(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return torch.addcmul(a, a, b, value=0.1)
        result = fn(self.a_fp32.double(), self.b_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy(self):
        @torch.jit.script
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
            return x, y, z, w
        x, y, z, w = fn(self.a_fp16, self.b_fp16, self.c_fp16, self.d_fp16, None)
        self.assertEqual(x.dtype, torch.float32)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy_fp64(self):
        @torch.jit.script
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
            return x, y, z, w
        x, y, z, w = fn(self.a_fp32.double(), self.b_fp32.double(), self.c_fp32.double(), self.d_fp32.double(), None)
        self.assertEqual(x.dtype, torch.float64)
        self.assertEqual(y.dtype, torch.float64)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float64)

    @unittest.skipIf(True, "broken due to lack of type propagation")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_control_flow(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                if a[0][0] > 0.5:
                    e = torch.mm(a, b)
                    x = 1
                else:
                    e = torch.mm(c, d)
                    x = 2
                f = torch.mm(d, e) * x
            return e, f
        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    # this works fine in regular Python, but it creates a delicate
    # situation in TorchScript where the types are not consistent across
    # the then/else branches
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_types(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                if a[0][0] > 0.5:
                    e = torch.mm(a, b)
                    f = torch.mm(a, b).float()
                else:
                    e = torch.mm(c, d).float()
                    f = torch.mm(a, b)
            return torch.mm(e.float(), f.float())
        result = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # another, more complex case of divergent types
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            if a[0][0] > 0.5:
                with autocast_on:
                    e = torch.mm(a, b)
            else:
                with autocast_off:
                    e = torch.mm(c, d)
            return torch.mm(e, e)
        fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_conditional_autocast(self):
        @torch.jit.script
        def fn(a, b):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            with autocast_on if a[0][0] > 0.5 else autocast_off:
                return torch.mm(a, b)
        # conditional autocast expressions are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_nested_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast(enabled=False):
                e = torch.mm(a, b)
                with autocast(enabled=True):
                    f = torch.mm(e, c)
                    with autocast(enabled=False):
                        g = torch.mm(e, d)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_implicitly_nested_autocast(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False), autocast(enabled=True):
                return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_reused_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            autocast_instance = autocast(enabled=True)
            with autocast_instance:
                e = torch.mm(a, b)
                with autocast_instance:
                    e = torch.mm(c, d)
                    f = torch.mm(d, e)
            g = torch.mm(e, f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    # TODO: fix and enable this test?
    #   (we could technically fix this, but is it really worth it?)
    @unittest.skipIf(True, "unsupported autocast syntax")
    def test_reused_autocast_expr(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast(enabled=True) as autocast_instance:
                e = torch.mm(a, b)
                with autocast_instance:
                    e = torch.mm(c, d)
                    f = torch.mm(d, e)
            g = torch.mm(e, f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees(self):
        def helper(a, b):
            return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                tmp = helper(a, b)
                tmp = helper(tmp, tmp)
                tmp = helper(tmp, tmp)
                tmp = helper(tmp, tmp)
                return helper(tmp, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_on(self):
        def helper(a, b):
            with autocast(enabled=True):
                return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False):
                return helper(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_off(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return helper(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripting inside eager autocast
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_eager_and_script(self):
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)
        for i in range(8):
            use_autocast = (i % 2 == 0)
            expected_dtype = torch.float16 if use_autocast else torch.float32
            with autocast(enabled=use_autocast):
                result = fn(self.a_fp32, self.b_fp32)
            self.assertEqual(result.dtype, expected_dtype)

    # traced inside scripting
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing(self):
        def helper(a, b):
            return torch.mm(a, b)

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return traced(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # traced with autocast inside scripting
    @unittest.skipIf(True, "autocast(False) is ignored inside traced functions")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing_with_autocast(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b) * 2.0

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return traced(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripted called from traced
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_and_script(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                return torch.mm(a, b)

        def traced(a, b):
            return fn(a, b)

        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # scripted called from traced with autocast
    @unittest.skipIf(True, "calling a scripted function from traced TorchScript is not yet working")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_with_autocast_and_script(self):
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)

        def traced(a, b):
            with autocast(enabled=True):
                return fn(a, b)

        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_module(self):
        class TestModule(torch.nn.Module):
            def __init__(self, N, M):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand((N, M), dtype=torch.float32))
                self.linear = torch.nn.Linear(N, M).float()

            def forward(self, input):
                with autocast(enabled=True):
                    output = self.weight.mv(input)
                    output = self.linear(output)
                    return output

        scripted_module = torch.jit.script(TestModule(2, 3)).cuda()
        input = torch.rand(3, dtype=torch.float32, device='cuda')
        result = scripted_module(input)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(True, "autocast decorators not supported")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator(self):
        @torch.jit.script
        @autocast(enabled=True)
        def fn(a, b):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # this is equivalent to running scripted functions inside autocast
    # (see also test_eager_and_script)
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator_outside_jit(self):
        @autocast(enabled=True)
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_inplace(self):
        @torch.jit.script
        def fn(a, b, c):
            with autocast(enabled=True):
                x = torch.addmm(a, b, c)
                y = torch.addmm(a, b, c, out=a)
                z = a.addmm_(b, c)
                return x, y, z
        x, y, z = fn(self.a_fp32, self.b_fp32, self.c_fp32)
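        # the out= and in-place variants write into an existing fp32 tensor, so (as asserted
        # below) only the functional addmm result is expected to be autocast to fp16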
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float32)

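    # helper: run `func` both eagerly and as a scripted function, optionally check that
    # `cast_op` appears in the scripted graph, and compare the output dtypes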
    def _test_autocast(self, func, cast_op, *args):
        jit_func = torch.jit.script(func)
        o = func(*args)
        jit_o = jit_func(*args)
        if cast_op is not None:
            FileCheck().check(cast_op).run(jit_func.graph_for(*args))
        for o0, o1 in zip(o, jit_o):
            self.assertEqual(o0.dtype, o1.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api(self):

        def t_autocast_cpu(x, y):
            with torch.autocast("cpu", dtype=torch.bfloat16):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            with torch.autocast("cuda", dtype=torch.half):
                return torch.mm(x, y)

        def t_cuda_amp_autocast(x, y):
            with torch.cuda.amp.autocast():
                return torch.mm(x, y)

        def t_cpu_amp_autocast(x, y):
            with torch.cpu.amp.autocast():
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cuda_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cpu_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(True, "a dtype argument needs to be provided at the moment")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api_not_supported(self):

        def t_autocast_cpu(x, y):
            # omitting the dtype is not currently supported
            with torch.autocast("cpu"):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            # omitting the dtype is not currently supported
            with torch.autocast("cuda"):
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_mixed_dtypes(self):

        def t(cpu0, cpu1, cuda0, cuda1):
            with torch.autocast("cpu", torch.bfloat16):
                with torch.autocast("cuda", torch.float16):
                    cpu_o = torch.mm(cpu0, cpu1)
                    cuda_o = torch.mm(cuda0, cuda1)
                    return cpu_o, cuda_o

        jit_t = torch.jit.script(t)
        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_executor_under_autocast(self):

        def t(cpu0, cpu1, cuda0, cuda1):
            cpu_o = torch.mm(cpu0, cpu1)
            cuda_o = torch.mm(cuda0, cuda1)
            return cpu_o, cuda_o

        jit_t = torch.jit.script(t)
        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)

        with torch.autocast("cpu", torch.bfloat16):
            with torch.autocast("cuda", torch.float16):
                self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cpu", torch.bfloat16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cuda", torch.float16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        # no cast op should be observed when executing outside autocast context
        self._test_autocast(t, None, cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_autodiff(self):
        def t(t0, t1):
            o = torch.mm(t0, t1)
            return o.relu()

        jit_t = torch.jit.script(t)
        t0 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
        t1 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()

        # run optimization
        for i in range(5):
            with torch.autocast("cuda", torch.float16):
                jit_o = jit_t(t0, t1)
            jit_o.sum().backward()

        t0.grad = None
        t1.grad = None
        ref_t0 = t0.detach().requires_grad_()
        ref_t1 = t1.detach().requires_grad_()

        with torch.autocast("cuda", torch.float16):
            o = t(ref_t0, ref_t1)
            jit_o = jit_t(t0, t1)
        jit_o.sum().backward()
        o.sum().backward()
        self.assertEqual(o, jit_o)
        self.assertEqual(t0.grad, ref_t0.grad)
        self.assertEqual(t1.grad, ref_t1.grad)
        self.assertEqual(o.dtype, jit_o.dtype)
        self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
        self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_call_method_under_autocast(self):
        @torch.jit.interface
        class Iface(torch.nn.Module):
            def forward(self, x, y) -> torch.Tensor:
                pass

        class Impl(Iface):
            def forward(self, x, y):
                return torch.mm(x, y)

        class Thing1(torch.nn.Module):
            impl: Iface

            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    a = torch.mm(x, y)
                    b = self.impl.forward(a, x)
                    return b

        scripted_impl = torch.jit.script(Impl())
        thing1 = Thing1()
        thing1.impl = scripted_impl
        scripted_thing1 = torch.jit.script(thing1)
        x = torch.rand([2, 2])
        y = torch.rand([2, 2])

        # make sure this doesn't throw an error
        with torch.cuda.amp.autocast():
            ans = scripted_thing1.forward(x, y)
        self.assertEqual(torch.mm(torch.mm(x, y), x), ans)

        # sanity check: this isn't supported currently when global autocasting
        # isn't enabled
        self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_basic(self):
        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(x, y)

        x = torch.rand((3, 4), dtype=torch.float).cuda()
        y = torch.rand((4, 5), dtype=torch.float).cuda()

        mod = TestModule().eval()

        # sanity check
        self._test_autocast(mod, "aten::_autocast_to_reduced_precision", x, y)

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
698*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(frozen_mod.graph)
699*da0073e9SAndroid Build Coastguard Worker
700*da0073e9SAndroid Build Coastguard Worker        # make sure that the runtime pass doesn't duplicate autocast nodes
701*da0073e9SAndroid Build Coastguard Worker        frozen_mod(x, y)
702*da0073e9SAndroid Build Coastguard Worker        optimized_graph = frozen_mod.graph_for(x, y)
703*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(optimized_graph)
704*da0073e9SAndroid Build Coastguard Worker
705*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "No cuda")
706*da0073e9SAndroid Build Coastguard Worker    def test_jit_freeze_autocast_constants(self):
707*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.nn.Module):
708*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
709*da0073e9SAndroid Build Coastguard Worker                super().__init__()
710*da0073e9SAndroid Build Coastguard Worker                self.x = torch.rand((3, 4), dtype=torch.float).cuda()
711*da0073e9SAndroid Build Coastguard Worker
712*da0073e9SAndroid Build Coastguard Worker            def forward(self, y):
713*da0073e9SAndroid Build Coastguard Worker                with torch.cuda.amp.autocast():
714*da0073e9SAndroid Build Coastguard Worker                    return torch.mm(self.x, y)
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker        y = torch.rand((4, 5), dtype=torch.float).cuda()
717*da0073e9SAndroid Build Coastguard Worker        mod = TestModule().eval()
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
720*da0073e9SAndroid Build Coastguard Worker        # freezing should pre-cast the constant self.x to remove one autocast call
721*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(frozen_mod.graph)
722*da0073e9SAndroid Build Coastguard Worker
723*da0073e9SAndroid Build Coastguard Worker        # the runtime autocasting pass will re-insert the second autocast call,
724*da0073e9SAndroid Build Coastguard Worker        # but constant propagation will merge it with the constant that it's casting.
725*da0073e9SAndroid Build Coastguard Worker        frozen_mod(y)
726*da0073e9SAndroid Build Coastguard Worker        optimized_graph = frozen_mod.graph_for(y)
727*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(optimized_graph)
728*da0073e9SAndroid Build Coastguard Worker
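    # The scripted softmax below runs under CPU autocast; the assertion checks
    # that a bfloat16 input comes back as bfloat16.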
    @unittest.skipIf(TEST_CUDA, "CPU-only test")
    def test_jit_autocast_softmax_cpu(self):
        def fn(x):
            with torch.cpu.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.bfloat16)
        fn_s(x)
        y = fn_s(x)

        self.assertTrue(y.dtype == torch.bfloat16)

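    # Same check on CUDA: the assertion below expects the scripted softmax to
    # return float32 even though the input is half precision.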
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_autocast_softmax_gpu(self):
        def fn(x):
            with torch.cuda.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.half).cuda()
        fn_s(x)
        y = fn_s(x)

        self.assertTrue(y.dtype == torch.float)

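    # With _set_ignore_amp(True), the optimized graph of the scripted function
    # should contain no reduced-precision autocast casts, even when it runs
    # inside an autocast region.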
    def test_ignore_amp(self):
        @torch.jit.script
        def foo(x):
            return torch.mm(x, x)

        inp = torch.rand([10, 10], dtype=torch.float)
        foo._set_ignore_amp(True)
        with torch.cpu.amp.autocast():
            foo(inp)
            foo(inp)

        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_not("_autocast_to_reduced").run(g)

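# Small Conv2d + BatchNorm2d module used as a test model by TestJitTraceAutocast below.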
class convbn(torch.nn.Module):
    def __init__(self, bias_enabled=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, 7, stride=2, bias=bias_enabled)
        self.bn = torch.nn.BatchNorm2d(64)

    def forward(self, x):
        return self.bn(self.conv(x))

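# Trace-based autocast tests. Note that setUp disables the JIT autocast pass
# via torch._C._jit_set_autocast_mode(False).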
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestJitTraceAutocast(JitTestCase):
    def setUp(self):
        super().setUp()
        self.previous_default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        self.models = [MnistNet(),
                       convbn(bias_enabled=True),
                       convbn(bias_enabled=False)]
        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu')]
        self.previous_jit_autocast_pass = torch._C._jit_set_autocast_mode(False)

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.previous_jit_autocast_pass)
        torch.set_default_dtype(self.previous_default_dtype)
        super().tearDown()

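    # smoke test: tracing and freezing each model under CPU autocast should not raise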
    def test_generate_autocast_jit_trace_model(self):
        def test_generate_autocast_jit_trace_model(model, x):
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)
        for i in range(len(self.models)):
            test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_nchw_autocast_jit_trace_model(self):
        def test_nchw_autocast_jit_trace_model(model, x):
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)
            with torch.no_grad():
                y = traced_model(x.clone())
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone())
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
        for i in range(len(self.models)):
            test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_nhwc_autocast_jit_trace_model(self):
        def test_nhwc_autocast_jit_trace_model(model, x):
            model = model.to(memory_format=torch.channels_last)
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
            traced_model = torch.jit.freeze(traced_model)
            with torch.no_grad():
                y = traced_model(x.clone().to(memory_format=torch.channels_last))
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone().to(memory_format=torch.channels_last))
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
        for i in range(len(self.models)):
            if self.inputs[i].dim() == 5:
                # NHWC 3D case not supported yet
                continue
            test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_cat_promote(self):
        class TestModel(torch.nn.Module):
            def forward(self, a, b):
                return torch.cat([a, b], 0)

        with torch.jit.fuser("none"):
            # In this test case, we check whether cat promotes mixed-dtype inputs under AMP.
            # The fuser is disabled to keep the cat out of a TE fusion group.
            for jit_freeze_or_not in [False, True]:
                test_model = TestModel().eval()
                with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
                    a = torch.rand(24, 128, 128)
                    b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
                    c = test_model(a, b)
                    traced = torch.jit.trace(test_model, (a, b))
                if jit_freeze_or_not:
                    traced = torch.jit.freeze(traced)
                for _ in range(3):
                    c2 = traced(a, b)
                self.assertEqual(c.dtype, torch.float32)
                self.assertEqual(c2.dtype, torch.float32)
                traced_graph = traced.graph_for(a, b)
                self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes()))

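    # torch.is_autocast_cpu_enabled() should be scriptable, emitted as a graph
    # op, and the scripted function should agree with eager mode under CPU autocast.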
    def test_script_autocast_cpu(self):
        def fn(x):
            if torch.is_autocast_cpu_enabled():
                return x.relu()
            else:
                return x.sin()

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))

        with torch.cpu.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))

        self.assertTrue(any("is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()))

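    # Same check for the CUDA-side flag, torch.is_autocast_enabled().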
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_autocast_cuda(self):
        def fn(x):
            if torch.is_autocast_enabled():
                return x.relu()
            else:
                return x.sin()

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))

        with torch.cuda.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))

        self.assertTrue(any("is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()))

    def test_scripted_aliasing(self):
        # torch.is_autocast_enabled should not be able to move inside of the autocast context.
        def fn(x):
            if torch.is_autocast_enabled():
                y = True
            else:
                y = False
            with torch.cuda.amp.autocast(enabled=True):
                z = x.relu()
            return y, z

        fn_s = torch.jit.script(fn)
        graph = fn_s.graph

        aliasdb = graph.alias_db()

        is_enabled_nodes = graph.findAllNodes("aten::is_autocast_enabled")
        enter_nodes = graph.findAllNodes("prim::Enter")

        self.assertEqual(len(is_enabled_nodes), 1)
        self.assertEqual(len(enter_nodes), 1)

        self.assertFalse(aliasdb.move_after_topologically_valid(is_enabled_nodes[0], enter_nodes[0]))

    def test_script_autocast_enable_and_check(self):
        def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
            b1 = torch.is_autocast_cpu_enabled()
            v1 = torch.mm(x, y)
            with torch.cpu.amp.autocast(enabled=True):
                b2 = torch.is_autocast_cpu_enabled()
                v2 = torch.mm(x, y)
                with torch.cpu.amp.autocast(enabled=False):
                    b3 = torch.is_autocast_cpu_enabled()
                    v3 = torch.mm(x, y)
            return (v1, b1, v2, b2, v3, b3)

        # bx = is_autocast_cpu_enabled() result should be False iff (vx = mm(x, y)).dtype is float
        def check_fn_results(arr):
            [v1, b1, v2, b2, v3, b3] = arr
            self.assertTrue((v1.dtype == torch.float) != b1)
            self.assertTrue((v2.dtype == torch.float) != b2)
            self.assertTrue((v3.dtype == torch.float) != b3)

        x = torch.rand((2, 2), dtype=torch.float)
        y = torch.rand((2, 2), dtype=torch.float)

        fn_s = torch.jit.script(fn)

        with torch.cpu.amp.autocast(enabled=False):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))

        with torch.cpu.amp.autocast(enabled=True):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))


if __name__ == "__main__":
    run_tests()