xref: /aosp_15_r20/external/pytorch/test/test_custom_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: custom-operators"]
2
3import collections
4import itertools
5import os
6import re
7import subprocess
8import sys
9import typing
10import unittest
11from typing import *  # noqa: F403
12
13import numpy as np
14
15import torch._custom_ops as custom_ops
16import torch.testing._internal.optests as optests
17import torch.utils._pytree as pytree
18import torch.utils.cpp_extension
19from functorch import make_fx
20from torch import Tensor
21from torch._custom_op.impl import CustomOp, infer_schema
22from torch._library.infer_schema import tuple_to_list
23from torch._utils_internal import get_file_path_2
24from torch.testing._internal import custom_op_db
25from torch.testing._internal.common_cuda import TEST_CUDA
26from torch.testing._internal.common_device_type import (
27    instantiate_device_type_tests,
28    OpDTypes,
29    ops,
30)
31from torch.testing._internal.common_utils import (
32    instantiate_parametrized_tests,
33    IS_WINDOWS,
34    parametrize,
35    run_tests,
36    skipIfTorchDynamo,
37    subtest,
38    TestCase,
39)
40from torch.testing._internal.custom_op_db import numpy_nonzero
41
42
43# Shadowed by `torch.testing._internal.common_utils.custom_op`
44from torch._custom_op.impl import custom_op  # usort: skip
45
46
47def requires_compile(fun):
48    fun = unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")(fun)
49    return fun
50
51
52class CustomOpTestCaseBase(TestCase):
53    test_ns = "_test_custom_op"
54
55    def setUp(self):
56        super().setUp()
57        self.libraries = []
58
59    def tearDown(self):
60        super().tearDown()
61        import torch._custom_op
62
63        keys = list(torch._custom_op.impl.global_registry.keys())
64        for key in keys:
65            if not key.startswith(f"{self.test_ns}::"):
66                continue
67            torch._custom_op.impl.global_registry[key]._destroy()
68        if hasattr(torch.ops, self.test_ns):
69            delattr(torch.ops, self.test_ns)
70        for lib in self.libraries:
71            lib._destroy()
72        del self.libraries
73
74    def ns(self):
75        return getattr(torch.ops, self.test_ns)
76
77    def lib(self):
78        result = torch.library.Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
79        self.libraries.append(result)
80        return result
81
82    def get_op(self, qualname):
83        return torch._custom_op.impl.get_op(qualname)
84
85
86@requires_compile
87class TestCustomOpTesting(CustomOpTestCaseBase):
88    @parametrize("check_gradients", (False, "auto"))
89    @parametrize("dynamic", (True, False))
90    def test_aot_autograd_check_degenerate_cases(
91        self, device, dynamic, check_gradients
92    ):
93        def simple(x):
94            return x.clone()
95
96        # Should not raise
97        x = torch.randn(3, device=device)
98        optests.aot_autograd_check(
99            simple, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
100        )
101
102        def outputs_dont_require_grad(x):
103            return x.detach()
104
105        # Should not raise
106        y = torch.randn(3, device=device, requires_grad=True)
107        optests.aot_autograd_check(
108            simple, (y,), {}, dynamic=dynamic, check_gradients=check_gradients
109        )
110
111        def no_outputs(x):
112            return x.detach()
113
114        # Should not raise
115        x = torch.randn(3, device=device, requires_grad=True)
116        y = torch.randn(3, device=device, requires_grad=False)
117        optests.aot_autograd_check(
118            no_outputs, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
119        )
120        optests.aot_autograd_check(
121            no_outputs, (y,), {}, dynamic=dynamic, check_gradients=check_gradients
122        )
123
124    def test_incorrect_schema_mutation(self, device):
125        lib = self.lib()
126        lib.define("foo(Tensor x) -> Tensor")
127        op = self.ns().foo.default
128
129        class Foo(torch.autograd.Function):
130            @staticmethod
131            def forward(ctx, x):
132                guard = torch._C._AutoDispatchBelowAutograd()
133                try:
134                    return op(x)
135                finally:
136                    del guard
137
138            @staticmethod
139            def backward(ctx, gx):
140                return gx
141
142        def foo_impl(x):
143            x.sin_()
144            return x.clone()
145
146        lib.impl("foo", Foo.apply, "Autograd")
147        lib.impl("foo", foo_impl, "CPU")
148        lib.impl("foo", foo_impl, "CUDA")
149
150        x = torch.tensor(3.14159 / 3, requires_grad=True, device=device)
151        with self.assertRaisesRegex(
152            optests.OpCheckError, "Argument x is not defined as mutable but was mutated"
153        ):
154            torch.library.opcheck(op, (x,), {})
155
156    def test_incorrect_schema_view(self, device):
157        lib = self.lib()
158        lib.define("foo(Tensor x) -> Tensor")
159        op = self.ns().foo.default
160
161        class Foo(torch.autograd.Function):
162            @staticmethod
163            def forward(ctx, x):
164                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
165                with torch._C._AutoDispatchBelowAutograd():
166                    with torch._C._ExcludeDispatchKeyGuard(
167                        torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
168                    ):
169                        return op(x)
170
171            @staticmethod
172            def backward(ctx, gx):
173                return gx
174
175        def foo_impl(x):
176            return x.view_as(x)
177
178        def foo_meta(x):
179            return x.view_as(x)
180
181        lib.impl("foo", Foo.apply, "Autograd")
182        lib.impl("foo", foo_impl, "CPU")
183        lib.impl("foo", foo_meta, "Meta")
184
185        x = torch.tensor(3.14159 / 3, requires_grad=True)
186        with self.assertRaisesRegex(
187            optests.OpCheckError,
188            "Argument x is not defined to alias output but was aliasing",
189        ):
190            torch.library.opcheck(op, (x,), {})
191
192    def test_missing_abstract_impl(self, device):
193        lib = self.lib()
194        lib.define("foo(Tensor x) -> Tensor")
195        op = self.ns().foo.default
196
197        class Foo(torch.autograd.Function):
198            @staticmethod
199            def forward(ctx, x):
200                with torch._C._AutoDispatchBelowAutograd():
201                    return op(x)
202
203            @staticmethod
204            def backward(ctx, gx):
205                return 2 * gx
206
207        def foo_impl(x):
208            return torch.tensor(x.cpu().numpy() ** 2, device=x.device)
209
210        lib.impl("foo", Foo.apply, "Autograd")
211        lib.impl("foo", foo_impl, "CPU")
212        lib.impl("foo", foo_impl, "CUDA")
213
214        x = torch.tensor([0, 1.0], requires_grad=True)
215        with self.assertRaisesRegex(
216            optests.OpCheckError,
217            "_test_custom_op.foo.default",
218        ):
219            torch.library.opcheck(op, (x,), {})
220
221    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
222    def test_incorrect_abstract_impl(self, device):
223        lib = self.lib()
224        lib.define("foo(Tensor x) -> Tensor")
225        op = self.ns().foo.default
226
227        class Foo(torch.autograd.Function):
228            @staticmethod
229            def forward(ctx, x):
230                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
231                guard = torch._C._AutoDispatchBelowAutograd()
232                guard2 = torch._C.ExcludeDispatchKeyGuard(
233                    torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
234                )
235                try:
236                    return op(x)
237                finally:
238                    del guard
239                    del guard2
240
241            @staticmethod
242            def backward(ctx, gx):
243                return gx
244
245        def foo_impl(x):
246            return x**2
247
248        def foo_meta(x):
249            return x.unsqueeze(1) ** 2
250
251        lib.impl("foo", Foo.apply, "Autograd")
252        lib.impl("foo", foo_impl, "CPU")
253        lib.impl("foo", foo_impl, "CUDA")
254        lib.impl("foo", foo_meta, "Meta")
255
256        x = torch.tensor([0, 1.0], requires_grad=True)
257        with self.assertRaisesRegex(optests.OpCheckError, "Shapes .* are not equal"):
258            torch.library.opcheck(op, (x,), {})
259
260    def test_missing_functionalization(self, device):
261        lib = self.lib()
262        lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
263        op = self.ns().foo.default
264
265        class Foo(torch.autograd.Function):
266            @staticmethod
267            def forward(ctx, x):
268                ctx.mark_dirty(x)
269                with torch._C._AutoDispatchBelowAutograd():
270                    return op(x)
271
272            @staticmethod
273            def backward(ctx, gx):
274                return gx
275
276        def foo_impl(x):
277            return x.sin_()
278
279        def foo_meta(x):
280            return x
281
282        lib.impl("foo", Foo.apply, "Autograd")
283        lib.impl("foo", foo_impl, "CPU")
284        lib.impl("foo", foo_impl, "CUDA")
285        lib.impl("foo", foo_meta, "Meta")
286
287        x = torch.tensor([0, 1.0])
288        y = x.clone()
289        with self.assertRaisesRegex(
290            optests.OpCheckError,
291            "We only support functionalizing operators whose outputs do not have alias annotations",
292        ):
293            torch.library.opcheck(op, (y,), {})
294
295    def test_autograd_registered_at_backend(self, device):
296        lib = self.lib()
297        lib.define("foo(Tensor x) -> Tensor")
298        op = self.ns().foo.default
299
300        class Foo(torch.autograd.Function):
301            @staticmethod
302            def forward(ctx, x):
303                return x.clone()
304
305            @staticmethod
306            def backward(ctx, gx):
307                return gx * 0.5
308
309        lib.impl("foo", Foo.apply, "CPU")
310        lib.impl("foo", Foo.apply, "CUDA")
311        lib.impl("foo", lambda x: x.clone(), "Meta")
312
313        x = torch.randn([], requires_grad=True)
314
315        with self.assertRaisesRegex(
316            torch.testing._internal.optests.OpCheckError,
317            "does not have an autograd kernel",
318        ):
319            torch.library.opcheck(op, (x,), {})
320
321        # I'm not sure why this is necessary
322        del lib
323
324    def test_global_state_mutation(self, device):
325        lib = self.lib()
326        lib.define("foo(Tensor x) -> Tensor")
327        op = self.ns().foo.default
328
329        class Foo(torch.autograd.Function):
330            invoked = 0
331
332            @staticmethod
333            def forward(ctx, x):
334                Foo.invoked += 1
335                return x.clone() * Foo.invoked
336
337            @staticmethod
338            def backward(ctx, gx):
339                return gx
340
341        lib.impl("foo", Foo.apply, "CompositeImplicitAutograd")
342
343        x = torch.tensor(3.14159 / 3, requires_grad=True)
344        with self.assertRaisesRegex(
345            optests.OpCheckError, "eager-mode PyTorch vs AOTAutograd"
346        ):
347            torch.library.opcheck(op, (x,), {})
348
349    @ops(custom_op_db.custom_op_db, dtypes=OpDTypes.any_one)
350    def test_opcheck_opinfo(self, device, dtype, op):
351        for sample_input in op.sample_inputs(
352            device, dtype, requires_grad=op.supports_autograd
353        ):
354            args = [sample_input.input] + list(sample_input.args)
355            kwargs = sample_input.kwargs
356            torch.library.opcheck(op.op, args, kwargs)
357
358    def test_opcheck_fails_basic(self, device):
359        @custom_op(f"{self.test_ns}::foo")
360        def foo(x: torch.Tensor) -> torch.Tensor: ...
361
362        @foo.impl(["cpu", "cuda"])
363        def foo_impl(x):
364            return x.sum()
365
366        x = torch.randn(3, device=device, requires_grad=True)
367        # Triggers the CustomOp autograd NYI error
368        with self.assertRaisesRegex(
369            optests.OpCheckError, "Autograd has not been implemented for operator"
370        ):
371            torch.library.opcheck(self.get_op(f"{self.test_ns}::foo"), (x,), {})
372
373    def test_autograd_registration_check_autograd_kernel(self, device):
374        lib = self.lib()
375        lib.define("foo(Tensor x) -> Tensor")
376        op = self.ns().foo.default
377
378        class Foo(torch.autograd.Function):
379            @staticmethod
380            def forward(ctx, x):
381                with torch._C._AutoDispatchBelowAutograd():
382                    return op(x)
383
384            @staticmethod
385            def backward(ctx, gx):
386                return gx
387
388        def foo_impl(x):
389            return x.sin()
390
391        lib.impl("foo", Foo.apply, "Autograd")
392        lib.impl("foo", foo_impl, "CPU")
393        lib.impl("foo", foo_impl, "CUDA")
394
395        x = torch.randn(3, requires_grad=True, device=device)
396        # Should not raise
397        optests.autograd_registration_check(op, (x,), {})
398
399    def test_autograd_registration_check_compositeimplicitautograd(self, device):
400        lib = self.lib()
401        lib.define("foo(Tensor x) -> Tensor")
402        op = self.ns().foo.default
403
404        def foo_impl(x):
405            return x.sin().cos()
406
407        lib.impl("foo", foo_impl, "CompositeImplicitAutograd")
408
409        x = torch.randn(3, requires_grad=True, device=device)
410        # Should not raise
411        optests.autograd_registration_check(op, (x,), {})
412
413    def test_autograd_registration_check_incorrect_composite(self, device):
414        lib = self.lib()
415        lib.define("foo(Tensor x) -> Tensor")
416        op = self.ns().foo.default
417
418        def foo_impl(x):
419            return x.sin().cos()
420
421        lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
422
423        x = torch.randn(3, requires_grad=True, device=device)
424        with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
425            optests.autograd_registration_check(op, (x,), {})
426
427    def test_autograd_registration_check_incorrect(self, device):
428        lib = self.lib()
429        lib.define("foo(Tensor x) -> Tensor")
430        op = self.ns().foo.default
431
432        class Foo(torch.autograd.Function):
433            @staticmethod
434            def forward(ctx, x):
435                return torch.sin(x)
436
437            @staticmethod
438            def backward(ctx, gx):
439                return gx
440
441        lib.impl("foo", Foo.apply, "CPU")
442        lib.impl("foo", Foo.apply, "CUDA")
443
444        x = torch.randn(3, requires_grad=True, device=device)
445        with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
446            optests.autograd_registration_check(op, (x,), {})
447
448    def test_assert_raises_regex(self, device):
449        from torch.testing._internal.optests.aot_autograd import assert_raises_regex
450
451        with assert_raises_regex(RuntimeError, "c"):
452            raise RuntimeError("abcd")
453        with assert_raises_regex(RuntimeError, "c.*"):
454            raise RuntimeError("abcd")
455        with self.assertRaisesRegex(AssertionError, "instead got"):
456            with assert_raises_regex(RuntimeError, "c.*"):
457                raise ValueError("abcd")
458        with self.assertRaisesRegex(AssertionError, "Expected exception"):
459            with assert_raises_regex(RuntimeError, "c.*"):
460                pass
461        with self.assertRaisesRegex(AssertionError, "to match regex"):
462            with assert_raises_regex(RuntimeError, "f"):
463                raise RuntimeError("abcd")
464
465
466class TestCustomOp(CustomOpTestCaseBase):
467    test_ns = "_test_custom_op"
468
469    @requires_compile
470    def test_functionalize_error(self):
471        with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib:
472            lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
473
474            def foo(x):
475                return x.sin_()
476
477            lib.impl("foo", foo, "CompositeExplicitAutograd")
478            foo_op = self.get_op(f"{self.test_ns}::foo")
479
480            lib.define("bar(Tensor(a) x) -> Tensor(a)")
481
482            def bar(x):
483                return x.view(-1)
484
485            lib.impl("bar", bar, "CompositeExplicitAutograd")
486            bar_op = self.get_op(f"{self.test_ns}::bar")
487
488            msg = r".*We only support functionalizing operators whose outputs do not have alias annotations"
489
490            x = torch.randn(3)
491
492            @torch.compile(backend="aot_eager", fullgraph=True)
493            def f(x):
494                return foo_op(x)
495
496            @torch.compile(backend="aot_eager", fullgraph=True)
497            def g(x):
498                return bar_op(x)
499
500            with self.assertRaisesRegex(RuntimeError, msg):
501                f(x)
502            with self.assertRaisesRegex(RuntimeError, msg):
503                g(x)
504
505    def test_invalid_schemas(self):
506        # function schmea validation goes through torchgen, so this is just a
507        # basic test.
508        with self.assertRaisesRegex(AssertionError, "Invalid function schema: foo"):
509            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(")
510
511    def test_invalid_qualname(self):
512        with self.assertRaisesRegex(ValueError, "overload"):
513            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo.Tensor", "() -> ()")
514
515    def test_name_must_match(self):
516        with self.assertRaisesRegex(ValueError, "to have name"):
517
518            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
519            def baz(x: Tensor) -> Tensor:
520                raise NotImplementedError
521
522    def test_unsupported_schemas(self):
523        with self.assertRaisesRegex(ValueError, "only supports functional"):
524            custom_ops.custom_op(
525                f"{TestCustomOp.test_ns}::foo", "(Tensor(a!) x) -> Tensor(a)"
526            )(foo)
527        with self.assertRaisesRegex(ValueError, "only supports functional"):
528            custom_ops.custom_op(
529                f"{TestCustomOp.test_ns}::foo", "(Tensor(a) x) -> Tensor(a)"
530            )(foo)
531        with self.assertRaisesRegex(ValueError, "only supports functional"):
532            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor x) -> ()")(
533                foo
534            )
535        with self.assertRaisesRegex(ValueError, "self"):
536            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor self) -> ()")(
537                foo
538            )
539
540    # Tests for the older custom_op API
541    def test_schema_matches_signature(self):
542        with self.assertRaisesRegex(ValueError, "signature to match"):
543
544            @custom_op(f"{TestCustomOp.test_ns}::blah", "(Tensor y) -> Tensor")
545            def blah(x):
546                pass
547
548        with self.assertRaisesRegex(ValueError, "signature to match"):
549
550            @custom_op(
551                f"{TestCustomOp.test_ns}::blah2", "(Tensor x, *, Tensor y) -> Tensor"
552            )
553            def blah2(x, y):
554                pass
555
556        with self.assertRaisesRegex(ValueError, "signature to match"):
557
558            @custom_op(
559                f"{TestCustomOp.test_ns}::blah3",
560                "(Tensor x, *, Tensor w, Tensor z) -> Tensor",
561            )
562            def blah3(x, *, y, z):
563                pass
564
565        with self.assertRaisesRegex(ValueError, "signature to match"):
566
567            @custom_op(
568                f"{TestCustomOp.test_ns}::blah4",
569                "(Tensor x, *, Tensor z, Tensor y) -> Tensor",
570            )
571            def blah4(x, *, y, z):
572                pass
573
574        with self.assertRaisesRegex(ValueError, "not supported"):
575
576            @custom_op(f"{TestCustomOp.test_ns}::blah5", "(Tensor x) -> Tensor")
577            def blah5(*args):
578                pass
579
580        with self.assertRaisesRegex(ValueError, "not supported"):
581
582            @custom_op(
583                f"{TestCustomOp.test_ns}::blah6", "(*, Tensor z, Tensor y) -> Tensor"
584            )
585            def blah6(**kwargs):
586                pass
587
588        with self.assertRaisesRegex(ValueError, "default arguments"):
589
590            @custom_op(
591                f"{TestCustomOp.test_ns}::blah7", "(Tensor x, *, Tensor y) -> Tensor"
592            )
593            def blah7(x=1, *, y):
594                pass
595
596        with self.assertRaisesRegex(ValueError, "default arguments"):
597
598            @custom_op(
599                f"{TestCustomOp.test_ns}::blah8", "(Tensor x, *, Tensor y) -> Tensor"
600            )
601            def blah8(x, *, y=1):
602                pass
603
604        # kwonly-arg works
605        @custom_op(
606            f"{TestCustomOp.test_ns}::blah9", "(Tensor x, *, Tensor y) -> Tensor"
607        )
608        def blah9(x, *, y):
609            pass
610
611    def test_infer_schema_no_return(self):
612        with self.assertRaisesRegex(
613            ValueError, "No return type annotation was provided. Please add one."
614        ):
615
616            @torch.library.custom_op("mylib::foo", mutates_args={})
617            def foo(x: torch.Tensor, y: int):
618                return x * y
619
620    def test_infer_schema_supported(self):
621        def a(x: Tensor) -> Tensor:
622            return torch.empty([])
623
624        self.assertExpectedInline(
625            infer_schema(a, mutates_args=()), """(Tensor x) -> Tensor"""
626        )
627
628        def kwonly1(x: Tensor, *, y: int, z: float) -> Tensor:
629            return torch.empty([])
630
631        self.assertExpectedInline(
632            infer_schema(kwonly1, mutates_args=()),
633            """(Tensor x, *, SymInt y, float z) -> Tensor""",
634        )
635
636        def kwonly2(*, y: Tensor) -> Tensor:
637            return torch.empty([])
638
639        self.assertExpectedInline(
640            infer_schema(kwonly2, mutates_args=()), """(*, Tensor y) -> Tensor"""
641        )
642
643        def b(
644            x: Tensor,
645            y: int,
646            z: bool,
647            a: float,
648            b: torch.dtype,
649            c: torch.device,
650            d: torch.types.Number,
651        ) -> Tuple[Tensor, int, float, bool]:
652            return torch.empty([]), 1, 0.1, True
653
654        self.assertExpectedInline(
655            infer_schema(b, mutates_args=()),
656            """(Tensor x, SymInt y, bool z, float a, ScalarType b, Device c, Scalar d) -> (Tensor, SymInt, float, bool)""",
657        )
658
659        def c(
660            x: Tensor,
661            y: Sequence[Tensor],
662            z: Optional[Tensor],
663            w: Sequence[Optional[Tensor]],
664        ) -> List[Tensor]:
665            return [torch.empty([])]
666
667        self.assertExpectedInline(
668            infer_schema(c, mutates_args=()),
669            """(Tensor x, Tensor[] y, Tensor? z, Tensor?[] w) -> Tensor[]""",
670        )
671
672        def d(x: Tensor) -> Tuple[List[Tensor], Tensor]:
673            return [torch.empty([])], torch.empty([])
674
675        self.assertExpectedInline(
676            infer_schema(d, mutates_args=()), """(Tensor x) -> (Tensor[], Tensor)"""
677        )
678
679        def e() -> Tensor:
680            return torch.empty([])
681
682        self.assertExpectedInline(infer_schema(e, mutates_args=()), """() -> Tensor""")
683
684        def f(x: Tensor) -> None:
685            pass
686
687        self.assertExpectedInline(
688            infer_schema(f, mutates_args=()), """(Tensor x) -> ()"""
689        )
690
691        def g(
692            x: Tensor, y: List[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
693        ) -> None:
694            pass
695
696        self.assertExpectedInline(
697            infer_schema(g, mutates_args=()),
698            """(Tensor x, Tensor[] y, Tensor[] z, Tensor?[] w) -> ()""",
699        )
700
701        self.assertExpectedInline(
702            infer_schema(g, mutates_args={"x", "w", "z"}),
703            """(Tensor(a0!) x, Tensor[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""",
704        )
705
706        self.assertExpectedInline(
707            infer_schema(g, mutates_args="unknown"),
708            """(Tensor(a0!) x, Tensor(a1!)[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""",
709        )
710
711        def h(
712            x: Tensor,
713            a: Optional[int] = None,
714            b: float = 3.14,
715            c: bool = True,
716            d: int = 3,
717            e: str = "foo",
718            f: torch.dtype = torch.float,
719            g: torch.dtype = torch.float32,
720            h: torch.dtype = torch.int,
721            i: torch.device = torch.device("cpu:0"),
722            j: torch.device = "cpu",
723        ) -> None:
724            pass
725
726        self.assertExpectedInline(
727            infer_schema(h, mutates_args=()),
728            (
729                """(Tensor x, SymInt? a=None, float b=3.14, bool c=True, SymInt d=3, str e="foo", """
730                """ScalarType f=float32, ScalarType g=float32, ScalarType h=int32, Device i="cpu:0", Device j="cpu") -> ()"""
731            ),
732        )
733
734        def foo_impl(x: torch.Tensor) -> torch.Tensor:
735            return x.sin()
736
737        schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={})
738        self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor")
739
740    def test_infer_schema_unsupported(self):
741        with self.assertRaisesRegex(ValueError, "varargs"):
742
743            def foo(*args):
744                raise NotImplementedError
745
746            infer_schema(foo, mutates_args=())
747
748        with self.assertRaisesRegex(ValueError, "varkwargs"):
749
750            def foo(**kwargs):
751                raise NotImplementedError
752
753            infer_schema(foo, mutates_args=())
754
755        with self.assertRaisesRegex(ValueError, "must have a type annotation"):
756
757            def foo(x):
758                raise NotImplementedError
759
760            infer_schema(foo, mutates_args=())
761
762        with self.assertRaisesRegex(ValueError, "unsupported"):
763
764            def foo(x: Tensor) -> Tuple[Tensor, ...]:
765                raise NotImplementedError
766
767            infer_schema(foo, mutates_args=())
768
769        with self.assertRaisesRegex(ValueError, "can be mutated"):
770
771            def foo(x: Tensor, y: int) -> Tensor:
772                raise NotImplementedError
773
774            infer_schema(foo, mutates_args={"y"})
775
776    def _generate_examples(self, typ):
777        if typ is int:
778            return [17]
779        if typ is float:
780            return [3.14]
781        if typ is bool:
782            return [True]
783        if typ is str:
784            return ["foo"]
785        if typ is torch.dtype:
786            return [torch.float32]
787        if typ is torch.device:
788            return [torch.device("cpu")]
789        if typ == torch.types.Number:
790            return [2.718]
791        if typ is torch.Tensor:
792            return [torch.tensor(3)]
793        if typ == Optional[torch.types.Number]:
794            return [None, 2.718]
795        origin = typing.get_origin(typ)
796        if origin is Union:
797            args = typing.get_args(typ)
798            assert len(args) == 2 and (args[0] is type(None) or args[1] is type(None))
799            elt = args[0] if args[1] is type(None) else args[1]
800            return self._generate_examples(elt) + [None]
801        if origin is list:
802            args = typing.get_args(typ)
803            assert len(args) == 1
804            elt = args[0]
805            return [
806                self._generate_examples(elt),
807                self._generate_examples(elt),
808                self._generate_examples(elt),
809            ]
810        if origin is collections.abc.Sequence:
811            args = typing.get_args(typ)
812            assert len(args) == 1
813            examples = self._generate_examples(args[0])
814            return list(itertools.product(examples, examples)) + []
815        raise NotImplementedError(
816            f"testrunner cannot generate instanstance of type {typ}"
817        )
818
819    def test_supported_return_types_single_return(self):
820        for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES:
821            for example in self._generate_examples(typ):
822                try:
823
824                    @custom_ops.custom_op(f"{self.test_ns}::foo")
825                    def foo(x: Tensor) -> typ:
826                        raise NotImplementedError
827
828                    @custom_ops.impl(f"{self.test_ns}::foo")
829                    def foo_impl(x: Tensor) -> typ:
830                        return example
831
832                    op = self.get_op(f"{self.test_ns}::foo")
833                    result = op(torch.randn([]))
834                    self.assertEqual(result, example, msg=f"{typ} {example}")
835                finally:
836                    custom_ops._destroy(f"{self.test_ns}::foo")
837
838    def test_supported_return_types_multi_return(self):
839        for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES:
840            for example in self._generate_examples(typ):
841                try:
842
843                    @custom_ops.custom_op(f"{self.test_ns}::foo")
844                    def foo(x: Tensor) -> Tuple[typ, typ]:
845                        raise NotImplementedError
846
847                    @custom_ops.impl(f"{self.test_ns}::foo")
848                    def foo_impl(x: Tensor) -> Tuple[typ, typ]:
849                        return (example, example)
850
851                    op = self.get_op(f"{self.test_ns}::foo")
852                    result = op(torch.randn([]))
853                    expected = (example, example)
854                    self.assertEqual(result, expected, msg=f"{typ} {example}")
855                finally:
856                    custom_ops._destroy(f"{self.test_ns}::foo")
857
858    def test_supported_param_types(self):
859        for typ in torch._library.infer_schema.SUPPORTED_PARAM_TYPES:
860
861            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
862            def foo(x: Tensor, y: typ) -> Tensor:
863                raise NotImplementedError
864
865            yeet = None
866
867            @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=["cpu"])
868            def foo_cpu(x, y):
869                nonlocal yeet
870                yeet = y
871                return x.clone()
872
873            try:
874                for example in self._generate_examples(typ):
875                    op = self.get_op(f"{self.test_ns}::foo")
876                    op(torch.randn([]), example)
877                    self.assertEqual(yeet, example, msg=f"{typ} {example}")
878                    yeet = None
879            finally:
880                custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
881
882    def test_sequences(self):
883        # Sequence[int] gets automagically turned into int[] in the schema.
884        # This test checks that we actually do support arbitrary sequence types.
885        class MySequence(collections.abc.Sequence):
886            def __init__(self) -> None:
887                self._container = [1, 2, 3]
888
889            def __getitem__(self, idx):
890                return self._container[idx]
891
892            def __len__(self):
893                return len(self._container)
894
895        @custom_ops.custom_op(f"{self.test_ns}::foo")
896        def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor:
897            raise NotImplementedError
898
899        called = 0
900
901        @custom_ops.impl(f"{self.test_ns}::foo", device_types="cpu")
902        def foo_cpu(x, sizes):
903            nonlocal called
904            called += 1
905            # Dispatcher will normalize the sequence type into a List
906            self.assertEqual(sizes, [1, 2, 3])
907            return x.clone()
908
909        x = torch.randn([])
910        seq = MySequence()
911        op = self.get_op(f"{self.test_ns}::foo")
912        op(x, seq)
913        self.assertEqual(called, 1)
914
915    def test_unsupported_param_types(self):
916        # Not comprehensive (it doesn't need to be), just a check that our mechanism works
917        with self.assertRaisesRegex(ValueError, "unsupported type"):
918
919            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
920            def foo(x: Tensor, y: List[Optional[int]]) -> Tensor:
921                raise NotImplementedError
922
923            del foo
924
925        with self.assertRaisesRegex(ValueError, "unsupported type"):
926            # int[N] in Dispatcher is a bit wild, so we don't try to support it.
927            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
928            def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
929                raise NotImplementedError
930
931            del foo
932
933        with self.assertRaisesRegex(ValueError, r"For example, typing.List\[int\]"):
934            # test that we propose a correct and supported type.
935            @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={})
936            def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
937                raise NotImplementedError
938
939            del foo
940
941        with self.assertRaises(ValueError) as cm:
942
943            @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={})
944            def foo(x: Tensor, y: Tuple[int, float]) -> Tensor:
945                raise NotImplementedError
946
947            del foo
948
949            self.assertNotIn("example", str(cm.exception), "")
950
951        with self.assertRaisesRegex(ValueError, "unsupported type"):
952
953            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
954            def foo(x: Tensor, y: Callable) -> Tensor:
955                raise NotImplementedError
956
957            del foo
958
959    def test_supported_schemas(self):
960        # All of these should already be tested by PyTorch codegen
961        # (we share the same mechanism), but here's a sanity check.
962        schemas = [
963            "(Tensor x) -> Tensor",
964            "(Tensor x) -> Tensor y",
965            "(Tensor[] x) -> Tensor y",
966            "(Tensor x) -> (Tensor, Tensor)",
967            "(Tensor x) -> (Tensor y, Tensor z)",
968            "(Tensor x) -> (Tensor y, Tensor z)",
969        ]
970        other_schemas = [
971            "(Tensor x, Tensor w) -> (Tensor y, Tensor z)",
972            "(Tensor x, Tensor w) -> (Tensor, Tensor)",
973            "(Tensor x, Tensor w) -> Tensor",
974            "(Tensor? x, Tensor w) -> Tensor",
975            "(Tensor? x, Tensor[] w) -> Tensor",
976            "(Tensor x, int[] w) -> Tensor",
977            "(Tensor x, SymInt[] w) -> Tensor",
978            "(Tensor x, Scalar w) -> Tensor",
979            "(Tensor x, float w) -> Tensor",
980            "(Tensor x, float? w) -> Tensor",
981            "(Tensor x, bool[] w) -> Tensor",
982        ]
983
984        for schema in schemas:
985            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", schema)
986            custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
987        for schema in other_schemas:
988            custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar", schema)
989            custom_ops._destroy(f"{TestCustomOp.test_ns}::bar")
990
991    def test_reserved_ns(self):
992        from torch._custom_op.impl import RESERVED_NS
993
994        for ns in RESERVED_NS:
995            with self.assertRaisesRegex(ValueError, "is a reserved namespace"):
996                custom_ops.custom_op(f"{ns}::foo", "(Tensor x) -> Tensor")
997
998            with self.assertRaisesRegex(ValueError, "is a reserved namespace"):
999
1000                @custom_ops.custom_op(f"{ns}::foo2")
1001                def foo2(x: torch.Tensor) -> torch.Tensor:
1002                    raise NotImplementedError
1003
1004    def test_private_ctor(self):
1005        with self.assertRaisesRegex(RuntimeError, "CustomOp constructor is private"):
1006            CustomOp(None, None, None, None, None)
1007
1008    def test_lifetime(self):
1009        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1010        def foo(x: torch.Tensor) -> torch.Tensor:
1011            raise NotImplementedError
1012
1013        custom_op = torch._custom_op.impl.get_op(f"{TestCustomOp.test_ns}::foo")
1014
1015        # We can't define an op multiple times,
1016        with self.assertRaisesRegex(RuntimeError, "multiple times"):
1017
1018            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1019            def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
1020                raise NotImplementedError
1021
1022        # Unless we delete the original op.
1023        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
1024
1025        # Smoke test
1026        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1027        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
1028            raise NotImplementedError
1029
1030        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
1031
1032    def test_autograd_notimplemented(self):
1033        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1034        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
1035            raise NotImplementedError
1036
1037        x = torch.randn(3, requires_grad=True)
1038        op = self.get_op(f"{self.test_ns}::foo")
1039        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
1040            op(x)
1041        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
1042        del foo
1043
1044        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1045        def foo(x: Sequence[torch.Tensor]) -> torch.Tensor:
1046            raise NotImplementedError
1047
1048        x = torch.randn(3, requires_grad=True)
1049        y = torch.randn(3)
1050        op = self.get_op(f"{self.test_ns}::foo")
1051        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
1052            op([y, x])
1053        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
1054        del foo
1055
1056        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1057        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
1058            raise NotImplementedError
1059
1060        x = torch.randn(3, requires_grad=True)
1061        y = torch.randn(3)
1062        op = self.get_op(f"{self.test_ns}::foo")
1063        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
1064            op(y, x)
1065        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
1066
1067    def test_autograd_notimplemented_gradmode(self):
1068        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1069        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
1070            raise NotImplementedError
1071
1072        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
1073        def foo_impl(x, y):
1074            return x * y
1075
1076        x = torch.randn(3, requires_grad=True)
1077        y = torch.randn(3)
1078        op = self.get_op(f"{self.test_ns}::foo")
1079        with torch.no_grad():
1080            # Shouldn't raise, because we are in no_grad
1081            op(y, x)
1082
1083    def test_impl_cpu(self):
1084        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1085        def foo(x: torch.Tensor) -> torch.Tensor:
1086            raise NotImplementedError
1087
1088        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
1089        def foo_cpu(x):
1090            return x.sin()
1091
1092        x = torch.randn(3)
1093        op = self.get_op(f"{self.test_ns}::foo")
1094        result = op(x)
1095        self.assertEqual(result, foo_cpu(x))
1096
1097    def test_impl_invalid_devices(self):
1098        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1099        def foo(x: torch.Tensor) -> torch.Tensor:
1100            raise NotImplementedError
1101
1102        def foo_impl(x):
1103            return x.sin()
1104
1105        from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY
1106
1107        for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY.keys():
1108            # Smoke test: should not raise error
1109            custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=device_type)(
1110                foo_impl
1111            )
1112
1113        # Not supported by this API: we can either support them in the future
1114        # or provide some other CustomOp.def_* function. This depends on how
1115        # common the use cases are.
1116        for invalid_type in ["hip", "xla", "mkldnn", ["cpu", "hip"]]:
1117            with self.assertRaisesRegex(ValueError, "we only support device_type"):
1118                custom_ops.impl(
1119                    f"{TestCustomOp.test_ns}::foo", device_types=invalid_type
1120                )(foo_impl)
1121
1122    def test_backward_partially_registered(self):
1123        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1124        def foo(x: torch.Tensor) -> torch.Tensor:
1125            raise NotImplementedError
1126
1127        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
1128        def foo_impl(x):
1129            return x.sin()
1130
1131        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
1132        def foo_backward(ctx, saved, grad):
1133            return grad * saved.cos()
1134
1135        x = torch.randn([], requires_grad=True)
1136        op = self.get_op(f"{self.test_ns}::foo")
1137        with self.assertRaisesRegex(
1138            RuntimeError, "unable to find a 'save_for_backward'"
1139        ):
1140            y = op(x)
1141            y.backward()
1142
1143    def test_save_for_backward_inputs_are_namedtuple(self):
1144        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1145        def foo(x: torch.Tensor) -> torch.Tensor:
1146            raise NotImplementedError
1147
1148        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
1149        def foo_impl(x):
1150            return x.sin()
1151
1152        hit = 0
1153
1154        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
1155        def foo_save_for_backward(inputs, output):
1156            nonlocal hit
1157            hit += 1
1158            self.assertTrue(isinstance(inputs, tuple))
1159            self.assertEqual(list(inputs._asdict().keys()), ["x"])
1160            return inputs.x
1161
1162        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
1163        def foo_backward(ctx, saved, grad):
1164            return {"x": grad * saved.cos()}
1165
1166        x = torch.randn([], requires_grad=True)
1167        op = self.get_op(f"{self.test_ns}::foo")
1168        y = op(x)
1169        self.assertEqual(hit, 1)
1170        y.backward()
1171        self.assertEqual(hit, 1)
1172
1173    def test_backward_returns_dict(self):
1174        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1175        def foo(x: torch.Tensor) -> torch.Tensor:
1176            raise NotImplementedError
1177
1178        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
1179        def foo_impl(x):
1180            return x.sin()
1181
1182        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
1183        def foo_save_for_backward(inputs, output):
1184            return inputs.x
1185
1186        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
1187        def foo_backward(ctx, saved, grad):
1188            return grad * saved.cos()
1189
1190        x = torch.randn([], requires_grad=True)
1191        op = self.get_op(f"{self.test_ns}::foo")
1192        y = op(x)
1193        with self.assertRaisesRegex(RuntimeError, "to be a dict"):
1194            y.backward()
1195
1196    def test_backward_dict_invalid_keys(self):
1197        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1198        def foo(x: torch.Tensor) -> torch.Tensor:
1199            raise NotImplementedError
1200
1201        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
1202        def foo_impl(x):
1203            return x.sin()
1204
1205        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
1206        def foo_save_for_backward(inputs, output):
1207            return inputs.x
1208
1209        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
1210        def foo_backward(ctx, saved, grad):
1211            return {"x": grad * saved.cos(), "y": None}
1212
1213        x = torch.randn([], requires_grad=True)
1214        op = self.get_op(f"{self.test_ns}::foo")
1215        y = op(x)
1216        with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"):
1217            y.backward()
1218
1219    def test_backward_dict_grad_for_nontensor(self):
1220        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1221        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
1222            raise NotImplementedError
1223
1224        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
1225        def foo_impl(x, dim):
1226            return x.sin()
1227
1228        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
1229        def foo_save_for_backward(inputs, output):
1230            return inputs.x
1231
1232        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
1233        def foo_backward(ctx, saved, grad):
1234            return {"x": grad * saved.cos(), "dim": None}
1235
1236        x = torch.randn([], requires_grad=True)
1237        op = self.get_op(f"{self.test_ns}::foo")
1238        y = op(x, 32)
1239        with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"):
1240            y.backward()
1241
1242    def test_backward_dict_requires_keys_for_input_tensors(self):
1243        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1244        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
1245            raise NotImplementedError
1246
1247        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
1248        def foo_impl(x, y):
1249            return x.sin()
1250
1251        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
1252        def foo_save_for_backward(inputs, output):
1253            return inputs.x
1254
1255        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
1256        def foo_backward(ctx, saved, grad):
1257            return {"x": grad * saved.cos()}
1258
1259        x = torch.randn([], requires_grad=True)
1260        op = self.get_op(f"{self.test_ns}::foo")
1261        y = op(x, x)
1262        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
1263            y.backward()
1264
1265    def test_backward_dict_requires_keys_for_input_optional_tensors(self):
1266        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1267        def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
1268            raise NotImplementedError
1269
1270        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
1271        def foo_impl(x, y):
1272            return x.sin()
1273
1274        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
1275        def foo_save_for_backward(inputs, output):
1276            return inputs.x
1277
1278        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
1279        def foo_backward(ctx, saved, grad):
1280            return {"x": grad * saved.cos()}
1281
1282        x = torch.randn([], requires_grad=True)
1283        op = self.get_op(f"{self.test_ns}::foo")
1284        y = op(x, None)
1285        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
1286            y.backward()
1287
1288    def test_backward_grads_are_tensor_or_none(self):
1289        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1290        def foo(x: torch.Tensor) -> torch.Tensor:
1291            raise NotImplementedError
1292
1293        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
1294        def foo_impl(x):
1295            return x.sin()
1296
1297        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
1298        def foo_save_for_backward(inputs, output):
1299            return inputs.x
1300
1301        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
1302        def foo_backward(ctx, saved, grad):
1303            return {"x": (grad * saved.cos(),)}
1304
1305        x = torch.randn([], requires_grad=True)
1306        op = self.get_op(f"{self.test_ns}::foo")
1307        y = op(x)
1308        with self.assertRaisesRegex(RuntimeError, "either None or a Tensor"):
1309            y.backward()
1310
1311    def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self):
1312        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1313        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
1314            raise NotImplementedError
1315
1316        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
1317        def foo_impl(xs):
1318            return xs[0].sin()
1319
1320        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
1321        def foo_save_for_backward(inputs, output):
1322            return inputs.xs[0]
1323
1324        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
1325        def foo_backward(ctx, saved, grad):
1326            return {"xs": [grad * saved.cos(), None]}
1327
1328        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
1329        op = self.get_op(f"{self.test_ns}::foo")
1330        y = op(xs)
1331        with self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"):
1332            y.backward()
1333
1334    def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self):
1335        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1336        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
1337            raise NotImplementedError
1338
1339        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
1340        def foo_impl(xs):
1341            return xs[0].sin()
1342
1343        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
1344        def foo_save_for_backward(inputs, output):
1345            return inputs.xs[0]
1346
1347        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
1348        def foo_backward(ctx, saved, grad):
1349            return {"xs": [grad * saved.cos(), None, (None,)]}
1350
1351        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
1352        op = self.get_op(f"{self.test_ns}::foo")
1353        y = op(xs)
1354        with self.assertRaisesRegex(RuntimeError, "None or Tensor"):
1355            y.backward()
1356
1357    def test_backward_tensorlist_input_requires_list_grads(self):
1358        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1359        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
1360            raise NotImplementedError
1361
1362        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
1363        def foo_impl(xs):
1364            return xs[0].sin()
1365
1366        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
1367        def foo_save_for_backward(inputs, output):
1368            return inputs.xs[0]
1369
1370        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
1371        def foo_backward(ctx, saved, grad):
1372            return {"xs": None}
1373
1374        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
1375        op = self.get_op(f"{self.test_ns}::foo")
1376        y = op(xs)
1377        with self.assertRaisesRegex(RuntimeError, "list of gradients"):
1378            y.backward()
1379
1380    def test_backward_output_differentiability_type(self):
1381        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1382        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
1383            raise NotImplementedError
1384
1385        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
1386
1387            @custom_ops.impl_backward(
1388                f"{TestCustomOp.test_ns}::foo", output_differentiability=True
1389            )
1390            def foo_backward(ctx, saved, grad):
1391                return {"xs": None}
1392
1393    def test_backward_output_differentiability_numel(self):
1394        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1395        def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
1396            raise NotImplementedError
1397
1398        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
1399
1400            @custom_ops.impl_backward(
1401                f"{TestCustomOp.test_ns}::foo", output_differentiability=[True]
1402            )
1403            def foo_backward(ctx, saved, grad):
1404                return {"xs": None}
1405
1406    def test_backward_output_differentiability_tensorlist(self):
1407        @custom_ops.custom_op(f"{self.test_ns}::foo")
1408        def foo(x: Tensor) -> Tuple[List[Tensor], Tensor]:
1409            raise NotImplementedError
1410
1411        @custom_ops.impl(f"{self.test_ns}::foo")
1412        def foo_impl(x):
1413            return [x.clone(), x.clone()], x.clone()
1414
1415        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
1416        def foo_save_for_backward(inputs, output):
1417            return []
1418
1419        @custom_ops.impl_backward(
1420            f"{TestCustomOp.test_ns}::foo", output_differentiability=[False, True]
1421        )
1422        def foo_backward(ctx, saved, grad_lst, grad):
1423            return {"x": grad}
1424
1425        op = self.get_op(f"{self.test_ns}::foo")
1426        x = torch.randn(3, requires_grad=True)
1427        [a, b], c = op(x)
1428        self.assertFalse(a.requires_grad)
1429        self.assertFalse(b.requires_grad)
1430        self.assertTrue(c.requires_grad)
1431
1432    def test_backward_output_differentiability_non_tensor(self):
1433        @custom_ops.custom_op(f"{self.test_ns}::foo")
1434        def foo(x: Tensor) -> Tuple[Tensor, int]:
1435            raise NotImplementedError
1436
1437        @custom_ops.impl(f"{self.test_ns}::foo")
1438        def foo_impl(x):
1439            return x.clone(), 3
1440
1441        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
1442        def foo_save_for_backward(inputs, output):
1443            return []
1444
1445        @custom_ops.impl_backward(
1446            f"{TestCustomOp.test_ns}::foo", output_differentiability=[True, True]
1447        )
1448        def foo_backward(ctx, saved, grad0, grad1):
1449            return {"x": grad0}
1450
1451        op = self.get_op(f"{self.test_ns}::foo")
1452        x = torch.randn(3, requires_grad=True)
1453        with self.assertRaisesRegex(RuntimeError, "is not a Tensor"):
1454            op(x)
1455
1456    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
1457    def test_impl_separate(self):
1458        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1459        def foo(x: torch.Tensor) -> torch.Tensor:
1460            raise NotImplementedError
1461
1462        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
1463        def foo_cpu(x):
1464            return x.sin()
1465
1466        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cuda")
1467        def foo_cuda(x):
1468            return x.cos()
1469
1470        x = torch.randn(3)
1471        op = self.get_op(f"{self.test_ns}::foo")
1472        result = op(x)
1473        self.assertEqual(result, foo_cpu(x))
1474
1475        x_cuda = x.cuda()
1476        op = self.get_op(f"{self.test_ns}::foo")
1477        result = op(x_cuda)
1478        self.assertEqual(result, foo_cuda(x_cuda))
1479
1480    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
1481    def test_impl_multiple(self):
1482        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1483        def foo(x: torch.Tensor) -> torch.Tensor:
1484            raise NotImplementedError
1485
1486        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
1487        def foo_impl(x):
1488            return x.cos()
1489
1490        op = self.get_op(f"{self.test_ns}::foo")
1491        x = torch.randn(3)
1492        result = op(x)
1493        self.assertEqual(result, foo_impl(x))
1494
1495        x_cuda = x.cuda()
1496        result = op(x_cuda)
1497        self.assertEqual(result, foo_impl(x_cuda))
1498
1499    def test_impl_abstract_overload(self):
1500        lib = self.lib()
1501        lib.define("sin.blah(Tensor x) -> Tensor")
1502
1503        torch.library.impl_abstract(
1504            f"{self.test_ns}::sin.blah", torch.empty_like, lib=lib
1505        )
1506
1507        op = self.ns().sin.blah
1508        x = torch.randn(3, device="meta")
1509        op(x)
1510
1511    def test_impl_meta(self):
1512        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1513        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
1514            raise NotImplementedError
1515
1516        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
1517        def foo_meta(x, dim):
1518            output_shape = list(x.shape)
1519            del output_shape[dim]
1520            return x.new_empty(output_shape)
1521
1522        x = torch.randn(2, 3, device="meta")
1523        op = self.get_op(f"{self.test_ns}::foo")
1524        result = op(x, 1)
1525        self.assertEqual(result.shape, foo_meta(x, 1).shape)
1526
1527    def test_duplicate_impl(self):
1528        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1529        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
1530            raise NotImplementedError
1531
1532        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
1533        def foo_meta(x, dim):
1534            output_shape = list(x.shape)
1535            del output_shape[dim]
1536            return x.new_empty(output_shape)
1537
1538        with self.assertRaisesRegex(RuntimeError, r"test_custom_ops.py:\d+"):
1539
1540            @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
1541            def foo_meta2(x, dim):
1542                output_shape = list(x.shape)
1543                del output_shape[dim]
1544                return x.new_empty(output_shape)
1545
1546    def test_new_data_dependent_symint(self):
1547        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1548        def foo(x: torch.Tensor) -> torch.Tensor:
1549            raise NotImplementedError
1550
1551        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
1552        def foo_meta(x):
1553            ctx = torch.library.get_ctx()
1554            r = ctx.new_dynamic_size(min=1)
1555            with self.assertRaisesRegex(ValueError, "greater than or equal to 0"):
1556                ctx.new_dynamic_size(min=-1)
1557            with self.assertRaisesRegex(ValueError, "SymInt"):
1558                ctx.new_dynamic_size(max=x.numel())
1559            # NB: You must return dynamic sizes!
1560            return x.new_empty(r)
1561
1562        x = torch.randn(2, 3, device="cpu")
1563        op = self.get_op(f"{self.test_ns}::foo")
1564        make_fx(op, tracing_mode="symbolic")(x)
1565
1566    def test_meta_for_data_dependent_shape_operation(self):
1567        x = torch.randn(10, device="meta")
1568        with self.assertRaisesRegex(RuntimeError, "data-dependent output shape"):
1569            numpy_nonzero(x)
1570
1571    def test_basic_make_fx(self):
1572        # More serious tests are in our CustomOp opinfo db,
1573        # this one is just a sanity check.
1574        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1575        def foo(x: torch.Tensor) -> torch.Tensor:
1576            raise NotImplementedError
1577
1578        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
1579        def foo_meta(x):
1580            return x.sum()
1581
1582        x = torch.randn(3)
1583        op = self.get_op(f"{self.test_ns}::foo")
1584        gm = make_fx(op, tracing_mode="symbolic")(x)
1585        self.assertTrue(f"{TestCustomOp.test_ns}.foo" in gm.code)
1586
1587    def test_not_implemented_error(self):
1588        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
1589        def foo(x: torch.Tensor) -> torch.Tensor:
1590            raise NotImplementedError
1591
1592        x = torch.randn(3)
1593        op = self.get_op(f"{self.test_ns}::foo")
1594        with self.assertRaisesRegex(NotImplementedError, "cpu impl registered"):
1595            op(x)
1596
1597        x = torch.randn(3, device="meta")
1598        with self.assertRaisesRegex(NotImplementedError, "no fake impl or Meta kernel"):
1599            op(x)
1600
1601        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar")
1602        def bar(sizes: Sequence[int]) -> torch.Tensor:
1603            raise NotImplementedError
1604
1605        op = self.get_op(f"{self.test_ns}::bar")
1606        with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):
1607            op((1, 2, 3))
1608
1609    def test_data_dependent_basic(self):
1610        x = torch.randn(5, 5)
1611        gm = make_fx(numpy_nonzero, tracing_mode="symbolic")(x)
1612        self.assertTrue("nonzero" in gm.code)
1613
1614    def test_data_dependent_fake_tracing(self):
1615        x = torch.randn(5, 5)
1616        # We've updated to attempt to use unbacked symints even for fake
1617        # tracing
1618        make_fx(numpy_nonzero, tracing_mode="fake")(x)
1619
1620    def test_symints(self):
1621        def f(x):
1622            return torch.ops._torch_testing.numpy_view_copy(x, x.shape)
1623
1624        x = torch.randn(2, 3, 4)
1625        gm = make_fx(f, tracing_mode="symbolic")(x)
1626        result = gm(x)
1627        self.assertEqual(result, f(x))
1628        self.assertExpectedInline(
1629            gm.code.strip(),
1630            """\
1631def forward(self, x_1):
1632    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
1633    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
1634    sym_size_int_2 = torch.ops.aten.sym_size.int(x_1, 2)
1635    numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size_int, sym_size_int_1, sym_size_int_2]);  x_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = None
1636    return numpy_view_copy""",  # noqa: B950
1637        )
1638
1639    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
1640    def test_data_dependent_compile(self):
1641        import torch._dynamo.testing
1642        from torch._dynamo.utils import counters
1643
1644        counters.clear()
1645        cnt = torch._dynamo.testing.CompileCounter()
1646
1647        @torch.compile(backend=cnt)
1648        def f(x):
1649            return numpy_nonzero(x.clone()).clone()
1650
1651        f(torch.randn(10))
1652
1653        self.assertEqual(len(counters["graph_break"]), 1)
1654        self.assertEqual(next(iter(counters["graph_break"].values())), 1)
1655        self.assertExpectedInline(
1656            next(iter(counters["graph_break"].keys())).replace(";", "\n"),
1657            """\
1658dynamic shape operator: _torch_testing.numpy_nonzero.default
1659 to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True""",
1660        )
1661
1662    # pre-existing problem: torch.compile(dynamic=True) will, by default,
1663    # graph break on data-dependent operations. Eventually we'll make it so
1664    # that it never graph breaks on data-dependent operations.
1665    @unittest.expectedFailure
1666    def test_data_dependent_nms_dynamic_compile(self):
1667        import torch._dynamo.testing
1668        from torch._dynamo.utils import counters
1669
1670        counters.clear()
1671        cnt = torch._dynamo.testing.CompileCounter()
1672
1673        @torch.compile(backend=cnt, dynamic=True)
1674        def f(x, s, i):
1675            return torch.ops._torch_testing.numpy_nms(x.clone(), s, i).clone()
1676
1677        f(torch.randn(20, 4), torch.randn(20), 0.1)
1678
1679        self.assertEqual(len(counters["graph_break"]), 0)
1680
1681    def test_impl_on_existing_op(self):
1682        lib = self.lib()
1683        lib.define("foo(Tensor x) -> Tensor")
1684        qualname = f"{self.test_ns}::foo"
1685
1686        @torch._custom_ops.impl(qualname)
1687        def foo_impl(x):
1688            return x.sin()
1689
1690        op = self.get_op(qualname)
1691        x = torch.randn(3)
1692        result = op(x)
1693        self.assertEqual(result, x.sin())
1694
1695    @parametrize(
1696        "key", ["CPU", "CUDA", "CompositeImplicitAutograd", "CompositeExplicitAutograd"]
1697    )
1698    def test_impl_on_existing_op_with_cpu_registration(self, key):
1699        lib = self.lib()
1700        lib.define("foo(Tensor x) -> Tensor")
1701        qualname = f"{self.test_ns}::foo"
1702
1703        def foo_impl(x):
1704            return x.sin()
1705
1706        lib.impl("foo", foo_impl, key)
1707        op = self.get_op(qualname)
1708
1709        with self.assertRaisesRegex(RuntimeError, "already has an implementation"):
1710            custom_ops.impl(qualname, func=foo_impl)
1711
1712    def test_abstract_impl_on_existing_op(self):
1713        lib = self.lib()
1714        lib.define("foo(Tensor x) -> Tensor")
1715        qualname = f"{self.test_ns}::foo"
1716
1717        @torch.library.impl_abstract(qualname, lib=self.lib())
1718        def foo_impl(x):
1719            return x.sin()
1720
1721        op = self.get_op(qualname)
1722        with torch._subclasses.FakeTensorMode():
1723            x = torch.randn(3)
1724            result = op(x)
1725            self.assertEqual(result.shape, x.shape)
1726            self.assertEqual(result.stride(), x.stride())
1727
1728    def test_abstract_impl_on_existing_op_with_meta(self):
1729        lib = self.lib()
1730        lib.define("foo(Tensor x) -> Tensor")
1731        qualname = f"{self.test_ns}::foo"
1732
1733        def foo_impl(x):
1734            return x.sin()
1735
1736        lib.impl("foo", foo_impl, "Meta")
1737        op = self.get_op(qualname)
1738
1739        with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"):
1740            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())
1741
1742    def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self):
1743        lib = self.lib()
1744        lib.define("foo(Tensor x) -> Tensor")
1745        qualname = f"{self.test_ns}::foo"
1746
1747        def foo_impl(x):
1748            return x.sin()
1749
1750        lib.impl("foo", foo_impl, "CompositeImplicitAutograd")
1751        op = self.get_op(qualname)
1752
1753        with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"):
1754            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())
1755
1756    def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self):
1757        lib = self.lib()
1758        lib.define("foo(Tensor x) -> Tensor")
1759        qualname = f"{self.test_ns}::foo"
1760
1761        def foo_impl(x):
1762            return x.sin()
1763
1764        lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
1765        op = self.get_op(qualname)
1766
1767        torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib())
1768        with torch._subclasses.FakeTensorMode():
1769            x = torch.randn(10)
1770            result = op(x)
1771            self.assertEqual(result.shape, ())
1772
1773    def _test_backward_impl_raises(self, qualname, err_regex):
1774        with self.assertRaisesRegex(RuntimeError, err_regex):
1775
1776            @custom_ops.impl_save_for_backward(qualname)
1777            def foo2(x):
1778                return
1779
1780        with self.assertRaisesRegex(RuntimeError, err_regex):
1781
1782            @custom_ops.impl_backward(qualname)
1783            def foo3(x):
1784                return
1785
1786    def test_backward_impl_on_existing_op_incorrect_schema_views(self):
1787        lib = self.lib()
1788        lib.define("foo(Tensor(a) x) -> Tensor(a)")
1789        qualname = f"{self.test_ns}::foo"
1790        self._test_backward_impl_raises(qualname, "operator that returns views")
1791
1792    def test_backward_impl_on_existing_op_incorrect_schema_mutable(self):
1793        lib = self.lib()
1794        lib.define("foo(Tensor(a!) x) -> Tensor")
1795        qualname = f"{self.test_ns}::foo"
1796        self._test_backward_impl_raises(qualname, "non-functional")
1797
1798    def test_backward_impl_on_existing_op_incorrect_schema_no_output(self):
1799        lib = self.lib()
1800        lib.define("foo(Tensor x) -> ()")
1801        qualname = f"{self.test_ns}::foo"
1802        self._test_backward_impl_raises(qualname, "no returns")
1803
1804    def test_backward_impl_on_existing_op_CompositeImplicitAutograd(self):
1805        lib = self.lib()
1806        lib.define("foo(Tensor x) -> Tensor")
1807        qualname = f"{self.test_ns}::foo"
1808        lib.impl("foo", lambda x: x.sin().cos(), "CompositeImplicitAutograd")
1809        self._test_backward_impl_raises(qualname, "CompositeImplicitAutograd")
1810
1811    @parametrize("key", ["Autograd", "AutogradCPU", "AutogradCUDA"])
1812    def test_backward_impl_on_existing_op_with_key(self, key):
1813        lib = self.lib()
1814        lib.define("foo(Tensor x) -> Tensor")
1815        qualname = f"{self.test_ns}::foo"
1816        lib.impl("foo", lambda x: x.sin().cos(), key)
1817        self._test_backward_impl_raises(qualname, key)
1818
1819    def test_is_functional_schema(self):
1820        tests = {
1821            "foo(Tensor x) -> Tensor": True,
1822            "foo(Tensor(a) x) -> Tensor": True,
1823            "foo(Tensor(a!) x) -> Tensor": False,
1824            "foo(Tensor(a) x) -> Tensor(a)": False,
1825            "foo(Tensor x) -> ()": False,
1826        }
1827        for schema_str, expected in tests.items():
1828            res = torch._library.utils.is_functional_schema(schema_str)
1829            self.assertEqual(res, expected)
1830
1831            from torchgen.model import FunctionSchema
1832
1833            schema = FunctionSchema.parse(schema_str)
1834            res = torch._library.utils.is_functional_schema(schema)
1835            self.assertEqual(res, expected)
1836
1837            schema = torch._C.parse_schema(schema_str)
1838            res = torch._library.utils.is_functional_schema(schema)
1839            self.assertEqual(res, expected)
1840
1841    def test_incorrect_schema_types(self):
1842        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
1843            with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
1844                lib.define("foo12(Tensor a) -> asdfasdf")
1845            with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
1846                lib.define("foo12(asdf a) -> Tensor")
1847            with self.assertRaisesRegex(RuntimeError, "Use `SymInt` or `int`"):
1848                lib.define("foo12(int64_t a) -> Tensor")
1849            with self.assertRaisesRegex(RuntimeError, "Use `float`"):
1850                lib.define("foo12(double a) -> Tensor")
1851
1852    def test_is_tensorlist_like_type(self):
1853        tensorlists = [
1854            # Tensor[]
1855            torch.ops.aten.where.default._schema.returns[0].type,
1856            # Tensor?[]
1857            torch.ops.aten.index.Tensor._schema.arguments[1].type,
1858            # Tensor[]?
1859            torch._C.parse_schema("foo(Tensor[]? x) -> ()").arguments[0].type,
1860            # Tensor?[]?
1861            torch._C.parse_schema("foo(Tensor?[]? x) -> ()").arguments[0].type,
1862        ]
1863        non_tensorlists = [
1864            # Tensor
1865            torch.ops.aten.sin.default._schema.arguments[0].type,
1866            # IntList
1867            torch.ops.aten.sum.dim_IntList._schema.arguments[1].type,
1868        ]
1869        for a in tensorlists:
1870            self.assertTrue(torch._library.utils.is_tensorlist_like_type(a))
1871        for a in non_tensorlists:
1872            self.assertFalse(torch._library.utils.is_tensorlist_like_type(a))
1873
1874    def test_backward_impl_on_existing_op(self):
1875        lib = self.lib()
1876        lib.define("foo(Tensor x) -> Tensor")
1877        qualname = f"{self.test_ns}::foo"
1878
1879        @custom_ops.impl(qualname)
1880        def foo_impl(x):
1881            with torch.no_grad():
1882                return x.sin()
1883
1884        @custom_ops.impl_save_for_backward(qualname)
1885        def foo_save_for_backward(inputs, output):
1886            return inputs.x
1887
1888        @custom_ops.impl_backward(qualname)
1889        def foo_backward(ctx, saved, grad_out):
1890            return {"x": grad_out * saved.cos()}
1891
1892        op = self.get_op(qualname)
1893        x = torch.randn([], requires_grad=True)
1894        y = op(x)
1895        (gx,) = torch.autograd.grad(y, x)
1896        self.assertEqual(gx, x.cos())
1897
1898    @parametrize(
1899        "tags",
1900        [
1901            subtest(torch.Tag.pointwise, "single"),
1902            subtest((torch.Tag.pointwise,), "tuple"),
1903            subtest([torch.Tag.pointwise], "list"),
1904        ],
1905    )
1906    def test_define_with_tags(self, tags):
1907        lib = self.lib()
1908        tags = (torch.Tag.pointwise,)
1909        torch.library.define(
1910            f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib, tags=tags
1911        )
1912        actual = self.ns().foo.default.tags
1913        self.assertTrue(isinstance(actual, list))
1914        self.assertEqual(actual, list(tags))
1915
1916    def test_builtin_aten_ops_are_pt2_compliant(self):
1917        for op in [torch.ops.aten.sin.default, torch.ops.aten.sum.dim_IntList]:
1918            self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)
1919
1920    def test_builtin_torchscript_ops(self):
1921        for op in [torch.ops.aten.sub.complex, torch.ops.aten.mul.complex]:
1922            self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)
1923
1924    def test_autogen_aten_ops_are_pt2_compliant(self):
1925        for op in [torch.ops.aten.fill.Tensor_out]:
1926            self.assertIn(torch.Tag.generated, op.tags)
1927            self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)
1928
1929    def test_resolve_packet(self):
1930        x = torch.randn(3)
1931        result = torch._C._jit_resolve_packet("aten::sum", x)
1932        self.assertEqual(result, "default")
1933
1934        result = torch._C._jit_resolve_packet("aten::sum", x, dim=1)
1935        self.assertEqual(result, "dim_IntList")
1936
1937        with self.assertRaisesRegex(RuntimeError, "failed to match any schema"):
1938            result = torch._C._jit_resolve_packet("aten::sum", x, x, x)
1939
1940    def test_define_bad_schema(self):
1941        lib = self.lib()
1942        with self.assertRaisesRegex(ValueError, "expected schema to look like"):
1943            torch.library.define(f"{self.test_ns}::foo", "foo(Tensor x) -> Tensor")
1944
1945    def test_define_and_impl(self):
1946        lib = self.lib()
1947        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
1948
1949        @torch.library.impl(f"{self.test_ns}::foo", "CPU", lib=lib)
1950        def f(x):
1951            return torch.from_numpy(np.sin(x.numpy()))
1952
1953        x = torch.randn(3)
1954        y = self.ns().foo(x)
1955        assert torch.allclose(y, x.sin())
1956
1957    def test_define_validation(self):
1958        with self.assertRaisesRegex(ValueError, "namespace"):
1959            torch.library.define("foo", "(Tensor x) -> Tensor")
1960
1961    def test_legacy_define(self):
1962        lib = self.lib()
1963
1964        @torch.library.define(lib, "foo(Tensor x) -> Tensor")
1965        def f(x):
1966            return torch.from_numpy(np.sin(x.numpy()))
1967
1968        x = torch.randn(3)
1969        y = self.ns().foo(x)
1970        assert torch.allclose(y, x.sin())
1971
1972    def test_impl_function(self):
1973        lib = self.lib()
1974        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
1975
1976        def f(x):
1977            return torch.from_numpy(np.sin(x.numpy()))
1978
1979        torch.library.impl(f"{self.test_ns}::foo", "CPU", f, lib=lib)
1980        x = torch.randn(3)
1981        y = self.ns().foo(x)
1982        assert torch.allclose(y, x.sin())
1983
1984    def test_legacy_impl(self):
1985        lib = self.lib()
1986        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
1987
1988        @torch.library.impl(lib, "foo", "CPU")
1989        def f(x):
1990            return torch.from_numpy(np.sin(x.numpy()))
1991
1992        x = torch.randn(3)
1993        y = self.ns().foo(x)
1994        assert torch.allclose(y, x.sin())
1995
1996    def test_defined_in_python(self):
1997        self.assertFalse(torch.ops.aten.sin.default._defined_in_python)
1998        self.assertFalse(torch.ops.aten.sum.dim_IntList._defined_in_python)
1999
2000        lib = self.lib()
2001        torch.library.define("{self._test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
2002        ns = self.ns()
2003        self.assertTrue(ns.foo.default._defined_in_python)
2004
2005        torch.library.define(
2006            "{self._test_ns}::bar.overload", "(Tensor x) -> Tensor", lib=lib
2007        )
2008        self.assertTrue(ns.bar.overload._defined_in_python)
2009
2010    def _test_impl_device(self, name, types, device):
2011        lib = self.lib()
2012        torch.library.define(f"{self.test_ns}::{name}", "(Tensor x) -> Tensor", lib=lib)
2013
2014        @torch.library.impl(f"{self.test_ns}::{name}", types)
2015        def f(x):
2016            x_np = x.cpu().numpy()
2017            y = torch.from_numpy(np.sin(x_np))
2018            return y.to(device=x.device)
2019
2020        x = torch.randn(3, device=device)
2021        y = getattr(self.ns(), name)(x)
2022        assert torch.allclose(y, x.sin())
2023
2024    def test_impl_device_cpu(self):
2025        self._test_impl_device("foo1", "default", "cpu")
2026        self._test_impl_device("foo2", ["cpu"], "cpu")
2027        self._test_impl_device("foo3", ["cpu", "cuda"], "cpu")
2028
2029    @unittest.skipIf(not TEST_CUDA, "requires cuda")
2030    def test_impl_device_cuda(self):
2031        self._test_impl_device("foo4", "default", "cuda")
2032        self._test_impl_device("foo5", ["cuda"], "cuda")
2033        self._test_impl_device("foo6", ["cpu", "cuda"], "cuda")
2034
2035    def test_impl_device_function(self):
2036        lib = self.lib()
2037        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
2038
2039        def f(x):
2040            x_np = x.cpu().numpy()
2041            y = torch.from_numpy(np.sin(x_np))
2042            return y.to(device=x.device)
2043
2044        torch.library.impl(f"{self.test_ns}::foo", "default", f, lib=lib)
2045        x = torch.randn(3)
2046        y = self.ns().foo(x)
2047        assert torch.allclose(y, x.sin())
2048
2049    def test_impl_device_invalid(self):
2050        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu, cuda"):
2051            torch.library.impl("blah::blah", "somethingsomething")
2052
2053    def test_autograd_function_backed_op(self):
2054        cpp_source = """
2055struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
2056  static constexpr bool is_traceable = true;
2057
2058  static torch::Tensor forward(
2059      torch::autograd::AutogradContext* ctx,
2060      const torch::Tensor& x) {
2061    return x;
2062  }
2063
2064  static torch::autograd::variable_list backward(
2065      torch::autograd::AutogradContext *ctx,
2066      torch::autograd::variable_list grad_output) {
2067    return grad_output;
2068  }
2069};
2070
2071torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) {
2072  return CustomOpAutogradFunction::apply(x);
2073}
2074
2075TORCH_LIBRARY(mylib, m) {
2076    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
2077}
2078        """
2079
2080        module = torch.utils.cpp_extension.load_inline(
2081            name="mylib",
2082            cpp_sources=cpp_source,
2083            functions="custom_op_backed_by_autograd_fn",
2084            verbose=True,
2085        )
2086
2087        x = torch.ones(2, 2, requires_grad=True)
2088        temp = x.clone().detach()
2089        out = torch.ops.mylib.custom_op_backed_by_autograd_fn(x)
2090        loss = out.sum()
2091        loss.backward()
2092        self.assertEqual(x.grad, temp)
2093
2094
2095def op_with_incorrect_schema(testcase, name):
2096    lib = testcase.lib()
2097    lib.define(f"{name}(Tensor x) -> Tensor")
2098    qualname = f"{testcase.test_ns}::{name}"
2099    lib.impl(name, lambda x: x[:], "CompositeExplicitAutograd")
2100    return testcase.get_op(qualname)
2101
2102
2103class MiniOpTest(CustomOpTestCaseBase):
2104    test_ns = "mini_op_test"
2105
2106    def _init_op_delayed_backward_error(self):
2107        name = "delayed_error"
2108        qualname = f"{self.test_ns}::{name}"
2109        lib = self.lib()
2110        lib.define(f"{name}(Tensor x) -> Tensor")
2111        lib.impl(name, lambda x: x.clone(), "CompositeExplicitAutograd")
2112        op = self.get_op(qualname)
2113
2114        class Op(torch.autograd.Function):
2115            @staticmethod
2116            def forward(ctx, x):
2117                with torch._C._AutoDispatchBelowAutograd():
2118                    return op(x)
2119
2120            @staticmethod
2121            def backward(ctx, grad):
2122                raise NotImplementedError
2123
2124        def autograd_impl(x):
2125            return Op.apply(x)
2126
2127        lib.impl(name, autograd_impl, "Autograd")
2128        return op
2129
2130    def _init_op_with_no_abstract_impl(self):
2131        name = "no_abstract"
2132        qualname = f"{self.test_ns}::{name}"
2133        lib = self.lib()
2134        lib.define(f"{name}(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,))
2135        lib.impl(name, lambda x: x.clone(), "CPU")
2136        return torch._library.utils.lookup_op(qualname)
2137
2138    def setUp(self):
2139        super().setUp()
2140        self._op_with_no_abstract_impl = self._init_op_with_no_abstract_impl()
2141        self._op_delayed_backward_error = self._init_op_delayed_backward_error()
2142
2143    @optests.dontGenerateOpCheckTests("Testing this API")
2144    def test_dont_generate(self):
2145        op = op_with_incorrect_schema(self, "incorrect_schema")
2146        x = torch.randn(3)
2147        op(x)
2148
2149    def test_mm(self):
2150        x = torch.randn(2, 3, requires_grad=True)
2151        y = torch.randn(3, 5)
2152        result = torch.ops.aten.mm.default(x, y)
2153        self.assertEqual(result, x @ y)
2154
2155    def test_mm_meta(self):
2156        x = torch.randn(2, 3, requires_grad=True, device="meta")
2157        y = torch.randn(3, 5, device="meta")
2158        result = torch.ops.aten.mm.default(x, y)
2159        self.assertEqual(result.shape, (x @ y).shape)
2160
2161    def test_mm_fake(self):
2162        with torch._subclasses.fake_tensor.FakeTensorMode():
2163            x = torch.randn(2, 3, requires_grad=True, device="cpu")
2164            y = torch.randn(3, 5, device="cpu")
2165            result = torch.ops.aten.mm.default(x, y)
2166            self.assertEqual(result.shape, (x @ y).shape)
2167
2168    def test_mm_errors(self):
2169        x = torch.randn(2, 3, requires_grad=True)
2170        y = torch.randn(4, 5)
2171        with self.assertRaisesRegex(RuntimeError, "cannot be multiplied"):
2172            result = torch.ops.aten.mm.default(x, y)
2173
2174    def test_nonzero(self):
2175        x = torch.tensor([0, 1, 2, 0, 0])
2176        y = torch.ops.aten.nonzero.default(x)
2177        self.assertEqual(y, torch.tensor([[1], [2]]))
2178
2179    def test_inplace(self):
2180        x = torch.randn(3)
2181        x_clone = x.clone()
2182        y = torch.ops.aten.sin_(x)
2183        self.assertEqual(x, x_clone.sin())
2184
2185    def test_incorrect_schema(self):
2186        op = op_with_incorrect_schema(self, "incorrect_schema")
2187        x = torch.randn(3)
2188        op(x)
2189
2190    def test_no_abstract(self):
2191        op = self._op_with_no_abstract_impl
2192        x = torch.randn(3)
2193        op(x)
2194
2195    def test_delayed_error(self):
2196        op = self._op_delayed_backward_error
2197        x = torch.randn([], requires_grad=True)
2198        y = op(x)
2199        with self.assertRaises(NotImplementedError):
2200            y.sum().backward()
2201
2202    def test_delayed_error_no_requires_grad(self):
2203        op = self._op_delayed_backward_error
2204        x = torch.randn([])
2205        y = op(x)
2206
2207
2208class TestCustomOpAPI(TestCase):
2209    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2210    def test_basic(self):
2211        @torch.library.custom_op("_torch_testing::add", mutates_args=())
2212        def add(x: Tensor, y: float) -> Tensor:
2213            x_np = x.numpy(force=True)
2214            out_np = x_np + y
2215            return torch.from_numpy(out_np).to(x.device)
2216
2217        x = torch.randn(3)
2218        y = 3.14
2219        z = add(x, y)
2220        self.assertEqual(z, x + y)
2221
2222        cpu_called = False
2223
2224        @add.register_kernel("cpu")
2225        def _(x, y):
2226            nonlocal cpu_called
2227            cpu_called = True
2228            x_np = x.numpy()
2229            out_np = x_np + y
2230            return torch.from_numpy(out_np)
2231
2232        z = add(x, y)
2233        self.assertEqual(z, x + y)
2234        self.assertTrue(cpu_called)
2235
2236    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2237    def test_no_grad_skips_autograd(self):
2238        @torch.library.custom_op("_torch_testing::add", mutates_args=())
2239        def add(x: Tensor, y: float) -> Tensor:
2240            x_np = x.numpy(force=True)
2241            out_np = x_np + y
2242            return torch.from_numpy(out_np).to(x.device)
2243
2244        called = 0
2245
2246        def setup_context(ctx, inputs, output):
2247            nonlocal called
2248            called += 1
2249
2250        def backward(ctx, grad):
2251            raise AssertionError("should not be reached")
2252
2253        add.register_autograd(backward, setup_context=setup_context)
2254
2255        x = torch.randn(3, requires_grad=True)
2256        with torch.no_grad():
2257            y = add(x, 2.0)
2258        self.assertEqual(called, 0)
2259        self.assertEqual(y, x + 2.0)
2260
2261        x.requires_grad_(False)
2262        y = add(x, 2.0)
2263        self.assertEqual(called, 0)
2264        self.assertEqual(y, x + 2.0)
2265
2266        x = torch.randn(3, requires_grad=True)
2267        y = add(x, 2.0)
2268        self.assertEqual(called, 1)
2269        self.assertEqual(y, x + 2.0)
2270
2271    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2272    def test_manual_schema(self):
2273        @torch.library.custom_op(
2274            "_torch_testing::add",
2275            mutates_args=(),
2276            schema="(Tensor x, float y) -> Tensor",
2277        )
2278        def add(x, y):
2279            x_np = x.numpy(force=True)
2280            out_np = x_np + y
2281            return torch.from_numpy(out_np).to(x.device)
2282
2283        x = torch.randn(3)
2284        y = 3.14
2285        z = add(x, y)
2286        self.assertEqual(z, x + y)
2287
2288        @torch.library.custom_op(
2289            "_torch_testing::sin_",
2290            mutates_args=["x"],
2291            schema="(Tensor(a!) x) -> ()",
2292        )
2293        def sin_(x):
2294            x_np = x.numpy()
2295            np.sin(x_np, out=x_np)
2296
2297        x = torch.randn(3)
2298        expected = x.sin()
2299        sin_(x)
2300        self.assertEqual(x, expected)
2301
2302    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2303    def test_kwarg_only_tensors(self):
2304        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
2305
2306            @torch.library.custom_op("_torch_testing::foo", mutates_args=())
2307            def foo(x: Tensor, *, y: int, z: Tensor) -> Tensor:
2308                pass
2309
2310        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
2311
2312            @torch.library.custom_op("_torch_testing::foo", mutates_args=())
2313            def foo2(x: Tensor, *, y: int, z: Optional[Tensor]) -> Tensor:
2314                pass
2315
2316        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
2317
2318            @torch.library.custom_op("_torch_testing::foo", mutates_args=())
2319            def foo3(x: Tensor, *, y: int, z: List[Tensor]) -> Tensor:
2320                pass
2321
2322        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
2323            lib.define("foo(Tensor x, *, Tensor y) -> Tensor")
2324            with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
2325                torch.library.register_autograd(
2326                    "_torch_testing::foo",
2327                    lambda grad: grad,
2328                    setup_context=lambda ctx, inputs, keyword_only_inputs, output: None,
2329                )
2330
2331            with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
2332                torch.library.register_vmap(
2333                    "_torch_testing::foo",
2334                    lambda info, in_dims, x, *, y: (x, 0),
2335                )
2336
2337    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2338    def test_register_autograd_kwargonly_low_level(self):
2339        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
2340            lib.define("foo(Tensor x, *, float y) -> Tensor")
2341            called = False
2342
2343            def foo_impl(x, *, y):
2344                return x * y
2345
2346            lib.impl("foo", foo_impl, "CPU")
2347
2348            def backward(ctx, grad):
2349                nonlocal called
2350                called = True
2351                return grad * ctx.y
2352
2353            def setup_context(ctx, inputs, keyword_only_inputs, output):
2354                assert tuple(keyword_only_inputs.keys()) == ("y",)
2355                ctx.y = keyword_only_inputs["y"]
2356
2357            torch.library.register_autograd(
2358                "_torch_testing::foo", backward, setup_context=setup_context, lib=lib
2359            )
2360
2361            x = torch.randn(3, requires_grad=True)
2362            torch.ops._torch_testing.foo(x, y=3.14).sum().backward()
2363            self.assertTrue(called)
2364            self.assertEqual(x.grad, torch.tensor([3.14, 3.14, 3.14]))
2365
2366    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2367    def test_register_autograd_defaults(self):
2368        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
2369            lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor")
2370
2371            def foo_impl(w, x=2, *, y=3, z):
2372                return w * x * y * z
2373
2374            lib.impl("foo", foo_impl, "CPU")
2375
2376            called = False
2377
2378            def backward(ctx, grad):
2379                nonlocal called
2380                called = True
2381                return grad * ctx.c
2382
2383            def setup_context(ctx, inputs, keyword_only_inputs, output):
2384                assert len(inputs) == 2
2385                assert inputs[1] == 2
2386                assert keyword_only_inputs == {"y": 3, "z": 42}
2387                ctx.c = keyword_only_inputs["y"] * keyword_only_inputs["z"] * inputs[1]
2388
2389            torch.library.register_autograd(
2390                "_torch_testing::foo", backward, setup_context=setup_context, lib=lib
2391            )
2392
2393            w = torch.randn(3, requires_grad=True)
2394            torch.ops._torch_testing.foo(w, z=42).sum().backward()
2395            self.assertTrue(called)
2396            self.assertEqual(w.grad, torch.full_like(w, 2 * 3 * 42))
2397
2398    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2399    def test_manual_schema_error(self):
2400        with self.assertRaisesRegex(ValueError, "the op mutates {'x'}"):
2401
2402            @torch.library.custom_op(
2403                "_torch_testing::sin_",
2404                mutates_args=(),
2405                schema="(Tensor(a!) x) -> ()",
2406            )
2407            def sin_(x):
2408                x_np = x.numpy()
2409                np.sin(x_np, out=x_np)
2410
2411    def test_supports_tensorlist(self):
2412        @torch._library.autograd.supports_tensorlist
2413        class Stack(torch.autograd.Function):
2414            @staticmethod
2415            def forward(ctx, xs):
2416                ctx.num_xs = len(xs)
2417                return torch.stack(xs)
2418
2419            @staticmethod
2420            def backward(ctx, grad):
2421                expected = ([True] * ctx.num_xs,)
2422                self.assertEqual(ctx.needs_input_grad, expected)
2423                return list(grad.unbind(0))
2424
2425        # call two applys, do a backward on the first
2426        def t():
2427            return torch.randn([], requires_grad=True)
2428
2429        xs0 = [t(), t(), t()]
2430        xs1 = [t(), t(), t(), t()]
2431        y0 = Stack.apply(xs0)
2432        y1 = Stack.apply(xs1)
2433        grads = torch.autograd.grad(y0.sum(), xs0)
2434        self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)])
2435
2436        # call one apply, do multiple backwards
2437        xs = [t(), t(), t()]
2438        y = Stack.apply(xs)
2439        _ = torch.autograd.grad(y.sum(), xs, retain_graph=True)
2440        _ = torch.autograd.grad(y.sum(), xs, retain_graph=True)
2441        grads = torch.autograd.grad(y.sum(), xs, retain_graph=True)
2442        self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)])
2443
2444        # error: on access forward, backward directly
2445        with self.assertRaisesRegex(NotImplementedError, "Function.forward directly"):
2446            Stack.forward(None, xs)
2447        with self.assertRaisesRegex(NotImplementedError, "Function.backward directly"):
2448            Stack.backward(None, xs)
2449
2450        # the recursive case
2451        @torch._library.autograd.supports_tensorlist
2452        class Foo(torch.autograd.Function):
2453            @staticmethod
2454            def forward(ctx, xs):
2455                if len(xs) > 1:
2456                    return Foo.apply(xs[1:])
2457                ctx.len_xs = len(xs)
2458                return xs[0].sin()
2459
2460            @staticmethod
2461            def backward(ctx, grad):
2462                result = [None] * ctx.len_xs
2463                result[-1] = grad.cos()
2464                return result
2465
2466        # should work
2467        result = Foo.apply(xs)
2468        expected = xs[-1].sin()
2469        self.assertEqual(result, expected)
2470
2471        # recursive on backward
2472        @torch._library.autograd.supports_tensorlist
2473        class Bar(torch.autograd.Function):
2474            @staticmethod
2475            def forward(ctx, xs):
2476                return [xs[i] + i for i in range(len(xs))]
2477
2478            @staticmethod
2479            def backward(ctx, grads):
2480                f1 = Bar.apply(grads[:2])
2481                f2 = Bar.apply(grads[2:])
2482                return f1 + f2
2483
2484        xs = [torch.tensor(0.0, requires_grad=True) for _ in range(5)]
2485        ys = Bar.apply(xs)
2486        sum(ys).backward()
2487        result = [xi.grad for xi in xs]
2488        self.assertEqual(result, torch.tensor([1.0, 2, 1, 2, 3]).unbind(0))
2489
2490    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2491    def test_default_values(self):
2492        defaults = []
2493
2494        @torch.library.custom_op("_torch_testing::f", mutates_args=())
2495        def f(
2496            x: Tensor,
2497            a: Optional[int] = None,
2498            b: float = 3.14,
2499            c: bool = True,
2500            d: int = 3,
2501            e: str = "foo",
2502            f: torch.dtype = torch.float,
2503            g: torch.dtype = torch.float32,
2504            h: torch.dtype = torch.int,
2505            i: torch.device = torch.device("cpu:0"),
2506            j: torch.device = "cpu",
2507        ) -> Tensor:
2508            defaults.extend([a, b, c, d, e, f, g, h, i, j])
2509            return x.clone()
2510
2511        x = torch.randn(3)
2512        f(x)
2513        self.assertEqual(
2514            defaults,
2515            [
2516                None,
2517                3.14,
2518                True,
2519                3,
2520                "foo",
2521                torch.float,
2522                torch.float32,
2523                torch.int,
2524                torch.device("cpu:0"),
2525                "cpu",
2526            ],
2527        )
2528        default_values = [
2529            arg.default_value
2530            for arg in torch.ops._torch_testing.f.default._schema.arguments
2531        ]
2532        # enum values taken from c10/core/ScalarType.h
2533        type_enum = {
2534            "float": 6,
2535            "int": 3,
2536        }
2537        self.assertEqual(
2538            default_values,
2539            [
2540                None,
2541                None,
2542                3.14,
2543                True,
2544                3,
2545                "foo",
2546                type_enum["float"],
2547                type_enum["float"],
2548                type_enum["int"],
2549                torch.device("cpu:0"),
2550                torch.device("cpu"),
2551            ],
2552        )
2553
2554    def test_mutated_error(self):
2555        with self.assertRaisesRegex(
2556            ValueError, r".*{'y'} in mutates_args were not found"
2557        ):
2558
2559            @torch.library.custom_op(
2560                "_torch_testing::numpy_sin_inplace",
2561                mutates_args={"y"},
2562                device_types="cpu",
2563            )
2564            def numpy_sin_inplace(x: Tensor) -> None:
2565                x_np = x.numpy()
2566                np.sin(x_np, out=x_np)
2567
2568    def test_mutated(self):
2569        @torch.library.custom_op(
2570            "_torch_testing::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu"
2571        )
2572        def numpy_sin_inplace(x: Tensor) -> None:
2573            x_np = x.numpy()
2574            np.sin(x_np, out=x_np)
2575
2576        x = torch.randn(3)
2577        version = x._version
2578        expected = x.sin()
2579        numpy_sin_inplace(x)
2580        self.assertEqual(x, expected)
2581        self.assertGreater(x._version, version)
2582
2583        @torch.library.custom_op("_torch_testing::f", mutates_args={"y", "z", "w"})
2584        def f(
2585            x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
2586        ) -> None:
2587            return
2588
2589        x = torch.randn(3)
2590        y = torch.randn(3)
2591        z = [torch.randn(3), torch.randn(3)]
2592        w = [torch.randn(3), None, torch.randn(3)]
2593        initial_versions = pytree.tree_map_only(
2594            torch.Tensor, lambda x: x._version, (x, y, z, w)
2595        )
2596        f(x, y, z, w)
2597        new_versions = pytree.tree_map_only(
2598            torch.Tensor, lambda x: x._version, (x, y, z, w)
2599        )
2600
2601        self.assertEqual(initial_versions[0], new_versions[0])
2602        initial_versions, _ = pytree.tree_flatten(initial_versions[1:])
2603        new_versions, _ = pytree.tree_flatten(new_versions[1:])
2604        for prev, after in zip(initial_versions, new_versions):
2605            if prev is None and after is None:
2606                continue
2607            self.assertGreater(after, prev)
2608
2609    def test_mutated_unknown(self):
2610        @torch.library.custom_op(
2611            "_torch_testing::f", mutates_args="unknown", device_types="cpu"
2612        )
2613        def f(x: Tensor) -> None:
2614            x_np = x.numpy()
2615            np.sin(x_np, out=x_np)
2616
2617        x = torch.randn(3)
2618        version = x._version
2619        expected = x.sin()
2620        f(x)
2621        self.assertEqual(x, expected)
2622        self.assertGreater(x._version, version)
2623
2624        @torch.library.custom_op("_torch_testing::f2", mutates_args="unknown")
2625        def f2(
2626            x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
2627        ) -> None:
2628            return
2629
2630        x = torch.randn(3)
2631        y = torch.randn(3)
2632        z = [torch.randn(3), torch.randn(3)]
2633        w = [torch.randn(3), None, torch.randn(3)]
2634        initial_versions = pytree.tree_map_only(
2635            torch.Tensor, lambda x: x._version, (x, y, z, w)
2636        )
2637        f2(x, y, z, w)
2638        new_versions = pytree.tree_map_only(
2639            torch.Tensor, lambda x: x._version, (x, y, z, w)
2640        )
2641
2642        initial_versions, _ = pytree.tree_flatten(initial_versions)
2643        new_versions, _ = pytree.tree_flatten(new_versions)
2644        for prev, after in zip(initial_versions, new_versions):
2645            if prev is None and after is None:
2646                continue
2647            self.assertGreater(after, prev)
2648
2649        with self.assertRaisesRegex(ValueError, "string"):
2650
2651            @torch.library.custom_op("_torch_testing::f3", mutates_args="x")
2652            def f3(x: Tensor) -> None:
2653                return
2654
2655    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2656    def test_library_register_torch_dispatch_rule_subclass(self):
2657        from torch.testing._internal.two_tensor import TwoTensor
2658
2659        @torch.library.custom_op("mylib::foo", mutates_args={})
2660        def f(x: torch.Tensor) -> torch.Tensor:
2661            return x.sin()
2662
2663        x = torch.randn(3)
2664        y = torch.randn(3)
2665        z = TwoTensor(x, y)
2666
2667        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
2668            called = 0
2669
2670            def TwoTensor_foo(cls, func, types, args, kwargs):
2671                nonlocal called
2672                assert cls is TwoTensor
2673                called += 1
2674                return x.sin()
2675
2676            m._register_torch_dispatch_rule("foo", TwoTensor, TwoTensor_foo)
2677
2678            out = f(z)
2679            out2 = z.cos()
2680
2681        self.assertEqual(called, 1)
2682
2683    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2684    def test_library_register_torch_dispatch_rule_mode(self):
2685        from torch.testing._internal.two_tensor import TwoTensorMode
2686
2687        @torch.library.custom_op("mylib::foo", mutates_args={})
2688        def f(x: torch.Tensor) -> torch.Tensor:
2689            return x.sin()
2690
2691        x = torch.randn(3)
2692
2693        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
2694            called = 0
2695
2696            def TwoTensor_foo(mode, func, types, args, kwargs):
2697                nonlocal called
2698                called += 1
2699                return x.sin()
2700
2701            m._register_torch_dispatch_rule("foo", TwoTensorMode, TwoTensor_foo)
2702
2703            with TwoTensorMode():
2704                out = f(x)
2705                out2 = x.cos()
2706
2707        self.assertEqual(called, 1)
2708
2709    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2710    @parametrize("idx", [0, 1, 2, 3, 4, 5])
2711    def test_library_register_fake_source(self, idx):
2712        opname = f"source{idx}"
2713        op = getattr(torch.ops._torch_testing, opname).default
2714        entry = torch._library.simple_registry.singleton.find(op._name)
2715        source = entry.fake_impl.kernel.source
2716        assert source is not None
2717        self.assertTrue("custom_op_db.py" in source)
2718
2719    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2720    def test_library_register_fake(self):
2721        for mode in ["function", "qualname", "opoverload"]:
2722
2723            @torch.library.custom_op("_torch_testing::add", mutates_args=())
2724            def add(x: Tensor, y: float) -> Tensor:
2725                x_np = x.cpu().numpy()
2726                out_np = x_np + y
2727                return torch.from_numpy(out_np).to(x.device)
2728
2729            called = False
2730
2731            if mode == "function":
2732                dec = torch.library.register_fake(add)
2733                self.assertIsNotNone(dec)
2734            elif mode == "qualname":
2735                dec = torch.library.register_fake("_torch_testing::add")
2736                self.assertIsNotNone(dec)
2737            elif mode == "opoverload":
2738                dec = torch.library.register_fake(torch.ops._torch_testing.add.default)
2739                self.assertIsNotNone(dec)
2740            else:
2741                raise AssertionError("should not get here")
2742
2743            @dec
2744            def _(x, y):
2745                nonlocal called
2746                called = True
2747                return torch.empty_like(x)
2748
2749            with torch._subclasses.fake_tensor.FakeTensorMode():
2750                x = torch.randn(3)
2751                y = 3.14
2752                z = add(x, y)
2753                self.assertEqual(z.shape, x.shape)
2754                self.assertTrue(called)
2755
2756    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2757    def test_library_register_torch_dispatch(self):
2758        for mode in ["function", "qualname", "opoverload"]:
2759
2760            class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
2761                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
2762                    return func(*args, **kwargs)
2763
2764            @torch.library.custom_op("_torch_testing::add", mutates_args=())
2765            def add(x: Tensor, y: float) -> Tensor:
2766                x_np = x.cpu().numpy()
2767                out_np = x_np + y
2768                return torch.from_numpy(out_np).to(x.device)
2769
2770            called = False
2771
2772            if mode == "function":
2773                dec = torch.library.register_torch_dispatch(add, MyMode)
2774                self.assertIsNotNone(dec)
2775            elif mode == "qualname":
2776                dec = torch.library.register_torch_dispatch(
2777                    "_torch_testing::add", MyMode
2778                )
2779                self.assertIsNotNone(dec)
2780            elif mode == "opoverload":
2781                dec = torch.library.register_torch_dispatch(
2782                    torch.ops._torch_testing.add.default, MyMode
2783                )
2784                self.assertIsNotNone(dec)
2785            else:
2786                raise AssertionError("should not get here")
2787
2788            @dec
2789            def _(mode, func, types, args, kwargs):
2790                nonlocal called
2791                called = True
2792                return func(*args, **kwargs)
2793
2794            with MyMode():
2795                x = torch.randn(3)
2796                y = 3.14
2797                z = add(x, y)
2798                self.assertEqual(z.shape, x.shape)
2799                self.assertTrue(called)
2800
2801    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2802    def test_library_register_torch_dispatch_low_level(self):
2803        modes = ["qualname", "opoverload"]
2804        calls = ["decorator", "function"]
2805        device_types_options = [("cpu", "cuda"), "cpu", None]
2806
2807        for mode, call, device_types in itertools.product(
2808            modes, calls, device_types_options
2809        ):
2810            with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
2811                lib.define("add10(Tensor x, float y) -> Tensor")
2812
2813                if mode == "qualname":
2814                    op = "_torch_testing::add10"
2815                else:
2816                    assert mode == "opoverload"
2817                    op = torch.ops._torch_testing.add10.default
2818
2819                called = False
2820
2821                class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
2822                    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
2823                        return func(*args, **kwargs)
2824
2825                if call == "decorator":
2826
2827                    @torch.library.register_torch_dispatch(op, MyMode, lib=lib)
2828                    def _(mode, func, types, args, kwargs):
2829                        x, y = args
2830                        nonlocal called
2831                        called = True
2832                        return x + y
2833
2834                else:
2835                    assert call == "function"
2836
2837                    def add_stuff(mode, func, types, args, kwargs):
2838                        x, y = args
2839                        nonlocal called
2840                        called = True
2841                        return x + y
2842
2843                    torch.library.register_torch_dispatch(
2844                        op, MyMode, add_stuff, lib=lib
2845                    )
2846
2847                x = torch.randn(3)
2848                y = 3.14
2849                with MyMode():
2850                    z = torch.ops._torch_testing.add10.default(x, y)
2851                self.assertEqual(z, x + y)
2852                self.assertTrue(called)
2853
2854    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2855    def test_library_register_kernel(self):
2856        modes = ["function", "qualname", "opoverload"]
2857        calls = ["decorator", "function"]
2858        device_types_options = ["cpu", None]
2859
2860        for mode, call, device_types in itertools.product(
2861            modes, calls, device_types_options
2862        ):
2863
2864            @torch.library.custom_op(
2865                "_torch_testing::add", mutates_args=(), device_types="cuda"
2866            )
2867            def add(x: Tensor, y: float) -> Tensor:
2868                x_np = x.cpu().numpy()
2869                out_np = x_np + y
2870                return torch.from_numpy(out_np).to(x.device)
2871
2872            if mode == "function":
2873                op = add
2874            elif mode == "qualname":
2875                op = "_torch_testing::add"
2876            else:
2877                assert mode == "opoverload"
2878                op = torch.ops._torch_testing.add.default
2879
2880            called = False
2881
2882            if call == "decorator":
2883
2884                @torch.library.register_kernel(op, device_types)
2885                def _(x, y):
2886                    nonlocal called
2887                    called = True
2888                    x_np = x.numpy()
2889                    out_np = x_np + y
2890                    return torch.from_numpy(out_np)
2891
2892            else:
2893                assert call == "function"
2894
2895                def add_cpu(x, y):
2896                    nonlocal called
2897                    called = True
2898                    x_np = x.numpy()
2899                    out_np = x_np + y
2900                    return torch.from_numpy(out_np)
2901
2902                torch.library.register_kernel(op, device_types, add_cpu)
2903
2904            x = torch.randn(3)
2905            y = 3.14
2906            z = add(x, y)
2907            self.assertEqual(z, x + y)
2908            self.assertTrue(called)
2909
2910    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2911    def test_library_register_kernel_low_level(self):
2912        modes = ["qualname", "opoverload"]
2913        calls = ["decorator", "function"]
2914        device_types_options = [("cpu", "cuda"), "cpu", None]
2915
2916        for mode, call, device_types in itertools.product(
2917            modes, calls, device_types_options
2918        ):
2919            with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
2920                lib.define("add9(Tensor x, float y) -> Tensor")
2921
2922                if mode == "qualname":
2923                    op = "_torch_testing::add9"
2924                else:
2925                    assert mode == "opoverload"
2926                    op = torch.ops._torch_testing.add9.default
2927
2928                called = False
2929
2930                if call == "decorator":
2931
2932                    @torch.library.register_kernel(op, device_types, lib=lib)
2933                    def _(x, y):
2934                        nonlocal called
2935                        called = True
2936                        x_np = x.numpy()
2937                        out_np = x_np + y
2938                        return torch.from_numpy(out_np)
2939
2940                else:
2941                    assert call == "function"
2942
2943                    def add_cpu(x, y):
2944                        nonlocal called
2945                        called = True
2946                        x_np = x.numpy()
2947                        out_np = x_np + y
2948                        return torch.from_numpy(out_np)
2949
2950                    torch.library.register_kernel(op, device_types, add_cpu, lib=lib)
2951
2952                x = torch.randn(3)
2953                y = 3.14
2954                z = torch.ops._torch_testing.add9.default(x, y)
2955                self.assertEqual(z, x + y)
2956                self.assertTrue(called)
2957
2958    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
2959    def test_library_register_autograd(self):
2960        for mode in ["function", "qualname", "opoverload"]:
2961
2962            @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
2963            def numpy_sin(x: Tensor) -> Tensor:
2964                x_np = x.cpu().numpy()
2965                y_np = np.sin(x_np)
2966                return torch.from_numpy(y_np).to(device=x.device)
2967
2968            def setup_context(ctx, inputs, output) -> Tensor:
2969                (x,) = inputs
2970                ctx.save_for_backward(x)
2971
2972            called = False
2973
2974            def backward(ctx, grad):
2975                nonlocal called
2976                called = True
2977                (x,) = ctx.saved_tensors
2978                return grad * x.cos()
2979
2980            if mode == "function":
2981                torch.library.register_autograd(
2982                    numpy_sin, backward, setup_context=setup_context
2983                )
2984            elif mode == "qualname":
2985                torch.library.register_autograd(
2986                    "mylib::numpy_sin", backward, setup_context=setup_context
2987                )
2988            elif mode == "opoverload":
2989                torch.library.register_autograd(
2990                    torch.ops.mylib.numpy_sin.default,
2991                    backward,
2992                    setup_context=setup_context,
2993                )
2994
2995            x = torch.randn(3, requires_grad=True)
2996            y = numpy_sin(x)
2997            (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
2998            self.assertTrue(called)
2999            self.assertEqual(grad_x, x.cos())
3000
3001    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3002    def test_library_register_autograd_low_level(self):
3003        for mode in ["qualname", "opoverload"]:
3004            with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
3005                lib.define("sin5(Tensor x) -> Tensor")
3006
3007                def numpy_sin(x: Tensor) -> Tensor:
3008                    x_np = x.cpu().detach().numpy()
3009                    y_np = np.sin(x_np)
3010                    return torch.from_numpy(y_np).to(device=x.device)
3011
3012                def setup_context(ctx, inputs, output) -> Tensor:
3013                    (x,) = inputs
3014                    ctx.save_for_backward(x)
3015
3016                called = False
3017
3018                def backward(ctx, grad):
3019                    nonlocal called
3020                    called = True
3021                    (x,) = ctx.saved_tensors
3022                    return grad * x.cos()
3023
3024                lib.impl("sin5", numpy_sin, "CPU")
3025
3026                called = False
3027
3028                if mode == "qualname":
3029                    torch.library.register_autograd(
3030                        "_torch_testing::sin5",
3031                        backward,
3032                        setup_context=setup_context,
3033                        lib=lib,
3034                    )
3035                elif mode == "opoverload":
3036                    torch.library.register_autograd(
3037                        torch.ops._torch_testing.sin5.default,
3038                        backward,
3039                        setup_context=setup_context,
3040                        lib=lib,
3041                    )
3042                x = torch.randn(3, requires_grad=True)
3043                y = torch.ops._torch_testing.sin5(x)
3044                (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
3045                self.assertTrue(called)
3046                self.assertEqual(grad_x, x.cos())
3047
3048    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3049    def test_fake(self):
3050        @torch.library.custom_op("_torch_testing::add", mutates_args=())
3051        def add(x: Tensor, y: float) -> Tensor:
3052            x_np = x.cpu().numpy()
3053            out_np = x_np + y
3054            return torch.from_numpy(out_np).to(x.device)
3055
3056        x = torch.randn(3)
3057        y = 3.14
3058        z = add(x, y)
3059        self.assertEqual(z, x + y)
3060
3061        try:
3062            with torch._subclasses.fake_tensor.FakeTensorMode():
3063                x = torch.randn(3)
3064                add(x, y)
3065            raise AssertionError("should not be hit")
3066        except RuntimeError as e:
3067            abstract_impl_error_msg = str(e)
3068        abstract_impl_error_msg = re.sub(
3069            r"0x.*>\)>", "0xDEADBEEF>)>", abstract_impl_error_msg
3070        ).replace(". ", ".\n")
3071        self.assertExpectedInline(
3072            abstract_impl_error_msg,
3073            """\
3074There was no fake impl registered for <CustomOpDef(_torch_testing::add)>.
3075This is necessary for torch.compile/export/fx tracing to work.
3076Please use `add.register_fake` to add an fake impl.""",
3077        )
3078
3079        if not IS_WINDOWS:
3080
3081            @torch.compile(backend="eager")
3082            def f(x, y):
3083                return add(x, y)
3084
3085            x = torch.randn(3)
3086            with self.assertRaisesRegex(RuntimeError, "no fake impl"):
3087                f(x, y)
3088
3089        abstract_called = False
3090
3091        @add.register_fake
3092        def _(x, y):
3093            nonlocal abstract_called
3094            abstract_called = True
3095            return torch.empty_like(x)
3096
3097        with torch._subclasses.fake_tensor.FakeTensorMode():
3098            x = torch.randn(3)
3099            z = add(x, y)
3100            self.assertEqual(z.shape, x.shape)
3101            self.assertTrue(abstract_called)
3102
3103    @skipIfTorchDynamo("recursive dynamo")
3104    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
3105    def test_compile(self):
3106        called_impl = False
3107        called_abstract = False
3108
3109        @torch.library.custom_op("_torch_testing::linear", mutates_args=())
3110        def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
3111            nonlocal called_impl
3112            called_impl = True
3113            x_np = x.numpy()
3114            w_np = weight.numpy()
3115            b_np = bias.numpy()
3116            out_np = np.add(x_np @ w_np.T, bias)
3117            return out_np
3118
3119        @custom_linear.register_fake
3120        def _(x, weight, bias):
3121            nonlocal called_abstract
3122            called_abstract = True
3123            assert x.dim() == 2
3124            assert weight.dim() == 2
3125            assert bias.dim() == 1
3126            assert x.shape[1] == weight.shape[1]
3127            assert weight.shape[0] == bias.shape[0]
3128            assert x.device == weight.device
3129            return x.new_empty(x.size(0), weight.size(0))
3130
3131        x = torch.randn(2, 2)
3132        weight = torch.randn(2, 2)
3133        bias = torch.randn(2)
3134        out = torch.compile(custom_linear, backend="eager", fullgraph=True)(
3135            x, weight, bias
3136        )
3137        self.assertEqual(out, torch.nn.functional.linear(x, weight, bias))
3138        self.assertTrue(called_impl)
3139        self.assertTrue(called_abstract)
3140
3141    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3142    def test_register_autograd_error_cases(self):
3143        @torch.library.custom_op("_torch_testing::g", mutates_args=())
3144        def g(x: Tensor) -> Tensor:
3145            return x.sin()
3146
3147        x = torch.randn(3, requires_grad=True)
3148        y = g(x)
3149        with self.assertRaisesRegex(RuntimeError, "no autograd formula"):
3150            y.sum().backward()
3151
3152    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3153    def test_replacement(self):
3154        @torch.library.custom_op("_torch_testing::f", mutates_args=())
3155        def f(x: Tensor) -> Tensor:
3156            return x.sin()
3157
3158        x = torch.randn(3)
3159        y = f(x)
3160        self.assertEqual(y, x.sin())
3161
3162        @torch.library.custom_op("_torch_testing::f", mutates_args=())
3163        def f(x: Tensor) -> Tensor:
3164            return x.cos()
3165
3166        y = f(x)
3167        self.assertEqual(y, x.cos())
3168
3169    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3170    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
3171    def test_split_device(self):
3172        cpu_call_count = 0
3173        cuda_call_count = 0
3174
3175        @torch.library.custom_op(
3176            "_torch_testing::f", mutates_args=(), device_types="cpu"
3177        )
3178        def f(x: Tensor) -> Tensor:
3179            nonlocal cpu_call_count
3180            cpu_call_count += 1
3181            x_np = x.numpy()
3182            out_np = np.sin(x_np)
3183            return torch.from_numpy(out_np)
3184
3185        @f.register_kernel("cuda")
3186        def _(x: Tensor) -> Tensor:
3187            nonlocal cuda_call_count
3188            cuda_call_count += 1
3189            x_np = x.cpu().numpy()
3190            out_np = np.sin(x_np)
3191            return torch.from_numpy(out_np).to(x.device)
3192
3193        x = torch.randn(3)
3194        y = f(x)
3195        self.assertEqual(y, x.sin())
3196        self.assertEqual(cpu_call_count, 1)
3197        self.assertEqual(cuda_call_count, 0)
3198
3199        x = x.cuda()
3200        y = f(x)
3201        self.assertEqual(y, x.sin())
3202        self.assertEqual(cpu_call_count, 1)
3203        self.assertEqual(cuda_call_count, 1)
3204
3205    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3206    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
3207    def test_multi_types(self):
3208        @torch.library.custom_op(
3209            "_torch_testing::f", mutates_args=(), device_types=("cpu", "cuda")
3210        )
3211        def f(x: Tensor) -> Tensor:
3212            x_np = x.cpu().numpy()
3213            out_np = np.sin(x_np)
3214            return torch.from_numpy(out_np).to(x.device)
3215
3216        x = torch.randn(3)
3217        y = f(x)
3218        self.assertEqual(y, x.sin())
3219        x = x.cuda()
3220        y = f(x)
3221        self.assertEqual(y, x.sin())
3222
3223    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3224    def test_overloading(self):
3225        called_f = 0
3226        called_f1 = 0
3227
3228        @torch.library.custom_op("_torch_testing::f", mutates_args=())
3229        def f(x: Tensor) -> Tensor:
3230            nonlocal called_f
3231            called_f += 1
3232            return x.clone()
3233
3234        x = torch.randn(2, 3)
3235        torch.ops._torch_testing.f(x)
3236        self.assertEqual(called_f, 1)
3237
3238        @torch.library.custom_op("_torch_testing::f.overload", mutates_args=())
3239        def f1(x: Tensor, y: Tensor) -> Tensor:
3240            nonlocal called_f1
3241            called_f1 += 1
3242            return x.clone()
3243
3244        torch.ops._torch_testing.f(x, x)
3245        self.assertEqual(called_f1, 1)
3246
3247    def test_disallows_output_aliasing(self):
3248        @torch.library.custom_op("_torch_testing::f", mutates_args=())
3249        def f(x: Tensor) -> Tensor:
3250            return x.view(-1)
3251
3252        x = torch.randn(3)
3253        with self.assertRaisesRegex(RuntimeError, "may not alias"):
3254            f(x)
3255
3256        @torch.library.custom_op("_torch_testing::f", mutates_args=())
3257        def f(x: Tensor) -> Tensor:
3258            return x
3259
3260        x = torch.randn(3)
3261        with self.assertRaisesRegex(RuntimeError, "may not alias"):
3262            f(x)
3263
3264        @torch.library.custom_op(
3265            "_torch_testing::f", mutates_args={"x"}, device_types="cpu"
3266        )
3267        def numpy_sin_inplace(x: Tensor) -> Tensor:
3268            x_np = x.numpy()
3269            np.sin(x_np, out=x_np)
3270            return x
3271
3272        x = torch.randn(3)
3273        with self.assertRaisesRegex(RuntimeError, "may not alias"):
3274            numpy_sin_inplace(x)
3275
3276    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3277    def test_factory_function(self):
3278        @torch.library.custom_op(
3279            "_torch_testing::f", mutates_args={}, device_types="cpu"
3280        )
3281        def f(device: torch.device) -> Tensor:
3282            return torch.ones(3)
3283
3284        result = f(device="cpu")
3285        self.assertEqual(result.device, torch.device("cpu"))
3286        self.assertEqual(result, torch.ones(3))
3287
3288        with self.assertRaisesRegex(
3289            RuntimeError, "f does not have a kernel registered for cuda"
3290        ):
3291            f("cuda")
3292
3293        with self.assertRaisesRegex(
3294            ValueError,
3295            "Functions without tensor inputs are required to have a `device: torch.device` argument",
3296        ):
3297
3298            @torch.library.custom_op(
3299                "_torch_testing::f2", mutates_args={}, device_types="cpu"
3300            )
3301            def f2() -> Tensor:
3302                return torch.ones(3)
3303
3304        @torch.library.custom_op("_torch_testing::f3", mutates_args={})
3305        def f3() -> Tensor:
3306            raise NotImplementedError("NYI")
3307
3308        with self.assertRaisesRegex(
3309            ValueError,
3310            "Functions without tensor inputs are required to have a `device: torch.device` argument",
3311        ):
3312
3313            @f3.register_kernel("cpu")
3314            def _():
3315                return torch.zeros(3)
3316
3317            result = f(x)
3318
3319        @torch.library.custom_op("_torch_testing::f4", mutates_args={})
3320        def f4(device: torch.device) -> Tensor:
3321            raise NotImplementedError("NYI")
3322
3323        @f4.register_kernel("cpu")
3324        def _(device: torch.device):
3325            return torch.zeros(3)
3326
3327        result = f(device="cpu")
3328        self.assertEqual(result.device, torch.device("cpu"))
3329        self.assertEqual(result, torch.ones(3))
3330
3331    def test_library_schema_infer(self):
3332        def foo_impl(x: torch.Tensor) -> torch.Tensor:
3333            return x.sin()
3334
3335        schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={})
3336        self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor")
3337
3338        schema = torch.library.infer_schema(foo_impl, mutates_args={})
3339        self.assertExpectedInline(schema, "(Tensor x) -> Tensor")
3340
3341    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3342    def test_set_kernel_enabled(self):
3343        x = torch.ones(1)
3344
3345        @torch.library.custom_op("mylib::f", mutates_args=())
3346        def f(x: Tensor) -> Tensor:
3347            return x + 1
3348
3349        self.assertEqual(f(x), x + 1)
3350        with self.assertLogs("torch._library.custom_ops") as captured:
3351            with f.set_kernel_enabled("gpu", enabled=False):
3352                self.assertEqual(f(x), x + 1)
3353            self.assertIn(
3354                "no kernel was registered for this device type", captured.output[0]
3355            )
3356
3357        @f.register_kernel("cpu")
3358        def _(x):
3359            return x + 2
3360
3361        self.assertEqual(f(x), x + 2)
3362
3363        with self.assertLogs("torch._library.custom_ops") as captured:
3364            with f.set_kernel_enabled("cpu", enabled=True):
3365                self.assertEqual(f(x), x + 2)
3366            self.assertIn("already enabled", captured.output[0])
3367
3368        with f.set_kernel_enabled("cpu", enabled=False):
3369            self.assertEqual(f(x), x + 1)
3370
3371            with self.assertLogs("torch._library.custom_ops") as captured:
3372                with f.set_kernel_enabled("cpu", enabled=False):
3373                    self.assertEqual(f(x), x + 1)
3374                self.assertIn("already disabled", captured.output[0])
3375
3376            self.assertEqual(f(x), x + 1)
3377
3378        with f.set_kernel_enabled("cpu", enabled=True):
3379            self.assertEqual(f(x), x + 2)
3380
3381        with f.set_kernel_enabled("cpu", enabled=False):
3382            self.assertEqual(f(x), x + 1)
3383
3384    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3385    def test_register_vmap_kwargonly_low_level(self):
3386        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
3387            lib.define("foo(Tensor x, *, float y) -> Tensor")
3388            called = False
3389
3390            def foo_impl(x, *, y):
3391                return x * y
3392
3393            lib.impl("foo", foo_impl, "CPU")
3394
3395            def vmap(info, in_dims, x, *, y):
3396                nonlocal called
3397                called = True
3398                return x * y, 0
3399
3400            torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib)
3401
3402            x = torch.ones(3)
3403            result = torch.vmap(torch.ops._torch_testing.foo)(x, y=3.14)
3404            self.assertTrue(called)
3405            self.assertEqual(result, torch.tensor([3.14, 3.14, 3.14]))
3406
3407    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3408    def test_register_vmap_defaults(self):
3409        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
3410            lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor")
3411
3412            def foo_impl(w, x=2, *, y=3, z):
3413                return w * x * y * z
3414
3415            lib.impl("foo", foo_impl, "CPU")
3416
3417            called = False
3418
3419            def vmap(info, in_dims, w, x=2, *, y=3, z):
3420                nonlocal called
3421                called = True
3422                return w * x * y * z, 0
3423
3424            torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib)
3425
3426            w = torch.ones(3)
3427            result = torch.vmap(torch.ops._torch_testing.foo)(w, z=42)
3428            self.assertTrue(called)
3429            self.assertEqual(result, w * 2 * 3 * 42)
3430
3431    def test_layout_constraint_tags(self):
3432        needs_fixed_stride_order = torch._C.Tag.needs_fixed_stride_order
3433        flexible_layout = torch._C.Tag.flexible_layout
3434        # (tags, the result of the tag inference)
3435        tests = [
3436            ({needs_fixed_stride_order}, needs_fixed_stride_order),
3437            ({flexible_layout}, flexible_layout),
3438            # If no tags are provided, then the following is the default
3439            (set(), flexible_layout),
3440            # If multiple tags are provided, then we use the most constrained tag.
3441            ({flexible_layout, needs_fixed_stride_order}, needs_fixed_stride_order),
3442        ]
3443        from torch._inductor.lowering import get_layout_constraint_tag
3444
3445        for tags, expected in tests:
3446            with torch.library._scoped_library("mylib", "FRAGMENT") as m:
3447                m.define("foobar(Tensor x) -> Tensor", tags=tags)
3448                result = get_layout_constraint_tag(torch.ops.mylib.foobar.default)
3449                self.assertEqual(result, expected)
3450
3451    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3452    def test_library_register_vmap(self):
3453        for mode in ["function", "qualname", "opoverload", "c_opdef"]:
3454
3455            @torch.library.custom_op("mylib::f", mutates_args=())
3456            def f(x: Tensor, y: Tensor) -> Tensor:
3457                return x * y
3458
3459            called = False
3460
3461            def fvmap(info, in_dims, x, y):
3462                nonlocal called
3463                called = True
3464                x_bdim, y_bdim = in_dims
3465                x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
3466                y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
3467                result = x * y
3468                result = result.movedim(-1, 0)
3469                return result, 0
3470
3471            if mode == "function":
3472                torch.library.register_vmap(f, fvmap)
3473            elif mode == "qualname":
3474                torch.library.register_vmap("mylib::f", fvmap)
3475            elif mode == "opoverload":
3476                torch.library.register_vmap(torch.ops.mylib.f.default, fvmap)
3477            elif mode == "c_opdef":
3478                f.register_vmap(fvmap)
3479
3480            x = torch.randn(2, 2)
3481            y = torch.randn(2, 2)
3482
3483            result = torch.vmap(f)(x, y)
3484            self.assertTrue(called)
3485            self.assertEqual(result, x * y)
3486
3487            called = False
3488            result = torch.vmap(f, out_dims=1)(x, y)
3489            self.assertEqual(result, (x * y).T)
3490            self.assertTrue(called)
3491
3492            called = False
3493            result = torch.vmap(f, in_dims=1)(x, y)
3494            self.assertEqual(result, (x * y).T)
3495            self.assertTrue(called)
3496
3497    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3498    def test_library_register_vmap_library_decorator(self):
3499        @torch.library.custom_op("mylib::f", mutates_args=())
3500        def f(x: Tensor, y: Tensor) -> Tensor:
3501            return x * y
3502
3503        called = False
3504
3505        @torch.library.register_vmap("mylib::f")
3506        def fvmap(info, in_dims, x, y):
3507            nonlocal called
3508            called = True
3509            x_bdim, y_bdim = in_dims
3510            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
3511            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
3512            result = x * y
3513            result = result.movedim(-1, 0)
3514            return result, 0
3515
3516        x = torch.randn(2, 2)
3517        y = torch.randn(2, 2)
3518
3519        result = torch.vmap(f)(x, y)
3520        self.assertTrue(called)
3521        self.assertEqual(result, x * y)
3522
3523    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3524    def test_library_register_vmap_op_decorator(self):
3525        @torch.library.custom_op("mylib::f", mutates_args=())
3526        def f(x: Tensor, y: Tensor) -> Tensor:
3527            return x * y
3528
3529        called = False
3530
3531        @f.register_vmap
3532        def fvmap(info, in_dims, x, y):
3533            nonlocal called
3534            called = True
3535            x_bdim, y_bdim = in_dims
3536            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
3537            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
3538            result = x * y
3539            result = result.movedim(-1, 0)
3540            return result, 0
3541
3542        x = torch.randn(2, 2)
3543        y = torch.randn(2, 2)
3544
3545        result = torch.vmap(f)(x, y)
3546        self.assertTrue(called)
3547        self.assertEqual(result, x * y)
3548
3549    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3550    def test_library_register_vmap_register_multiple_times(self):
3551        @torch.library.custom_op("mylib::f", mutates_args=())
3552        def f(x: Tensor, y: Tensor) -> Tensor:
3553            return x * y
3554
3555        called = False
3556
3557        @f.register_vmap
3558        def fvmap(info, in_dims, x, y):
3559            nonlocal called
3560            called = True
3561            x_bdim, y_bdim = in_dims
3562            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
3563            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
3564            result = x * y
3565            result = result.movedim(-1, 0)
3566            return result, 0
3567
3568        x = torch.randn(2, 2)
3569        y = torch.randn(2, 2)
3570
3571        result = torch.vmap(f)(x, y)
3572        self.assertTrue(called)
3573        self.assertEqual(result, x * y)
3574        called = False
3575
3576        @f.register_vmap
3577        def fvmap2(info, in_dims, x, y):
3578            nonlocal called
3579            called = True
3580            x_bdim, y_bdim = in_dims
3581            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
3582            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
3583            result = x + y
3584            result = result.movedim(-1, 0)
3585            return result, 0
3586
3587        result = torch.vmap(f)(x, y)
3588        self.assertTrue(called)
3589        self.assertEqual(result, x + y)
3590
3591    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3592    def test_library_register_vmap_register_multiple_times_2(self):
3593        @torch.library.custom_op("mylib::f", mutates_args=())
3594        def f(x: Tensor, y: Tensor) -> Tensor:
3595            return x * y
3596
3597        called = False
3598
3599        @torch.library.register_vmap("mylib::f")
3600        def fvmap(info, in_dims, x, y):
3601            nonlocal called
3602            called = True
3603            x_bdim, y_bdim = in_dims
3604            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
3605            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
3606            result = x * y
3607            result = result.movedim(-1, 0)
3608            return result, 0
3609
3610        x = torch.randn(2, 2)
3611        y = torch.randn(2, 2)
3612
3613        result = torch.vmap(f)(x, y)
3614        self.assertTrue(called)
3615        self.assertEqual(result, x * y)
3616        called = False
3617
3618        @torch.library.register_vmap("mylib::f")
3619        def fvmap2(info, in_dims, x, y):
3620            nonlocal called
3621            called = True
3622            x_bdim, y_bdim = in_dims
3623            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
3624            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
3625            result = x + y
3626            result = result.movedim(-1, 0)
3627            return result, 0
3628
3629        result = torch.vmap(f)(x, y)
3630        self.assertTrue(called)
3631        self.assertEqual(result, x + y)
3632
3633
3634class MiniOpTestOther(CustomOpTestCaseBase):
3635    test_ns = "mini_op_test"
3636
3637    def test_nonzero_again(self):
3638        x = torch.tensor([0, 1, 2, 0, 0])
3639        y = torch.ops.aten.nonzero.default(x)
3640        self.assertEqual(y, torch.tensor([[1], [2]]))
3641
3642
3643optests.generate_opcheck_tests(
3644    MiniOpTest,
3645    ["aten", "mini_op_test"],
3646    get_file_path_2(os.path.dirname(__file__), "minioptest_failures_dict.json"),
3647    additional_decorators={
3648        "test_pt2_compliant_tag_mini_op_test_no_abstract": [unittest.expectedFailure]
3649    },
3650    test_utils=optests.generate_tests.DEPRECATED_DEFAULT_TEST_UTILS,
3651)
3652
3653optests.generate_opcheck_tests(
3654    MiniOpTestOther,
3655    ["aten", "mini_op_test"],
3656    get_file_path_2(os.path.dirname(__file__), "minioptest_failures_dict.json"),
3657    test_utils=optests.generate_tests.DEPRECATED_DEFAULT_TEST_UTILS,
3658)
3659
3660
3661class TestGenerateOpcheckTests(CustomOpTestCaseBase):
3662    def test_MiniOpTest(self):
3663        for orig_test in ["test_mm", "test_nonzero"]:
3664            for (
3665                test
3666            ) in torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS:
3667                expected_test = f"{test}__{orig_test}"
3668                self.assertTrue(hasattr(MiniOpTest, expected_test), msg=expected_test)
3669
3670    def test_generate_repro_save_data(self):
3671        from torch.testing._internal.optests.generate_tests import generate_repro
3672
3673        args = (torch.ones(2, 2),)
3674        kwargs = {"mat2": torch.zeros(2, 2)}
3675        actual = generate_repro(
3676            "test_schema",
3677            torch.ops.aten.sin.default,
3678            args,
3679            kwargs,
3680            save_data=True,
3681            dry_run=True,
3682        )
3683        actual = re.sub(r"torch.load\(\".*\.pt\"\)", 'torch.load("repro.pt")', actual)
3684        self.assertExpectedInline(
3685            actual,
3686            """\
3687# =========================================================
3688# BEGIN REPRO SCRIPT
3689# =========================================================
3690import torch
3691from torch.testing._internal.optests import opcheck
3692
3693# Make sure you have loaded the library that contains the op
3694# via an import or torch.ops.load_library(...)
3695op = torch.ops.aten.sin.default
3696
3697args, kwargs = torch.load("repro.pt")
3698opcheck(op, args, kwargs, test_utils="test_schema")
3699# =========================================================
3700# END REPRO SCRIPT
3701# =========================================================
3702""",
3703        )
3704
3705    def test_generate_repro_no_save_data(self):
3706        from torch.testing._internal.optests.generate_tests import generate_repro
3707
3708        args = (torch.ones(2, 2),)
3709        kwargs = {"mat2": torch.zeros(2, 2)}
3710        actual = generate_repro(
3711            "test_schema",
3712            torch.ops.aten.sin.default,
3713            args,
3714            kwargs,
3715            save_data=False,
3716            dry_run=True,
3717        )
3718        self.assertExpectedInline(
3719            actual,
3720            """\
3721# =========================================================
3722# BEGIN REPRO SCRIPT
3723# =========================================================
3724import torch
3725from torch.testing._internal.optests import opcheck
3726
3727# Make sure you have loaded the library that contains the op
3728# via an import or torch.ops.load_library(...)
3729op = torch.ops.aten.sin.default
3730
3731# If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1
3732# we will fill them in same (args, kwargs) as in your test
3733args = ()  # args to the operator
3734kwargs = {}  # kwargs to the operator
3735opcheck(op, args, kwargs, test_utils="test_schema")
3736# =========================================================
3737# END REPRO SCRIPT
3738# =========================================================
3739""",
3740        )
3741
3742    def test_failures_dict_validation(self):
3743        from torch.testing._internal.optests.generate_tests import (
3744            FailuresDict,
3745            validate_failures_dict_structure,
3746        )
3747
3748        failures = {
3749            "mini_op_test::incorrect_schema": {
3750                "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error": {
3751                    "comment": "",
3752                    "status": "success",
3753                }
3754            }
3755        }
3756        with self.assertRaisesRegex(RuntimeError, "got status=success"):
3757            validate_failures_dict_structure(
3758                FailuresDict("", failures),
3759                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
3760                MiniOpTest,
3761            )
3762
3763        failures = {
3764            "mini_op_test::incorrect_schema": {
3765                "MiniOpTest.test_aot_dispatch__test_delayed_error": {
3766                    "comment": "",
3767                    "status": "xfail",
3768                },
3769            }
3770        }
3771        with self.assertRaisesRegex(RuntimeError, "should begin with one of"):
3772            validate_failures_dict_structure(
3773                FailuresDict("", failures),
3774                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
3775                MiniOpTest,
3776            )
3777
3778        failures = {
3779            "mini_op_test::incorrect_schema": {
3780                "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error_nopenopenope": {
3781                    "comment": "",
3782                    "status": "xfail",
3783                },
3784            }
3785        }
3786        with self.assertRaisesRegex(RuntimeError, "does not exist on the TestCase"):
3787            validate_failures_dict_structure(
3788                FailuresDict("", failures),
3789                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
3790                MiniOpTest,
3791            )
3792
3793    def test_dont_generate_decorator(self):
3794        self.assertTrue(hasattr(MiniOpTest, "test_dont_generate"))
3795        self.assertFalse(hasattr(MiniOpTest, "test_schema__test_dont_generate"))
3796
3797    def test_opcheck(self):
3798        x = torch.randn(3, requires_grad=True)
3799        with self.assertRaisesRegex(ValueError, "OpOverload"):
3800            torch.library.opcheck(torch.sin, (x,))
3801        with self.assertRaisesRegex(ValueError, "test_utils to be subset of"):
3802            torch.library.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah")
3803        result = torch.library.opcheck(torch.ops.aten.sin.default, (x,))
3804
3805        self.assertEqual(
3806            result,
3807            {
3808                "test_schema": "SUCCESS",
3809                "test_autograd_registration": "SUCCESS",
3810                "test_faketensor": "SUCCESS",
3811                "test_aot_dispatch_dynamic": "SUCCESS",
3812            },
3813        )
3814
3815        result = torch.library.opcheck(
3816            torch.ops.aten.sin.default, (x,), test_utils="test_schema"
3817        )
3818        self.assertEqual(result, {"test_schema": "SUCCESS"})
3819
3820        result = torch.library.opcheck(
3821            torch.ops.aten.sin.default,
3822            (x,),
3823            test_utils=["test_schema", "test_faketensor"],
3824        )
3825        self.assertEqual(
3826            result,
3827            {
3828                "test_schema": "SUCCESS",
3829                "test_faketensor": "SUCCESS",
3830            },
3831        )
3832
3833    def test_opcheck_customopdef(self):
3834        sample_inputs = [
3835            (torch.randn(3),),
3836            (torch.randn(3, requires_grad=True),),
3837        ]
3838        if torch.cuda.is_available():
3839            sample_inputs.extend(
3840                [
3841                    (torch.randn(3, device="cuda"),),
3842                    (torch.randn(3, device="cuda", requires_grad=True),),
3843                ]
3844            )
3845        for args in sample_inputs:
3846            torch.library.opcheck(custom_op_db.numpy_cube, args)
3847
3848    def test_is_inside_opcheck_mode(self):
3849        self.assertFalse(optests.is_inside_opcheck_mode())
3850        with optests.generate_tests.OpCheckMode(
3851            ["foo"], "bar", lambda x: x, None, "baz", "brr"
3852        ):
3853            self.assertTrue(optests.is_inside_opcheck_mode())
3854
3855    def test_opcheck_bad_op(self):
3856        op = op_with_incorrect_schema(self, "foo")
3857        x = torch.randn(3)
3858        with self.assertRaisesRegex(Exception, "is not defined to alias output"):
3859            torch.library.opcheck(op, (x,))
3860
3861        result = torch.library.opcheck(op, (x,), raise_exception=False)
3862        self.assertTrue(isinstance(result["test_schema"], RuntimeError))
3863        del result["test_schema"]
3864        self.assertEqual(
3865            result,
3866            {
3867                "test_autograd_registration": "SUCCESS",
3868                "test_faketensor": "SUCCESS",
3869                "test_aot_dispatch_dynamic": "SUCCESS",
3870            },
3871        )
3872
3873    def test_opcheck_does_not_require_extra_deps(self):
3874        # torch.testing._internal.common_utils comes with a lot of additional
3875        # test-time dependencies. Since opcheck is public API, it should be
3876        # usable only with pytorch install-time dependencies.
3877        cmd = [
3878            sys.executable,
3879            "-c",
3880            "import torch; import sys; \
3881               x = torch.randn(3, requires_grad=True); \
3882               torch.library.opcheck(torch.ops.aten.sin.default, (x,)); \
3883               assert 'expecttest' not in sys.modules; \
3884               assert 'torch.testing._internal.common_utils' not in sys.modules",
3885        ]
3886        subprocess.check_output(cmd, shell=False)
3887
3888
3889class TestTypeConversion(TestCase):
3890    """In infer_schema(), we try to suggest a correct type when the type annotation is wrong."""
3891
3892    def setUp(self):
3893        self.supported_base_types = [
3894            int,
3895            float,
3896            bool,
3897            str,
3898            torch.device,
3899            torch.Tensor,
3900            torch.dtype,
3901            torch.types.Number,
3902        ]
3903
3904    def test_simple_tuple(self):
3905        self.assertEqual(List, tuple_to_list(Tuple))
3906
3907    def test_supported_types(self):
3908        for t in self.supported_base_types:
3909            result_type = tuple_to_list(Tuple[t, t, t])
3910            self.assertEqual(result_type, List[t])
3911
3912            result_type = tuple_to_list(Tuple[t])
3913            self.assertEqual(result_type, List[t])
3914
3915    def test_optional(self):
3916        for t in self.supported_base_types:
3917            result_type = tuple_to_list(Tuple[t, Optional[t]])
3918            self.assertEqual(result_type, List[Optional[t]])
3919
3920            result_type = tuple_to_list(Tuple[t, t, Optional[t]])
3921            self.assertEqual(result_type, List[Optional[t]])
3922
3923            result_type = tuple_to_list(Tuple[t, ...])
3924            self.assertEqual(result_type, List[t])
3925
3926    def test_mixed_types(self):
3927        result_type = tuple_to_list(Tuple[int, float])
3928        self.assertEqual(result_type, List[typing.Union[int, float]])
3929
3930        result_type = tuple_to_list(Tuple[int, float, str])
3931        self.assertEqual(result_type, List[typing.Union[int, float, str]])
3932
3933
3934only_for = ("cpu", "cuda")
3935instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)
3936instantiate_parametrized_tests(TestCustomOp)
3937instantiate_parametrized_tests(TestCustomOpAPI)
3938
3939if __name__ == "__main__":
3940    run_tests()
3941