# Owner(s): ["module: custom-operators"]

import collections
import itertools
import os
import re
import subprocess
import sys
import typing
import unittest
from typing import *  # noqa: F403

import numpy as np

import torch._custom_ops as custom_ops
import torch.testing._internal.optests as optests
import torch.utils._pytree as pytree
import torch.utils.cpp_extension
from functorch import make_fx
from torch import Tensor
from torch._custom_op.impl import CustomOp, infer_schema
from torch._library.infer_schema import tuple_to_list
from torch._utils_internal import get_file_path_2
from torch.testing._internal import custom_op_db
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_WINDOWS,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    subtest,
    TestCase,
)
from torch.testing._internal.custom_op_db import numpy_nonzero


# Shadowed by `torch.testing._internal.common_utils.custom_op`
from torch._custom_op.impl import custom_op  # usort: skip


def requires_compile(fun):
    fun = unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on Windows")(fun)
    return fun

class CustomOpTestCaseBase(TestCase):
    test_ns = "_test_custom_op"

    def setUp(self):
        super().setUp()
        self.libraries = []

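    # Clean up any ops registered into the test namespace so state does not
    # leak across tests.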
    def tearDown(self):
        super().tearDown()
        import torch._custom_op

        keys = list(torch._custom_op.impl.global_registry.keys())
        for key in keys:
            if not key.startswith(f"{self.test_ns}::"):
                continue
            torch._custom_op.impl.global_registry[key]._destroy()
        if hasattr(torch.ops, self.test_ns):
            delattr(torch.ops, self.test_ns)
        for lib in self.libraries:
            lib._destroy()
        del self.libraries

    def ns(self):
        return getattr(torch.ops, self.test_ns)

    def lib(self):
        result = torch.library.Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.libraries.append(result)
        return result

    def get_op(self, qualname):
        return torch._custom_op.impl.get_op(qualname)


@requires_compile
class TestCustomOpTesting(CustomOpTestCaseBase):
    @parametrize("check_gradients", (False, "auto"))
    @parametrize("dynamic", (True, False))
    def test_aot_autograd_check_degenerate_cases(
        self, device, dynamic, check_gradients
    ):
        def simple(x):
            return x.clone()

        # Should not raise
        x = torch.randn(3, device=device)
        optests.aot_autograd_check(
            simple, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
        )

        def outputs_dont_require_grad(x):
            return x.detach()

        # Should not raise
        y = torch.randn(3, device=device, requires_grad=True)
        optests.aot_autograd_check(
            outputs_dont_require_grad,
            (y,),
            {},
            dynamic=dynamic,
            check_gradients=check_gradients,
        )
        def no_outputs(x):
            return x.detach()

        # Should not raise
        x = torch.randn(3, device=device, requires_grad=True)
        y = torch.randn(3, device=device, requires_grad=False)
        optests.aot_autograd_check(
            no_outputs, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
        )
        optests.aot_autograd_check(
            no_outputs, (y,), {}, dynamic=dynamic, check_gradients=check_gradients
        )

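    # The schema declares foo as functional, but the backend kernel mutates
    # its input in-place; opcheck should flag the mismatch.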
    def test_incorrect_schema_mutation(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                guard = torch._C._AutoDispatchBelowAutograd()
                try:
                    return op(x)
                finally:
                    del guard

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            x.sin_()
            return x.clone()

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.tensor(3.14159 / 3, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            optests.OpCheckError, "Argument x is not defined as mutable but was mutated"
        ):
            torch.library.opcheck(op, (x,), {})

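    # The schema has no alias annotation, but the kernel returns a view of
    # its input; opcheck should flag the aliasing.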
    def test_incorrect_schema_view(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                with torch._C._AutoDispatchBelowAutograd():
                    with torch._C._ExcludeDispatchKeyGuard(
                        torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
                    ):
                        return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.view_as(x)

        def foo_meta(x):
            return x.view_as(x)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "Argument x is not defined to alias output but was aliasing",
        ):
            torch.library.opcheck(op, (x,), {})

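    # foo has no Meta/fake kernel, so opcheck should fail and mention the
    # op's qualified name.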
    def test_missing_abstract_impl(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return 2 * gx

        def foo_impl(x):
            return torch.tensor(x.cpu().numpy() ** 2, device=x.device)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "_test_custom_op.foo.default",
        ):
            torch.library.opcheck(op, (x,), {})

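    # The Meta kernel returns a different shape than the backend kernels;
    # opcheck should report the shape mismatch.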
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_incorrect_abstract_impl(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                guard = torch._C._AutoDispatchBelowAutograd()
                guard2 = torch._C.ExcludeDispatchKeyGuard(
                    torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
                )
                try:
                    return op(x)
                finally:
                    del guard
                    del guard2

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x**2

        def foo_meta(x):
            return x.unsqueeze(1) ** 2

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(optests.OpCheckError, "Shapes .* are not equal"):
            torch.library.opcheck(op, (x,), {})

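    # foo returns its mutated input (an aliased output), which opcheck's
    # functionalization test does not support.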
    def test_missing_functionalization(self, device):
        lib = self.lib()
        lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_dirty(x)
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.sin_()

        def foo_meta(x):
            return x

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor([0, 1.0])
        y = x.clone()
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "We only support functionalizing operators whose outputs do not have alias annotations",
        ):
            torch.library.opcheck(op, (y,), {})

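    # The autograd.Function is registered to backend keys (CPU/CUDA) instead
    # of the Autograd key, so opcheck reports a missing autograd kernel.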
    def test_autograd_registered_at_backend(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gx):
                return gx * 0.5

        lib.impl("foo", Foo.apply, "CPU")
        lib.impl("foo", Foo.apply, "CUDA")
        lib.impl("foo", lambda x: x.clone(), "Meta")

        x = torch.randn([], requires_grad=True)

        with self.assertRaisesRegex(
            torch.testing._internal.optests.OpCheckError,
            "does not have an autograd kernel",
        ):
            torch.library.opcheck(op, (x,), {})

        # I'm not sure why this is necessary
        del lib

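    # The kernel's output depends on hidden global state, so eager PyTorch
    # and AOTAutograd diverge; opcheck should catch this.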
    def test_global_state_mutation(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            invoked = 0

            @staticmethod
            def forward(ctx, x):
                Foo.invoked += 1
                return x.clone() * Foo.invoked

            @staticmethod
            def backward(ctx, gx):
                return gx

        lib.impl("foo", Foo.apply, "CompositeImplicitAutograd")

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError, "eager-mode PyTorch vs AOTAutograd"
        ):
            torch.library.opcheck(op, (x,), {})

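    # Run opcheck over the custom op database to confirm it passes on
    # correctly written custom ops.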
    @ops(custom_op_db.custom_op_db, dtypes=OpDTypes.any_one)
    def test_opcheck_opinfo(self, device, dtype, op):
        for sample_input in op.sample_inputs(
            device, dtype, requires_grad=op.supports_autograd
        ):
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            torch.library.opcheck(op.op, args, kwargs)

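    # A custom op with backend impls but no autograd support should trip
    # opcheck's autograd-not-implemented error when inputs require grad.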
    def test_opcheck_fails_basic(self, device):
        @custom_op(f"{self.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor: ...

        @foo.impl(["cpu", "cuda"])
        def foo_impl(x):
            return x.sum()

        x = torch.randn(3, device=device, requires_grad=True)
        # Triggers the CustomOp autograd NYI error
        with self.assertRaisesRegex(
            optests.OpCheckError, "Autograd has not been implemented for operator"
        ):
            torch.library.opcheck(self.get_op(f"{self.test_ns}::foo"), (x,), {})

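    # autograd_registration_check passes when the autograd.Function is
    # registered to the Autograd key.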
    def test_autograd_registration_check_autograd_kernel(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.randn(3, requires_grad=True, device=device)
        # Should not raise
        optests.autograd_registration_check(op, (x,), {})

    def test_autograd_registration_check_compositeimplicitautograd(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        def foo_impl(x):
            return x.sin().cos()

        lib.impl("foo", foo_impl, "CompositeImplicitAutograd")

        x = torch.randn(3, requires_grad=True, device=device)
        # Should not raise
        optests.autograd_registration_check(op, (x,), {})

    def test_autograd_registration_check_incorrect_composite(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        def foo_impl(x):
            return x.sin().cos()

        lib.impl("foo", foo_impl, "CompositeExplicitAutograd")

        x = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
            optests.autograd_registration_check(op, (x,), {})

    def test_autograd_registration_check_incorrect(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return torch.sin(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        lib.impl("foo", Foo.apply, "CPU")
        lib.impl("foo", Foo.apply, "CUDA")

        x = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
            optests.autograd_registration_check(op, (x,), {})

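    # Sanity-check the assert_raises_regex helper used by the optests suite.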
    def test_assert_raises_regex(self, device):
        from torch.testing._internal.optests.aot_autograd import assert_raises_regex

        with assert_raises_regex(RuntimeError, "c"):
            raise RuntimeError("abcd")
        with assert_raises_regex(RuntimeError, "c.*"):
            raise RuntimeError("abcd")
        with self.assertRaisesRegex(AssertionError, "instead got"):
            with assert_raises_regex(RuntimeError, "c.*"):
                raise ValueError("abcd")
        with self.assertRaisesRegex(AssertionError, "Expected exception"):
            with assert_raises_regex(RuntimeError, "c.*"):
                pass
        with self.assertRaisesRegex(AssertionError, "to match regex"):
            with assert_raises_regex(RuntimeError, "f"):
                raise RuntimeError("abcd")


class TestCustomOp(CustomOpTestCaseBase):
    test_ns = "_test_custom_op"

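    # Ops whose outputs have alias annotations cannot be functionalized;
    # compiling them with an aot backend should raise.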
    @requires_compile
    def test_functionalize_error(self):
        with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define("foo(Tensor(a!) x) -> Tensor(a!)")

            def foo(x):
                return x.sin_()

            lib.impl("foo", foo, "CompositeExplicitAutograd")
            foo_op = self.get_op(f"{self.test_ns}::foo")

            lib.define("bar(Tensor(a) x) -> Tensor(a)")

            def bar(x):
                return x.view(-1)

            lib.impl("bar", bar, "CompositeExplicitAutograd")
            bar_op = self.get_op(f"{self.test_ns}::bar")

            msg = r".*We only support functionalizing operators whose outputs do not have alias annotations"

            x = torch.randn(3)

            @torch.compile(backend="aot_eager", fullgraph=True)
            def f(x):
                return foo_op(x)

            @torch.compile(backend="aot_eager", fullgraph=True)
            def g(x):
                return bar_op(x)

            with self.assertRaisesRegex(RuntimeError, msg):
                f(x)
            with self.assertRaisesRegex(RuntimeError, msg):
                g(x)

    def test_invalid_schemas(self):
        # Function schema validation goes through torchgen, so this is just a
        # basic test.
        with self.assertRaisesRegex(AssertionError, "Invalid function schema: foo"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(")
    def test_invalid_qualname(self):
        with self.assertRaisesRegex(ValueError, "overload"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo.Tensor", "() -> ()")

    def test_name_must_match(self):
        with self.assertRaisesRegex(ValueError, "to have name"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def baz(x: Tensor) -> Tensor:
                raise NotImplementedError

    def test_unsupported_schemas(self):
        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(
                f"{TestCustomOp.test_ns}::foo", "(Tensor(a!) x) -> Tensor(a)"
            )(foo)
        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(
                f"{TestCustomOp.test_ns}::foo", "(Tensor(a) x) -> Tensor(a)"
            )(foo)
        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor x) -> ()")(
                foo
            )
        with self.assertRaisesRegex(ValueError, "self"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor self) -> ()")(
                foo
            )

    # Tests for the older custom_op API
    def test_schema_matches_signature(self):
        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(f"{TestCustomOp.test_ns}::blah", "(Tensor y) -> Tensor")
            def blah(x):
                pass

        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah2", "(Tensor x, *, Tensor y) -> Tensor"
            )
            def blah2(x, y):
                pass

        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah3",
                "(Tensor x, *, Tensor w, Tensor z) -> Tensor",
            )
            def blah3(x, *, y, z):
                pass

        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah4",
                "(Tensor x, *, Tensor z, Tensor y) -> Tensor",
            )
            def blah4(x, *, y, z):
                pass

        with self.assertRaisesRegex(ValueError, "not supported"):

            @custom_op(f"{TestCustomOp.test_ns}::blah5", "(Tensor x) -> Tensor")
            def blah5(*args):
                pass

        with self.assertRaisesRegex(ValueError, "not supported"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah6", "(*, Tensor z, Tensor y) -> Tensor"
            )
            def blah6(**kwargs):
                pass

        with self.assertRaisesRegex(ValueError, "default arguments"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah7", "(Tensor x, *, Tensor y) -> Tensor"
            )
            def blah7(x=1, *, y):
                pass

        with self.assertRaisesRegex(ValueError, "default arguments"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah8", "(Tensor x, *, Tensor y) -> Tensor"
            )
            def blah8(x, *, y=1):
                pass

        # kwonly-arg works
        @custom_op(
            f"{TestCustomOp.test_ns}::blah9", "(Tensor x, *, Tensor y) -> Tensor"
        )
        def blah9(x, *, y):
            pass

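    # infer_schema requires an explicit return type annotation.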
    def test_infer_schema_no_return(self):
        with self.assertRaisesRegex(
            ValueError, "No return type annotation was provided. Please add one."
        ):

            @torch.library.custom_op("mylib::foo", mutates_args={})
            def foo(x: torch.Tensor, y: int):
                return x * y

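    # Exercise the annotation-to-schema mapping (e.g. int -> SymInt,
    # torch.dtype -> ScalarType) across supported argument and return types.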
    def test_infer_schema_supported(self):
        def a(x: Tensor) -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(
            infer_schema(a, mutates_args=()), """(Tensor x) -> Tensor"""
        )

        def kwonly1(x: Tensor, *, y: int, z: float) -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(
            infer_schema(kwonly1, mutates_args=()),
            """(Tensor x, *, SymInt y, float z) -> Tensor""",
        )

        def kwonly2(*, y: Tensor) -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(
            infer_schema(kwonly2, mutates_args=()), """(*, Tensor y) -> Tensor"""
        )

        def b(
            x: Tensor,
            y: int,
            z: bool,
            a: float,
            b: torch.dtype,
            c: torch.device,
            d: torch.types.Number,
        ) -> Tuple[Tensor, int, float, bool]:
            return torch.empty([]), 1, 0.1, True

        self.assertExpectedInline(
            infer_schema(b, mutates_args=()),
            """(Tensor x, SymInt y, bool z, float a, ScalarType b, Device c, Scalar d) -> (Tensor, SymInt, float, bool)""",
        )

        def c(
            x: Tensor,
            y: Sequence[Tensor],
            z: Optional[Tensor],
            w: Sequence[Optional[Tensor]],
        ) -> List[Tensor]:
            return [torch.empty([])]

        self.assertExpectedInline(
            infer_schema(c, mutates_args=()),
            """(Tensor x, Tensor[] y, Tensor? z, Tensor?[] w) -> Tensor[]""",
        )

        def d(x: Tensor) -> Tuple[List[Tensor], Tensor]:
            return [torch.empty([])], torch.empty([])

        self.assertExpectedInline(
            infer_schema(d, mutates_args=()), """(Tensor x) -> (Tensor[], Tensor)"""
        )

        def e() -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(infer_schema(e, mutates_args=()), """() -> Tensor""")

        def f(x: Tensor) -> None:
            pass

        self.assertExpectedInline(
            infer_schema(f, mutates_args=()), """(Tensor x) -> ()"""
        )

        def g(
            x: Tensor, y: List[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
        ) -> None:
            pass

        self.assertExpectedInline(
            infer_schema(g, mutates_args=()),
            """(Tensor x, Tensor[] y, Tensor[] z, Tensor?[] w) -> ()""",
        )

        self.assertExpectedInline(
            infer_schema(g, mutates_args={"x", "w", "z"}),
            """(Tensor(a0!) x, Tensor[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""",
        )

        self.assertExpectedInline(
            infer_schema(g, mutates_args="unknown"),
            """(Tensor(a0!) x, Tensor(a1!)[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""",
        )

        def h(
            x: Tensor,
            a: Optional[int] = None,
            b: float = 3.14,
            c: bool = True,
            d: int = 3,
            e: str = "foo",
            f: torch.dtype = torch.float,
            g: torch.dtype = torch.float32,
            h: torch.dtype = torch.int,
            i: torch.device = torch.device("cpu:0"),
            j: torch.device = "cpu",
        ) -> None:
            pass

        self.assertExpectedInline(
            infer_schema(h, mutates_args=()),
            (
                """(Tensor x, SymInt? a=None, float b=3.14, bool c=True, SymInt d=3, str e="foo", """
                """ScalarType f=float32, ScalarType g=float32, ScalarType h=int32, Device i="cpu:0", Device j="cpu") -> ()"""
            ),
        )

        def foo_impl(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={})
        self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor")

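    # Signatures infer_schema rejects: varargs, varkwargs, missing type
    # annotations, unsupported types, and non-Tensor entries in mutates_args.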
    def test_infer_schema_unsupported(self):
        with self.assertRaisesRegex(ValueError, "varargs"):

            def foo(*args):
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "varkwargs"):

            def foo(**kwargs):
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "must have a type annotation"):

            def foo(x):
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "unsupported"):

            def foo(x: Tensor) -> Tuple[Tensor, ...]:
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "can be mutated"):

            def foo(x: Tensor, y: int) -> Tensor:
                raise NotImplementedError

            infer_schema(foo, mutates_args={"y"})

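    # Produces example values for a given schema type, recursing into
    # Optional[...], List[...], and Sequence[...] so the round-trip tests
    # below can feed every supported type through the dispatcher.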
    def _generate_examples(self, typ):
        if typ is int:
            return [17]
        if typ is float:
            return [3.14]
        if typ is bool:
            return [True]
        if typ is str:
            return ["foo"]
        if typ is torch.dtype:
            return [torch.float32]
        if typ is torch.device:
            return [torch.device("cpu")]
        if typ == torch.types.Number:
            return [2.718]
        if typ is torch.Tensor:
            return [torch.tensor(3)]
        if typ == Optional[torch.types.Number]:
            return [None, 2.718]
        origin = typing.get_origin(typ)
        if origin is Union:
            args = typing.get_args(typ)
            assert len(args) == 2 and (args[0] is type(None) or args[1] is type(None))
            elt = args[0] if args[1] is type(None) else args[1]
            return self._generate_examples(elt) + [None]
        if origin is list:
            args = typing.get_args(typ)
            assert len(args) == 1
            elt = args[0]
            return [
                self._generate_examples(elt),
                self._generate_examples(elt),
                self._generate_examples(elt),
            ]
        if origin is collections.abc.Sequence:
            args = typing.get_args(typ)
            assert len(args) == 1
            examples = self._generate_examples(args[0])
            return list(itertools.product(examples, examples))
        raise NotImplementedError(
            f"testrunner cannot generate instance of type {typ}"
        )

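    # For every supported return type, define an op returning an example
    # value of that type and check that calling through the dispatcher gives
    # the value back unchanged. Note that `typ` is used directly as the
    # return annotation, which is resolved when the op is defined.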
    def test_supported_return_types_single_return(self):
        for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES:
            for example in self._generate_examples(typ):
                try:

                    @custom_ops.custom_op(f"{self.test_ns}::foo")
                    def foo(x: Tensor) -> typ:
                        raise NotImplementedError

                    @custom_ops.impl(f"{self.test_ns}::foo")
                    def foo_impl(x: Tensor) -> typ:
                        return example

                    op = self.get_op(f"{self.test_ns}::foo")
                    result = op(torch.randn([]))
                    self.assertEqual(result, example, msg=f"{typ} {example}")
                finally:
                    custom_ops._destroy(f"{self.test_ns}::foo")

    def test_supported_return_types_multi_return(self):
        for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES:
            for example in self._generate_examples(typ):
                try:

                    @custom_ops.custom_op(f"{self.test_ns}::foo")
                    def foo(x: Tensor) -> Tuple[typ, typ]:
                        raise NotImplementedError

                    @custom_ops.impl(f"{self.test_ns}::foo")
                    def foo_impl(x: Tensor) -> Tuple[typ, typ]:
                        return (example, example)

                    op = self.get_op(f"{self.test_ns}::foo")
                    result = op(torch.randn([]))
                    expected = (example, example)
                    self.assertEqual(result, expected, msg=f"{typ} {example}")
                finally:
                    custom_ops._destroy(f"{self.test_ns}::foo")

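    # Same idea for parameter types: the CPU impl stashes the argument it
    # received into `yeet` so we can check that each example value survives
    # the trip through the dispatcher intact.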
    def test_supported_param_types(self):
        for typ in torch._library.infer_schema.SUPPORTED_PARAM_TYPES:

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: typ) -> Tensor:
                raise NotImplementedError

            yeet = None

            @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=["cpu"])
            def foo_cpu(x, y):
                nonlocal yeet
                yeet = y
                return x.clone()

            try:
                for example in self._generate_examples(typ):
                    op = self.get_op(f"{self.test_ns}::foo")
                    op(torch.randn([]), example)
                    self.assertEqual(yeet, example, msg=f"{typ} {example}")
                    yeet = None
            finally:
                custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

    def test_sequences(self):
        # Sequence[int] gets automagically turned into int[] in the schema.
        # This test checks that we actually do support arbitrary sequence types.
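        # For example, the schema inferred for `foo` below should come out
        # as something like "(Tensor x, int[] sizes) -> Tensor".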
        class MySequence(collections.abc.Sequence):
            def __init__(self) -> None:
                self._container = [1, 2, 3]

            def __getitem__(self, idx):
                return self._container[idx]

            def __len__(self):
                return len(self._container)

        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError

        called = 0

        @custom_ops.impl(f"{self.test_ns}::foo", device_types="cpu")
        def foo_cpu(x, sizes):
            nonlocal called
            called += 1
            # Dispatcher will normalize the sequence type into a List
            self.assertEqual(sizes, [1, 2, 3])
            return x.clone()

        x = torch.randn([])
        seq = MySequence()
        op = self.get_op(f"{self.test_ns}::foo")
        op(x, seq)
        self.assertEqual(called, 1)

    def test_unsupported_param_types(self):
        # Not comprehensive (it doesn't need to be), just a check that our mechanism works
        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: List[Optional[int]]) -> Tensor:
                raise NotImplementedError

            del foo

        with self.assertRaisesRegex(ValueError, "unsupported type"):
            # int[N] in Dispatcher is a bit wild, so we don't try to support it.
            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
                raise NotImplementedError

            del foo

        with self.assertRaisesRegex(ValueError, r"For example, typing.List\[int\]"):
            # test that we propose a correct and supported type.
            @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={})
            def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
                raise NotImplementedError

            del foo

        with self.assertRaises(ValueError) as cm:

            @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={})
            def foo(x: Tensor, y: Tuple[int, float]) -> Tensor:
                raise NotImplementedError

            del foo

        # The heterogeneous-tuple error should not suggest an example type.
        self.assertNotIn("example", str(cm.exception), "")

        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: Callable) -> Tensor:
                raise NotImplementedError

            del foo

    def test_supported_schemas(self):
        # All of these should already be tested by PyTorch codegen
        # (we share the same mechanism), but here's a sanity check.
        schemas = [
            "(Tensor x) -> Tensor",
            "(Tensor x) -> Tensor y",
            "(Tensor[] x) -> Tensor y",
            "(Tensor x) -> (Tensor, Tensor)",
            "(Tensor x) -> (Tensor y, Tensor z)",
            "(Tensor x) -> (Tensor y, Tensor z)",
        ]
        other_schemas = [
            "(Tensor x, Tensor w) -> (Tensor y, Tensor z)",
            "(Tensor x, Tensor w) -> (Tensor, Tensor)",
            "(Tensor x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor[] w) -> Tensor",
            "(Tensor x, int[] w) -> Tensor",
            "(Tensor x, SymInt[] w) -> Tensor",
            "(Tensor x, Scalar w) -> Tensor",
            "(Tensor x, float w) -> Tensor",
            "(Tensor x, float? w) -> Tensor",
            "(Tensor x, bool[] w) -> Tensor",
        ]

        for schema in schemas:
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", schema)
            custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
        for schema in other_schemas:
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar", schema)
            custom_ops._destroy(f"{TestCustomOp.test_ns}::bar")

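    # Op names live in namespaces ("ns::op"). Namespaces that belong to
    # PyTorch itself (e.g. "aten") are reserved and cannot be used for custom
    # ops, whether the op is declared via a schema string or a decorated
    # function.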
    def test_reserved_ns(self):
        from torch._custom_op.impl import RESERVED_NS

        for ns in RESERVED_NS:
            with self.assertRaisesRegex(ValueError, "is a reserved namespace"):
                custom_ops.custom_op(f"{ns}::foo", "(Tensor x) -> Tensor")

            with self.assertRaisesRegex(ValueError, "is a reserved namespace"):

                @custom_ops.custom_op(f"{ns}::foo2")
                def foo2(x: torch.Tensor) -> torch.Tensor:
                    raise NotImplementedError

    def test_private_ctor(self):
        with self.assertRaisesRegex(RuntimeError, "CustomOp constructor is private"):
            CustomOp(None, None, None, None, None)

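    # Defining the same qualified op name twice is an error until the first
    # definition is destroyed; _destroy tears down the registration so the
    # name can be reused.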
    def test_lifetime(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        custom_op = torch._custom_op.impl.get_op(f"{TestCustomOp.test_ns}::foo")

        # We can't define an op multiple times,
        with self.assertRaisesRegex(RuntimeError, "multiple times"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
                raise NotImplementedError

        # Unless we delete the original op.
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

        # Smoke test
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            raise NotImplementedError

        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

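    # A custom op with no autograd rules registered should raise when called
    # under grad mode with an input that requires grad, regardless of whether
    # the Tensor arrives as a plain argument or inside a TensorList.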
    def test_autograd_notimplemented(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            raise NotImplementedError

        x = torch.randn(3, requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
            op(x)
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
        del foo

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        x = torch.randn(3, requires_grad=True)
        y = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
            op([y, x])
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
        del foo

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        x = torch.randn(3, requires_grad=True)
        y = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
            op(y, x)
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

    def test_autograd_notimplemented_gradmode(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, y):
            return x * y

        x = torch.randn(3, requires_grad=True)
        y = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with torch.no_grad():
            # Shouldn't raise, because we are in no_grad
            op(y, x)

    def test_impl_cpu(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
        def foo_cpu(x):
            return x.sin()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x)
        self.assertEqual(result, foo_cpu(x))

    def test_impl_invalid_devices(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        def foo_impl(x):
            return x.sin()

        from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY

        for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY.keys():
            # Smoke test: should not raise error
            custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=device_type)(
                foo_impl
            )

        # Not supported by this API: we can either support them in the future
        # or provide some other CustomOp.def_* function. This depends on how
        # common the use cases are.
        for invalid_type in ["hip", "xla", "mkldnn", ["cpu", "hip"]]:
            with self.assertRaisesRegex(ValueError, "we only support device_type"):
                custom_ops.impl(
                    f"{TestCustomOp.test_ns}::foo", device_types=invalid_type
                )(foo_impl)

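    # The backward tests below all follow the same registration pattern:
    # impl_save_for_backward decides what to stash between forward and
    # backward, and impl_backward receives (ctx, saved, grad) and must return
    # a dict mapping input names to their gradients. Each test violates one
    # rule of that contract and checks the resulting error.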
    def test_backward_partially_registered(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(
            RuntimeError, "unable to find a 'save_for_backward'"
        ):
            y = op(x)
            y.backward()

    def test_save_for_backward_inputs_are_namedtuple(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        hit = 0

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            nonlocal hit
            hit += 1
            self.assertTrue(isinstance(inputs, tuple))
            self.assertEqual(list(inputs._asdict().keys()), ["x"])
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        self.assertEqual(hit, 1)
        y.backward()
        self.assertEqual(hit, 1)

    def test_backward_returns_dict(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        with self.assertRaisesRegex(RuntimeError, "to be a dict"):
            y.backward()

    def test_backward_dict_invalid_keys(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "y": None}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"):
            y.backward()

    def test_backward_dict_grad_for_nontensor(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, dim):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "dim": None}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x, 32)
        with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"):
            y.backward()

    def test_backward_dict_requires_keys_for_input_tensors(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, y):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x, x)
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
            y.backward()

    def test_backward_dict_requires_keys_for_input_optional_tensors(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, y):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x, None)
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
            y.backward()

    def test_backward_grads_are_tensor_or_none(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": (grad * saved.cos(),)}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        with self.assertRaisesRegex(RuntimeError, "either None or a Tensor"):
            y.backward()

    def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None]}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(xs)
        with self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"):
            y.backward()

    def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None, (None,)]}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(xs)
        with self.assertRaisesRegex(RuntimeError, "None or Tensor"):
            y.backward()

    def test_backward_tensorlist_input_requires_list_grads(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": None}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(xs)
        with self.assertRaisesRegex(RuntimeError, "list of gradients"):
            y.backward()

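    # output_differentiability is a per-output list of bools passed to
    # impl_backward: it must be a list, its length must match the number of
    # outputs, and entries marked True must correspond to Tensor outputs.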
1380*da0073e9SAndroid Build Coastguard Worker    def test_backward_output_differentiability_type(self):
1381*da0073e9SAndroid Build Coastguard Worker        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1382*da0073e9SAndroid Build Coastguard Worker        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
1383*da0073e9SAndroid Build Coastguard Worker            raise NotImplementedError
1384*da0073e9SAndroid Build Coastguard Worker
1385*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
1386*da0073e9SAndroid Build Coastguard Worker
1387*da0073e9SAndroid Build Coastguard Worker            @custom_ops.impl_backward(
1388*da0073e9SAndroid Build Coastguard Worker                f"{TestCustomOp.test_ns}::foo", output_differentiability=True
1389*da0073e9SAndroid Build Coastguard Worker            )
1390*da0073e9SAndroid Build Coastguard Worker            def foo_backward(ctx, saved, grad):
1391*da0073e9SAndroid Build Coastguard Worker                return {"xs": None}
1392*da0073e9SAndroid Build Coastguard Worker
1393*da0073e9SAndroid Build Coastguard Worker    def test_backward_output_differentiability_numel(self):
1394*da0073e9SAndroid Build Coastguard Worker        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1395*da0073e9SAndroid Build Coastguard Worker        def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
1396*da0073e9SAndroid Build Coastguard Worker            raise NotImplementedError
1397*da0073e9SAndroid Build Coastguard Worker
1398*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
1399*da0073e9SAndroid Build Coastguard Worker
1400*da0073e9SAndroid Build Coastguard Worker            @custom_ops.impl_backward(
1401*da0073e9SAndroid Build Coastguard Worker                f"{TestCustomOp.test_ns}::foo", output_differentiability=[True]
1402*da0073e9SAndroid Build Coastguard Worker            )
1403*da0073e9SAndroid Build Coastguard Worker            def foo_backward(ctx, saved, grad):
1404*da0073e9SAndroid Build Coastguard Worker                return {"xs": None}
1405*da0073e9SAndroid Build Coastguard Worker
    def test_backward_output_differentiability_tensorlist(self):
        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: Tensor) -> Tuple[List[Tensor], Tensor]:
            raise NotImplementedError

        @custom_ops.impl(f"{self.test_ns}::foo")
        def foo_impl(x):
            return [x.clone(), x.clone()], x.clone()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return []

        @custom_ops.impl_backward(
            f"{TestCustomOp.test_ns}::foo", output_differentiability=[False, True]
        )
        def foo_backward(ctx, saved, grad_lst, grad):
            return {"x": grad}

        op = self.get_op(f"{self.test_ns}::foo")
        x = torch.randn(3, requires_grad=True)
        [a, b], c = op(x)
        self.assertFalse(a.requires_grad)
        self.assertFalse(b.requires_grad)
        self.assertTrue(c.requires_grad)

    def test_backward_output_differentiability_non_tensor(self):
        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: Tensor) -> Tuple[Tensor, int]:
            raise NotImplementedError

        @custom_ops.impl(f"{self.test_ns}::foo")
        def foo_impl(x):
            return x.clone(), 3

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return []

        @custom_ops.impl_backward(
            f"{TestCustomOp.test_ns}::foo", output_differentiability=[True, True]
        )
        def foo_backward(ctx, saved, grad0, grad1):
            return {"x": grad0}

        op = self.get_op(f"{self.test_ns}::foo")
        x = torch.randn(3, requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, "is not a Tensor"):
            op(x)

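    # Kernels can be registered per device type; dispatch then routes to the
    # CPU or CUDA implementation based on the input tensor's device.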
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_impl_separate(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
        def foo_cpu(x):
            return x.sin()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cuda")
        def foo_cuda(x):
            return x.cos()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x)
        self.assertEqual(result, foo_cpu(x))

        x_cuda = x.cuda()
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x_cuda)
        self.assertEqual(result, foo_cuda(x_cuda))

    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_impl_multiple(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.cos()

        op = self.get_op(f"{self.test_ns}::foo")
        x = torch.randn(3)
        result = op(x)
        self.assertEqual(result, foo_impl(x))

        x_cuda = x.cuda()
        result = op(x_cuda)
        self.assertEqual(result, foo_impl(x_cuda))

    def test_impl_abstract_overload(self):
        lib = self.lib()
        lib.define("sin.blah(Tensor x) -> Tensor")

        torch.library.impl_abstract(
            f"{self.test_ns}::sin.blah", torch.empty_like, lib=lib
        )

        op = self.ns().sin.blah
        x = torch.randn(3, device="meta")
        op(x)

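    # A meta/abstract impl only computes output metadata (shape, dtype,
    # device), which is exactly what runs for tensors on the "meta" device.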
    def test_impl_meta(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x, dim):
            output_shape = list(x.shape)
            del output_shape[dim]
            return x.new_empty(output_shape)

        x = torch.randn(2, 3, device="meta")
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x, 1)
        self.assertEqual(result.shape, foo_meta(x, 1).shape)

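    # A second abstract impl for the same op must be rejected; the error is
    # expected to cite the file:line of the existing registration.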
    def test_duplicate_impl(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x, dim):
            output_shape = list(x.shape)
            del output_shape[dim]
            return x.new_empty(output_shape)

        with self.assertRaisesRegex(RuntimeError, r"test_custom_ops.py:\d+"):

            @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
            def foo_meta2(x, dim):
                output_shape = list(x.shape)
                del output_shape[dim]
                return x.new_empty(output_shape)

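    # Data-dependent output shapes are modeled with unbacked SymInts from
    # ctx.new_dynamic_size(): `min` must be non-negative, and the bounds
    # themselves may not be SymInts.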
    def test_new_data_dependent_symint(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x):
            ctx = torch.library.get_ctx()
            r = ctx.new_dynamic_size(min=1)
            with self.assertRaisesRegex(ValueError, "greater than or equal to 0"):
                ctx.new_dynamic_size(min=-1)
            with self.assertRaisesRegex(ValueError, "SymInt"):
                ctx.new_dynamic_size(max=x.numel())
            # NB: You must return dynamic sizes!
            return x.new_empty(r)

        x = torch.randn(2, 3, device="cpu")
        op = self.get_op(f"{self.test_ns}::foo")
        make_fx(op, tracing_mode="symbolic")(x)

    def test_meta_for_data_dependent_shape_operation(self):
        x = torch.randn(10, device="meta")
        with self.assertRaisesRegex(RuntimeError, "data-dependent output shape"):
            numpy_nonzero(x)

    def test_basic_make_fx(self):
        # More serious tests are in our CustomOp opinfo db;
        # this one is just a sanity check.
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x):
            return x.sum()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        gm = make_fx(op, tracing_mode="symbolic")(x)
        self.assertTrue(f"{TestCustomOp.test_ns}.foo" in gm.code)

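    # With no kernels registered, calling the op should raise an actionable
    # NotImplementedError: missing device impl, missing fake/meta impl, or
    # no Tensor inputs to dispatch on.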
    def test_not_implemented_error(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(NotImplementedError, "cpu impl registered"):
            op(x)

        x = torch.randn(3, device="meta")
        with self.assertRaisesRegex(NotImplementedError, "no fake impl or Meta kernel"):
            op(x)

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar")
        def bar(sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError

        op = self.get_op(f"{self.test_ns}::bar")
        with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):
            op((1, 2, 3))

    def test_data_dependent_basic(self):
        x = torch.randn(5, 5)
        gm = make_fx(numpy_nonzero, tracing_mode="symbolic")(x)
        self.assertTrue("nonzero" in gm.code)

    def test_data_dependent_fake_tracing(self):
        x = torch.randn(5, 5)
        # We've updated to attempt to use unbacked symints even for fake
        # tracing
        make_fx(numpy_nonzero, tracing_mode="fake")(x)

    def test_symints(self):
        def f(x):
            return torch.ops._torch_testing.numpy_view_copy(x, x.shape)

        x = torch.randn(2, 3, 4)
        gm = make_fx(f, tracing_mode="symbolic")(x)
        result = gm(x)
        self.assertEqual(result, f(x))
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
    sym_size_int_2 = torch.ops.aten.sym_size.int(x_1, 2)
    numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size_int, sym_size_int_1, sym_size_int_2]);  x_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = None
    return numpy_view_copy""",  # noqa: B950
        )

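    # dynamo graph-breaks on ops with data-dependent output shapes unless
    # torch._dynamo.config.capture_dynamic_output_shape_ops is set.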
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
    def test_data_dependent_compile(self):
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def f(x):
            return numpy_nonzero(x.clone()).clone()

        f(torch.randn(10))

        self.assertEqual(len(counters["graph_break"]), 1)
        self.assertEqual(next(iter(counters["graph_break"].values())), 1)
        self.assertExpectedInline(
            next(iter(counters["graph_break"].keys())).replace(";", "\n"),
            """\
dynamic shape operator: _torch_testing.numpy_nonzero.default
 to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True""",
        )

    # pre-existing problem: torch.compile(dynamic=True) will, by default,
    # graph break on data-dependent operations. Eventually we'll make it so
    # that it never graph breaks on data-dependent operations.
    @unittest.expectedFailure
    def test_data_dependent_nms_dynamic_compile(self):
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, dynamic=True)
        def f(x, s, i):
            return torch.ops._torch_testing.numpy_nms(x.clone(), s, i).clone()

        f(torch.randn(20, 4), torch.randn(20), 0.1)

        self.assertEqual(len(counters["graph_break"]), 0)

    def test_impl_on_existing_op(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @torch._custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        op = self.get_op(qualname)
        x = torch.randn(3)
        result = op(x)
        self.assertEqual(result, x.sin())

    @parametrize(
        "key", ["CPU", "CUDA", "CompositeImplicitAutograd", "CompositeExplicitAutograd"]
    )
    def test_impl_on_existing_op_with_cpu_registration(self, key):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, key)
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, "already has an implementation"):
            custom_ops.impl(qualname, func=foo_impl)

    def test_abstract_impl_on_existing_op(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @torch.library.impl_abstract(qualname, lib=self.lib())
        def foo_impl(x):
            return x.sin()

        op = self.get_op(qualname)
        with torch._subclasses.FakeTensorMode():
            x = torch.randn(3)
            result = op(x)
            self.assertEqual(result.shape, x.shape)
            self.assertEqual(result.stride(), x.stride())

    def test_abstract_impl_on_existing_op_with_meta(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "Meta")
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"):
            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())

    def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "CompositeImplicitAutograd")
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"):
            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())

    def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
        op = self.get_op(qualname)

        torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib())
        with torch._subclasses.FakeTensorMode():
            x = torch.randn(10)
            result = op(x)
            self.assertEqual(result.shape, ())

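    # Helper: registering save_for_backward/backward for `qualname` should
    # fail with an error matching `err_regex`.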
    def _test_backward_impl_raises(self, qualname, err_regex):
        with self.assertRaisesRegex(RuntimeError, err_regex):

            @custom_ops.impl_save_for_backward(qualname)
            def foo2(x):
                return

        with self.assertRaisesRegex(RuntimeError, err_regex):

            @custom_ops.impl_backward(qualname)
            def foo3(x):
                return

    def test_backward_impl_on_existing_op_incorrect_schema_views(self):
        lib = self.lib()
        lib.define("foo(Tensor(a) x) -> Tensor(a)")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "operator that returns views")

    def test_backward_impl_on_existing_op_incorrect_schema_mutable(self):
        lib = self.lib()
        lib.define("foo(Tensor(a!) x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "non-functional")

    def test_backward_impl_on_existing_op_incorrect_schema_no_output(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> ()")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "no returns")

    def test_backward_impl_on_existing_op_CompositeImplicitAutograd(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        lib.impl("foo", lambda x: x.sin().cos(), "CompositeImplicitAutograd")
        self._test_backward_impl_raises(qualname, "CompositeImplicitAutograd")

    @parametrize("key", ["Autograd", "AutogradCPU", "AutogradCUDA"])
    def test_backward_impl_on_existing_op_with_key(self, key):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        lib.impl("foo", lambda x: x.sin().cos(), key)
        self._test_backward_impl_raises(qualname, key)

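    # is_functional_schema accepts a schema string, a torchgen
    # FunctionSchema, or a torch._C.FunctionSchema; schemas with mutation,
    # aliased returns, or no returns are non-functional.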
    def test_is_functional_schema(self):
        tests = {
            "foo(Tensor x) -> Tensor": True,
            "foo(Tensor(a) x) -> Tensor": True,
            "foo(Tensor(a!) x) -> Tensor": False,
            "foo(Tensor(a) x) -> Tensor(a)": False,
            "foo(Tensor x) -> ()": False,
        }
        from torchgen.model import FunctionSchema

        for schema_str, expected in tests.items():
            res = torch._library.utils.is_functional_schema(schema_str)
            self.assertEqual(res, expected)

            schema = FunctionSchema.parse(schema_str)
            res = torch._library.utils.is_functional_schema(schema)
            self.assertEqual(res, expected)

            schema = torch._C.parse_schema(schema_str)
            res = torch._library.utils.is_functional_schema(schema)
            self.assertEqual(res, expected)

    def test_incorrect_schema_types(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
                lib.define("foo12(Tensor a) -> asdfasdf")
            with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
                lib.define("foo12(asdf a) -> Tensor")
            with self.assertRaisesRegex(RuntimeError, "Use `SymInt` or `int`"):
                lib.define("foo12(int64_t a) -> Tensor")
            with self.assertRaisesRegex(RuntimeError, "Use `float`"):
                lib.define("foo12(double a) -> Tensor")

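    # "Tensorlist-like" covers Tensor[] and Tensor?[], optional or not, but
    # not plain Tensor or int[].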
    def test_is_tensorlist_like_type(self):
        tensorlists = [
            # Tensor[]
            torch.ops.aten.where.default._schema.returns[0].type,
            # Tensor?[]
            torch.ops.aten.index.Tensor._schema.arguments[1].type,
            # Tensor[]?
            torch._C.parse_schema("foo(Tensor[]? x) -> ()").arguments[0].type,
            # Tensor?[]?
            torch._C.parse_schema("foo(Tensor?[]? x) -> ()").arguments[0].type,
        ]
        non_tensorlists = [
            # Tensor
            torch.ops.aten.sin.default._schema.arguments[0].type,
            # IntList
            torch.ops.aten.sum.dim_IntList._schema.arguments[1].type,
        ]
        for a in tensorlists:
            self.assertTrue(torch._library.utils.is_tensorlist_like_type(a))
        for a in non_tensorlists:
            self.assertFalse(torch._library.utils.is_tensorlist_like_type(a))

    def test_backward_impl_on_existing_op(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @custom_ops.impl(qualname)
        def foo_impl(x):
            with torch.no_grad():
                return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad_out):
            return {"x": grad_out * saved.cos()}

        op = self.get_op(qualname)
        x = torch.randn([], requires_grad=True)
        y = op(x)
        (gx,) = torch.autograd.grad(y, x)
        self.assertEqual(gx, x.cos())

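    # torch.library.define accepts a single torch.Tag or any sequence of
    # tags; the op always reports its tags back as a list.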
    @parametrize(
        "tags",
        [
            subtest(torch.Tag.pointwise, "single"),
            subtest((torch.Tag.pointwise,), "tuple"),
            subtest([torch.Tag.pointwise], "list"),
        ],
    )
    def test_define_with_tags(self, tags):
        lib = self.lib()
        torch.library.define(
            f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib, tags=tags
        )
        actual = self.ns().foo.default.tags
        self.assertIsInstance(actual, list)
        expected = [tags] if isinstance(tags, torch.Tag) else list(tags)
        self.assertEqual(actual, expected)

    def test_builtin_aten_ops_are_pt2_compliant(self):
        for op in [torch.ops.aten.sin.default, torch.ops.aten.sum.dim_IntList]:
            self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)

    def test_builtin_torchscript_ops(self):
        for op in [torch.ops.aten.sub.complex, torch.ops.aten.mul.complex]:
            self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)

    def test_autogen_aten_ops_are_pt2_compliant(self):
        for op in [torch.ops.aten.fill.Tensor_out]:
            self.assertIn(torch.Tag.generated, op.tags)
            self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)

    def test_resolve_packet(self):
        x = torch.randn(3)
        result = torch._C._jit_resolve_packet("aten::sum", x)
        self.assertEqual(result, "default")

        result = torch._C._jit_resolve_packet("aten::sum", x, dim=1)
        self.assertEqual(result, "dim_IntList")

        with self.assertRaisesRegex(RuntimeError, "failed to match any schema"):
            result = torch._C._jit_resolve_packet("aten::sum", x, x, x)

    def test_define_bad_schema(self):
        lib = self.lib()
        with self.assertRaisesRegex(ValueError, "expected schema to look like"):
            torch.library.define(f"{self.test_ns}::foo", "foo(Tensor x) -> Tensor")

    def test_define_and_impl(self):
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        @torch.library.impl(f"{self.test_ns}::foo", "CPU", lib=lib)
        def f(x):
            return torch.from_numpy(np.sin(x.numpy()))

        x = torch.randn(3)
        y = self.ns().foo(x)
        assert torch.allclose(y, x.sin())

    def test_define_validation(self):
        with self.assertRaisesRegex(ValueError, "namespace"):
            torch.library.define("foo", "(Tensor x) -> Tensor")

    def test_legacy_define(self):
        lib = self.lib()

        @torch.library.define(lib, "foo(Tensor x) -> Tensor")
        def f(x):
            return torch.from_numpy(np.sin(x.numpy()))

        x = torch.randn(3)
        y = self.ns().foo(x)
        assert torch.allclose(y, x.sin())

    def test_impl_function(self):
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        def f(x):
            return torch.from_numpy(np.sin(x.numpy()))

        torch.library.impl(f"{self.test_ns}::foo", "CPU", f, lib=lib)
        x = torch.randn(3)
        y = self.ns().foo(x)
        assert torch.allclose(y, x.sin())

    def test_legacy_impl(self):
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        @torch.library.impl(lib, "foo", "CPU")
        def f(x):
            return torch.from_numpy(np.sin(x.numpy()))

        x = torch.randn(3)
        y = self.ns().foo(x)
        assert torch.allclose(y, x.sin())

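    # Ops defined from Python via torch.library are flagged with
    # _defined_in_python; builtin ATen ops are not.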
    def test_defined_in_python(self):
        self.assertFalse(torch.ops.aten.sin.default._defined_in_python)
        self.assertFalse(torch.ops.aten.sum.dim_IntList._defined_in_python)

        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
        ns = self.ns()
        self.assertTrue(ns.foo.default._defined_in_python)

        torch.library.define(
            f"{self.test_ns}::bar.overload", "(Tensor x) -> Tensor", lib=lib
        )
        self.assertTrue(ns.bar.overload._defined_in_python)

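    # Helper: register one impl for the given device `types`, then check it
    # is picked up for tensors on `device`.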
    def _test_impl_device(self, name, types, device):
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::{name}", "(Tensor x) -> Tensor", lib=lib)

        @torch.library.impl(f"{self.test_ns}::{name}", types)
        def f(x):
            x_np = x.cpu().numpy()
            y = torch.from_numpy(np.sin(x_np))
            return y.to(device=x.device)

        x = torch.randn(3, device=device)
        y = getattr(self.ns(), name)(x)
        assert torch.allclose(y, x.sin())

    def test_impl_device_cpu(self):
        self._test_impl_device("foo1", "default", "cpu")
        self._test_impl_device("foo2", ["cpu"], "cpu")
        self._test_impl_device("foo3", ["cpu", "cuda"], "cpu")

    @unittest.skipIf(not TEST_CUDA, "requires cuda")
    def test_impl_device_cuda(self):
        self._test_impl_device("foo4", "default", "cuda")
        self._test_impl_device("foo5", ["cuda"], "cuda")
        self._test_impl_device("foo6", ["cpu", "cuda"], "cuda")

    def test_impl_device_function(self):
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        def f(x):
            x_np = x.cpu().numpy()
            y = torch.from_numpy(np.sin(x_np))
            return y.to(device=x.device)

        torch.library.impl(f"{self.test_ns}::foo", "default", f, lib=lib)
        x = torch.randn(3)
        y = self.ns().foo(x)
        assert torch.allclose(y, x.sin())

    def test_impl_device_invalid(self):
        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu, cuda"):
            torch.library.impl("blah::blah", "somethingsomething")

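    # An op backed by a C++ torch::autograd::Function, compiled on the fly
    # with load_inline. Forward is the identity, so the gradient of
    # out.sum() w.r.t. x is all-ones (== temp).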
    def test_autograd_function_backed_op(self):
        cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
  static constexpr bool is_traceable = true;

  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& x) {
    return x;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext *ctx,
      torch::autograd::variable_list grad_output) {
    return grad_output;
  }
};

torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) {
  return CustomOpAutogradFunction::apply(x);
}

TORCH_LIBRARY(mylib, m) {
    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
}
        """

        module = torch.utils.cpp_extension.load_inline(
            name="mylib",
            cpp_sources=cpp_source,
            functions="custom_op_backed_by_autograd_fn",
            verbose=True,
        )

        x = torch.ones(2, 2, requires_grad=True)
        temp = x.clone().detach()
        out = torch.ops.mylib.custom_op_backed_by_autograd_fn(x)
        loss = out.sum()
        loss.backward()
        self.assertEqual(x.grad, temp)


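# Helper that defines an op whose kernel returns a view (x[:]), which is
# inconsistent with its functional "(Tensor x) -> Tensor" schema; used below
# to exercise the opcheck test generation machinery.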
2095*da0073e9SAndroid Build Coastguard Workerdef op_with_incorrect_schema(testcase, name):
2096*da0073e9SAndroid Build Coastguard Worker    lib = testcase.lib()
2097*da0073e9SAndroid Build Coastguard Worker    lib.define(f"{name}(Tensor x) -> Tensor")
2098*da0073e9SAndroid Build Coastguard Worker    qualname = f"{testcase.test_ns}::{name}"
2099*da0073e9SAndroid Build Coastguard Worker    lib.impl(name, lambda x: x[:], "CompositeExplicitAutograd")
2100*da0073e9SAndroid Build Coastguard Worker    return testcase.get_op(qualname)
2101*da0073e9SAndroid Build Coastguard Worker
2102*da0073e9SAndroid Build Coastguard Worker
2103*da0073e9SAndroid Build Coastguard Workerclass MiniOpTest(CustomOpTestCaseBase):
2104*da0073e9SAndroid Build Coastguard Worker    test_ns = "mini_op_test"
2105*da0073e9SAndroid Build Coastguard Worker
2106*da0073e9SAndroid Build Coastguard Worker    def _init_op_delayed_backward_error(self):
2107*da0073e9SAndroid Build Coastguard Worker        name = "delayed_error"
2108*da0073e9SAndroid Build Coastguard Worker        qualname = f"{self.test_ns}::{name}"
2109*da0073e9SAndroid Build Coastguard Worker        lib = self.lib()
2110*da0073e9SAndroid Build Coastguard Worker        lib.define(f"{name}(Tensor x) -> Tensor")
2111*da0073e9SAndroid Build Coastguard Worker        lib.impl(name, lambda x: x.clone(), "CompositeExplicitAutograd")
2112*da0073e9SAndroid Build Coastguard Worker        op = self.get_op(qualname)
2113*da0073e9SAndroid Build Coastguard Worker
2114*da0073e9SAndroid Build Coastguard Worker        class Op(torch.autograd.Function):
2115*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2116*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
2117*da0073e9SAndroid Build Coastguard Worker                with torch._C._AutoDispatchBelowAutograd():
2118*da0073e9SAndroid Build Coastguard Worker                    return op(x)
2119*da0073e9SAndroid Build Coastguard Worker
2120*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2121*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
2122*da0073e9SAndroid Build Coastguard Worker                raise NotImplementedError
2123*da0073e9SAndroid Build Coastguard Worker
2124*da0073e9SAndroid Build Coastguard Worker        def autograd_impl(x):
2125*da0073e9SAndroid Build Coastguard Worker            return Op.apply(x)
2126*da0073e9SAndroid Build Coastguard Worker
2127*da0073e9SAndroid Build Coastguard Worker        lib.impl(name, autograd_impl, "Autograd")
2128*da0073e9SAndroid Build Coastguard Worker        return op
2129*da0073e9SAndroid Build Coastguard Worker
    def _init_op_with_no_abstract_impl(self):
        name = "no_abstract"
        qualname = f"{self.test_ns}::{name}"
        lib = self.lib()
        lib.define(f"{name}(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,))
        lib.impl(name, lambda x: x.clone(), "CPU")
        return torch._library.utils.lookup_op(qualname)

    def setUp(self):
        super().setUp()
        self._op_with_no_abstract_impl = self._init_op_with_no_abstract_impl()
        self._op_delayed_backward_error = self._init_op_delayed_backward_error()

    @optests.dontGenerateOpCheckTests("Testing this API")
    def test_dont_generate(self):
        op = op_with_incorrect_schema(self, "incorrect_schema")
        x = torch.randn(3)
        op(x)

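    # Smoke tests against regular aten ops (mm, nonzero, sin_): real tensors,
    # meta tensors, FakeTensorMode, shape errors, data-dependent output shapes,
    # and in-place mutation.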
    def test_mm(self):
        x = torch.randn(2, 3, requires_grad=True)
        y = torch.randn(3, 5)
        result = torch.ops.aten.mm.default(x, y)
        self.assertEqual(result, x @ y)

    def test_mm_meta(self):
        x = torch.randn(2, 3, requires_grad=True, device="meta")
        y = torch.randn(3, 5, device="meta")
        result = torch.ops.aten.mm.default(x, y)
        self.assertEqual(result.shape, (x @ y).shape)

    def test_mm_fake(self):
        with torch._subclasses.fake_tensor.FakeTensorMode():
            x = torch.randn(2, 3, requires_grad=True, device="cpu")
            y = torch.randn(3, 5, device="cpu")
            result = torch.ops.aten.mm.default(x, y)
            self.assertEqual(result.shape, (x @ y).shape)

    def test_mm_errors(self):
        x = torch.randn(2, 3, requires_grad=True)
        y = torch.randn(4, 5)
        with self.assertRaisesRegex(RuntimeError, "cannot be multiplied"):
            result = torch.ops.aten.mm.default(x, y)

    def test_nonzero(self):
        x = torch.tensor([0, 1, 2, 0, 0])
        y = torch.ops.aten.nonzero.default(x)
        self.assertEqual(y, torch.tensor([[1], [2]]))

    def test_inplace(self):
        x = torch.randn(3)
        x_clone = x.clone()
        y = torch.ops.aten.sin_(x)
        self.assertEqual(x, x_clone.sin())

    def test_incorrect_schema(self):
        op = op_with_incorrect_schema(self, "incorrect_schema")
        x = torch.randn(3)
        op(x)

    def test_no_abstract(self):
        op = self._op_with_no_abstract_impl
        x = torch.randn(3)
        op(x)

    def test_delayed_error(self):
        op = self._op_delayed_backward_error
        x = torch.randn([], requires_grad=True)
        y = op(x)
        with self.assertRaises(NotImplementedError):
            y.sum().backward()

    def test_delayed_error_no_requires_grad(self):
        op = self._op_delayed_backward_error
        x = torch.randn([])
        y = op(x)


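# Tests for the Python custom-op API (torch.library.custom_op and friends).
# A minimal usage sketch, mirroring only the calls exercised in the tests
# below (the op name "mylib::mul2" is hypothetical):
#
#   @torch.library.custom_op("mylib::mul2", mutates_args=())
#   def mul2(x: Tensor) -> Tensor:
#       return x * 2
#
#   @mul2.register_kernel("cpu")  # device-specific override
#   def _(x):
#       return x * 2
#
#   torch.library.register_fake(mul2)(lambda x: torch.empty_like(x))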
class TestCustomOpAPI(TestCase):
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_basic(self):
        @torch.library.custom_op("_torch_testing::add", mutates_args=())
        def add(x: Tensor, y: float) -> Tensor:
            x_np = x.numpy(force=True)
            out_np = x_np + y
            return torch.from_numpy(out_np).to(x.device)

        x = torch.randn(3)
        y = 3.14
        z = add(x, y)
        self.assertEqual(z, x + y)

        cpu_called = False

        @add.register_kernel("cpu")
        def _(x, y):
            nonlocal cpu_called
            cpu_called = True
            x_np = x.numpy()
            out_np = x_np + y
            return torch.from_numpy(out_np)

        z = add(x, y)
        self.assertEqual(z, x + y)
        self.assertTrue(cpu_called)

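    # Under torch.no_grad() (or when no input requires grad), the Autograd key
    # is skipped entirely, so the registered setup_context/backward never run.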
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_no_grad_skips_autograd(self):
        @torch.library.custom_op("_torch_testing::add", mutates_args=())
        def add(x: Tensor, y: float) -> Tensor:
            x_np = x.numpy(force=True)
            out_np = x_np + y
            return torch.from_numpy(out_np).to(x.device)

        called = 0

        def setup_context(ctx, inputs, output):
            nonlocal called
            called += 1

        def backward(ctx, grad):
            raise AssertionError("should not be reached")

        add.register_autograd(backward, setup_context=setup_context)

        x = torch.randn(3, requires_grad=True)
        with torch.no_grad():
            y = add(x, 2.0)
        self.assertEqual(called, 0)
        self.assertEqual(y, x + 2.0)

        x.requires_grad_(False)
        y = add(x, 2.0)
        self.assertEqual(called, 0)
        self.assertEqual(y, x + 2.0)

        x = torch.randn(3, requires_grad=True)
        y = add(x, 2.0)
        self.assertEqual(called, 1)
        self.assertEqual(y, x + 2.0)

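    # custom_op also accepts an explicit schema string instead of inferring
    # one from type annotations; alias annotations like "Tensor(a!)" must
    # agree with mutates_args (see test_manual_schema_error below).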
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_manual_schema(self):
        @torch.library.custom_op(
            "_torch_testing::add",
            mutates_args=(),
            schema="(Tensor x, float y) -> Tensor",
        )
        def add(x, y):
            x_np = x.numpy(force=True)
            out_np = x_np + y
            return torch.from_numpy(out_np).to(x.device)

        x = torch.randn(3)
        y = 3.14
        z = add(x, y)
        self.assertEqual(z, x + y)

        @torch.library.custom_op(
            "_torch_testing::sin_",
            mutates_args=["x"],
            schema="(Tensor(a!) x) -> ()",
        )
        def sin_(x):
            x_np = x.numpy()
            np.sin(x_np, out=x_np)

        x = torch.randn(3)
        expected = x.sin()
        sin_(x)
        self.assertEqual(x, expected)

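    # Keyword-only Tensor arguments (including Optional[Tensor] and
    # List[Tensor]) are unsupported; custom_op, register_autograd, and
    # register_vmap should all raise NotImplementedError for them.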
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_kwarg_only_tensors(self):
        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):

            @torch.library.custom_op("_torch_testing::foo", mutates_args=())
            def foo(x: Tensor, *, y: int, z: Tensor) -> Tensor:
                pass

        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):

            @torch.library.custom_op("_torch_testing::foo", mutates_args=())
            def foo2(x: Tensor, *, y: int, z: Optional[Tensor]) -> Tensor:
                pass

        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):

            @torch.library.custom_op("_torch_testing::foo", mutates_args=())
            def foo3(x: Tensor, *, y: int, z: List[Tensor]) -> Tensor:
                pass

        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, *, Tensor y) -> Tensor")
            with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
                torch.library.register_autograd(
                    "_torch_testing::foo",
                    lambda grad: grad,
                    setup_context=lambda ctx, inputs, keyword_only_inputs, output: None,
                )

            with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
                torch.library.register_vmap(
                    "_torch_testing::foo",
                    lambda info, in_dims, x, *, y: (x, 0),
                )

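    # For ops defined via lib.define, register_autograd delivers non-Tensor
    # keyword-only arguments to setup_context through the keyword_only_inputs
    # dict rather than the positional inputs tuple.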
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_autograd_kwargonly_low_level(self):
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, *, float y) -> Tensor")
            called = False

            def foo_impl(x, *, y):
                return x * y

            lib.impl("foo", foo_impl, "CPU")

            def backward(ctx, grad):
                nonlocal called
                called = True
                return grad * ctx.y

            def setup_context(ctx, inputs, keyword_only_inputs, output):
                assert tuple(keyword_only_inputs.keys()) == ("y",)
                ctx.y = keyword_only_inputs["y"]

            torch.library.register_autograd(
                "_torch_testing::foo", backward, setup_context=setup_context, lib=lib
            )

            x = torch.randn(3, requires_grad=True)
            torch.ops._torch_testing.foo(x, y=3.14).sum().backward()
            self.assertTrue(called)
            self.assertEqual(x.grad, torch.tensor([3.14, 3.14, 3.14]))

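    # Default values declared in the schema are materialized before
    # setup_context runs: calling foo(w, z=42) below still surfaces x=2 in
    # inputs and y=3 in keyword_only_inputs.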
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_autograd_defaults(self):
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor")

            def foo_impl(w, x=2, *, y=3, z):
                return w * x * y * z

            lib.impl("foo", foo_impl, "CPU")

            called = False

            def backward(ctx, grad):
                nonlocal called
                called = True
                return grad * ctx.c

            def setup_context(ctx, inputs, keyword_only_inputs, output):
                assert len(inputs) == 2
                assert inputs[1] == 2
                assert keyword_only_inputs == {"y": 3, "z": 42}
                ctx.c = keyword_only_inputs["y"] * keyword_only_inputs["z"] * inputs[1]

            torch.library.register_autograd(
                "_torch_testing::foo", backward, setup_context=setup_context, lib=lib
            )

            w = torch.randn(3, requires_grad=True)
            torch.ops._torch_testing.foo(w, z=42).sum().backward()
            self.assertTrue(called)
            self.assertEqual(w.grad, torch.full_like(w, 2 * 3 * 42))

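    # A manual schema whose alias annotations disagree with mutates_args is
    # rejected up front: "(Tensor(a!) x)" declares a mutation that
    # mutates_args=() does not.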
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_manual_schema_error(self):
        with self.assertRaisesRegex(ValueError, "the op mutates {'x'}"):

            @torch.library.custom_op(
                "_torch_testing::sin_",
                mutates_args=(),
                schema="(Tensor(a!) x) -> ()",
            )
            def sin_(x):
                x_np = x.numpy()
                np.sin(x_np, out=x_np)

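    # supports_tensorlist lets an autograd.Function accept and return lists of
    # tensors; the cases below cover repeated applies, repeated backwards,
    # direct forward/backward access (an error), and recursive applies.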
    def test_supports_tensorlist(self):
        @torch._library.autograd.supports_tensorlist
        class Stack(torch.autograd.Function):
            @staticmethod
            def forward(ctx, xs):
                ctx.num_xs = len(xs)
                return torch.stack(xs)

            @staticmethod
            def backward(ctx, grad):
                expected = ([True] * ctx.num_xs,)
                self.assertEqual(ctx.needs_input_grad, expected)
                return list(grad.unbind(0))

        # call apply twice, then do a backward on the first result
        def t():
            return torch.randn([], requires_grad=True)

        xs0 = [t(), t(), t()]
        xs1 = [t(), t(), t(), t()]
        y0 = Stack.apply(xs0)
        y1 = Stack.apply(xs1)
        grads = torch.autograd.grad(y0.sum(), xs0)
        self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)])

        # call apply once, then do multiple backwards
        xs = [t(), t(), t()]
        y = Stack.apply(xs)
        _ = torch.autograd.grad(y.sum(), xs, retain_graph=True)
        _ = torch.autograd.grad(y.sum(), xs, retain_graph=True)
        grads = torch.autograd.grad(y.sum(), xs, retain_graph=True)
        self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)])

        # error: accessing forward/backward directly
        with self.assertRaisesRegex(NotImplementedError, "Function.forward directly"):
            Stack.forward(None, xs)
        with self.assertRaisesRegex(NotImplementedError, "Function.backward directly"):
            Stack.backward(None, xs)

        # the recursive case: forward re-enters apply
        @torch._library.autograd.supports_tensorlist
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, xs):
                if len(xs) > 1:
                    return Foo.apply(xs[1:])
                ctx.len_xs = len(xs)
                return xs[0].sin()

            @staticmethod
            def backward(ctx, grad):
                result = [None] * ctx.len_xs
                result[-1] = grad.cos()
                return result

        # the recursive forward should work
        result = Foo.apply(xs)
        expected = xs[-1].sin()
        self.assertEqual(result, expected)

        # recursion inside backward
        @torch._library.autograd.supports_tensorlist
        class Bar(torch.autograd.Function):
            @staticmethod
            def forward(ctx, xs):
                return [xs[i] + i for i in range(len(xs))]

            @staticmethod
            def backward(ctx, grads):
                f1 = Bar.apply(grads[:2])
                f2 = Bar.apply(grads[2:])
                return f1 + f2

        xs = [torch.tensor(0.0, requires_grad=True) for _ in range(5)]
        ys = Bar.apply(xs)
        sum(ys).backward()
        result = [xi.grad for xi in xs]
        self.assertEqual(result, torch.tensor([1.0, 2, 1, 2, 3]).unbind(0))

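    # Defaults of every supported type (None, float, bool, int, str, dtype,
    # device) should round-trip through schema inference; dtype defaults are
    # stored in the schema as their ScalarType enum values.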
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_default_values(self):
        defaults = []

        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(
            x: Tensor,
            a: Optional[int] = None,
            b: float = 3.14,
            c: bool = True,
            d: int = 3,
            e: str = "foo",
            f: torch.dtype = torch.float,
            g: torch.dtype = torch.float32,
            h: torch.dtype = torch.int,
            i: torch.device = torch.device("cpu:0"),
            j: torch.device = "cpu",
        ) -> Tensor:
            defaults.extend([a, b, c, d, e, f, g, h, i, j])
            return x.clone()

        x = torch.randn(3)
        f(x)
        self.assertEqual(
            defaults,
            [
                None,
                3.14,
                True,
                3,
                "foo",
                torch.float,
                torch.float32,
                torch.int,
                torch.device("cpu:0"),
                "cpu",
            ],
        )
        default_values = [
            arg.default_value
            for arg in torch.ops._torch_testing.f.default._schema.arguments
        ]
        # enum values taken from c10/core/ScalarType.h
        type_enum = {
            "float": 6,
            "int": 3,
        }
        self.assertEqual(
            default_values,
            [
                None,
                None,
                3.14,
                True,
                3,
                "foo",
                type_enum["float"],
                type_enum["float"],
                type_enum["int"],
                torch.device("cpu:0"),
                torch.device("cpu"),
            ],
        )

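    # Every name in mutates_args must correspond to an actual parameter of
    # the op; unknown names are rejected with a ValueError.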
    def test_mutated_error(self):
        with self.assertRaisesRegex(
            ValueError, r".*{'y'} in mutates_args were not found"
        ):

            @torch.library.custom_op(
                "_torch_testing::numpy_sin_inplace",
                mutates_args={"y"},
                device_types="cpu",
            )
            def numpy_sin_inplace(x: Tensor) -> None:
                x_np = x.numpy()
                np.sin(x_np, out=x_np)

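    # Inputs listed in mutates_args should have their version counters bumped
    # by the call, including Optional[Tensor] and List[...] arguments; inputs
    # not listed (x below) must keep their original version.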
    def test_mutated(self):
        @torch.library.custom_op(
            "_torch_testing::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu"
        )
        def numpy_sin_inplace(x: Tensor) -> None:
            x_np = x.numpy()
            np.sin(x_np, out=x_np)

        x = torch.randn(3)
        version = x._version
        expected = x.sin()
        numpy_sin_inplace(x)
        self.assertEqual(x, expected)
        self.assertGreater(x._version, version)

        @torch.library.custom_op("_torch_testing::f", mutates_args={"y", "z", "w"})
        def f(
            x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
        ) -> None:
            return

        x = torch.randn(3)
        y = torch.randn(3)
        z = [torch.randn(3), torch.randn(3)]
        w = [torch.randn(3), None, torch.randn(3)]
        initial_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )
        f(x, y, z, w)
        new_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )

        self.assertEqual(initial_versions[0], new_versions[0])
        initial_versions, _ = pytree.tree_flatten(initial_versions[1:])
        new_versions, _ = pytree.tree_flatten(new_versions[1:])
        for prev, after in zip(initial_versions, new_versions):
            if prev is None and after is None:
                continue
            self.assertGreater(after, prev)

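    # mutates_args="unknown" is an escape hatch meaning "assume every input
    # may be mutated"; any other bare string is rejected, since mutates_args
    # otherwise expects a collection of parameter names.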
    def test_mutated_unknown(self):
        @torch.library.custom_op(
            "_torch_testing::f", mutates_args="unknown", device_types="cpu"
        )
        def f(x: Tensor) -> None:
            x_np = x.numpy()
            np.sin(x_np, out=x_np)

        x = torch.randn(3)
        version = x._version
        expected = x.sin()
        f(x)
        self.assertEqual(x, expected)
        self.assertGreater(x._version, version)

        @torch.library.custom_op("_torch_testing::f2", mutates_args="unknown")
        def f2(
            x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
        ) -> None:
            return

        x = torch.randn(3)
        y = torch.randn(3)
        z = [torch.randn(3), torch.randn(3)]
        w = [torch.randn(3), None, torch.randn(3)]
        initial_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )
        f2(x, y, z, w)
        new_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )

        initial_versions, _ = pytree.tree_flatten(initial_versions)
        new_versions, _ = pytree.tree_flatten(new_versions)
        for prev, after in zip(initial_versions, new_versions):
            if prev is None and after is None:
                continue
            self.assertGreater(after, prev)

        with self.assertRaisesRegex(ValueError, "string"):

            @torch.library.custom_op("_torch_testing::f3", mutates_args="x")
            def f3(x: Tensor) -> None:
                return

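    # _register_torch_dispatch_rule installs a per-op dispatch rule for a
    # specific subclass (here TwoTensor) or mode, overriding the op's normal
    # __torch_dispatch__ handling for that type.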
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch_rule_subclass(self):
        from torch.testing._internal.two_tensor import TwoTensor

        @torch.library.custom_op("mylib::foo", mutates_args={})
        def f(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        x = torch.randn(3)
        y = torch.randn(3)
        z = TwoTensor(x, y)

        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            called = 0

            def TwoTensor_foo(cls, func, types, args, kwargs):
                nonlocal called
                assert cls is TwoTensor
                called += 1
                return x.sin()

            m._register_torch_dispatch_rule("foo", TwoTensor, TwoTensor_foo)

            out = f(z)
            out2 = z.cos()

        self.assertEqual(called, 1)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch_rule_mode(self):
        from torch.testing._internal.two_tensor import TwoTensorMode

        @torch.library.custom_op("mylib::foo", mutates_args={})
        def f(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        x = torch.randn(3)

        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            called = 0

            def TwoTensor_foo(mode, func, types, args, kwargs):
                nonlocal called
                called += 1
                return x.sin()

            m._register_torch_dispatch_rule("foo", TwoTensorMode, TwoTensor_foo)

            with TwoTensorMode():
                out = f(x)
                out2 = x.cos()

        self.assertEqual(called, 1)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @parametrize("idx", [0, 1, 2, 3, 4, 5])
    def test_library_register_fake_source(self, idx):
        opname = f"source{idx}"
        op = getattr(torch.ops._torch_testing, opname).default
        entry = torch._library.simple_registry.singleton.find(op._name)
        source = entry.fake_impl.kernel.source
        assert source is not None
        self.assertTrue("custom_op_db.py" in source)

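    # register_fake can be addressed three ways: by the custom-op object, by
    # qualname string, or by OpOverload. In all three cases the registered
    # fake kernel is what runs under FakeTensorMode.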
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_fake(self):
        for mode in ["function", "qualname", "opoverload"]:

            @torch.library.custom_op("_torch_testing::add", mutates_args=())
            def add(x: Tensor, y: float) -> Tensor:
                x_np = x.cpu().numpy()
                out_np = x_np + y
                return torch.from_numpy(out_np).to(x.device)

            called = False

            if mode == "function":
                dec = torch.library.register_fake(add)
                self.assertIsNotNone(dec)
            elif mode == "qualname":
                dec = torch.library.register_fake("_torch_testing::add")
                self.assertIsNotNone(dec)
            elif mode == "opoverload":
                dec = torch.library.register_fake(torch.ops._torch_testing.add.default)
                self.assertIsNotNone(dec)
            else:
                raise AssertionError("should not get here")

            @dec
            def _(x, y):
                nonlocal called
                called = True
                return torch.empty_like(x)

            with torch._subclasses.fake_tensor.FakeTensorMode():
                x = torch.randn(3)
                y = 3.14
                z = add(x, y)
                self.assertEqual(z.shape, x.shape)
                self.assertTrue(called)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch(self):
        for mode in ["function", "qualname", "opoverload"]:

            class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                    return func(*args, **kwargs)

            @torch.library.custom_op("_torch_testing::add", mutates_args=())
            def add(x: Tensor, y: float) -> Tensor:
                x_np = x.cpu().numpy()
                out_np = x_np + y
                return torch.from_numpy(out_np).to(x.device)

            called = False

            if mode == "function":
                dec = torch.library.register_torch_dispatch(add, MyMode)
                self.assertIsNotNone(dec)
            elif mode == "qualname":
                dec = torch.library.register_torch_dispatch(
                    "_torch_testing::add", MyMode
                )
                self.assertIsNotNone(dec)
            elif mode == "opoverload":
                dec = torch.library.register_torch_dispatch(
                    torch.ops._torch_testing.add.default, MyMode
                )
                self.assertIsNotNone(dec)
            else:
                raise AssertionError("should not get here")

            @dec
            def _(mode, func, types, args, kwargs):
                nonlocal called
                called = True
                return func(*args, **kwargs)

            with MyMode():
                x = torch.randn(3)
                y = 3.14
                z = add(x, y)
                self.assertEqual(z.shape, x.shape)
                self.assertTrue(called)

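    # The same API also works for ops defined directly through lib.define,
    # addressed by qualname or OpOverload, and registered either as a
    # decorator or by passing the function explicitly.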
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch_low_level(self):
        modes = ["qualname", "opoverload"]
        calls = ["decorator", "function"]
        device_types_options = [("cpu", "cuda"), "cpu", None]

        for mode, call, device_types in itertools.product(
            modes, calls, device_types_options
        ):
            with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
                lib.define("add10(Tensor x, float y) -> Tensor")

                if mode == "qualname":
                    op = "_torch_testing::add10"
                else:
                    assert mode == "opoverload"
                    op = torch.ops._torch_testing.add10.default

                called = False

                class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
                    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                        return func(*args, **kwargs)

                if call == "decorator":

                    @torch.library.register_torch_dispatch(op, MyMode, lib=lib)
                    def _(mode, func, types, args, kwargs):
                        x, y = args
                        nonlocal called
                        called = True
                        return x + y

                else:
                    assert call == "function"

                    def add_stuff(mode, func, types, args, kwargs):
                        x, y = args
                        nonlocal called
                        called = True
                        return x + y

                    torch.library.register_torch_dispatch(
                        op, MyMode, add_stuff, lib=lib
                    )

                x = torch.randn(3)
                y = 3.14
                with MyMode():
                    z = torch.ops._torch_testing.add10.default(x, y)
                self.assertEqual(z, x + y)
                self.assertTrue(called)

2854*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_kernel(self):
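        # The op is defined with a CUDA-only implementation, so the kernel
        # registered below via torch.library.register_kernel (for "cpu", or
        # for all devices when device_types is None) must be the one that
        # handles the CPU call at the end of each iteration.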
        modes = ["function", "qualname", "opoverload"]
        calls = ["decorator", "function"]
        device_types_options = ["cpu", None]

        for mode, call, device_types in itertools.product(
            modes, calls, device_types_options
        ):

            @torch.library.custom_op(
                "_torch_testing::add", mutates_args=(), device_types="cuda"
            )
            def add(x: Tensor, y: float) -> Tensor:
                x_np = x.cpu().numpy()
                out_np = x_np + y
                return torch.from_numpy(out_np).to(x.device)

            if mode == "function":
                op = add
            elif mode == "qualname":
                op = "_torch_testing::add"
            else:
                assert mode == "opoverload"
                op = torch.ops._torch_testing.add.default

            called = False

            if call == "decorator":

                @torch.library.register_kernel(op, device_types)
                def _(x, y):
                    nonlocal called
                    called = True
                    x_np = x.numpy()
                    out_np = x_np + y
                    return torch.from_numpy(out_np)

            else:
                assert call == "function"

                def add_cpu(x, y):
                    nonlocal called
                    called = True
                    x_np = x.numpy()
                    out_np = x_np + y
                    return torch.from_numpy(out_np)

                torch.library.register_kernel(op, device_types, add_cpu)

            x = torch.randn(3)
            y = 3.14
            z = add(x, y)
            self.assertEqual(z, x + y)
            self.assertTrue(called)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_kernel_low_level(self):
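        # Low-level variant of the previous test: the op is declared directly
        # with lib.define on a scoped library fragment (so there is no
        # CustomOpDef handle), and register_kernel is passed lib= so the
        # registration is torn down together with the scoped library.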
        modes = ["qualname", "opoverload"]
        calls = ["decorator", "function"]
        device_types_options = [("cpu", "cuda"), "cpu", None]

        for mode, call, device_types in itertools.product(
            modes, calls, device_types_options
        ):
            with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
                lib.define("add9(Tensor x, float y) -> Tensor")

                if mode == "qualname":
                    op = "_torch_testing::add9"
                else:
                    assert mode == "opoverload"
                    op = torch.ops._torch_testing.add9.default

                called = False

                if call == "decorator":

                    @torch.library.register_kernel(op, device_types, lib=lib)
                    def _(x, y):
                        nonlocal called
                        called = True
                        x_np = x.numpy()
                        out_np = x_np + y
                        return torch.from_numpy(out_np)

                else:
                    assert call == "function"

                    def add_cpu(x, y):
                        nonlocal called
                        called = True
                        x_np = x.numpy()
                        out_np = x_np + y
                        return torch.from_numpy(out_np)

                    torch.library.register_kernel(op, device_types, add_cpu, lib=lib)

                x = torch.randn(3)
                y = 3.14
                z = torch.ops._torch_testing.add9.default(x, y)
                self.assertEqual(z, x + y)
                self.assertTrue(called)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_autograd(self):
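        # torch.library.register_autograd should accept the op as a
        # CustomOpDef, a qualname string, or an OpOverload; in each case the
        # registered backward must run and produce grad * x.cos() for
        # numpy_sin.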
        for mode in ["function", "qualname", "opoverload"]:

            @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
            def numpy_sin(x: Tensor) -> Tensor:
                x_np = x.cpu().numpy()
                y_np = np.sin(x_np)
                return torch.from_numpy(y_np).to(device=x.device)

            def setup_context(ctx, inputs, output) -> None:
                (x,) = inputs
                ctx.save_for_backward(x)

            called = False

            def backward(ctx, grad):
                nonlocal called
                called = True
                (x,) = ctx.saved_tensors
                return grad * x.cos()

            if mode == "function":
                torch.library.register_autograd(
                    numpy_sin, backward, setup_context=setup_context
                )
            elif mode == "qualname":
                torch.library.register_autograd(
                    "mylib::numpy_sin", backward, setup_context=setup_context
                )
            elif mode == "opoverload":
                torch.library.register_autograd(
                    torch.ops.mylib.numpy_sin.default,
                    backward,
                    setup_context=setup_context,
                )

            x = torch.randn(3, requires_grad=True)
            y = numpy_sin(x)
            (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
            self.assertTrue(called)
            self.assertEqual(grad_x, x.cos())

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_autograd_low_level(self):
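        # Low-level variant: the op is defined with lib.define/lib.impl on a
        # scoped library fragment, so there is no function handle and the op
        # must be invoked through torch.ops; register_autograd is passed lib=
        # so the registration lives and dies with the fragment.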
        for mode in ["qualname", "opoverload"]:
            with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
                lib.define("sin5(Tensor x) -> Tensor")

                def numpy_sin(x: Tensor) -> Tensor:
                    x_np = x.cpu().detach().numpy()
                    y_np = np.sin(x_np)
                    return torch.from_numpy(y_np).to(device=x.device)

                def setup_context(ctx, inputs, output) -> None:
                    (x,) = inputs
                    ctx.save_for_backward(x)

                def backward(ctx, grad):
                    nonlocal called
                    called = True
                    (x,) = ctx.saved_tensors
                    return grad * x.cos()

                lib.impl("sin5", numpy_sin, "CPU")

                called = False

                if mode == "qualname":
                    torch.library.register_autograd(
                        "_torch_testing::sin5",
                        backward,
                        setup_context=setup_context,
                        lib=lib,
                    )
                elif mode == "opoverload":
                    torch.library.register_autograd(
                        torch.ops._torch_testing.sin5.default,
                        backward,
                        setup_context=setup_context,
                        lib=lib,
                    )
                x = torch.randn(3, requires_grad=True)
                y = torch.ops._torch_testing.sin5(x)
                (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
                self.assertTrue(called)
                self.assertEqual(grad_x, x.cos())

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_fake(self):
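        # Calling a custom op without a fake impl under FakeTensorMode must
        # raise; the expected message below is asserted verbatim (wording and
        # typos included), so it has to track the runtime error exactly.
        # Registering a fake impl afterwards makes the same call succeed.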
        @torch.library.custom_op("_torch_testing::add", mutates_args=())
        def add(x: Tensor, y: float) -> Tensor:
            x_np = x.cpu().numpy()
            out_np = x_np + y
            return torch.from_numpy(out_np).to(x.device)

        x = torch.randn(3)
        y = 3.14
        z = add(x, y)
        self.assertEqual(z, x + y)

        try:
            with torch._subclasses.fake_tensor.FakeTensorMode():
                x = torch.randn(3)
                add(x, y)
            raise AssertionError("should not be hit")
        except RuntimeError as e:
            abstract_impl_error_msg = str(e)
        abstract_impl_error_msg = re.sub(
            r"0x.*>\)>", "0xDEADBEEF>)>", abstract_impl_error_msg
        ).replace(". ", ".\n")
        self.assertExpectedInline(
            abstract_impl_error_msg,
            """\
There was no fake impl registered for <CustomOpDef(_torch_testing::add)>.
This is necessary for torch.compile/export/fx tracing to work.
Please use `add.register_fake` to add an fake impl.""",
        )

        if not IS_WINDOWS:

            @torch.compile(backend="eager")
            def f(x, y):
                return add(x, y)

            x = torch.randn(3)
            with self.assertRaisesRegex(RuntimeError, "no fake impl"):
                f(x, y)

        abstract_called = False

        @add.register_fake
        def _(x, y):
            nonlocal abstract_called
            abstract_called = True
            return torch.empty_like(x)

        with torch._subclasses.fake_tensor.FakeTensorMode():
            x = torch.randn(3)
            z = add(x, y)
            self.assertEqual(z.shape, x.shape)
            self.assertTrue(abstract_called)

    @skipIfTorchDynamo("recursive dynamo")
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
    def test_compile(self):
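        # End-to-end check that torch.compile traces through a custom op: the
        # fake impl runs during tracing (called_abstract) and the real numpy
        # kernel runs at execution time (called_impl), with fullgraph=True
        # guaranteeing no graph break around the op.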
        called_impl = False
        called_abstract = False

        @torch.library.custom_op("_torch_testing::linear", mutates_args=())
        def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
            nonlocal called_impl
            called_impl = True
            x_np = x.numpy()
            w_np = weight.numpy()
            b_np = bias.numpy()
            out_np = np.add(x_np @ w_np.T, b_np)
            return torch.from_numpy(out_np)

        @custom_linear.register_fake
        def _(x, weight, bias):
            nonlocal called_abstract
            called_abstract = True
            assert x.dim() == 2
            assert weight.dim() == 2
            assert bias.dim() == 1
            assert x.shape[1] == weight.shape[1]
            assert weight.shape[0] == bias.shape[0]
            assert x.device == weight.device
            return x.new_empty(x.size(0), weight.size(0))

        x = torch.randn(2, 2)
        weight = torch.randn(2, 2)
        bias = torch.randn(2)
        out = torch.compile(custom_linear, backend="eager", fullgraph=True)(
            x, weight, bias
        )
        self.assertEqual(out, torch.nn.functional.linear(x, weight, bias))
        self.assertTrue(called_impl)
        self.assertTrue(called_abstract)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_autograd_error_cases(self):
        @torch.library.custom_op("_torch_testing::g", mutates_args=())
        def g(x: Tensor) -> Tensor:
            return x.sin()

        x = torch.randn(3, requires_grad=True)
        y = g(x)
        with self.assertRaisesRegex(RuntimeError, "no autograd formula"):
            y.sum().backward()

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_replacement(self):
        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x.sin()

        x = torch.randn(3)
        y = f(x)
        self.assertEqual(y, x.sin())

        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x.cos()

        y = f(x)
        self.assertEqual(y, x.cos())

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_split_device(self):
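        # Per-device dispatch: the body passed to custom_op handles CPU only
        # (device_types="cpu"), and register_kernel("cuda") supplies the CUDA
        # path; the two counters verify that exactly one kernel runs per
        # device.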
        cpu_call_count = 0
        cuda_call_count = 0

        @torch.library.custom_op(
            "_torch_testing::f", mutates_args=(), device_types="cpu"
        )
        def f(x: Tensor) -> Tensor:
            nonlocal cpu_call_count
            cpu_call_count += 1
            x_np = x.numpy()
            out_np = np.sin(x_np)
            return torch.from_numpy(out_np)

        @f.register_kernel("cuda")
        def _(x: Tensor) -> Tensor:
            nonlocal cuda_call_count
            cuda_call_count += 1
            x_np = x.cpu().numpy()
            out_np = np.sin(x_np)
            return torch.from_numpy(out_np).to(x.device)

        x = torch.randn(3)
        y = f(x)
        self.assertEqual(y, x.sin())
        self.assertEqual(cpu_call_count, 1)
        self.assertEqual(cuda_call_count, 0)

        x = x.cuda()
        y = f(x)
        self.assertEqual(y, x.sin())
        self.assertEqual(cpu_call_count, 1)
        self.assertEqual(cuda_call_count, 1)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_multi_types(self):
        @torch.library.custom_op(
            "_torch_testing::f", mutates_args=(), device_types=("cpu", "cuda")
        )
        def f(x: Tensor) -> Tensor:
            x_np = x.cpu().numpy()
            out_np = np.sin(x_np)
            return torch.from_numpy(out_np).to(x.device)

        x = torch.randn(3)
        y = f(x)
        self.assertEqual(y, x.sin())
        x = x.cuda()
        y = f(x)
        self.assertEqual(y, x.sin())

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_overloading(self):
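        # Defining "_torch_testing::f.overload" adds a second overload to the
        # same op packet, so torch.ops._torch_testing.f resolves by signature:
        # a single-tensor call hits f, a two-tensor call hits f1.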
        called_f = 0
        called_f1 = 0

        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            nonlocal called_f
            called_f += 1
            return x.clone()

        x = torch.randn(2, 3)
        torch.ops._torch_testing.f(x)
        self.assertEqual(called_f, 1)

        @torch.library.custom_op("_torch_testing::f.overload", mutates_args=())
        def f1(x: Tensor, y: Tensor) -> Tensor:
            nonlocal called_f1
            called_f1 += 1
            return x.clone()

        torch.ops._torch_testing.f(x, x)
        self.assertEqual(called_f1, 1)

    def test_disallows_output_aliasing(self):
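        # Custom ops must return tensors that do not alias their inputs: a
        # view of the input, the input itself, and an in-place kernel that
        # returns its mutated argument (even with mutates_args declared) are
        # all rejected at runtime.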
        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x.view(-1)

        x = torch.randn(3)
        with self.assertRaisesRegex(RuntimeError, "may not alias"):
            f(x)

        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x

        x = torch.randn(3)
        with self.assertRaisesRegex(RuntimeError, "may not alias"):
            f(x)

        @torch.library.custom_op(
            "_torch_testing::f", mutates_args={"x"}, device_types="cpu"
        )
        def numpy_sin_inplace(x: Tensor) -> Tensor:
            x_np = x.numpy()
            np.sin(x_np, out=x_np)
            return x

        x = torch.randn(3)
        with self.assertRaisesRegex(RuntimeError, "may not alias"):
            numpy_sin_inplace(x)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_factory_function(self):
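        # Factory ops (no tensor inputs) must take a `device: torch.device`
        # argument so the dispatcher can pick a kernel: defining a device-less
        # factory (f2) or registering a kernel for one (f3) raises ValueError,
        # while f4 shows that a device-aware signature is accepted.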
        @torch.library.custom_op(
            "_torch_testing::f", mutates_args={}, device_types="cpu"
        )
        def f(device: torch.device) -> Tensor:
            return torch.ones(3)

        result = f(device="cpu")
        self.assertEqual(result.device, torch.device("cpu"))
        self.assertEqual(result, torch.ones(3))

        with self.assertRaisesRegex(
            RuntimeError, "f does not have a kernel registered for cuda"
        ):
            f("cuda")

        with self.assertRaisesRegex(
            ValueError,
            "Functions without tensor inputs are required to have a `device: torch.device` argument",
        ):

            @torch.library.custom_op(
                "_torch_testing::f2", mutates_args={}, device_types="cpu"
            )
            def f2() -> Tensor:
                return torch.ones(3)

        @torch.library.custom_op("_torch_testing::f3", mutates_args={})
        def f3() -> Tensor:
            raise NotImplementedError("NYI")

        with self.assertRaisesRegex(
            ValueError,
            "Functions without tensor inputs are required to have a `device: torch.device` argument",
        ):

            @f3.register_kernel("cpu")
            def _():
                return torch.zeros(3)

        # With a `device` argument present, registering a kernel succeeds.
        @torch.library.custom_op("_torch_testing::f4", mutates_args={})
        def f4(device: torch.device) -> Tensor:
            raise NotImplementedError("NYI")

        @f4.register_kernel("cpu")
        def _(device: torch.device):
            return torch.zeros(3)

        result = f(device="cpu")
        self.assertEqual(result.device, torch.device("cpu"))
        self.assertEqual(result, torch.ones(3))

    def test_library_schema_infer(self):
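        # infer_schema derives the FunctionSchema string from the Python type
        # annotations; op_name only controls whether the result carries a name
        # prefix or is anonymous.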
        def foo_impl(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={})
        self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor")

        schema = torch.library.infer_schema(foo_impl, mutates_args={})
        self.assertExpectedInline(schema, "(Tensor x) -> Tensor")

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_set_kernel_enabled(self):
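        # set_kernel_enabled temporarily toggles a registered kernel:
        # disabling the CPU override falls back to the original body (x + 1),
        # redundant toggles (and toggling "gpu", for which nothing was ever
        # registered) only emit a log message.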
        x = torch.ones(1)

        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x + 1

        self.assertEqual(f(x), x + 1)
        with self.assertLogs("torch._library.custom_ops") as captured:
            with f.set_kernel_enabled("gpu", enabled=False):
                self.assertEqual(f(x), x + 1)
            self.assertIn(
                "no kernel was registered for this device type", captured.output[0]
            )

        @f.register_kernel("cpu")
        def _(x):
            return x + 2

        self.assertEqual(f(x), x + 2)

        with self.assertLogs("torch._library.custom_ops") as captured:
            with f.set_kernel_enabled("cpu", enabled=True):
                self.assertEqual(f(x), x + 2)
            self.assertIn("already enabled", captured.output[0])

        with f.set_kernel_enabled("cpu", enabled=False):
            self.assertEqual(f(x), x + 1)

            with self.assertLogs("torch._library.custom_ops") as captured:
                with f.set_kernel_enabled("cpu", enabled=False):
                    self.assertEqual(f(x), x + 1)
                self.assertIn("already disabled", captured.output[0])

            self.assertEqual(f(x), x + 1)

        with f.set_kernel_enabled("cpu", enabled=True):
            self.assertEqual(f(x), x + 2)

        with f.set_kernel_enabled("cpu", enabled=False):
            self.assertEqual(f(x), x + 1)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_vmap_kwargonly_low_level(self):
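        # register_vmap must thread keyword-only arguments through to the
        # vmap rule: y is kw-only in the schema, so the rule receives it as a
        # keyword as well and only the tensor argument is batched.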
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, *, float y) -> Tensor")
            called = False

            def foo_impl(x, *, y):
                return x * y

            lib.impl("foo", foo_impl, "CPU")

            def vmap(info, in_dims, x, *, y):
                nonlocal called
                called = True
                return x * y, 0

            torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib)

            x = torch.ones(3)
            result = torch.vmap(torch.ops._torch_testing.foo)(x, y=3.14)
            self.assertTrue(called)
            self.assertEqual(result, torch.tensor([3.14, 3.14, 3.14]))

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_vmap_defaults(self):
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor")

            def foo_impl(w, x=2, *, y=3, z):
                return w * x * y * z

            lib.impl("foo", foo_impl, "CPU")

            called = False

            def vmap(info, in_dims, w, x=2, *, y=3, z):
                nonlocal called
                called = True
                return w * x * y * z, 0

            torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib)

            w = torch.ones(3)
            result = torch.vmap(torch.ops._torch_testing.foo)(w, z=42)
            self.assertTrue(called)
            self.assertEqual(result, w * 2 * 3 * 42)

    def test_layout_constraint_tags(self):
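        # Maps operator tags to Inductor layout constraints: with multiple
        # tags the most constrained one (needs_fixed_stride_order) wins, and
        # an untagged op falls back to flexible_layout, per the table below.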
        needs_fixed_stride_order = torch._C.Tag.needs_fixed_stride_order
        flexible_layout = torch._C.Tag.flexible_layout
        # (tags, the result of the tag inference)
        tests = [
            ({needs_fixed_stride_order}, needs_fixed_stride_order),
            ({flexible_layout}, flexible_layout),
            # If no tags are provided, then the following is the default
            (set(), flexible_layout),
            # If multiple tags are provided, then we use the most constrained tag.
            ({flexible_layout, needs_fixed_stride_order}, needs_fixed_stride_order),
        ]
        from torch._inductor.lowering import get_layout_constraint_tag

        for tags, expected in tests:
            with torch.library._scoped_library("mylib", "FRAGMENT") as m:
                m.define("foobar(Tensor x) -> Tensor", tags=tags)
                result = get_layout_constraint_tag(torch.ops.mylib.foobar.default)
                self.assertEqual(result, expected)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_vmap(self):
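        # register_vmap accepts the op as a function handle, qualname string,
        # OpOverload, or via the CustomOpDef.register_vmap method. The rule
        # moves any batch dims to the last position, multiplies, moves the
        # result's batch dim to 0, and reports out_dim 0; the in_dims/out_dims
        # sweeps below confirm vmap honors that contract.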
3453*da0073e9SAndroid Build Coastguard Worker        for mode in ["function", "qualname", "opoverload", "c_opdef"]:
3454*da0073e9SAndroid Build Coastguard Worker
3455*da0073e9SAndroid Build Coastguard Worker            @torch.library.custom_op("mylib::f", mutates_args=())
3456*da0073e9SAndroid Build Coastguard Worker            def f(x: Tensor, y: Tensor) -> Tensor:
3457*da0073e9SAndroid Build Coastguard Worker                return x * y
3458*da0073e9SAndroid Build Coastguard Worker
3459*da0073e9SAndroid Build Coastguard Worker            called = False
3460*da0073e9SAndroid Build Coastguard Worker
3461*da0073e9SAndroid Build Coastguard Worker            def fvmap(info, in_dims, x, y):
3462*da0073e9SAndroid Build Coastguard Worker                nonlocal called
3463*da0073e9SAndroid Build Coastguard Worker                called = True
3464*da0073e9SAndroid Build Coastguard Worker                x_bdim, y_bdim = in_dims
3465*da0073e9SAndroid Build Coastguard Worker                x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
3466*da0073e9SAndroid Build Coastguard Worker                y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
3467*da0073e9SAndroid Build Coastguard Worker                result = x * y
3468*da0073e9SAndroid Build Coastguard Worker                result = result.movedim(-1, 0)
3469*da0073e9SAndroid Build Coastguard Worker                return result, 0
3470*da0073e9SAndroid Build Coastguard Worker
3471*da0073e9SAndroid Build Coastguard Worker            if mode == "function":
3472*da0073e9SAndroid Build Coastguard Worker                torch.library.register_vmap(f, fvmap)
3473*da0073e9SAndroid Build Coastguard Worker            elif mode == "qualname":
3474*da0073e9SAndroid Build Coastguard Worker                torch.library.register_vmap("mylib::f", fvmap)
3475*da0073e9SAndroid Build Coastguard Worker            elif mode == "opoverload":
3476*da0073e9SAndroid Build Coastguard Worker                torch.library.register_vmap(torch.ops.mylib.f.default, fvmap)
3477*da0073e9SAndroid Build Coastguard Worker            elif mode == "c_opdef":
3478*da0073e9SAndroid Build Coastguard Worker                f.register_vmap(fvmap)
3479*da0073e9SAndroid Build Coastguard Worker
3480*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(2, 2)
3481*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(2, 2)
3482*da0073e9SAndroid Build Coastguard Worker
3483*da0073e9SAndroid Build Coastguard Worker            result = torch.vmap(f)(x, y)
3484*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(called)
3485*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, x * y)
3486*da0073e9SAndroid Build Coastguard Worker
3487*da0073e9SAndroid Build Coastguard Worker            called = False
3488*da0073e9SAndroid Build Coastguard Worker            result = torch.vmap(f, out_dims=1)(x, y)
3489*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, (x * y).T)
3490*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(called)
3491*da0073e9SAndroid Build Coastguard Worker
3492*da0073e9SAndroid Build Coastguard Worker            called = False
3493*da0073e9SAndroid Build Coastguard Worker            result = torch.vmap(f, in_dims=1)(x, y)
3494*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, (x * y).T)
3495*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(called)
3496*da0073e9SAndroid Build Coastguard Worker
3497*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3498*da0073e9SAndroid Build Coastguard Worker    def test_library_register_vmap_library_decorator(self):
3499*da0073e9SAndroid Build Coastguard Worker        @torch.library.custom_op("mylib::f", mutates_args=())
3500*da0073e9SAndroid Build Coastguard Worker        def f(x: Tensor, y: Tensor) -> Tensor:
3501*da0073e9SAndroid Build Coastguard Worker            return x * y
3502*da0073e9SAndroid Build Coastguard Worker
3503*da0073e9SAndroid Build Coastguard Worker        called = False
3504*da0073e9SAndroid Build Coastguard Worker
3505*da0073e9SAndroid Build Coastguard Worker        @torch.library.register_vmap("mylib::f")
3506*da0073e9SAndroid Build Coastguard Worker        def fvmap(info, in_dims, x, y):
3507*da0073e9SAndroid Build Coastguard Worker            nonlocal called
3508*da0073e9SAndroid Build Coastguard Worker            called = True
3509*da0073e9SAndroid Build Coastguard Worker            x_bdim, y_bdim = in_dims
3510*da0073e9SAndroid Build Coastguard Worker            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
3511*da0073e9SAndroid Build Coastguard Worker            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
3512*da0073e9SAndroid Build Coastguard Worker            result = x * y
3513*da0073e9SAndroid Build Coastguard Worker            result = result.movedim(-1, 0)
3514*da0073e9SAndroid Build Coastguard Worker            return result, 0
3515*da0073e9SAndroid Build Coastguard Worker
3516*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
3517*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 2)
3518*da0073e9SAndroid Build Coastguard Worker
3519*da0073e9SAndroid Build Coastguard Worker        result = torch.vmap(f)(x, y)
3520*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(called)
3521*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, x * y)
3522*da0073e9SAndroid Build Coastguard Worker
3523*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3524*da0073e9SAndroid Build Coastguard Worker    def test_library_register_vmap_op_decorator(self):
3525*da0073e9SAndroid Build Coastguard Worker        @torch.library.custom_op("mylib::f", mutates_args=())
3526*da0073e9SAndroid Build Coastguard Worker        def f(x: Tensor, y: Tensor) -> Tensor:
3527*da0073e9SAndroid Build Coastguard Worker            return x * y
3528*da0073e9SAndroid Build Coastguard Worker
3529*da0073e9SAndroid Build Coastguard Worker        called = False
3530*da0073e9SAndroid Build Coastguard Worker
        @f.register_vmap
        def fvmap(info, in_dims, x, y):
            nonlocal called
            called = True
            x_bdim, y_bdim = in_dims
            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
            result = x * y
            result = result.movedim(-1, 0)
            return result, 0

        x = torch.randn(2, 2)
        y = torch.randn(2, 2)

        result = torch.vmap(f)(x, y)
        self.assertTrue(called)
        self.assertEqual(result, x * y)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_vmap_register_multiple_times(self):
        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor, y: Tensor) -> Tensor:
            return x * y

        called = False

        @f.register_vmap
        def fvmap(info, in_dims, x, y):
            nonlocal called
            called = True
            x_bdim, y_bdim = in_dims
            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
            result = x * y
            result = result.movedim(-1, 0)
            return result, 0

        x = torch.randn(2, 2)
        y = torch.randn(2, 2)

        result = torch.vmap(f)(x, y)
        self.assertTrue(called)
        self.assertEqual(result, x * y)
        called = False

        @f.register_vmap
        def fvmap2(info, in_dims, x, y):
            nonlocal called
            called = True
            x_bdim, y_bdim = in_dims
            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
            result = x + y
            result = result.movedim(-1, 0)
            return result, 0

        result = torch.vmap(f)(x, y)
        self.assertTrue(called)
        self.assertEqual(result, x + y)
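        # Registering a second vmap rule overrides the first: vmap now
        # dispatches to fvmap2 and computes x + y instead of x * y.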

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_vmap_register_multiple_times_2(self):
        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor, y: Tensor) -> Tensor:
            return x * y

        called = False

        @torch.library.register_vmap("mylib::f")
        def fvmap(info, in_dims, x, y):
            nonlocal called
            called = True
            x_bdim, y_bdim = in_dims
            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
            result = x * y
            result = result.movedim(-1, 0)
            return result, 0

        x = torch.randn(2, 2)
        y = torch.randn(2, 2)

        result = torch.vmap(f)(x, y)
        self.assertTrue(called)
        self.assertEqual(result, x * y)
        called = False

        @torch.library.register_vmap("mylib::f")
        def fvmap2(info, in_dims, x, y):
            nonlocal called
            called = True
            x_bdim, y_bdim = in_dims
            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
            result = x + y
            result = result.movedim(-1, 0)
            return result, 0

        result = torch.vmap(f)(x, y)
        self.assertTrue(called)
        self.assertEqual(result, x + y)
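        # Same override behavior as the previous test, exercised through the
        # functional torch.library.register_vmap form instead of
        # f.register_vmap.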


class MiniOpTestOther(CustomOpTestCaseBase):
    test_ns = "mini_op_test"

    def test_nonzero_again(self):
        x = torch.tensor([0, 1, 2, 0, 0])
        y = torch.ops.aten.nonzero.default(x)
        self.assertEqual(y, torch.tensor([[1], [2]]))
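        # nonzero returns a (num_nonzero, ndim) tensor of indices, so this
        # 1-D input yields a column vector of the nonzero positions 1 and 2.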


optests.generate_opcheck_tests(
    MiniOpTest,
    ["aten", "mini_op_test"],
    get_file_path_2(os.path.dirname(__file__), "minioptest_failures_dict.json"),
    additional_decorators={
        "test_pt2_compliant_tag_mini_op_test_no_abstract": [unittest.expectedFailure]
    },
    test_utils=optests.generate_tests.DEPRECATED_DEFAULT_TEST_UTILS,
)
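# generate_opcheck_tests (here and below) synthesizes one test per
# (test util, original test) pair, named f"{util}__{orig_test}" (e.g.
# test_schema__test_mm), consulting the failures-dict JSON for known
# xfails/skips; additional_decorators lets individual generated tests be
# decorated, here marking one as an expected failure.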

optests.generate_opcheck_tests(
    MiniOpTestOther,
    ["aten", "mini_op_test"],
    get_file_path_2(os.path.dirname(__file__), "minioptest_failures_dict.json"),
    test_utils=optests.generate_tests.DEPRECATED_DEFAULT_TEST_UTILS,
)


class TestGenerateOpcheckTests(CustomOpTestCaseBase):
    def test_MiniOpTest(self):
        for orig_test in ["test_mm", "test_nonzero"]:
            for (
                test
            ) in torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS:
                expected_test = f"{test}__{orig_test}"
                self.assertTrue(hasattr(MiniOpTest, expected_test), msg=expected_test)

    def test_generate_repro_save_data(self):
        from torch.testing._internal.optests.generate_tests import generate_repro

        args = (torch.ones(2, 2),)
        kwargs = {"mat2": torch.zeros(2, 2)}
        actual = generate_repro(
            "test_schema",
            torch.ops.aten.sin.default,
            args,
            kwargs,
            save_data=True,
            dry_run=True,
        )
        actual = re.sub(r"torch.load\(\".*\.pt\"\)", 'torch.load("repro.pt")', actual)
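        # generate_repro saved (args, kwargs) to a fresh temp .pt file; the
        # substitution above normalizes that path to "repro.pt" so the
        # expected string below stays stable across runs.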
        self.assertExpectedInline(
            actual,
            """\
# =========================================================
# BEGIN REPRO SCRIPT
# =========================================================
import torch
from torch.testing._internal.optests import opcheck

# Make sure you have loaded the library that contains the op
# via an import or torch.ops.load_library(...)
op = torch.ops.aten.sin.default

args, kwargs = torch.load("repro.pt")
opcheck(op, args, kwargs, test_utils="test_schema")
# =========================================================
# END REPRO SCRIPT
# =========================================================
""",
        )

    def test_generate_repro_no_save_data(self):
        from torch.testing._internal.optests.generate_tests import generate_repro

        args = (torch.ones(2, 2),)
        kwargs = {"mat2": torch.zeros(2, 2)}
        actual = generate_repro(
            "test_schema",
            torch.ops.aten.sin.default,
            args,
            kwargs,
            save_data=False,
            dry_run=True,
        )
        self.assertExpectedInline(
            actual,
            """\
# =========================================================
# BEGIN REPRO SCRIPT
# =========================================================
import torch
from torch.testing._internal.optests import opcheck

# Make sure you have loaded the library that contains the op
# via an import or torch.ops.load_library(...)
op = torch.ops.aten.sin.default

# If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1
# we will fill them in same (args, kwargs) as in your test
args = ()  # args to the operator
kwargs = {}  # kwargs to the operator
opcheck(op, args, kwargs, test_utils="test_schema")
# =========================================================
# END REPRO SCRIPT
# =========================================================
""",
        )

    def test_failures_dict_validation(self):
        from torch.testing._internal.optests.generate_tests import (
            FailuresDict,
            validate_failures_dict_structure,
        )

        failures = {
            "mini_op_test::incorrect_schema": {
                "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error": {
                    "comment": "",
                    "status": "success",
                }
            }
        }
        with self.assertRaisesRegex(RuntimeError, "got status=success"):
            validate_failures_dict_structure(
                FailuresDict("", failures),
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
                MiniOpTest,
            )

        failures = {
            "mini_op_test::incorrect_schema": {
                "MiniOpTest.test_aot_dispatch__test_delayed_error": {
                    "comment": "",
                    "status": "xfail",
                },
            }
        }
        with self.assertRaisesRegex(RuntimeError, "should begin with one of"):
            validate_failures_dict_structure(
                FailuresDict("", failures),
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
                MiniOpTest,
            )

        failures = {
            "mini_op_test::incorrect_schema": {
                "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error_nopenopenope": {
                    "comment": "",
                    "status": "xfail",
                },
            }
        }
        with self.assertRaisesRegex(RuntimeError, "does not exist on the TestCase"):
            validate_failures_dict_structure(
                FailuresDict("", failures),
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
                MiniOpTest,
            )
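        # The three cases above cover the structural checks on a failures
        # dict: a status other than xfail/skip, a test name that does not
        # begin with a known test util, and a test name that does not exist
        # on the TestCase.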

    def test_dont_generate_decorator(self):
        self.assertTrue(hasattr(MiniOpTest, "test_dont_generate"))
        self.assertFalse(hasattr(MiniOpTest, "test_schema__test_dont_generate"))

    def test_opcheck(self):
        x = torch.randn(3, requires_grad=True)
        with self.assertRaisesRegex(ValueError, "OpOverload"):
            torch.library.opcheck(torch.sin, (x,))
        with self.assertRaisesRegex(ValueError, "test_utils to be subset of"):
            torch.library.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah")
        result = torch.library.opcheck(torch.ops.aten.sin.default, (x,))

        self.assertEqual(
            result,
            {
                "test_schema": "SUCCESS",
                "test_autograd_registration": "SUCCESS",
                "test_faketensor": "SUCCESS",
                "test_aot_dispatch_dynamic": "SUCCESS",
            },
        )

        result = torch.library.opcheck(
            torch.ops.aten.sin.default, (x,), test_utils="test_schema"
        )
        self.assertEqual(result, {"test_schema": "SUCCESS"})

        result = torch.library.opcheck(
            torch.ops.aten.sin.default,
            (x,),
            test_utils=["test_schema", "test_faketensor"],
        )
        self.assertEqual(
            result,
            {
                "test_schema": "SUCCESS",
                "test_faketensor": "SUCCESS",
            },
        )
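        # test_utils defaults to running all four checks; a single util name
        # or a subset list restricts which checks run, as exercised above.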

    def test_opcheck_customopdef(self):
        sample_inputs = [
            (torch.randn(3),),
            (torch.randn(3, requires_grad=True),),
        ]
        if torch.cuda.is_available():
            sample_inputs.extend(
                [
                    (torch.randn(3, device="cuda"),),
                    (torch.randn(3, device="cuda", requires_grad=True),),
                ]
            )
        for args in sample_inputs:
            torch.library.opcheck(custom_op_db.numpy_cube, args)

    def test_is_inside_opcheck_mode(self):
        self.assertFalse(optests.is_inside_opcheck_mode())
        with optests.generate_tests.OpCheckMode(
            ["foo"], "bar", lambda x: x, None, "baz", "brr"
        ):
            self.assertTrue(optests.is_inside_opcheck_mode())

    def test_opcheck_bad_op(self):
        op = op_with_incorrect_schema(self, "foo")
        x = torch.randn(3)
        with self.assertRaisesRegex(Exception, "is not defined to alias output"):
            torch.library.opcheck(op, (x,))

        result = torch.library.opcheck(op, (x,), raise_exception=False)
        self.assertTrue(isinstance(result["test_schema"], RuntimeError))
        del result["test_schema"]
        self.assertEqual(
            result,
            {
                "test_autograd_registration": "SUCCESS",
                "test_faketensor": "SUCCESS",
                "test_aot_dispatch_dynamic": "SUCCESS",
            },
        )

    def test_opcheck_does_not_require_extra_deps(self):
        # torch.testing._internal.common_utils comes with a lot of additional
        # test-time dependencies. Since opcheck is public API, it should be
        # usable only with pytorch install-time dependencies.
        cmd = [
            sys.executable,
            "-c",
            "import torch; import sys; \
               x = torch.randn(3, requires_grad=True); \
               torch.library.opcheck(torch.ops.aten.sin.default, (x,)); \
               assert 'expecttest' not in sys.modules; \
               assert 'torch.testing._internal.common_utils' not in sys.modules",
        ]
        subprocess.check_output(cmd, shell=False)


class TestTypeConversion(TestCase):
    """In infer_schema(), we try to suggest a correct type when the type annotation is wrong."""

    def setUp(self):
        self.supported_base_types = [
            int,
            float,
            bool,
            str,
            torch.device,
            torch.Tensor,
            torch.dtype,
            torch.types.Number,
        ]

    def test_simple_tuple(self):
        self.assertEqual(List, tuple_to_list(Tuple))

    def test_supported_types(self):
        for t in self.supported_base_types:
            result_type = tuple_to_list(Tuple[t, t, t])
            self.assertEqual(result_type, List[t])

            result_type = tuple_to_list(Tuple[t])
            self.assertEqual(result_type, List[t])

    def test_optional(self):
        for t in self.supported_base_types:
            result_type = tuple_to_list(Tuple[t, Optional[t]])
            self.assertEqual(result_type, List[Optional[t]])

            result_type = tuple_to_list(Tuple[t, t, Optional[t]])
            self.assertEqual(result_type, List[Optional[t]])

            result_type = tuple_to_list(Tuple[t, ...])
            self.assertEqual(result_type, List[t])

    def test_mixed_types(self):
        result_type = tuple_to_list(Tuple[int, float])
        self.assertEqual(result_type, List[typing.Union[int, float]])

        result_type = tuple_to_list(Tuple[int, float, str])
        self.assertEqual(result_type, List[typing.Union[int, float, str]])
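    # Summarizing the conversions exercised above:
    #   tuple_to_list(Tuple[int, int, int]) -> List[int]
    #   tuple_to_list(Tuple[int, ...])      -> List[int]
    #   tuple_to_list(Tuple[int, float])    -> List[Union[int, float]]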


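# instantiate_device_type_tests expands TestCustomOpTesting into per-device
# classes (e.g. TestCustomOpTestingCPU and TestCustomOpTestingCUDA), and
# instantiate_parametrized_tests materializes the @parametrize'd variants.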
only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)
instantiate_parametrized_tests(TestCustomOp)
instantiate_parametrized_tests(TestCustomOpAPI)

if __name__ == "__main__":
    run_tests()