# Owner(s): ["module: autograd"]

import contextlib
import warnings

import numpy as np

import torch
from torch.library import _scoped_library, Library
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


@contextlib.contextmanager
def autograd_fallback_mode(mode):
    prev = torch._C._get_autograd_fallback_mode()
    try:
        torch._C._set_autograd_fallback_mode(mode)
        yield
    finally:
        torch._C._set_autograd_fallback_mode(prev)

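# These tests exercise the autograd fallback: the kernel that runs when an
# operator has no kernel registered to the Autograd dispatch key. Two modes
# are covered: "nothing" (the fallback stays silent) and "warn" (backward
# through an affected output warns that "an autograd kernel was not
# registered").
#
# A minimal usage sketch mirroring the tests below (`op` is assumed to be a
# custom operator with only a CPU kernel):
#
#     with autograd_fallback_mode("warn"):
#         out = op(x)             # no autograd kernel registered for op
#         out.sum().backward()    # expected to emit a UserWarning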

class TestAutogradFallback(TestCase):
    test_ns = "_test_autograd_fallback"

    def tearDown(self):
        if hasattr(torch.ops, self.test_ns):
            delattr(torch.ops, self.test_ns)
        if hasattr(self, "lib"):
            del self.lib.m
            del self.lib

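    # Resolve the default overload of an op defined in the test namespace.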
    def get_op(self, name):
        return getattr(getattr(torch.ops, self.test_ns), name).default

    def get_lib(self):
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.lib = lib
        return lib

    @parametrize("mode", ("nothing", "warn"))
    def test_no_grad(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            lib.impl("foo", lambda a, b, c: a + b + c, "CPU")
            op = self.get_op("foo")

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                with torch.no_grad():
                    a = torch.randn([], requires_grad=True)
                    b = torch.randn([], requires_grad=True)
                    out = op(a, b, 1)
                self.assertFalse(out.requires_grad)

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                a = torch.randn([])
                b = torch.randn([])
                out = op(a, b, 1)
                self.assertFalse(out.requires_grad)

    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            op = self.get_op("foo")

            def foo_impl(a, b, c):
                result = a.detach().numpy() + b.detach().numpy() + c
                return torch.tensor(result)

            lib.impl("foo", foo_impl, "CPU")

            # Some inputs requiring grad
            a = torch.randn([], requires_grad=False)
            b = torch.randn([], requires_grad=True)
            out = op(a, b, 1).sum()
            with self._check_ctx(mode, mode_nothing_raises=True):
                out.backward()
            self.assertIsNone(b.grad)

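    # Context manager encoding the expected fallback behavior: "warn" mode
    # should emit a UserWarning; "nothing" mode should either have no effect
    # or, when the output has no grad_fn at all, raise "does not require grad".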
    def _check_ctx(self, mode, *, mode_nothing_raises=False):
        if mode == "warn":
            return self.assertWarnsRegex(
                UserWarning, "an autograd kernel was not registered"
            )
        assert mode == "nothing"
        if mode_nothing_raises:
            return self.assertRaisesRegex(RuntimeError, "does not require grad")
        return contextlib.nullcontext()

    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel_inplace(self, mode):
        with autograd_fallback_mode(mode):
            # input modified in-place gets returned as output
            lib = self.get_lib()
            lib.define("foo(Tensor(a!) self, Tensor(b!) y) -> (Tensor(a!), Tensor(b!))")
            op = self.get_op("foo")

            def foo_impl(x, y):
                with torch.no_grad():
                    x.sin_()
                    y.cos_()
                return x, y

            lib.impl("foo", foo_impl, "CPU")

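            # y0 and y1 are views of differentiable clones; after the in-place
            # op, backward through any of the related tensors should still run
            # (warning first in "warn" mode).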
            x = torch.randn(3, requires_grad=True)
            w = x.clone()
            v = x.clone()
            y0 = w[0]
            y1 = v[1]
            z0, z1 = op(y0, y1)
            for tensor in [w, v, z0, z1, y0, y1]:
                with self._check_ctx(mode):
                    tensor.sum().backward(retain_graph=True)

            # no outputs: we don't do anything. Maybe we should in the future.
            # This is not a common failure mode.
            lib.define("bar(Tensor(a!) self) -> ()")
            op = self.get_op("bar")

            def bar_impl(x):
                with torch.no_grad():
                    x.sin_()

            lib.impl("bar", bar_impl, "CPU")
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                x = torch.randn([], requires_grad=True)
                y = x.clone()
                z = op(y)
                y.backward()
                self.assertEqual(x.grad, torch.ones_like(x))

    @parametrize("mode", ("nothing", "warn"))
    def test_cpu_return_self(self, mode):
        with autograd_fallback_mode(mode):
            # To be clear, none of these situations are OK and will lead
            # to other problems down the line. We're testing them because
            # it is fairly common to actually do these things.
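            # Concretely: foo returns its input without an (a!) alias
            # annotation, and bar declares (a!) but performs no mutation.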
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")
                lib.impl("foo", lambda x: x, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))

                lib.define("bar(Tensor(a!) self) -> Tensor(a!)")
                lib.impl("bar", lambda x: x, "CPU")
                op = self.get_op("bar")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))

    @parametrize("mode", ("nothing", "warn"))
    def test_composite_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")
                lib.impl("foo", lambda x: x.sin().sum(), "CPU")
                op = self.get_op("foo")

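                # The CPU impl is composed of differentiable torch ops, so the
                # output already carries a grad_fn; backward succeeds (after a
                # warning in "warn" mode) and d(sum(sin(x)))/dx = cos(x).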
                x = torch.randn(3, requires_grad=True)
                y = op(x)
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, x.cos())

    @parametrize("mode", ("nothing", "warn"))
    def test_autograd_function_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")

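                # An autograd.Function registered to the CPU key records its
                # own backward graph, so gradients flow; "warn" mode still
                # warns because nothing is registered to the Autograd key.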
                class NumpySin(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.save_for_backward(x)
                        return torch.tensor(np.sin(x.cpu().numpy()))

                    @staticmethod
                    def backward(ctx, gx):
                        (x,) = ctx.saved_tensors
                        return gx * x.cos()

                lib.impl("foo", NumpySin.apply, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, x.cos())

    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_autograd_function_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor(a!) self) -> Tensor(a!)")

                class NumpySin_(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.save_for_backward(x.clone())
                        x_np = x.detach().numpy()
                        np.sin(x_np, out=x_np)
                        ctx.mark_dirty(x)
                        return x

                    @staticmethod
                    def backward(ctx, gx):
                        (x,) = ctx.saved_tensors
                        return gx * x.cos()

                lib.impl("foo", NumpySin_.apply, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                z = x.clone()
                w = z[0]
                y = op(w)

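                # y = sin(z[0]) with z = x.clone(), so dy/dx is cos(x[0]) at
                # index 0 and zero elsewhere; grad through z additionally gets
                # the identity contribution from the untouched clone entries.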
                expected = torch.zeros_like(x)
                expected[0] = x[0].cos()
                with self._check_ctx(mode):
                    (gx,) = torch.autograd.grad(
                        y, x, torch.ones_like(y), retain_graph=True
                    )
                    self.assertEqual(gx, expected)

                expected = torch.ones_like(x)
                expected[0] = x[0].cos()
                with self._check_ctx(mode):
                    (gx,) = torch.autograd.grad(z, x, torch.ones_like(z))
                    self.assertEqual(gx, expected)

    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_on_tensor_that_does_not_require_grad(self, mode):
        # We don't do anything special (that is, we don't rebase history).
        # See NOTE [autograd fallback and in-place operations] for why
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                # Correct usage of (a!)
                lib.define("foo(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                def foo_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)
                    return x

                lib.impl("foo", foo_impl, "CPU")
                foo = self.get_op("foo")

                # Incorrect usage of (a!): user doesn't return tensor as-is
                lib.define("bar(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                def bar_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)
                    return x_d.clone()

                lib.impl("bar", bar_impl, "CPU")
                bar = self.get_op("bar")

                # User mutated input tensor but didn't return it.
                lib.define("baz(Tensor(a!) self, Tensor other) -> ()")

                def baz_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)

                lib.impl("baz", baz_impl, "CPU")
                baz = self.get_op("baz")

                # Test in-place on non-view
                for op in (foo, bar, baz):
                    x = torch.randn(3)
                    y = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        z = x.clone()
                        op(z, y)
                        torch.autograd.grad(z, y, torch.ones_like(z), allow_unused=True)

                # Test in-place on view
                for op in (foo, bar, baz):
                    x = torch.randn(3)
                    y = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        z = x[:]
                        op(z, y)
                        torch.autograd.grad(z, x, torch.ones_like(z), allow_unused=True)

    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_leaf(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a) -> (Tensor, Tensor)")
            op = self.get_op("foo")

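            # The second output is a fresh leaf created inside the kernel; in
            # "warn" mode, backward through it still warns that no autograd
            # kernel was registered.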
            lib.impl(
                "foo", lambda a: (a.clone(), a.clone().detach().requires_grad_()), "CPU"
            )
            x = torch.randn(3, requires_grad=True)
            y, z = op(x)
            with self._check_ctx(mode):
                z.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_inputs_outputs(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                return None, b.clone()

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            # NB: PyTorch dispatcher treats "None" as undefined Tensor.
            y, z = op(None, x)
            with self._check_ctx(mode):
                z.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_grads(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                return a.sin(), b.cos()

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            y = torch.randn(3)
            w, z = op(x, y)
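            # UndefinedGrad is an internal autograd node that feeds undefined
            # (i.e. None) gradients into backward; the fallback must tolerate
            # them.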
            w = torch._C._functions.UndefinedGrad()(w)
            z = torch._C._functions.UndefinedGrad()(z)
            with self._check_ctx(mode):
                (z + w).sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_base_does_not_require_grad(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
            op = self.get_op("foo")

            def foo_impl(a):
                with torch.no_grad():
                    return a.zero_()

            lib.impl("foo", foo_impl, "CPU")
            x = torch.randn(3)
            y = x[:]
            y.requires_grad_()
            w = y[:]
            self.assertTrue(w._base is x)

            # Hook should be registered on w, but not w._base
            op(w)
            with self._check_ctx(mode):
                w.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_mix_of_requires_grad_tensors(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                with torch.no_grad():
                    x = a.clone()
                    z = b.clone()
                y = a * b
                return x, y, z

            lib.impl("foo", foo_impl, "CPU")
            a = torch.randn(3, requires_grad=True)
            b = torch.randn(3, requires_grad=True)
            x, y, z = op(a, b)

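            # x and z carry no grad_fn (created under no_grad), so grad through
            # them raises in "nothing" mode; y = a * b is differentiable, so
            # grad through it succeeds. In "warn" mode all three warn first.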
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    x, (a, b), torch.ones_like(x), allow_unused=True, retain_graph=True
                )

            with self._check_ctx(mode, mode_nothing_raises=False):
                torch.autograd.grad(
                    y, (a, b), torch.ones_like(y), allow_unused=True, retain_graph=True
                )

            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    z, (a, b), torch.ones_like(z), allow_unused=True, retain_graph=True
                )

    @parametrize("mode", ("nothing", "warn"))
    def test_supports_tensor_lists(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor[] a) -> Tensor[]")
            op = self.get_op("foo")

            def foo_impl(a):
                x, y, z = a
                with torch.no_grad():
                    return x + y + z, x * y * z

            lib.impl("foo", foo_impl, "CPU")
            x = torch.randn(3, requires_grad=True)
            y = torch.randn(1, requires_grad=True)
            z = torch.randn(2, 1, requires_grad=True)
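            # Both outputs are computed under no_grad, so grad raises in
            # "nothing" mode and warns in "warn" mode; the point is that the
            # fallback handles Tensor[] arguments.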
            a, b = op([x, y, z])
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    a,
                    (x, y, z),
                    torch.ones_like(a),
                    allow_unused=True,
                    retain_graph=True,
                )
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    b,
                    (x, y, z),
                    torch.ones_like(b),
                    allow_unused=True,
                    retain_graph=True,
                )


instantiate_parametrized_tests(TestAutogradFallback)

if __name__ == "__main__":
    run_tests()