# Owner(s): ["module: unknown"]

import unittest

import torch
from torch.testing._internal.autocast_test_lists import (
    AutocastCPUTestLists,
    TestAutocast,
)
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.utils._python_dispatch import TorchDispatchMode


class TestAutocastCPU(TestAutocast):
    def setUp(self):
        super().setUp()
        self.autocast_lists = AutocastCPUTestLists(torch.device("cpu"))

    def tearDown(self):
        del self.autocast_lists
        super().tearDown()

    @skipIfTorchDynamo()
    def test_autocast_torch_expect_builtin_promote(self):
        for (
            op,
            args1,
            args2,
            out_type,
        ) in self.autocast_lists.torch_expect_builtin_promote:
            self._run_autocast_outofplace(
                op, args1, torch.float32, device="cpu", out_type=out_type
            )
            self._run_autocast_outofplace(
                op,
                args2,
                torch.float32,
                device="cpu",
                out_type=out_type,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_methods_expect_builtin_promote(self):
        for (
            op,
            args1,
            args2,
            out_type,
        ) in self.autocast_lists.methods_expect_builtin_promote:
            self._run_autocast_outofplace(
                op, args1, torch.float32, device="cpu", module=None, out_type=out_type
            )
            self._run_autocast_outofplace(
                op,
                args2,
                torch.float32,
                device="cpu",
                module=None,
                out_type=out_type,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_16(self):
        for op_with_args in self.autocast_lists.torch_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.bfloat16, device="cpu", add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float16,
                device="cpu",
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_nn_16(self):
        for op_with_args in self.autocast_lists.nn_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op,
                args,
                torch.bfloat16,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float16,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_fp32(self):
        for op_with_args in self.autocast_lists.torch_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.float32, device="cpu", add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cpu",
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_nn_fp32(self):
        for op_with_args in self.autocast_lists.nn_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_need_autocast_promote(self):
        for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
            self._run_autocast_outofplace(op, args1, torch.float32, device="cpu")
            self._run_autocast_outofplace(
                op, args2, torch.float32, device="cpu", amp_dtype=torch.float16
            )

    @unittest.skipIf(IS_WINDOWS, "Limited bf16 support on Windows")
    def test_autocast_rnn(self):
        if (
            torch.backends.mkldnn.is_available()
            and torch.ops.mkldnn._is_mkldnn_bf16_supported()
        ):
            x = torch.randn(1, 2, 1)
            hx = torch.randn(2, 2, 1)
            cx = torch.randn(2, 2, 1)

            m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)

            # Raises a ValueError when autocast is not enabled
            with self.assertRaisesRegex(ValueError, "input must have the type"):
                m(x, (hx, cx))

            # The same call should succeed under autocast
            with torch.amp.autocast(device_type="cpu"):
                m(x, (hx, cx))

    def test_autocast_disabled_with_fp32_dtype(self):
        with torch.autocast(device_type="cpu", dtype=torch.float32, enabled=False):
            _ = torch.ones(10)

    def test_generic_autocast(self):
        for op_with_args in self.autocast_lists.torch_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            with torch.amp.autocast(device_type="cpu"):
                generic_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
            with torch.cpu.amp.autocast():
                cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
            # The generic and the CPU-specific context managers should give the same result
            self.assertEqual(generic_autocast_output, cpu_autocast_output)

    def test_cpu_autocast_deprecated_warning(self):
        with self.assertWarnsRegex(
            FutureWarning,
            r"`torch.cpu.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cpu', args...\)` instead.",
        ):
            with torch.cpu.amp.autocast():
                _ = torch.ones(10)
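
    # Companion sketch (illustrative, hedged; not from the upstream suite): exercises
    # the replacement spelling recommended by the warning above. Assumes torch.mm is
    # on the CPU autocast lower-precision op list, so it runs in the requested
    # bfloat16 dtype.
    def test_cpu_autocast_recommended_api(self):
        with torch.amp.autocast("cpu", dtype=torch.bfloat16):
            out = torch.mm(torch.randn(2, 3), torch.randn(3, 2))
        self.assertEqual(out.dtype, torch.bfloat16)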


class CustomLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w_t):
        ctx.save_for_backward(x, w_t)
        return torch.nn.functional.linear(x, w_t)

    @staticmethod
    def backward(ctx, grad_output):
        x, w_t = ctx.saved_tensors
        with torch.autocast(device_type="cuda"):
            dL_dX = torch.matmul(grad_output, w_t)
            dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
        return dL_dX, dL_dW


class WeightDTypeCastCounterMode(TorchDispatchMode):
    def __init__(self, weight):
        super().__init__()
        self.dtype_cast_counter = 0
        self.weight = weight

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if (
            func is torch.ops.aten._to_copy.default
            and args[0] is self.weight
            and kwargs["dtype"] is torch.float16
        ):
            self.dtype_cast_counter += 1
        return func(*args, **kwargs)

    def __enter__(self):
        self.old_clear_cache = torch.clear_autocast_cache
        torch.clear_autocast_cache = lambda: None
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.clear_autocast_cache = self.old_clear_cache
        return super().__exit__(exc_type, exc_val, exc_tb)


@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
class TestAutocastGPU(TestCase):
    def test_cast_cache_is_global(self):
        """
        Verifies that the autocast cache is global. This is done by
        mocking out cache clearing at the end of the forward pass,
        running forward+backward with an explicit call to autocast in the
        backward, and verifying that the weight only gets cast to float16 once.
        """

        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())

        with WeightDTypeCastCounterMode(weight) as mode:
            with torch.autocast(device_type="cuda"):
                output = CustomLinear.apply(data, weight)
                s = output.sum()
            s.backward()

        self.assertEqual(mode.dtype_cast_counter, 1)

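    # Companion sketch (illustrative, hedged; not from the upstream suite): within a
    # single autocast region the float16 copy of an fp32 leaf parameter is cached, so
    # a second forward over the same weight should not trigger another aten::_to_copy.
    def test_cast_cache_reused_within_region(self):
        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())

        with WeightDTypeCastCounterMode(weight) as mode:
            with torch.autocast(device_type="cuda"):
                # Two forward passes; the second should hit the autocast cast cache.
                torch.nn.functional.linear(data, weight)
                torch.nn.functional.linear(data, weight)

        self.assertEqual(mode.dtype_cast_counter, 1)
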
    def test_cache_disabled(self):
        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())

        try:
            torch._C._set_cached_tensors_enabled(True)
            torch._C._add_cached_tensor(weight)

            with WeightDTypeCastCounterMode(weight) as mode:
                with torch.autocast(device_type="cuda"):
                    output = CustomLinear.apply(data, weight)
                    s = output.sum()
                s.backward()

            # we should not have cached the conversion of the weight
            self.assertEqual(mode.dtype_cast_counter, 2)

        finally:
            torch._C._set_cached_tensors_enabled(False)

    # index_put under AMP follows a cast policy called "promote":
    # https://github.com/pytorch/pytorch/blob/4fcd15a667df5b80e81db6563d8d3123a0cbd051/aten/src/ATen/autocast_mode.h#L205-L230
    # That means:
    #   (1) double precision is ignored,
    #   (2) if any argument is float, then all arguments are promoted to float,
    #   (3) if all arguments are of lower precision dtypes, they must all share the same amp autocast dtype.
    # Because the AMP autocast dtype is thread-local, it is not preserved across thread boundaries during autograd
    # execution; since autograd is multi-threaded, the forward pass here runs in bfloat16 while the backward pass
    # falls back to the default float16. The dtype mismatch violates criterion (3) and triggers an error in the
    # policy. For more info see https://github.com/pytorch/pytorch/issues/132715.
    def test_autocast_prioritize(self):
        device = "cuda"
        dtype = torch.bfloat16

        with torch.autocast(device_type=device, enabled=True, dtype=dtype):
            t = torch.randn([3, 4, 5], dtype=dtype, device=device, requires_grad=True)
            index = torch.randint(
                low=0, high=3, size=[3, 4, 5], dtype=torch.int64, device=device
            )
            val = torch.randn(1, dtype=dtype, device=device)

            res = torch.index_put(t, [index], val)

            loss = res.mean()
            loss.backward()
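
    # Companion sketch (illustrative, hedged; not from the upstream suite) of
    # criterion (2) above. Assumes torch.addcmul follows the "promote" policy under
    # CUDA autocast, so mixing float32 and float16 arguments runs the op in the
    # widest dtype, float32.
    def test_promote_policy_runs_in_widest_dtype(self):
        device = "cuda"
        with torch.autocast(device_type=device):
            a = torch.randn(3, device=device, dtype=torch.float32)
            b = torch.randn(3, device=device, dtype=torch.float16)
            c = torch.randn(3, device=device, dtype=torch.float16)
            # Without autocast this mixed-dtype call would fail; under the promote
            # policy all arguments are first cast to float32.
            out = torch.addcmul(a, b, c)
        self.assertEqual(out.dtype, torch.float32)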


@unittest.skipIf(not torch.backends.mps.is_available(), "requires mps")
class TestAutocastMPS(TestCase):
    def test_cast_cache_is_global(self):
        class CustomLinear(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, w_t):
                ctx.save_for_backward(x, w_t)
                return torch.nn.functional.linear(x, w_t)

            @staticmethod
            def backward(ctx, grad_output):
                x, w_t = ctx.saved_tensors
                with torch.autocast(device_type="mps"):
                    dL_dX = torch.matmul(grad_output, w_t)
                    dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
                return dL_dX, dL_dW

        data = torch.randn(2, 3).to("mps")
        weight = torch.nn.Parameter(torch.randn(4, 3).to("mps"))
        weight_dtype_cast_counter = 0

        class WeightDTypeCastCounterMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if (
                    func is torch.ops.aten._to_copy.default
                    and args[0] is weight
                    and kwargs["dtype"] is torch.float16
                ):
                    nonlocal weight_dtype_cast_counter
                    weight_dtype_cast_counter += 1
                return func(*args, **kwargs)

            def __enter__(self):
                # self.old_clear_cache = torch.clear_autocast_cache
                # torch.clear_autocast_cache = lambda: None
                return super().__enter__()

            def __exit__(self, exc_type, exc_val, exc_tb):
                # torch.clear_autocast_cache = self.old_clear_cache
                return super().__exit__(exc_type, exc_val, exc_tb)

        with WeightDTypeCastCounterMode():
            with torch.autocast(device_type="mps"):
                output = CustomLinear.apply(data, weight)
                s = output.sum()
            s.backward()
        # Unlike the CUDA test above, cache clearing is not mocked out here, so the
        # autocast cache is dropped when the forward region exits and the weight is
        # cast again in the backward pass: two casts in total.
        self.assertEqual(weight_dtype_cast_counter, 2)


class TestTorchAutocast(TestCase):
    def test_autocast_fast_dtype(self):
        gpu_fast_dtype = torch.get_autocast_dtype(device_type="cuda")
        cpu_fast_dtype = torch.get_autocast_dtype(device_type="cpu")
        self.assertEqual(gpu_fast_dtype, torch.half)
        self.assertEqual(cpu_fast_dtype, torch.bfloat16)

    def test_invalid_device(self):
        dev = "not a real device"
        msg = f"Invalid device string: '{dev}'"
        with self.assertRaisesRegex(RuntimeError, msg):
            with torch.autocast(device_type=dev):
                _ = torch.tensor(1)
        with self.assertRaisesRegex(RuntimeError, msg):
            assert torch.amp.is_autocast_available(device_type=dev)

    def test_non_string_device(self):
        """Test that `autocast` throws a ValueError when provided a `torch.device` object for `device_type` instead of a string"""
        dev = torch.device("cpu")
        msg = f"Expected `device_type` of type `str`, got: `{type(dev)}`"
        with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
            torch.autocast(device_type=dev)


if __name__ == "__main__":
    run_tests()