# Owner(s): ["module: unknown"]

import unittest

import torch
from torch.testing._internal.autocast_test_lists import (
    AutocastCPUTestLists,
    TestAutocast,
)
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.utils._python_dispatch import TorchDispatchMode


class TestAutocastCPU(TestAutocast):
    def setUp(self):
        super().setUp()
        self.autocast_lists = AutocastCPUTestLists(torch.device("cpu"))

    def tearDown(self):
        del self.autocast_lists
        super().tearDown()

    @skipIfTorchDynamo()
    def test_autocast_torch_expect_builtin_promote(self):
        for (
            op,
            args1,
            args2,
            out_type,
        ) in self.autocast_lists.torch_expect_builtin_promote:
            self._run_autocast_outofplace(
                op, args1, torch.float32, device="cpu", out_type=out_type
            )
            self._run_autocast_outofplace(
                op,
                args2,
                torch.float32,
                device="cpu",
                out_type=out_type,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_methods_expect_builtin_promote(self):
        for (
            op,
            args1,
            args2,
            out_type,
        ) in self.autocast_lists.methods_expect_builtin_promote:
            self._run_autocast_outofplace(
                op, args1, torch.float32, device="cpu", module=None, out_type=out_type
            )
            self._run_autocast_outofplace(
                op,
                args2,
                torch.float32,
                device="cpu",
                module=None,
                out_type=out_type,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_16(self):
        for op_with_args in self.autocast_lists.torch_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.bfloat16, device="cpu", add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float16,
                device="cpu",
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_nn_16(self):
        for op_with_args in self.autocast_lists.nn_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op,
                args,
                torch.bfloat16,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float16,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_fp32(self):
        for op_with_args in self.autocast_lists.torch_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.float32, device="cpu", add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cpu",
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_nn_fp32(self):
        for op_with_args in self.autocast_lists.nn_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cpu",
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_need_autocast_promote(self):
        for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
            self._run_autocast_outofplace(op, args1, torch.float32, device="cpu")
            self._run_autocast_outofplace(
                op, args2, torch.float32, device="cpu", amp_dtype=torch.float16
            )

    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def test_autocast_rnn(self):
        if (
            torch.backends.mkldnn.is_available()
            and torch.ops.mkldnn._is_mkldnn_bf16_supported()
        ):
            x = torch.randn(1, 2, 1)
            hx = torch.randn(2, 2, 1)
            cx = torch.randn(2, 2, 1)

            m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)

            # Raise ValueError when autocast is not enabled
            with self.assertRaisesRegex(ValueError, "input must have the type"):
                m(x, (hx, cx))

            # Should be able to run the below case with autocast
            with torch.amp.autocast(device_type="cpu"):
                m(x, (hx, cx))

    def test_autocast_disabled_with_fp32_dtype(self):
        with torch.autocast(device_type="cpu", dtype=torch.float32, enabled=False):
            _ = torch.ones(10)

    def test_generic_autocast(self):
        for op_with_args in self.autocast_lists.torch_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            with torch.amp.autocast(device_type="cpu"):
                generic_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
            with torch.amp.autocast(device_type="cpu"):
                cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
            self.assertEqual(generic_autocast_output, cpu_autocast_output)

    def test_cpu_autocast_deprecated_warning(self):
        with self.assertWarnsRegex(
            FutureWarning,
            r"`torch.cpu.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cpu', args...\)` instead.",
        ):
            with torch.cpu.amp.autocast():
                _ = torch.ones(10)


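# Autograd Function used by the autocast cache tests below. Its backward
# intentionally runs its matmuls inside an autocast region, so the fp32 weight
# needs a float16 copy in the backward pass as well as in the forward pass.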
class CustomLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w_t):
        ctx.save_for_backward(x, w_t)
        return torch.nn.functional.linear(x, w_t)

    @staticmethod
    def backward(ctx, grad_output):
        x, w_t = ctx.saved_tensors
        with torch.autocast(device_type="cuda"):
            dL_dX = torch.matmul(grad_output, w_t)
            dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
        return dL_dX, dL_dW


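# Dispatch mode that counts how many times `weight` is cast to float16, i.e.
# how many aten._to_copy calls target it with dtype=torch.float16.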
class WeightDTypeCastCounterMode(TorchDispatchMode):
    def __init__(self, weight):
        super().__init__()
        self.dtype_cast_counter = 0
        self.weight = weight

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if (
            func is torch.ops.aten._to_copy.default
            and args[0] is self.weight
            and kwargs["dtype"] is torch.float16
        ):
            self.dtype_cast_counter += 1
        return func(*args, **kwargs)

    def __enter__(self):
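        # Mock out autocast cache clearing while this mode is active, so that a
        # cast cached during the forward pass can survive into the backward pass.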
        self.old_clear_cache = torch.clear_autocast_cache
        torch.clear_autocast_cache = lambda: None
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.clear_autocast_cache = self.old_clear_cache
        return super().__exit__(exc_type, exc_val, exc_tb)


@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
class TestAutocastGPU(TestCase):
    def test_cast_cache_is_global(self):
        """
        Verifies that the autocast cache is global. This is done by
        mocking out cache clearing at the end of the forward pass,
        running forward+backward with an explicit call to autocast in the
        backward, and verifying that the weight is cast to float16 only once.
        """

        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())

        with WeightDTypeCastCounterMode(weight) as mode:
            with torch.autocast(device_type="cuda"):
                output = CustomLinear.apply(data, weight)
                s = output.sum()
            s.backward()

        self.assertEqual(mode.dtype_cast_counter, 1)

    def test_cache_disabled(self):
        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())

        try:
            torch._C._set_cached_tensors_enabled(True)
            torch._C._add_cached_tensor(weight)
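            # Registering the weight as a cached tensor is expected to opt it
            # out of the autocast cast cache, so each pass re-casts it.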

            with WeightDTypeCastCounterMode(weight) as mode:
                with torch.autocast(device_type="cuda"):
                    output = CustomLinear.apply(data, weight)
                    s = output.sum()
                s.backward()

            # we should not have cached the conversion of the weight
            self.assertEqual(mode.dtype_cast_counter, 2)

        finally:
            torch._C._set_cached_tensors_enabled(False)

    # index_put under AMP follows a cast policy called "promote",
    # https://github.com/pytorch/pytorch/blob/4fcd15a667df5b80e81db6563d8d3123a0cbd051/aten/src/ATen/autocast_mode.h#L205-L230
    # That means:
    #   (1) double precision is ignored,
    #   (2) if any argument is float, then all arguments are promoted to float,
    #   (3) if all arguments are of lower precision dtype, then all dtypes must be equal to the same amp autocast dtype.
    # Since the AMP autocast dtype is thread-local, it is not preserved across thread boundaries during autograd
    # execution, and because autograd is multi-threaded, the forward pass runs in bfloat16 while the backward pass
    # defaults to float16. The resulting dtype mismatch violates criterion (3) and triggers an error under this policy.
    # For more info see https://github.com/pytorch/pytorch/issues/132715.
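    # As a sketch of the rules above (not exercised by this test): mixing a
    # float32 input with a bfloat16 input would run index_put in float32
    # (rule 2), while mixing bfloat16 with float16 inputs would be rejected
    # because the two low precision dtypes disagree (rule 3).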
    def test_autocast_prioritize(self):
        device = "cuda"
        dtype = torch.bfloat16

        with torch.autocast(device_type=device, enabled=True, dtype=dtype):
            t = torch.randn([3, 4, 5], dtype=dtype, device=device, requires_grad=True)
            index = torch.randint(
                low=0, high=3, size=[3, 4, 5], dtype=torch.int64, device=device
            )
            val = torch.randn(1, dtype=dtype, device=device)

            res = torch.index_put(t, [index], val)

            loss = res.mean()
            loss.backward()


@unittest.skipIf(not torch.backends.mps.is_available(), "requires mps")
class TestAutocastMPS(TestCase):
    def test_cast_cache_is_global(self):
        class CustomLinear(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, w_t):
                ctx.save_for_backward(x, w_t)
                return torch.nn.functional.linear(x, w_t)

            @staticmethod
            def backward(ctx, grad_output):
                x, w_t = ctx.saved_tensors
                with torch.autocast(device_type="mps"):
                    dL_dX = torch.matmul(grad_output, w_t)
                    dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
                return dL_dX, dL_dW

        data = torch.randn(2, 3).to("mps")
        weight = torch.nn.Parameter(torch.randn(4, 3).to("mps"))
        weight_dtype_cast_counter = 0

        class WeightDTypeCastCounterMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if (
                    func is torch.ops.aten._to_copy.default
                    and args[0] is weight
                    and kwargs["dtype"] is torch.float16
                ):
                    nonlocal weight_dtype_cast_counter
                    weight_dtype_cast_counter += 1
                return func(*args, **kwargs)

            def __enter__(self):
                # self.old_clear_cache = torch.clear_autocast_cache
                # torch.clear_autocast_cache = lambda: None
                return super().__enter__()

            def __exit__(self, exc_type, exc_val, exc_tb):
                # torch.clear_autocast_cache = self.old_clear_cache
                return super().__exit__(exc_type, exc_val, exc_tb)

        with WeightDTypeCastCounterMode():
            with torch.autocast(device_type="mps"):
                output = CustomLinear.apply(data, weight)
                s = output.sum()
            s.backward()
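        # Cache clearing is not mocked out here (see the commented-out lines in
        # __enter__ above), so the weight is expected to be cast to float16 once
        # in the forward pass and once more in the backward pass.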
        self.assertEqual(weight_dtype_cast_counter, 2)


class TestTorchAutocast(TestCase):
    def test_autocast_fast_dtype(self):
        gpu_fast_dtype = torch.get_autocast_dtype(device_type="cuda")
        cpu_fast_dtype = torch.get_autocast_dtype(device_type="cpu")
        self.assertEqual(gpu_fast_dtype, torch.half)
        self.assertEqual(cpu_fast_dtype, torch.bfloat16)

    def test_invalid_device(self):
        dev = "not a real device"
        msg = f"Invalid device string: '{dev}'"
        with self.assertRaisesRegex(RuntimeError, msg):
            with torch.autocast(device_type=dev):
                _ = torch.tensor(1)
        with self.assertRaisesRegex(RuntimeError, msg):
            assert torch.amp.is_autocast_available(device_type=dev)

    def test_non_string_device(self):
        """Test that `autocast` throws a ValueError when provided a `torch.device` object for `device_type` instead of a string"""
        dev = torch.device("cpu")
        msg = f"Expected `device_type` of type `str`, got: `{type(dev)}`"
        with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
            torch.autocast(device_type=dev)


if __name__ == "__main__":
    run_tests()