# Owner(s): ["module: autograd"]
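"""Tests for the autograd fallback: what happens when an operator registered
via torch.library has a CPU kernel but no autograd kernel, under the two
fallback modes, "nothing" (silently produce non-differentiable outputs) and
"warn" (warn when backward reaches such an output)."""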

import contextlib
import warnings

import numpy as np

import torch
from torch.library import _scoped_library, Library
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


@contextlib.contextmanager
def autograd_fallback_mode(mode):
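    """Temporarily set the autograd fallback mode; restore the previous mode on exit."""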
    prev = torch._C._get_autograd_fallback_mode()
    try:
        torch._C._set_autograd_fallback_mode(mode)
        yield
    finally:
        torch._C._set_autograd_fallback_mode(prev)


class TestAutogradFallback(TestCase):
    test_ns = "_test_autograd_fallback"

    def tearDown(self):
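        # Clear out the test namespace and any Library created via get_lib
        # so op registrations do not leak between tests.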
        if hasattr(torch.ops, self.test_ns):
            delattr(torch.ops, self.test_ns)
        if hasattr(self, "lib"):
            del self.lib.m
            del self.lib

    def get_op(self, name):
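        """Return the default overload of `name` from the test namespace."""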
        return getattr(getattr(torch.ops, self.test_ns), name).default

    def get_lib(self):
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.lib = lib
        return lib

    @parametrize("mode", ("nothing", "warn"))
    def test_no_grad(self, mode):
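        # Outputs computed under no_grad (or from inputs that don't require
        # grad) should not require grad, and neither mode should warn.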
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            lib.impl("foo", lambda a, b, c: a + b + c, "CPU")
            op = self.get_op("foo")

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                with torch.no_grad():
                    a = torch.randn([], requires_grad=True)
                    b = torch.randn([], requires_grad=True)
                    out = op(a, b, 1)
                self.assertFalse(out.requires_grad)

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                a = torch.randn([])
                b = torch.randn([])
                out = op(a, b, 1)
                self.assertFalse(out.requires_grad)

    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel(self, mode):
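        # Without an autograd kernel, backward warns in "warn" mode and
        # raises in "nothing" mode; either way no gradient reaches b.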
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            op = self.get_op("foo")

            def foo_impl(a, b, c):
                result = a.detach().numpy() + b.detach().numpy() + c
                return torch.tensor(result)

            lib.impl("foo", foo_impl, "CPU")

            # Some inputs requiring grad
            a = torch.randn([], requires_grad=False)
            b = torch.randn([], requires_grad=True)
            out = op(a, b, 1).sum()
            with self._check_ctx(mode, mode_nothing_raises=True):
                out.backward()
            self.assertIsNone(b.grad)

    def _check_ctx(self, mode, *, mode_nothing_raises=False):
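        """Context manager asserting the expected fallback behavior: a
        UserWarning in "warn" mode, a RuntimeError in "nothing" mode when
        mode_nothing_raises is set, and no effect otherwise.
        """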
        if mode == "warn":
            return self.assertWarnsRegex(
                UserWarning, "an autograd kernel was not registered"
            )
        assert mode == "nothing"
        if mode_nothing_raises:
            return self.assertRaisesRegex(RuntimeError, "does not require grad")
        return contextlib.nullcontext()

    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel_inplace(self, mode):
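        # In-place ops without autograd kernels: the mutated inputs are
        # returned as outputs, and backward through the views, their bases,
        # and the outputs should all work (warning in "warn" mode).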
        with autograd_fallback_mode(mode):
            # input modified in-place gets returned as output
            lib = self.get_lib()
            lib.define("foo(Tensor(a!) self, Tensor(b!) y) -> (Tensor(a!), Tensor(b!))")
            op = self.get_op("foo")

            def foo_impl(x, y):
                with torch.no_grad():
                    x.sin_()
                    y.cos_()
                return x, y

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            w = x.clone()
            v = x.clone()
            y0 = w[0]
            y1 = v[1]
            z0, z1 = op(y0, y1)
            for tensor in [w, v, z0, z1, y0, y1]:
                with self._check_ctx(mode):
                    tensor.sum().backward(retain_graph=True)

            # no outputs: we don't do anything. Maybe we should in the future.
            # This is not a common failure mode.
            lib.define("bar(Tensor(a!) self) -> ()")
            op = self.get_op("bar")

            def bar_impl(x):
                with torch.no_grad():
                    x.sin_()

            lib.impl("bar", bar_impl, "CPU")
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                x = torch.randn([], requires_grad=True)
                y = x.clone()
                z = op(y)
                y.backward()
                self.assertEqual(x.grad, torch.ones_like(x))

    @parametrize("mode", ("nothing", "warn"))
    def test_cpu_return_self(self, mode):
        with autograd_fallback_mode(mode):
            # To be clear, none of these situations are OK and will lead
            # to other problems down the line. We're testing them because
            # it is fairly common to actually do these things.
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")
                lib.impl("foo", lambda x: x, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))

                lib.define("bar(Tensor(a!) self) -> Tensor(a!)")
                lib.impl("bar", lambda x: x, "CPU")
                op = self.get_op("bar")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))

    @parametrize("mode", ("nothing", "warn"))
    def test_composite_registered_to_cpu(self, mode):
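        # The CPU kernel is composite (implemented with differentiable torch
        # ops), so the gradients come out correct even though no autograd
        # kernel was registered; "warn" mode still warns.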
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")
                lib.impl("foo", lambda x: x.sin().sum(), "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x)
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, x.cos())

    @parametrize("mode", ("nothing", "warn"))
    def test_autograd_function_registered_to_cpu(self, mode):
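        # An autograd.Function registered at the CPU key carries its own
        # backward, so gradients are correct; "warn" mode still warns.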
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")

                class NumpySin(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.save_for_backward(x)
                        return torch.tensor(np.sin(x.cpu().numpy()))

                    @staticmethod
                    def backward(ctx, gx):
                        (x,) = ctx.saved_tensors
                        return gx * x.cos()

                lib.impl("foo", NumpySin.apply, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, x.cos())

    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_autograd_function_registered_to_cpu(self, mode):
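        # Same as above, but the autograd.Function mutates its input in place
        # and marks it dirty; gradients flow through the mutated view and its
        # base.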
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor(a!) self) -> Tensor(a!)")

                class NumpySin_(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.save_for_backward(x.clone())
                        x_np = x.detach().numpy()
                        np.sin(x_np, out=x_np)
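                        # mark_dirty tells autograd that `x` was mutated in place.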
                        ctx.mark_dirty(x)
                        return x

                    @staticmethod
                    def backward(ctx, gx):
                        (x,) = ctx.saved_tensors
                        return gx * x.cos()

                lib.impl("foo", NumpySin_.apply, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                z = x.clone()
                w = z[0]
                y = op(w)

                expected = torch.zeros_like(x)
                expected[0] = x[0].cos()
                with self._check_ctx(mode):
                    (gx,) = torch.autograd.grad(
                        y, x, torch.ones_like(y), retain_graph=True
                    )
                    self.assertEqual(gx, expected)

                expected = torch.ones_like(x)
                expected[0] = x[0].cos()
                with self._check_ctx(mode):
                    (gx,) = torch.autograd.grad(z, x, torch.ones_like(z))
                    self.assertEqual(gx, expected)

    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_on_tensor_that_does_not_require_grad(self, mode):
        # We don't do anything special (that is, we don't rebase history).
        # See NOTE [autograd fallback and in-place operations] for why
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                # Correct usage of (a!)
                lib.define("foo(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                def foo_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)
                    return x

                lib.impl("foo", foo_impl, "CPU")
                foo = self.get_op("foo")

                # Incorrect usage of (a!): user doesn't return tensor as-is
                lib.define("bar(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                def bar_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)
                    return x_d.clone()

                lib.impl("bar", bar_impl, "CPU")
                bar = self.get_op("bar")

                # User mutated input tensor but didn't return it.
                lib.define("baz(Tensor(a!) self, Tensor other) -> ()")

                def baz_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)

                lib.impl("baz", baz_impl, "CPU")
                baz = self.get_op("baz")

                # Test in-place on non-view
                for op in (foo, bar, baz):
                    x = torch.randn(3)
                    y = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        z = x.clone()
                        op(z, y)
                        torch.autograd.grad(z, y, torch.ones_like(z), allow_unused=True)

                # Test in-place on view
                for op in (foo, bar, baz):
                    x = torch.randn(3)
                    y = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        z = x[:]
                        op(z, y)
                        torch.autograd.grad(z, x, torch.ones_like(z), allow_unused=True)

    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_leaf(self, mode):
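        # The kernel returns a fresh leaf (created and detached inside the
        # kernel) as one of its outputs; backward through it should still work.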
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            lib.impl(
                "foo", lambda a: (a.clone(), a.clone().detach().requires_grad_()), "CPU"
            )
            x = torch.randn(3, requires_grad=True)
            y, z = op(x)
            with self._check_ctx(mode):
                z.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_inputs_outputs(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                return None, b.clone()

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            # NB: PyTorch dispatcher treats "None" as undefined Tensor.
            y, z = op(None, x)
            with self._check_ctx(mode):
                z.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_grads(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                return a.sin(), b.cos()

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            y = torch.randn(3)
            w, z = op(x, y)
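            # UndefinedGrad inserts a node whose backward produces undefined
            # gradients, exercising how the fallback handles them.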
            w = torch._C._functions.UndefinedGrad()(w)
            z = torch._C._functions.UndefinedGrad()(z)
            with self._check_ctx(mode):
                (z + w).sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_base_does_not_require_grad(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
            op = self.get_op("foo")

            def foo_impl(a):
                with torch.no_grad():
                    return a.zero_()

            lib.impl("foo", foo_impl, "CPU")
            x = torch.randn(3)
            y = x[:]
            y.requires_grad_()
            w = y[:]
            self.assertTrue(w._base is x)

            # Hook should be registered on w, but not w._base
            op(w)
            with self._check_ctx(mode):
                w.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_mix_of_requires_grad_tensors(self, mode):
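        # x and z are produced under no_grad, so they are disconnected from
        # the graph (grad through them warns or raises); y is differentiable
        # and should behave normally.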
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                with torch.no_grad():
                    x = a.clone()
                    z = b.clone()
                y = a * b
                return x, y, z

            lib.impl("foo", foo_impl, "CPU")
            a = torch.randn(3, requires_grad=True)
            b = torch.randn(3, requires_grad=True)
            x, y, z = op(a, b)

            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    x, (a, b), torch.ones_like(x), allow_unused=True, retain_graph=True
                )

            with self._check_ctx(mode, mode_nothing_raises=False):
                torch.autograd.grad(
                    y, (a, b), torch.ones_like(y), allow_unused=True, retain_graph=True
                )

            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    z, (a, b), torch.ones_like(z), allow_unused=True, retain_graph=True
                )

    @parametrize("mode", ("nothing", "warn"))
    def test_supports_tensor_lists(self, mode):
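        # The fallback must also handle ops whose inputs and outputs are
        # Tensor lists.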
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor[] a) -> Tensor[]")
            op = self.get_op("foo")

            def foo_impl(a):
                x, y, z = a
                with torch.no_grad():
                    return x + y + z, x * y * z

            lib.impl("foo", foo_impl, "CPU")
            x = torch.randn(3, requires_grad=True)
            y = torch.randn(1, requires_grad=True)
            z = torch.randn(2, 1, requires_grad=True)
            a, b = op([x, y, z])
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    a,
                    (x, y, z),
                    torch.ones_like(a),
                    allow_unused=True,
                    retain_graph=True,
                )
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    b,
                    (x, y, z),
                    torch.ones_like(b),
                    allow_unused=True,
                    retain_graph=True,
                )


instantiate_parametrized_tests(TestAutogradFallback)

if __name__ == "__main__":
    run_tests()