# Owner(s): ["module: __torch_dispatch__"]

import logging
import sys
import tempfile
import unittest
from copy import deepcopy

import torch
import torch._dynamo
from torch import SymInt
from torch._C import DispatchKey, DispatchKeySet
from torch._custom_op.functional import register_functional_op
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.cuda.jiterator import _create_jit_fn
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.library import _scoped_library, fallthrough_kernel, impl, Library
from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
    first_sample,
    IS_WINDOWS,
    run_tests,
    TEST_WITH_ROCM,
    TestCase,
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.logging_tensor import (
    capture_logs,
    capture_logs_with_logging_tensor_mode,
    log_input,
    LoggingTensor,
    LoggingTensorMode,
    LoggingTensorReentrant,
)
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils import _pytree as pytree
from torch.utils._mode_utils import all_same_mode, no_dispatch
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _get_current_dispatch_mode_stack,
    is_in_torch_dispatch_mode,
    TorchDispatchMode,
)
from torch.utils._pytree import tree_map, tree_map_only


# used as DataLoader collate_fn below; named here to avoid trying to pickle a lambda
def _identity(x):
    return x


class TestDispatcherPythonBindings(TestCase):
    def test_call_boxed(self) -> None:
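        # Note: _dispatch_find_schema_or_throw looks up the dispatcher entry
        # for "aten::sin" (default overload), and _dispatch_call_boxed invokes
        # it through the boxed, stack-based calling convention rather than the
        # usual unboxed Python binding; the result should match x.sin().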
        sin = torch._C._dispatch_find_schema_or_throw("aten::sin", "")
        x = torch.randn(3)
        y = torch._C._dispatch_call_boxed(sin, x)
        self.assertEqual(y, x.sin())


class TestPythonRegistration(TestCase):
    test_ns = "_test_python_registration"

    def tearDown(self):
        if hasattr(torch.ops, self.test_ns):
            del torch.ops._test_python_registration

    def test_fallback(self) -> None:
        test_key = "TESTING_ONLY_GenericMode"
        test_keyset = torch._C.DispatchKeySet(test_key)
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()
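        # Adding the testing-only key to the TLS include set (and leaving it
        # out of the exclude set) means that, under the _ForceDispatchKeyGuard
        # below, every operator call is routed to the fallback registered for
        # that key.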

        with _scoped_library("_", "IMPL") as my_lib:
            expected_op = None
            expected_args = None
            expected_kwargs = None
            # Use this out shape to make sure the result from our fallback
            # is what is returned to the user
            out_shape = None

            def my_fallback(op, *args, **kwargs):
                # Disable our handler during checks and generating the output
                with torch._C._ForceDispatchKeyGuard(
                    include_to_set, exclude_to_set | test_keyset
                ):
                    self.assertIs(op, expected_op)
                    self.assertEqual(args, expected_args)
                    self.assertEqual(kwargs, expected_kwargs)
                    # Return something specific
                    return torch.empty(out_shape)

            my_lib.fallback(my_fallback, test_key)

            a, b = torch.rand(2), torch.rand(2)

            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                # Check a factory function
                expected_op = torch.ops.aten.empty.memory_format
                expected_args = ((2, 2),)
                # Extra kwargs to bypass issues with default args in factory functions
                expected_kwargs = {
                    "dtype": torch.float64,
                    "pin_memory": False,
                    "device": torch.device("cpu"),
                }
                out_shape = (3,)
                out = torch.empty(*expected_args, **expected_kwargs)
                self.assertEqual(out.size(), out_shape)

                # Check a regular function
                expected_op = torch.ops.aten.add.Tensor
                expected_args = (a, b)
                expected_kwargs = {}
                out_shape = (4,)
                out = a + b
                self.assertEqual(out.size(), out_shape)

    def test_fallback_keyset(self) -> None:
        test_key_first = "TESTING_ONLY_GenericMode"
        test_key_second = "TESTING_ONLY_GenericWrapper"
        test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
            test_key_second
        )
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

        with _scoped_library("_", "IMPL") as my_lib:
            first_called = False
            second_called = False

            def first_fallback(keyset, op, *args, **kwargs):
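                # Registered with with_keyset=True, so this fallback also
                # receives the DispatchKeySet being dispatched on. On the outer
                # call we strip our own key and redispatch, landing on the
                # second fallback; the subtraction issued from there starts a
                # fresh dispatcher call that comes back here and is executed
                # with both testing keys excluded.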
                nonlocal first_called
                if second_called:
                    # Recursive call
                    first_called = True
                    with torch._C._ForceDispatchKeyGuard(
                        include_to_set, exclude_to_set | test_keyset
                    ):
                        return op(*args, **kwargs)
                else:
                    # Redispatch down
                    keyset = keyset.remove(test_key_first)
                    return op.redispatch(keyset, *args, **kwargs)

            def second_fallback(op, *args, **kwargs):
                nonlocal second_called
                # Set to avoid infinite recursion
                second_called = True
                # New dispatcher call should hit the first callback again
                self.assertFalse(first_called)
                a, b = args
                # Do a subtraction here instead of an add!
                c = a - b
                self.assertTrue(first_called)
                return c

            my_lib.fallback(first_fallback, test_key_first, with_keyset=True)
            my_lib.fallback(second_fallback, test_key_second)

            a, b = torch.rand(2), torch.rand(2)
            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                c = a + b

            self.assertEqual(c, a - b)
            self.assertTrue(first_called)
            self.assertTrue(second_called)

    def test_fallback_fallthrough(self) -> None:
        test_key_first = "TESTING_ONLY_GenericMode"
        test_key_second = "TESTING_ONLY_GenericWrapper"
        test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
            test_key_second
        )
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

        with _scoped_library("_", "IMPL") as my_lib:
            is_called = False

            def my_fallback(op, *args, **kwargs):
                nonlocal is_called
                is_called = True
                with torch._C._ForceDispatchKeyGuard(
                    include_to_set, exclude_to_set | test_keyset
                ):
                    return op(*args, **kwargs)

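            # Registering fallthrough_kernel for the first testing key makes
            # the dispatcher skip that key entirely, so the add below goes
            # straight to the fallback registered for the second key, which
            # simply re-runs the op with the testing keys disabled.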
            my_lib.fallback(torch.library.fallthrough_kernel, test_key_first)
            my_lib.fallback(my_fallback, test_key_second)

            a, b = torch.rand(2), torch.rand(2)
            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                c = a + b

            self.assertEqual(c, a + b)
            self.assertTrue(is_called)

    def test_override_aten_ops_with_multiple_libraries(self) -> None:
        x = torch.tensor([1, 2])
        with _scoped_library("aten", "IMPL") as my_lib2:
            with _scoped_library("aten", "IMPL") as my_lib1:
                # Example 1
                def my_neg(*args, **kwargs):
                    return args[0]._neg_view()

                # Now we are secretly making the operator a view op so autograd needs to know how
                # to handle it
                my_lib1.impl("neg", my_neg, "AutogradCPU")

                self.assertTrue(torch.neg(x).is_neg())

                # RuntimeError: impl("aten::neg", ...):
                # Explicitly provided namespace (aten) in operator name does not match ...
                with self.assertRaisesRegex(
                    RuntimeError, "operator name does not match namespace"
                ):
                    with _scoped_library("foo", "DEF") as my_lib3:
                        my_lib3.define("neg(Tensor self) -> Tensor")
                        my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")

                # Example 2
                def my_mul(*args, **kwargs):
                    return torch.zeros_like(args[0])

                # torch.ops.aten.mul.Tensor
                my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")

                y = torch._efficientzerotensor(2)
                self.assertFalse(torch.mul(x, y)._is_zerotensor())

                # Assert that a user can't override the behavior of a (ns, op, dispatch_key)
                # combination if someone has already overridden the behavior for it before them
                with self.assertRaisesRegex(
                    RuntimeError, "already a kernel registered from python"
                ):
                    my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")

            # Validate that lib2 is not affected by removing lib1
            self.assertFalse(torch.mul(x, y)._is_zerotensor())

        # Validate that the old behavior is restored for neg and mul
        self.assertFalse(torch.neg(x).is_neg())
        self.assertTrue(torch.mul(x, y)._is_zerotensor())

    def test_error_if_fn_not_callable(self):
        with self.assertRaisesRegex(
            TypeError, "Input function is required to be a callable"
        ):
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl(torch.ops.aten.neg.default, [], "AutogradCPU")

    def test_finalizer(self):
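        # This test relies on CPython reference counts to check that Library
        # cleanup is driven by a finalizer: deleting `lib` should drop the
        # extra reference on torch.library._impls and remove the impl entry
        # that was registered through it.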
        impls_refcnt = sys.getrefcount(torch.library._impls)
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        lib.define("foo123(Tensor x) -> Tensor")

        # 1 for `lib`, 1 for sys.getrefcount
        self.assertEqual(sys.getrefcount(lib), 2)
        # We gained an additional reference that gets cleared when the finalizer runs
        self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt + 1)
        # 1 for `lib`
        # 1 for the finalizer
        # 1 for sys.getrefcount
        self.assertEqual(sys.getrefcount(lib._op_impls), 3)

        def foo123(x):
            pass

        lib.impl(f"{self.test_ns}::foo123", foo123, "CPU")
        key = f"{self.test_ns}/foo123/CPU"
        self.assertTrue(key in torch.library._impls)

        saved_op_impls = lib._op_impls

        # del will definitely work if the following passes
        self.assertEqual(sys.getrefcount(lib), 2)
        del lib

        # 1 for saved_op_impls
        # 1 for sys.getrefcount
        # This function should be the last user of lib._op_impls:
        # - lib should not have a reference anymore (it was del'ed)
        # - lib's finalizer should not have a reference anymore
        self.assertEqual(sys.getrefcount(saved_op_impls), 2)

        self.assertTrue(key not in torch.library._impls)

        # lib's finalizer should not have a reference anymore
        self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt)

    def test_override_cpu_sum(self) -> None:
        # Example 1
        run = [False]

        def my_sum(*args, **kwargs):
            run[0] = True
            return args[0].clone()

        with _scoped_library("aten", "IMPL") as my_lib1:
            my_lib1.impl("aten::sum", my_sum, "CPU")
            x = torch.tensor([1, 2])
            self.assertEqual(torch.sum(x), x)
            self.assertTrue(run[0])
        # Validate that the old behavior is restored for sum
        self.assertEqual(torch.sum(x), torch.tensor(3))

    def test_override_cuda_with_jiterator(self) -> None:
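        # Each helper below compiles an elementwise CUDA kernel from a C++
        # source string via the jiterator (_create_jit_fn) and registers it
        # over an aten op's CUDA dispatch key inside a scoped library, so the
        # override is removed again when the library goes out of scope.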
        def override_where_cuda() -> None:
            # Example 1: Invert the behavior of where's condition input
            not_where_code_string = """
            template <typename T> T inverted_where(bool cond, T a, T b){
                return !cond ? a : b;
            }
            """
            jitted_where = _create_jit_fn(not_where_code_string)

            CALLED = [False]

            def inverted_where(*args, **kwargs):
                CALLED[0] = True
                return jitted_where(*args, **kwargs)

            # overriding where's cuda kernel with Jiterator generated kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::where.self", inverted_where, "CUDA")

                device = "cuda"
                cond = torch.tensor(
                    [True, True, False], device=device, dtype=torch.bool
                )
                x = torch.tensor([1, 2, 3], device=device)
                y = torch.tensor([-1, -2, -3], device=device)

                self.assertEqual(torch.where(cond, x, y), torch.tensor([-1, -2, 3]))
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(torch.where(cond, x, y), torch.tensor([1, 2, -3]))

        def override_gelu_cuda() -> None:
            # Example 2: Use relu to approximate gelu for faster compute
            fastest_gelu_code_string = """
            template <typename T> T fast_gelu(T a){
                return a > 0 ? a : 0;
            }
            """
            jitted_gelu = _create_jit_fn(fastest_gelu_code_string)

            CALLED = [False]

            def fast_gelu(*args, **kwargs):
                CALLED[0] = True
                return jitted_gelu(*args, **kwargs)

            # overriding gelu's cuda kernel with Jiterator generated relu kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::gelu", fast_gelu, "CUDA")

                x = torch.rand([3, 3], device="cuda", dtype=torch.float)
                self.assertEqual(
                    torch.nn.functional.gelu(x), torch.nn.functional.relu(x)
                )
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertNotEqual(
                torch.nn.functional.gelu(x), torch.nn.functional.relu(x)
            )

        def override_exp_cuda() -> None:
            # Example 3: Preventing exp from exploding for float16
            clipped_exp_code_string = """
            template <typename T> T clipped_exp(T a){
                return a > T(10.0) ? T(22026.4657948) : exp(a);
            }
            """
            jitted_exp = _create_jit_fn(clipped_exp_code_string)

            CALLED = [False]

            def clipped_exp(*args, **kwargs):
                CALLED[0] = True
                return jitted_exp(*args, **kwargs)

            # overriding exp's cuda kernel with clipped_exp kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::exp", clipped_exp, "CUDA")

                x = torch.tensor([0.0, 100.0], device="cuda", dtype=torch.float16)
                self.assertEqual(
                    torch.exp(x),
                    torch.tensor([1.0, 22026.4657948], dtype=torch.float16),
                )
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(
                torch.exp(x), torch.tensor([1.0, torch.inf], dtype=torch.float16)
            )

        def override_add_cuda() -> None:
            # Example 4: simulate a hardware bug, where the adder is always off by 1
            buggy_add_code_string = """
            template <typename T> T buggy_add(T a, T b){
                return a + b + T(1);
            }
            """
            jitted_add = _create_jit_fn(buggy_add_code_string)

            CALLED = [False]

            def buggy_add(*args, **kwargs):
                CALLED[0] = True
                return jitted_add(*args, **kwargs)

            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::add.Tensor", buggy_add, "CUDA")

                x_cpu = torch.rand([3, 3], device="cpu")
                y_cpu = torch.rand([3], device="cpu")

                x_cuda = x_cpu.cuda()
                y_cuda = y_cpu.cuda()

                self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu + 1)
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu)

        if torch.cuda.is_available() and not TEST_WITH_ROCM:
            override_where_cuda()
            override_gelu_cuda()
            override_exp_cuda()
            override_add_cuda()

    def test_extend_library_with_dispatch_key_arg(self):
        def my_sum(*args, **kwargs):
            return args[0].clone()

        with _scoped_library("aten", "IMPL", dispatch_key="CPU") as my_lib1:
            # RuntimeError: Explicitly provided dispatch key (Conjugate) is
            # inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block
            with self.assertRaisesRegex(
                RuntimeError, "inconsistent with the dispatch key"
            ):
                my_lib1.impl("sum", my_sum, "Conjugate")
            my_lib1.impl("aten::sum", my_sum)
            x = torch.tensor([1, 2])
            self.assertEqual(torch.sum(x), x)

    def test_create_new_library(self) -> None:
        with _scoped_library(self.test_ns, "DEF") as my_lib1:
            my_lib1.define("sum(Tensor self) -> Tensor")

            # Example 1
            @torch.library.impl(my_lib1, "sum", "CPU")
            def my_sum(*args, **kwargs):
                return args[0].clone()

            x = torch.tensor([1, 2])
            op = getattr(torch.ops, self.test_ns).sum
            self.assertEqual(op(x), x)

            with _scoped_library(self.test_ns, "IMPL") as my_lib2:
                # Example 2
                @torch.library.impl(my_lib2, op.default, "ZeroTensor")
                def my_sum_zt(*args, **kwargs):
                    if args[0]._is_zerotensor():
                        return torch._efficientzerotensor(args[0].shape)
                    else:
                        return args[0].clone()

                y = torch._efficientzerotensor(3)
                self.assertTrue(op(y)._is_zerotensor())
                self.assertEqual(op(x), x)

    def test_create_new_library_fragment_no_existing(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as my_lib:
            my_lib.define("sum2(Tensor self) -> Tensor")

            @torch.library.impl(my_lib, "sum2", "CPU")
            def my_sum(*args, **kwargs):
                return args[0]

            x = torch.tensor([1, 2])
            self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)

    def test_create_new_library_fragment_with_existing(self):
        with _scoped_library(self.test_ns, "DEF") as my_lib1:
            # Create a fragment
            with _scoped_library(self.test_ns, "FRAGMENT") as my_lib2:
                my_lib2.define("sum4(Tensor self) -> Tensor")

                @torch.library.impl(my_lib2, "sum4", "CPU")
                def my_sum4(*args, **kwargs):
                    return args[0]

                x = torch.tensor([1, 2])
                self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)

                # Create another fragment
                with _scoped_library(self.test_ns, "FRAGMENT") as my_lib3:
                    my_lib3.define("sum3(Tensor self) -> Tensor")

                    @torch.library.impl(my_lib3, "sum3", "CPU")
                    def my_sum3(*args, **kwargs):
                        return args[0]

                    x = torch.tensor([1, 2])
                    self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)

    @unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
    def test_alias_analysis(self):
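        # With the default alias analysis (FROM_SCHEMA, per the comment at the
        # call site below), an op that takes no arguments and returns nothing
        # can be treated as side-effect free and eliminated from the scripted
        # graph, so the assertion on the graph fails; CONSERVATIVE alias
        # analysis keeps the call alive.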
        def test_helper(alias_analysis=""):
            my_lib1 = Library(self.test_ns, "DEF")  # noqa: TOR901

            called = [0]

            @torch.library.define(
                my_lib1, "_op() -> None", alias_analysis=alias_analysis
            )
            def _op(*args, **kwargs):
                called[0] += 1

            @torch.jit.script
            def _test():
                torch.ops._test_python_registration._op()

            assert "_test_python_registration::_op" in str(_test.graph)

        with self.assertRaises(AssertionError):
            test_helper("")  # alias_analysis="FROM_SCHEMA"

        test_helper("CONSERVATIVE")

    def test_error_for_unsupported_ns_or_kind(self) -> None:
        with self.assertRaisesRegex(ValueError, "Unsupported kind"):
            my_lib1 = Library("myns", "BLA")  # noqa: TOR901

        for kind in ("DEF", "FRAGMENT"):
            with self.assertRaisesRegex(ValueError, "reserved namespace"):
                my_lib1 = Library("prim", kind)  # noqa: TOR901

    def test_returning_symint(self) -> None:
        shape_env = ShapeEnv()
        fake_tensor_mode = FakeTensorMode(shape_env=shape_env)

        ft = fake_tensor_mode.from_tensor(torch.rand(2, 3))

        s0, s1 = ft.shape
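        # ft was created under FakeTensorMode with a ShapeEnv, so s0 and s1 are
        # SymInts backed by symbolic expressions rather than plain ints. The
        # custom op below computes on them symbolically, and evaluating the
        # resulting expression against the recorded sizes (2 and 3) gives
        # 2*2 + 3*3 == 13.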

        with _scoped_library(self.test_ns, "DEF") as tlib:
            tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")

            @impl(tlib, "sqsum", "CompositeExplicitAutograd")
            def sqsum(a: SymInt, b: SymInt):
                return a * a + b * b

            out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
            out_val = shape_env.evaluate_expr(out.node.expr)
        self.assertEqual(out_val, 13)

    def test_register_functional_op_error_cases(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            with self.assertRaisesRegex(TypeError, "instance of OpOverload"):
                register_functional_op(lib, "abs", torch.ops.aten.abs_)
            with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
                register_functional_op(lib, "abs", torch.ops.aten.abs_.default)
            with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
                register_functional_op(lib, "abs", torch.ops.aten.abs.out)

            schemas = [
                "foo(Tensor x, Tensor(a!)[] y) -> ()",
                "foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)",
                "foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))",
            ]

        for schema in schemas:
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define(schema)
                with self.assertRaisesRegex(RuntimeError, "NYI"):
                    register_functional_op(
                        lib,
                        "foo_functional",
                        getattr(torch.ops, self.test_ns).foo.default,
                    )

    def _check_is_functional_variant(self, mutable_op, functional_op, args):
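        # Checks that `functional_op` is the functional counterpart of
        # `mutable_op`: it must leave its inputs untouched, return the mutable
        # op's outputs followed by the values the mutable op would have written
        # into its mutated arguments, and be the op that functionalization
        # rewrites the mutable op into (verified via make_fx below).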
        # functional op should not mutate
        cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
        functional_result = functional_op(*cloned_args)
        self.assertEqual(cloned_args, args)

        # check functional_result includes mutable_result
        mutable_result = mutable_op(*cloned_args)
        if mutable_result is None:
            flat_mutable_result = []
        else:
            flat_mutable_result = pytree.tree_leaves(mutable_result)
        flat_functional_result = pytree.tree_leaves(functional_result)
        assert len(flat_functional_result) > len(flat_mutable_result)
        self.assertEqual(
            flat_functional_result[: len(flat_mutable_result)], flat_mutable_result
        )

        # check rest of functional_result is the mutated args
        mutated_args = [
            maybe_mutated_arg
            for maybe_mutated_arg, arg in zip(cloned_args, args)
            if not (
                maybe_mutated_arg is not None
                and arg is not None
                and torch.allclose(maybe_mutated_arg, arg)
            )
        ]
        self.assertEqual(
            flat_functional_result[len(flat_mutable_result) :], mutated_args
        )

        # check that functionalization kernel was indeed registered
        def fn(*args):
            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
            mutable_op(*cloned_args)
            return cloned_args

        gm = make_fx(torch.func.functionalize(fn))(*args)
        has_functional_op = False
        for node in gm.graph.nodes:
            self.assertFalse(node.target is mutable_op)
            if node.target is functional_op:
                has_functional_op = True
        self.assertTrue(has_functional_op)

    def test_register_functional_op_no_returns(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define("foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()")

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                w.fill_(2.71)

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )
            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )

    def test_register_functional_op_with_optional(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                z.fill_(2.71)
                if w is not None:
                    w.fill_(1.618)

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )
            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, None),
            )

    def test_register_functional_op_one_return(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                w.fill_(2.71)
                z.fill_(0.99)
                return x.clone()

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )
            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )

    def test_register_functional_op_multiple_returns(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                w.fill_(2.71)
                return x.clone(), z.clone()

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )

            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )

    def test_register_fallthrough(self):
        with _scoped_library("aten", "IMPL") as my_lib:
            my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")

            a = torch.randn(2, 3, device="cpu", dtype=torch.float32)
            b = torch.randn(3, 2, device="cpu", dtype=torch.float32)
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                # dtype for mm should be float32 since we registered a fallthrough
                self.assertEqual(torch.mm(a, b).dtype, torch.float32)
                # ops that don't have a fallthrough registered should not be affected
                self.assertEqual(torch.matmul(a, b).dtype, torch.bfloat16)

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            # default behavior should have been restored
            self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16)


class TestPythonDispatch(TestCase):
    def test_basic(self) -> None:
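        # LoggingTensor is a wrapper subclass whose __torch_dispatch__ records
        # every aten call it sees; capture_logs collects those records so the
        # autograd trace of y = x * x can be compared against the expected
        # output below.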
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
            log_input("x", x)
            y = x * x
            saved_x = y.grad_fn._saved_self
            grad_y = LoggingTensor(torch.tensor([1.0]))
            log_input("grad_y", grad_y)
            (g,) = torch.autograd.grad((y,), (x,), (grad_y,))

        self.assertEqual(g.elem, torch.tensor([6.0]))
        with torch.no_grad():
            self.assertEqual(saved_x, x)
            self.assertEqual(saved_x._version, x._version)
            x.add_(2)
            self.assertEqual(saved_x, x)
            # TODO: figure out why broken
            # self.assertEqual(saved_x._version, x._version)
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.mul.Tensor($0, $0)
$2: f32[1] = input('grad_y')
$3: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$4: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$5: f32[1] = torch._ops.aten.add.Tensor($4, $3)""",
        )

    def test_out(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.zeros(1))
            log_input("x", x)
            log_input("y", y)
            torch.abs(x, out=y)

        self.assertEqual(y.elem, torch.ones(1))
        # TODO: arguably this shouldn't pass and we should complain
        # that out isn't a kwarg
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = input('y')
$2: f32[1] = torch._ops.aten.abs.out($0, out=$1)""",
        )

    def test_kwarg_only(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.ones(1, 1))
            z = LoggingTensor(torch.ones(1))
            log_input("x", x)
            log_input("y", y)
            log_input("z", z)
            torch.addmv(x, y, z)
            torch.addmv(x, y, z, beta=1)
            torch.addmv(x, y, z, beta=2)
            torch.addmv(x, y, z, alpha=2)
            torch.addmv(x, y, z, beta=2, alpha=2)

        # The expectation is that beta/alpha don't show up when they're
        # defaulted.  This is even if the user explicitly specified it.
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1, 1] = input('y')
$2: f32[1] = input('z')
$3: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$4: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$5: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
$6: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
$7: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)""",
        )

    def test_kwarg_only_and_positional_default(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            log_input("x", x)
            torch.ops.aten._foobar(x)
            torch.ops.aten._foobar(x, False)
            torch.ops.aten._foobar(x, arg3=False)
            torch.ops.aten._foobar(x, False, arg3=False)

        # What we are testing here is that we omit arg2
        # if it is defaulted, even if a kwarg is set
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten._foobar.default($0)
$2: f32[1] = torch._ops.aten._foobar.default($0, False)
$3: f32[1] = torch._ops.aten._foobar.default($0, arg3=False)
$4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)""",
        )

    def test_produce_real_type(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(2, 2))
            log_input("x", x)
            x.to(dtype=torch.double)  # non-optional dtype
            torch.cumprod(x, 0, dtype=torch.double)  # optional dtype
            x[:, 1].contiguous(
                memory_format=torch.contiguous_format
            )  # optional memory format
            # There don't appear to be any layout signatures that are
            # triggerable using tensor subclasses (you need to use a mode)

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[2, 2] = input('x')
$1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
$2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
$4: f32[2] = torch._ops.aten.select.int($3, 1, 1)
$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)""",
        )

    def test_optional_tensor_list(self) -> None:
        def weird(xs):
            print("woof")
            return torch.empty(())

        with _scoped_library("my_lib", "DEF") as my_lib:
            my_lib.define("weird(Tensor?[] self) -> Tensor")
            my_lib.impl("weird", weird, "CPU")
            with capture_logs() as logs:
                x = LoggingTensor(torch.ones(2, 2))
                log_input("x", x)
                torch.ops.my_lib.weird.default([None, x])

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[2, 2] = input('x')
$1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
        )

    def test_list_ret(self) -> None:
        # test all sequence types are permissible returns
        for list_type in (list, tuple):

            class A(torch.Tensor):
                @staticmethod
                def __new__(cls, elem):
                    return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

                @classmethod
                def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                    if func.overloadpacket == torch.ops.aten.split:
                        with no_dispatch():
                            return list_type(torch.split(*args))
                    else:
                        raise AssertionError(f"unrecognized func: {func}")

            self.assertEqual(
                torch.split(A(torch.tensor([0, 1])), 2),
                torch.split(torch.tensor([0, 1]), 2),
            )

    def test_invalid_ret(self) -> None:
        # test invalid return gets reasonable error message
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return "arf"

        # Wobbles depending on NDEBUG mode of pybind11
        self.assertRaisesRegex(
            RuntimeError,
            "Unable to cast",
            lambda: A(torch.zeros(1)).neg(),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Unable to cast",
            lambda: A(torch.zeros(1)).detach(),
        )

    def test_detach_appears_twice_when_called_once(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
            log_input("x", x)
            x.detach()
        # FIXME: We actually want this to emit a single detach. However,
        # it currently emits two, for reasons unclear to us. Leaving
        # this test here to make sure we don't regress even further (it
        # would be bad if calling .detach() once emits 3+ detaches).
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.detach.default($0)
$2: f32[1] = torch._ops.aten.detach.default($1)""",
        )

    def test_storage(self) -> None:
        # For now, just make sure it doesn't crash.  Ideally, we should
        # return some virtual storage that is safe to work with
        x = LoggingTensor(torch.ones(1))
        storage = x.untyped_storage()
        self.assertRaises(RuntimeError, lambda: storage.data_ptr())

    def test_make_wrapper_subclass_noalloc(self) -> None:
        # This is ludicrously big (8TB) and this should pass because wrapper
        # subclasses don't allocate
        torch.Tensor._make_wrapper_subclass(LoggingTensor, (1000000000000,))

    def test_version(self) -> None:
        x = LoggingTensor(torch.ones(1))
        prev_vc = x._version
        x.detach().add_(2)
        cur_vc = x._version
        self.assertNotEqual(prev_vc, cur_vc)
        x.data.add_(2)
        self.assertEqual(cur_vc, x._version)

    def test_subclass_priority(self) -> None:
        class ErrorA(RuntimeError):
            pass

        class ErrorB(RuntimeError):
            pass

        # The big tests for code coverage are test_precedence_semantics in
        # test_overrides.py; this is just to make sure it is wired up at all
        # correctly for __torch_dispatch__
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorA

        class B(A):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorB

        self.assertRaises(
            ErrorA, lambda: torch.add(A(torch.empty(1)), A(torch.empty(1)))
        )
        self.assertRaises(
            ErrorB, lambda: torch.add(A(torch.empty(1)), B(torch.empty(1)))
        )
        self.assertRaises(
            ErrorB, lambda: torch.add(B(torch.empty(1)), A(torch.empty(1)))
        )
        self.assertRaises(
            ErrorB, lambda: torch.add(B(torch.empty(1)), B(torch.empty(1)))
        )

    def test_format(self) -> None:
        x = LoggingTensor(torch.ones(1))
        s1 = str(x)
        s2 = repr(x)
        s3 = f"{x}"
        self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""")
        self.assertEqual(s1, s2)
        self.assertEqual(s1, s3)

    def test_custom_autograd(self) -> None:
        escape = [None]

        class Square(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x**2
                ctx.save_for_backward(x)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                assert isinstance(grad_output, LoggingTensor)
                (x,) = ctx.saved_tensors
                assert isinstance(x, LoggingTensor)
                escape[0] = x
                return grad_output * 2 * x

        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1), requires_grad=True)
            log_input("x", x)
            x.grad = LoggingTensor(torch.zeros(1))
            log_input("x.grad", x.grad)
            y = Square.apply(x)
            grad_output = LoggingTensor(torch.ones(1))
            log_input("grad_output", grad_output)
            y.backward(grad_output)

        with torch.no_grad():
            self.assertEqual(escape[0], x)
            self.assertEqual(escape[0]._version, x._version)
            # TODO: figure out why x.requires_grad = False doesn't
            # trigger an error for LoggingTensor
            x.add_(2)
            self.assertEqual(escape[0], x)
            # TODO: figure out why this is broken
            # self.assertEqual(escape[0]._version, x._version)

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = input('x.grad')
$2: f32[1] = torch._ops.aten.pow.Tensor_Scalar($0, 2)
$3: f32[1] = input('grad_output')
$4: f32[1] = torch._ops.aten.mul.Tensor($3, 2)
$5: f32[1] = torch._ops.aten.mul.Tensor($4, $0)
$6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)""",
        )

    def test_subclass_creation(self):
        # Make sure these statements run without error.
        # In particular, check that when the internal detach returns
        # subclasses, these are cleanly overwritten.
        class Foo(torch.Tensor):
            pass

        err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
        with self.assertRaisesRegex(RuntimeError, err_msg):
            a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
        with self.assertRaisesRegex(RuntimeError, err_msg):
            b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
        with self.assertRaisesRegex(RuntimeError, err_msg):
            Foo(LoggingTensor(torch.rand(2)))

        with self.assertRaisesRegex(TypeError, "Foo must define __torch_dispatch__"):
            torch.Tensor._make_wrapper_subclass(Foo, (2, 2))

    def test_new_ones(self) -> None:
        class MyTensor(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)

        self.assertEqual(type(MyTensor(2).new_ones(3)), MyTensor)

    def test_like(self) -> None:
        class MyTensor(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)

        for f in ["empty", "ones", "rand", "randn", "zeros"]:
            f_name = f + "_like"
            self.assertEqual(type(getattr(torch, f_name)(MyTensor(2))), MyTensor)

        self.assertEqual(type(torch.full_like(MyTensor(2), 1.0)), MyTensor)
        self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor)

    def test_make_fx_with_subclass(self) -> None:
        def f(x, y):
            # Returns (TwoTensor, Tensor)
            return x * y, y + y

        x_a = torch.zeros(4)
        x_b = torch.zeros(4)
        y = torch.ones(4)

        # make_fx() is not responsible for unwrapping tensor subclass inputs,
        # so we do it manually here.
        # Why? In general, make_fx(f)(*args) promises that the graph returned has the same calling
        # convention as f(*args). Unwrapping tensor subclass inputs can potentially change
        # the number of input args to the graph, breaking that assumption
        def f_to_trace(x_a, x_b, y):
            x = TwoTensor(x_a, x_b)
            out1, out2 = f(x, y)
            out1_unwrapped_attrs, _ = out1.__tensor_flatten__()
            return (*[getattr(out1, attr) for attr in out1_unwrapped_attrs], out2)

        fx_g = make_fx(f_to_trace, tracing_mode="fake")(x_a, x_b, y)
        self.assertExpectedInline(
            fx_g.code,
            """\



def forward(self, x_a_1, x_b_1, y_1):
    mul = torch.ops.aten.mul.Tensor(x_a_1, y_1);  x_a_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(x_b_1, y_1);  x_b_1 = None
    add = torch.ops.aten.add.Tensor(y_1, y_1);  y_1 = None
    return (mul, mul_1, add)
    """,
        )

    # See https://github.com/pytorch/pytorch/issues/117794
    def test_return_and_correct_aliasing_gives_correct_stride(self):
        t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2))
        x = torch.randn(2, 2)
        # slicing should result in the same stride for TwoTensor as a dense tensor would give
        self.assertEqual(t[:, 0].stride(), x[:, 0].stride())

    def test_make_wrapper_subclass_propagates_metadata(self) -> None:
        class WrapperTensor(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ["elem"]

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    layout=elem.layout,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                    strides=elem.stride(),
                    storage_offset=elem.storage_offset(),
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise RuntimeError("NYI")

        # non-contiguous strides, non-zero storage offset
        x = torch.randn(4, 6).t().diagonal(offset=2)
        y = WrapperTensor(x)
        self.assertEqual(y.size(), x.size())
        self.assertEqual(y.stride(), x.stride())
        self.assertEqual(y.storage_offset(), x.storage_offset())

    def test_wrapper_subclass_serializes(self) -> None:
        with tempfile.TemporaryFile() as f:
            # purposefully use int64 to test non-default dtype
            x = LoggingTensor(torch.randperm(3))
            torch.save(x, f)
            f.seek(0)
            with torch.serialization.safe_globals([LoggingTensor]):
                x_loaded = torch.load(f)
            self.assertTrue(type(x_loaded) is type(x))
            self.assertEqual(x, x_loaded)
            self.assertEqual(x.elem, x_loaded.elem)
            self.assertFalse(x is x_loaded)

    def test_deepcopy_wrapper_subclass(self) -> None:
        # purposefully use int64 to test non-default dtype
        x = LoggingTensor(torch.randperm(3))
        x_copy = deepcopy(x)
        self.assertTrue(type(x_copy) is type(x))
        self.assertEqual(x, x_copy)
        self.assertEqual(x.elem, x_copy.elem)
        self.assertFalse(x is x_copy)

    def test_deepcopy_wrapper_subclass_with_clone_returning_different_type(
        self,
    ) -> None:
        class MyWrapperTensor(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ["elem"]

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    layout=elem.layout,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                    strides=elem.stride(),
                    storage_offset=elem.storage_offset(),
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if func.overloadpacket.__name__ == "clone":
                    # Return a plain tensor from clone().
                    return args[0].elem.clone()
                raise RuntimeError("NYI")

            # NB: The default Tensor.__torch_function__ implementation called for deepcopy
            # disables __torch_function__ by the time we get to clone(), so there is no need to
            # explicitly disable __torch_function__ for this subclass.

        x = MyWrapperTensor(torch.randn(3))
        with self.assertRaisesRegex(
            RuntimeError,
            "for which cloning returns another instance of the same subclass",
        ):
            x_copy = deepcopy(x)

    def test_deepcopy_non_wrapper_subclass(self) -> None:
        # Ensure correct error is thrown for common error cases.
        class SubTensorError1(torch.Tensor):
            # Default implementation of new_empty() returns a plain tensor.
            pass

        class SubTensorError2(torch.Tensor):
            # new_empty() incorrectly returns a different type (i.e. a plain tensor).
            def new_empty(self, shape):
                return torch.Tensor(shape)

        for error_cls in [SubTensorError1, SubTensorError2]:
            x = error_cls(3)
            with self.assertRaisesRegex(
                RuntimeError,
                "for which that function returns another instance of the same subclass",
            ):
                x_copy = deepcopy(x)

        # Ensure a correctly implemented new_empty() causes deepcopy() to work.
        class SubTensorSuccess(torch.Tensor):
            def new_empty(self, shape):
                return type(self)(shape)

        x = SubTensorSuccess(3)
        x_copy = deepcopy(x)
        self.assertIs(type(x_copy), type(x))

    def test_wrapper_subclass_extra_dispatch_keys(self) -> None:
        class ExtraKeysTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                # NB: only the non-kwarg overload of _make_wrapper_subclass supports
                #     extra dispatch keys. We probably want to unify the two APIs
                #     in the future.
                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls,
                    elem.size(),
                    elem.stride(),
                    elem.storage_offset(),
                    torch.contiguous_format,
                    elem.dtype,
                    elem.layout,
                    elem.device,
                    False,
                    False,
                    None,
                    False,
                    False,
                    DispatchKeySet(DispatchKey.NestedTensor),
                )
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                pass

        x = ExtraKeysTensor(torch.randn(3))
        self.assertTrue(torch._C._dispatch_keys(x).has(DispatchKey.NestedTensor))
        self.assertFalse(
            torch._C._dispatch_keys(x).has(DispatchKey.AutogradNestedTensor)
        )

    def test_wrapper_subclass_multiprocessing_preserves_dtype(self):
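        # Round-tripping the TwoTensor through DataLoader worker processes
        # exercises sending wrapper subclasses across process boundaries; the
        # reconstructed batch should keep the original dtype rather than the
        # default assumed by _make_wrapper_subclass().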
1328        # a and b have dtype of int64, which is purposefully different from the default
1329        # assumed by _make_wrapper_subclass().
1330        a = torch.randperm(5)
1331        b = torch.randperm(5)
1332        data = TwoTensor(a, b)
1333        expected_dtype = data.dtype
1334
1335        loader = torch.utils.data.DataLoader(
1336            [data, data],
1337            batch_size=2,
1338            num_workers=2,
1339            collate_fn=_identity,
1340        )
1341        for batch in loader:
1342            self.assertEqual(batch[0].dtype, expected_dtype)
1343
1344    def test_index_put_where_only_index_is_subclass(self) -> None:
1345        called_funcs = []
1346
1347        class MyTensor(torch.Tensor):
1348            elem: torch.Tensor
1349            __slots__ = ["elem"]
1350
1351            @staticmethod
1352            def __new__(cls, elem, *args, **kwargs):
1353                r = torch.Tensor._make_wrapper_subclass(
1354                    cls,
1355                    elem.size(),
1356                    dtype=elem.dtype,
1357                    layout=elem.layout,
1358                    device=elem.device,
1359                    requires_grad=elem.requires_grad,
1360                )
1361                r.elem = elem
1362                return r
1363
1364            @classmethod
1365            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1366                called_funcs.append(func)
1367                return MyTensor(torch.tensor(3))
1368
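        # Even though only the index (not self or values) is a subclass, index_put_
        # should still be routed through MyTensor's __torch_dispatch__.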
1369        x = torch.randn(3, 3)
1370        idxs = (MyTensor(torch.tensor(0)),)
1371        v = torch.randn(1)
1372        res = x.index_put_(idxs, v)
1373        self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default])
1374
1375    def test_torch_dispatch_mode_basic(self) -> None:
1376        with capture_logs(is_mode=True) as logs:
1377            with LoggingTensorMode():
1378                torch.empty([])
1379        self.assertExpectedInline(
1380            "\n".join(logs),
1381            """\
1382$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""",
1383        )
1384
1385    def test_torch_dispatch_mode_unrelated_tensors(self) -> None:
1386        x = torch.randn([])
1387        y = torch.randn([])
1388        with capture_logs(is_mode=True) as logs:
1389            with LoggingTensorMode():
1390                x + y
1391        self.assertExpectedInline(
1392            "\n".join(logs), """$2: f32[] = torch._ops.aten.add.Tensor($0, $1)"""
1393        )
1394
1395    def test_nested_push_logging_tensor_mode(self):
1396        x = torch.randn([])
1397        y = torch.randn([])
1398        with capture_logs(is_mode=True) as logs:
1399            with LoggingTensorMode():
1400                with LoggingTensorMode():
1401                    torch.empty([])
1402                    x + y
1403
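        # With two LoggingTensorModes active, each op is logged once per mode, so every
        # call appears twice below.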
1404        self.assertExpectedInline(
1405            "\n".join(logs),
1406            """\
1407$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
1408$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
1409$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
1410$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
1411        )
1412
1413    def test_capture_logs_with_torch_dispatch_mode(self):
1414        x = torch.randn([])
1415        y = torch.randn([])
1416        with capture_logs_with_logging_tensor_mode() as logs:
1417            torch.empty([])
1418            x + y
1419        self.assertExpectedInline(
1420            "\n".join(logs),
1421            """\
1422$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
1423$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
1424        )
1425
1426        x = torch.randn([])
1427        y = torch.randn([])
1428
1429        with capture_logs_with_logging_tensor_mode() as logs1:
1430            with capture_logs_with_logging_tensor_mode() as logs2:
1431                torch.empty([])
1432                x + y
1433
1434        self.assertExpectedInline(
1435            "\n".join(logs2),
1436            """\
1437$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
1438$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
1439$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
1440$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
1441        )
1442
1443        self.assertEqual(logs1, logs2)
1444
1445    def test_torch_dispatch_mode_subclass_priority(self) -> None:
1446        class ErrorA(RuntimeError):
1447            pass
1448
1449        class ErrorB(RuntimeError):
1450            pass
1451
1452        class A(torch.Tensor):
1453            @staticmethod
1454            def __new__(cls, elem):
1455                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
1456
1457            @classmethod
1458            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1459                with AMode():
1460                    raise ErrorA
1461
1462        class B(A):
1463            @staticmethod
1464            def __new__(cls, elem):
1465                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
1466
1467            @classmethod
1468            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1469                with BMode():
1470                    func(*args, **kwargs)
1471
1472        class AMode(TorchDispatchMode):
1473            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1474                raise ErrorA
1475
1476        class BMode(TorchDispatchMode):
1477            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1478                raise ErrorB
1479
1480        a = A(torch.empty(1))
1481        b = B(torch.empty(1))
1482        with self.assertRaises(ErrorA):
1483            a + a
1484        with self.assertRaises(ErrorB):
1485            a + b
1486
1487        # B has precedence over A due to the subclass relationship, but modes
1488        # take precedence over subclass arguments.
1489        with self.assertRaises(ErrorA):
1490            with AMode():
1491                b + b
1492        with self.assertRaises(ErrorB):
1493            with BMode():
1494                a + a
1495        with self.assertRaises(ErrorB):
1496            with BMode():
1497                a + b
1498
1499    def test_mode_with_make_subclass(self):
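        # _make_subclass should work while a mode is active and still return the
        # subclass type.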
1500        class SubTensor(torch.Tensor):
1501            @staticmethod
1502            def __new__(cls, elem):
1503                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
1504
1505        class BasicMode(TorchDispatchMode):
1506            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1507                return func(*args, **kwargs)
1508
1509        x = torch.randn(3)
1510        with BasicMode():
1511            y = SubTensor(x)
1512        self.assertIsInstance(y, SubTensor)
1513
1514    def test_torch_dispatch_mode_respects_no_dispatch(self) -> None:
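        # Ops run under no_dispatch() should not be seen by the mode, so both capture
        # runs below produce identical logs.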
1515        with capture_logs(is_mode=True) as logs1:
1516            with LoggingTensorMode():
1517                torch.ones([2, 3])
1518                with no_dispatch():
1519                    torch.ones([2, 3])
1520        with capture_logs(is_mode=True) as logs2:
1521            with LoggingTensorMode():
1522                torch.ones([2, 3])
1523        self.assertEqual(logs1, logs2)
1524
1525    def test_shallow_copy_and_detach(self) -> None:
1526        seen = set()
1527        test_case = self
1528
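        # The mode checks that every tensor argument it sees was produced by an earlier
        # call it also saw; in particular, the tensors autograd saves/detaches for
        # backward have to flow through the mode too.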
1529        class TestMode(TorchDispatchMode):
1530            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1531                tree_map_only(
1532                    torch.Tensor, lambda t: test_case.assertIn(t, seen), (args, kwargs)
1533                )
1534                if kwargs is None:
1535                    kwargs = {}
1536                r = func(*args, **kwargs)
1537                tree_map_only(torch.Tensor, lambda t: seen.add(t), r)
1538                return r
1539
1540        with TestMode():
1541            x = torch.randn(3, requires_grad=True)
1542            loss = (x * x).sum()
1543            loss.backward()
1544
1545    def test_exception_handling(self):
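        # Even if an exception escapes __torch_dispatch__, the mode should stay active
        # and keep handling subsequent calls.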
1546        class A(torch.Tensor):
1547            @staticmethod
1548            def __new__(cls, elem):
1549                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
1550
1551        class AMode(TorchDispatchMode):
1552            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1553                if func.__name__ == "randn.default":
1554                    raise RuntimeError
1555                return A(torch.zeros(()))
1556
1557        with AMode():
1558            try:
1559                torch.randn(())
1560            except RuntimeError:
1561                pass
1562            self.assertTrue(isinstance(torch.zeros(()), A))
1563
1564    def test_with_mode_created_separately(self):
1565        class ErrorA(RuntimeError):
1566            pass
1567
1568        class A(TorchDispatchMode):
1569            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1570                raise ErrorA
1571
1572        x = A()
1573        with self.assertRaises(ErrorA):
1574            with x:
1575                torch.empty([])
1576
1577    def test_with_nested_modes(self):
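        # The innermost (most recently pushed) mode should handle the call first.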
1578        class ErrorA(RuntimeError):
1579            def __init__(self, msg):
1580                super().__init__(msg)
1581
1582        class A(TorchDispatchMode):
1583            def __init__(self, msg):
1584                self.msg = msg
1585
1586            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1587                raise ErrorA(self.msg)
1588
1589        with self.assertRaisesRegex(ErrorA, "layer2"):
1590            with A("layer1"):
1591                with A("layer2"):
1592                    torch.empty([])
1593
1594    def test_make_subclass_with_modes(self):
1595        class ModeTensor(torch.Tensor):
1596            def __new__(cls, elem, mode):
1597                r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
1598                r.elem = elem
1599                r.mode = mode
1600                return r
1601
1602            @classmethod
1603            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1604                raise NotImplementedError("Shouldn't be here")
1605
1606        class Mode(TorchDispatchMode):
1607            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1608                def unwrap(e):
1609                    if isinstance(e, ModeTensor):
1610                        return e.elem
1611                    else:
1612                        return e
1613
1614                def wrap(t):
1615                    if isinstance(t, torch.Tensor):
1616                        return ModeTensor(t, self)
1617                    else:
1618                        return t
1619
1620                return wrap(func(*tuple(unwrap(a) for a in args), **kwargs))
1621
1622        class BasicMode(TorchDispatchMode):
1623            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1624                return func(*args, **kwargs)
1625
1626        x = torch.tensor(4.0)
1627        with Mode():
1628            y = x + x
1629            z = y + y
1630        self.assertIsInstance(y, ModeTensor)
1631        self.assertIsInstance(z, ModeTensor)
1632
1633        with Mode():
1634            with BasicMode():  # we can't nest two modes that both call _make_subclass, since it only accepts plain tensors
1635                y = x + x
1636                z = y + y
1637        self.assertIsInstance(y, ModeTensor)
1638        self.assertIsInstance(z, ModeTensor)
1639
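        # NB: as written, this only constructs the assertRaisesRegex context manager
        # (which is truthy), so nothing is actually checked here.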
1640        assert self.assertRaisesRegex(
1641            RuntimeError,
1642            "subclass Mode but.* associated to a python object of type Mode",
1643        )
1644
1645    def test_notimplemented_mode(self):
1646        sub_count = 0
1647
1648        class PoliteMode(TorchDispatchMode):
1649            def __init__(self) -> None:
1650                self.pre_count = 0
1651                self.post_count = 0
1652
1653            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1654                self.pre_count += 1
1655                if any(t is not torch.Tensor for t in types):
1656                    return NotImplemented
1657                self.post_count += 1
1658                return func(*args, **kwargs)
1659
1660        class SubTensor(torch.Tensor):
1661            def __new__(cls, elem):
1662                r = torch.Tensor._make_wrapper_subclass(cls, elem.shape)
1663                r.elem = elem
1664                return r
1665
1666            @classmethod
1667            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1668                nonlocal sub_count
1669                sub_count += 1
1670
1671                def unwrap(t):
1672                    if isinstance(t, SubTensor):
1673                        return t.elem
1674                    else:
1675                        return t
1676
1677                return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
1678
1679        a = SubTensor(torch.randn(2))
1680        with PoliteMode() as mode:
1681            a.abs()
1682
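        # The mode sees the call twice: once with SubTensor in types (where it returns
        # NotImplemented and defers to the subclass), and once more when the subclass
        # redispatches on the unwrapped tensor.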
1683        self.assertEqual(mode.pre_count, 2)
1684        self.assertEqual(mode.post_count, 1)
1685        self.assertEqual(sub_count, 1)
1686
1687        # make sure this doesn't error
1688        with PoliteMode():
1689            with PoliteMode():
1690                a.abs()
1691
1692    def test_nesting_same_mode(self):
1693        # Pushing an already active mode is allowed as long as it is the same instance as the current mode.
1694
1695        with capture_logs(is_mode=True) as logs:
1696            with LoggingTensorMode() as reenabled:
1697                with reenabled:
1698                    torch.empty([])
1699            self.assertExpectedInline(
1700                "\n".join(logs),
1701                """\
1702$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
1703$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""",
1704            )
1705
1706    def test_error_using_class_method_on_mode(self):
1707        class A(TorchDispatchMode):
1708            @classmethod
1709            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1710                return func(args, kwargs)
1711
1712        x = torch.tensor(5.0)
1713        with self.assertRaisesRegex(
1714            RuntimeError, "classmethod is not supported, please make it a plain method"
1715        ):
1716            with A():
1717                x + x
1718
1719    def test_get_cur_mode(self):
1720        class A(TorchDispatchMode):
1721            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1722                pass
1723
1724        self.assertEqual(_get_current_dispatch_mode(), None)
1725
1726        with A() as mode1:
1727            self.assertEqual(_get_current_dispatch_mode(), mode1)
1728
1729        with mode1:
1730            with A() as mode2:
1731                self.assertEqual(_get_current_dispatch_mode(), mode2)
1732
1733    def test_get_mode_stack(self):
1734        class A(TorchDispatchMode):
1735            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1736                pass
1737
1738        self.assertEqual(_get_current_dispatch_mode_stack(), [])
1739
1740        with A() as mode1:
1741            self.assertEqual(_get_current_dispatch_mode_stack(), [mode1])
1742
1743        with mode1:
1744            with A() as mode2:
1745                self.assertEqual(_get_current_dispatch_mode_stack(), [mode1, mode2])
1746
1747    def test_all_same_mode(self):
1748        x = LoggingTensorMode()
1749        y = LoggingTensorMode()
1750        self.assertTrue(all_same_mode([x, x, x]))
1751        self.assertFalse(all_same_mode([x, None]))
1752        self.assertFalse(all_same_mode([x, y]))
1753
1754    def test_mode_detection(self):
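        # is_in_torch_dispatch_mode(include_infra_modes=False) should ignore modes
        # whose class reports is_infra_mode() == True.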
1755        class InfraMode(TorchDispatchMode):
1756            @classmethod
1757            def is_infra_mode(cls):
1758                return True
1759
1760        class NonInfraMode(TorchDispatchMode):
1761            pass
1762
1763        with InfraMode():
1764            self.assertTrue(is_in_torch_dispatch_mode())
1765            self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False))
1766            with NonInfraMode():
1767                self.assertTrue(is_in_torch_dispatch_mode())
1768                self.assertTrue(is_in_torch_dispatch_mode(include_infra_modes=False))
1769                with InfraMode():
1770                    self.assertTrue(is_in_torch_dispatch_mode())
1771                    self.assertTrue(
1772                        is_in_torch_dispatch_mode(include_infra_modes=False)
1773                    )
1774
1775                self.assertTrue(is_in_torch_dispatch_mode())
1776                self.assertTrue(is_in_torch_dispatch_mode(include_infra_modes=False))
1777            self.assertTrue(is_in_torch_dispatch_mode())
1778            self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False))
1779
1780        self.assertFalse(is_in_torch_dispatch_mode())
1781        self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False))
1782
1783    def test_tolist_numpy_with_torch_dispatch_mode(self) -> None:
1784        x = LoggingTensor(torch.tensor([2.0, 3.0]))
1785        with self.assertRaisesRegex(
1786            RuntimeError, "is not supported for tensor subclasses."
1787        ):
1788            x.tolist()
1789        with self.assertRaisesRegex(
1790            RuntimeError, "is not supported for tensor subclasses."
1791        ):
1792            x.numpy()
1793        with self.assertRaises(AssertionError):
1794            self.assertEqual(x, None)
1795
1796    def test_record_stream(self) -> None:
1797        class TestMode(TorchDispatchMode):
1798            def __init__(self, testcase):
1799                self.testcase = testcase
1800
1801            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1802                self.testcase.assertEqual(func.name(), "aten::record_stream")
1803                self.testcase.assertIsInstance(args[0], torch.Tensor)
1804                self.testcase.assertIsInstance(args[1], torch.Stream)
1805                self.testcase.assertEqual(args[1].stream_id, 1)
1806                self.testcase.assertEqual(args[1].device_index, 2)
1807                self.testcase.assertEqual(args[1].device_type, 3)
1808
1809        t = torch.tensor(5.0)
1810        s = torch.Stream(stream_id=1, device_index=2, device_type=3)
1811        with TestMode(self):
1812            t.record_stream(s)
1813
1814    def test_return_stream(self) -> None:
1815        with _scoped_library("test_return_stream", "DEF") as l_def:
1816            l_def.define("return_stream(Tensor self) -> Stream")
1817            with _scoped_library("test_return_stream", "IMPL", "CPU") as l_impl:
1818                l_impl.impl(
1819                    "return_stream",
1820                    lambda _: torch.Stream(stream_id=0, device_index=1, device_type=2),
1821                )
1822
1823                class TestMode(TorchDispatchMode):
1824                    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1825                        return torch.Stream(stream_id=1, device_index=2, device_type=3)
1826
1827                t = torch.tensor(5.0)
1828                s = torch.ops.test_return_stream.return_stream(t)
1829                self.assertIsInstance(s, torch.Stream)
1830                self.assertEqual(s.stream_id, 0)
1831                self.assertEqual(s.device_index, 1)
1832                self.assertEqual(s.device_type, 2)
1833
1834                with TestMode():
1835                    s = torch.ops.test_return_stream.return_stream(t)
1836                self.assertIsInstance(s, torch.Stream)
1837                self.assertEqual(s.stream_id, 1)
1838                self.assertEqual(s.device_index, 2)
1839                self.assertEqual(s.device_type, 3)
1840
1841    def test_subclass_autograd_device_check(self) -> None:
1842        class NonWrapperSubclass(torch.Tensor):
1843            elem: torch.Tensor
1844
1845            __slots__ = ["elem"]
1846
1847            @staticmethod
1848            def __new__(cls, elem, *args, **kwargs):
1849                # Wrong device here!
1850                r = torch.Tensor._make_subclass(
1851                    cls, elem.to("meta"), elem.requires_grad
1852                )
1853                # ...the real tensor is held as an element on the tensor.
1854                r.elem = elem
1855                return r
1856
1857            @classmethod
1858            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1859                def unwrap(e):
1860                    return e.elem if isinstance(e, NonWrapperSubclass) else e
1861
1862                def wrap(e):
1863                    return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e
1864
1865                rs = tree_map(
1866                    wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
1867                )
1868                logging.getLogger("NonWrapperSubclass").info(
1869                    f"{func.__module__}.{func.__name__}",  # noqa: G004
1870                    args,
1871                    kwargs,
1872                    rs,
1873                )
1874                return rs
1875
1876        x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True))
1877        y = torch.randn(2, requires_grad=True)
1878        z = x * y
1879        self.assertIsInstance(z, NonWrapperSubclass)
1880        z.sum().backward(torch.tensor(1))
1881        self.assertEqual(x.grad, y)
1882        self.assertEqual(y.grad, x)
1883
1884    def test_none_wrapping(self):
1885        # A Tensor subclass that returns None when doing add
1886        # See LoggingTensor above for more details on the subclass
1887        class SubclassWithNone(torch.Tensor):
1888            @staticmethod
1889            def __new__(cls, elem, *args, **kwargs):
1890                r = torch.Tensor._make_wrapper_subclass(
1891                    cls,
1892                    elem.size(),
1893                    dtype=elem.dtype,
1894                    layout=elem.layout,
1895                    device=elem.device,
1896                    requires_grad=elem.requires_grad,
1897                )
1898                r.elem = elem
1899                return r
1900
1901            @classmethod
1902            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1903                def unwrap(e):
1904                    return e.elem if isinstance(e, SubclassWithNone) else e
1905
1906                def wrap(e):
1907                    return SubclassWithNone(e) if isinstance(e, torch.Tensor) else e
1908
1909                rs = tree_map(
1910                    wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
1911                )
1912                if func.overloadpacket.__name__ == "add":
1913                    return None
1914                else:
1915                    return rs
1916
1917        x = SubclassWithNone(torch.rand(2))
1918        # Make sure both run without error
1919        self.assertIsInstance(x * 2, SubclassWithNone)
1920        self.assertIsNone(x + 2)
1921
1922        x.requires_grad_()
1923        out = x.acos().sum()
1924
1925        # The backward of acos does an add and then an rsqrt, so here we make sure
1926        # that the None (undefined Tensor) returned by the user code is handled nicely.
1927        # If the acos backward formula changes in the future, this can be replaced by
1928        # any other composite function whose backward does an add followed by another op.
1929        with self.assertRaisesRegex(RuntimeError, "but got None"):
1930            out.backward()
1931
1932    def test_storage_can_be_converted_to_python_object(self):
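        # Smoke test: calling set_() with a plain Storage on a LoggingTensor should not
        # error when the storage is passed through __torch_dispatch__.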
1933        s = torch.Storage()
1934        z = LoggingTensor(torch.empty([]))
1935        z.set_(s)
1936
1937    def test_autograd_in_attr(self):
1938        # We want the wrapped Tensor to require gradients!
1939        true_t = torch.rand(2, requires_grad=True)
1940        t = LoggingTensorReentrant(true_t)
1941
1942        out = t + 2
1943
1944        self.assertFalse(out.requires_grad)
1945        self.assertIsNone(out.grad_fn)
1946
1947        self.assertTrue(out.elem.requires_grad)
1948        self.assertIsNotNone(out.elem.grad_fn)
1949
1950        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
1951            out.sum().backward()
1952
1953        out.elem.sum().backward()
1954
1955        self.assertIsNone(t.grad)
1956        self.assertIsNotNone(t.elem.grad)
1957
1958    def test_dispatch_super_call(self):
1959        called = []
1960
1961        class SubTensor(torch.Tensor):
1962            @staticmethod
1963            def __new__(cls, elem):
1964                return torch.Tensor._make_subclass(cls, elem)
1965
1966            @classmethod
1967            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1968                called.append(func)
1969                return super().__torch_dispatch__(func, types, args, kwargs)
1970
1971        x = torch.randn(2)
1972        y = torch.randn(2)
1973        self.assertEqual(SubTensor(x) + SubTensor(y), x + y)
1974        self.assertEqual(called, [torch.ops.aten.add.Tensor])
1975
1976    def test_dispatch_super_call_list_arg(self):
1977        called = []
1978
1979        class SubTensorWithListArg(torch.Tensor):
1980            @staticmethod
1981            def __new__(cls, elem):
1982                return torch.Tensor._make_subclass(cls, elem)
1983
1984            @classmethod
1985            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1986                called.append(func)
1987                return super().__torch_dispatch__(func, types, list(args), kwargs)
1988
1989        x = torch.randn(2)
1990        self.assertEqual(SubTensorWithListArg(x).neg(), x.neg())
1991        self.assertEqual(called, [torch.ops.aten.neg.default])
1992
1993    def test_dispatch_super_dont_autograd(self):
1994        called = []
1995
1996        class SubTensor(torch.Tensor):
1997            @staticmethod
1998            def __new__(cls, elem):
1999                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
2000
2001            @classmethod
2002            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
2003                called.append(func)
2004                # This argument still requires grad because it was passed
2005                # through directly...
2006                self.assertTrue(args[0].requires_grad)
2007                r = super().__torch_dispatch__(func, types, args, kwargs)
2008                # But the output better not require grad, because that means
2009                # you did autograd again in torch dispatch (oops)
2010                self.assertFalse(r.requires_grad)
2011                return r
2012
2013        x = SubTensor(torch.randn(2, requires_grad=True))
2014        x.neg()
2015        self.assertEqual(called, [torch.ops.aten.neg.default])
2016
2017    def test_set_data(self):
2018        called = 0
2019
2020        class SubTensor(torch.Tensor):
2021            @classmethod
2022            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
2023                nonlocal called
2024                called += 1
2025                return super().__torch_dispatch__(func, types, args, kwargs)
2026
2027        x = SubTensor(torch.empty(2))
2028        x.data
2029        self.assertEqual(called, 1)
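        # Assigning to .data does not go through __torch_dispatch__ (the count stays
        # the same), whereas set_() further down does.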
2030        x.data = torch.empty(2)
2031        self.assertEqual(called, 1)
2032        x.data
2033        self.assertEqual(called, 2)
2034        self.assertIs(type(x), SubTensor)
2035        x.set_(torch.empty(2))
2036        self.assertEqual(called, 3)
2037        x.data
2038        self.assertEqual(called, 4)
2039        self.assertIs(type(x), SubTensor)
2040
2041    def test_construct_int_tensor(self):
2042        class SubTensor(torch.Tensor):
2043            pass
2044
2045        # should not fail
2046        SubTensor(torch.zeros(2, dtype=torch.int))
2047
2048    def test_multiple_ops_subclass(self):
2049        # This is a Direct Subclass, don't do that!
2050        class MySubclass(torch.Tensor):
2051            @staticmethod
2052            def __new__(cls, elem):
2053                r = torch.Tensor._make_subclass(cls, elem)
2054                return r
2055
2056            @classmethod
2057            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
2058                with no_dispatch():
2059                    return func(*args, **kwargs)
2060
2061        x = MySubclass(torch.rand(2, 2, dtype=torch.complex64))
2062        y = x.conj()
2063        # Details of the bug that this tests for:
2064        # Here, y's dispatch keys are: {PythonTLSSnapshot, AutogradCPU, Conjugate, Python, CPU}
2065        # A few calls to the dispatcher are going to happen here:
2066        #  - call_exp: the user calling exp on y
2067        #    - PythonTLSSnapshot: records the TLS on entry and redispatches
2068        #    - AutogradCPU: no input requires grad, so it does nothing and redispatches
2069        #    - Conjugate: no special implementation for exp, so it uses the fallback, which
2070        #                 first clones the Tensor (to materialize the conj) and then redispatches
2071        #      - call_clone: the conjugate fallback calling clone on y
2072        #        - PythonTLSSnapshot: records the TLS on entry and redispatches
2073        #        - (AutogradCPU: skipped, as autograd added itself to the exclude set above)
2074        #        - Conjugate: special implementation for clone: just skip this key
2075        #        - Python: resets the TLS based on the snapshot above and calls the user implementation (this
2076        #                  actually calls into the dispatcher again, but since we disabled both of our keys
2077        #                  before, that part is not detailed here)
2078        #        - exit Python: restore the TLS and exit
2079        #        - exit Conjugate: nothing was inplace so just exit
2080        #        - exit PythonTLSSnapshot: done with this call, reset the saved TLS to empty
2081        #    - Python: Reset the TLS again based on the snapshot. <- this used to fail
2082        #    - More steps....
2083        y.exp()
2084
2085    @staticmethod
2086    def subclass_helper(cls, data, use_wrapper_subclass, **kwargs):
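        # Shared helper for the slow-path tests below: builds either a wrapper subclass
        # or a plain _make_subclass subclass of `data`, forwarding extra kwargs such as
        # dispatch_sizes_strides_policy.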
2087        if use_wrapper_subclass:
2088            kwargs["device"] = data.device
2089            kwargs["dtype"] = data.dtype
2090            kwargs["layout"] = data.layout
2091            kwargs["requires_grad"] = True
2092            return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)  # type: ignore[attr-defined]
2093        else:
2094            return torch.Tensor._make_subclass(cls, data, True, **kwargs)
2095
2096    def test_is_contiguous_slow_path(self):
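        # With dispatch_sizes_strides_policy="strides", contiguity queries are routed
        # to __torch_dispatch__ as torch.ops.aten.is_contiguous.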
2097        data = torch.randn(3, 3)
2098        contiguous_data = data.clone()
2099        not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2))
2100
2101        for use_wrapper_subclass in [True, False]:
2102
2103            class ExampleTensor1(torch.Tensor):
2104                @staticmethod
2105                def __new__(cls, data, wrapper):
2106                    return TestPythonDispatch.subclass_helper(
2107                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
2108                    )
2109
2110                @classmethod
2111                def __torch_dispatch__(cls, func, types, args, kwargs):
2112                    return NotImplemented
2113
2114            class ExampleTensor2(torch.Tensor):
2115                @staticmethod
2116                def __new__(cls, data, wrapper):
2117                    return TestPythonDispatch.subclass_helper(
2118                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
2119                    )
2120
2121                @classmethod
2122                def __torch_dispatch__(cls, func, types, args, kwargs):
2123                    if func.overloadpacket == torch.ops.aten.is_contiguous:
2124                        return contiguous_data.is_contiguous()
2125                    return NotImplemented
2126
2127            class ExampleTensor3(torch.Tensor):
2128                @staticmethod
2129                def __new__(cls, data, wrapper):
2130                    return TestPythonDispatch.subclass_helper(
2131                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
2132                    )
2133
2134                @classmethod
2135                def __torch_dispatch__(cls, func, types, args, kwargs):
2136                    if func.overloadpacket == torch.ops.aten.is_contiguous:
2137                        return not_contiguous_data.is_contiguous()
2138                    return NotImplemented
2139
2140            err_msg = "Multiple dispatch failed for 'torch.ops.aten.is_contiguous'"
2141            e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
2142            with self.assertRaisesRegex(TypeError, err_msg):
2143                e.is_contiguous()
2144            with self.assertRaisesRegex(TypeError, err_msg):
2145                e.contiguous()
2146
2147            e = ExampleTensor2(torch.randn(3, 3), use_wrapper_subclass)
2148            self.assertEqual(e.is_contiguous(), True)
2149            e.contiguous()  # this will just return the original TensorImpl since is_contiguous = True
2150
2151            err_msg = "Multiple dispatch failed for"
2152            e = ExampleTensor3(torch.randn(3, 3), use_wrapper_subclass)
2153            self.assertEqual(e.is_contiguous(), False)
2154            with self.assertRaisesRegex(TypeError, err_msg):
2155                e.contiguous()
2156
2157    def test_fancy_strides(self):
2158        calls = []
2159
2160        class ExampleTensor(torch.Tensor):
2161            @staticmethod
2162            def __new__(cls, data):
2163                return TestPythonDispatch.subclass_helper(
2164                    cls, data, False, dispatch_sizes_strides_policy="strides"
2165                )
2166
2167            @classmethod
2168            def __torch_dispatch__(cls, func, types, args, kwargs):
2169                if func in [
2170                    torch.ops.aten.is_contiguous.default,
2171                    torch.ops.aten.is_contiguous.memory_format,
2172                    torch.ops.aten.is_strides_like_format.default,
2173                    torch.ops.aten.is_non_overlapping_and_dense.default,
2174                    torch.ops.aten.stride.default,
2175                ]:
2176                    calls.append((func, list(args)[1:]))
2177                    return None
2178                with no_dispatch():
2179                    return func(*args, **kwargs)
2180
2181        e = ExampleTensor(torch.randn(2, 2))
2182        self.assertFalse(e.is_contiguous(memory_format=torch.channels_last))
2183        self.assertEqual(
2184            calls, [(torch.ops.aten.is_contiguous.memory_format, [torch.channels_last])]
2185        )
2186        calls.clear()
2187        self.assertFalse(
2188            torch.ops.aten.is_strides_like_format.default(e, torch.channels_last)
2189        )
2190        self.assertEqual(
2191            calls,
2192            [(torch.ops.aten.is_strides_like_format.default, [torch.channels_last])],
2193        )
2194        calls.clear()
2195        self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(e))
2196        self.assertEqual(
2197            calls, [(torch.ops.aten.is_non_overlapping_and_dense.default, [])]
2198        )
2199
2200    def test_device_slowpath(self):
2201        for use_wrapper_subclass in [True]:
2202
2203            class ExampleTensor1(torch.Tensor):
2204                @staticmethod
2205                def __new__(cls, data, wrapper):
2206                    return TestPythonDispatch.subclass_helper(
2207                        cls, data, wrapper, dispatch_device=True
2208                    )
2209
2210                @classmethod
2211                def __torch_dispatch__(cls, func, types, args, kwargs):
2212                    return NotImplemented
2213
2214            class ExampleTensor2(torch.Tensor):
2215                @staticmethod
2216                def __new__(cls, data, wrapper):
2217                    return TestPythonDispatch.subclass_helper(
2218                        cls, data, wrapper, dispatch_device=True
2219                    )
2220
2221                @classmethod
2222                def __torch_dispatch__(cls, func, types, args, kwargs):
2223                    if func.overloadpacket == torch.ops.prim.device:
2224                        return torch.device("meta")
2225                    return NotImplemented
2226
2227            class ExampleTensor3(torch.Tensor):
2228                @staticmethod
2229                def __new__(cls, data, wrapper):
2230                    return TestPythonDispatch.subclass_helper(
2231                        cls, data, wrapper, dispatch_device=True
2232                    )
2233
2234                @classmethod
2235                def __torch_dispatch__(cls, func, types, args, kwargs):
2236                    if func.overloadpacket == torch.ops.prim.device:
2237                        return torch.device("meta")
2238                    return NotImplemented
2239
2240            err_msg = "Multiple dispatch failed for 'torch.ops.prim.device'"
2241            with self.assertRaisesRegex(TypeError, err_msg):
2242                e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
2243                e.device()
2244
2245            ten = torch.rand([1])
2246            e = ExampleTensor2(torch.randn(3, 3, device="cpu"), use_wrapper_subclass)
2247            self.assertEqual(e.device.type, "meta")
2248            self.assertEqual(ten.type_as(e).device.type, "meta")
2249
2250            e = ExampleTensor3(torch.randn(3, 3, device="cpu"), use_wrapper_subclass)
2251            self.assertEqual(e.device.type, "meta")
2252            self.assertEqual(ten.type_as(e).device.type, "meta")
2253
2254    def test_dim_slowpath(self):
2255        data = torch.randn(3, 3)
2256
2257        for use_wrapper_subclass in [True, False]:
2258
2259            class DimNotImplementedTensor(torch.Tensor):
2260                @staticmethod
2261                def __new__(cls, data, wrapper):
2262                    return TestPythonDispatch.subclass_helper(
2263                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
2264                    )
2265
2266                @classmethod
2267                def __torch_dispatch__(cls, func, types, args, kwargs):
2268                    return NotImplemented
2269
2270            class DimImplementedTensor(torch.Tensor):
2271                @staticmethod
2272                def __new__(cls, data, wrapper):
2273                    return TestPythonDispatch.subclass_helper(
2274                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
2275                    )
2276
2277                @classmethod
2278                def __torch_dispatch__(cls, func, types, args, kwargs):
2279                    if func.overloadpacket == torch.ops.aten.dim:
2280                        return data.dim()
2281                    return NotImplemented
2282
2283            err_msg = "Multiple dispatch failed for 'torch.ops.aten.dim'"
2284            e = DimNotImplementedTensor(torch.randn(3, 3), use_wrapper_subclass)
2285            with self.assertRaisesRegex(TypeError, err_msg):
2286                e.dim()
2287
2288            t = DimImplementedTensor(torch.randn(3, 3), use_wrapper_subclass)
2289            self.assertEqual(t.dim(), 2)
2290
2291    def test_maybe_tuple_bug(self):
2292        class T(torch.Tensor):
2293            @classmethod
2294            def __torch_function__(cls, *args, **kwargs):
2295                pass
2296
2297        a = torch.rand(3)
2298
2299        a[[T(), T()]]
2300
2301    def test_standard_is_not_subclass(self):
2302        # https://github.com/pytorch/pytorch/issues/79079
2303        self.assertFalse(torch._C._dispatch_isTensorSubclassLike(torch.empty(0)))
2304
2305    def test_sym_sizes_strides_slow_path(self):
2306        class TestTensor(torch.Tensor):
2307            @staticmethod
2308            def __new__(cls, *args, **kwargs):
2309                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
2310                    cls, (0,), dispatch_sizes_strides_policy="sizes"
2311                )
2312                return r
2313
2314            @classmethod
2315            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
2316                if func in (
2317                    torch.ops.aten.sym_size.default,
2318                    torch.ops.aten.sym_stride.default,
2319                ):
2320                    from torch._dynamo.source import ConstantSource
2321                    from torch.fx.experimental.symbolic_shapes import (
2322                        DimDynamic,
2323                        ShapeEnv,
2324                    )
2325
2326                    shape_env = ShapeEnv()
2327                    si = shape_env.create_symintnode(
2328                        shape_env.create_symbol(
2329                            123,
2330                            source=ConstantSource("abc"),
2331                            dynamic_dim=DimDynamic.DUCK,
2332                            constraint_dim=None,
2333                        ),
2334                        hint=123,
2335                    )
2336                    return (si,)
2337
2338        t = TestTensor()
2339        si = t.size()[0]
2340        self.assertIsInstance(si, torch.SymInt)
2341        si = t.stride()[0]
2342        self.assertIsInstance(si, torch.SymInt)
2343
2344    def test_strides_slow_path(self):
2345        for use_wrapper_subclass in [True, False]:
2346
2347            class StridesNotImplemented(torch.Tensor):
2348                @staticmethod
2349                def __new__(cls, data, wrapper):
2350                    return TestPythonDispatch.subclass_helper(
2351                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
2352                    )
2353
2354                @classmethod
2355                def __torch_dispatch__(cls, func, types, args, kwargs):
2356                    return NotImplemented
2357
2358            class StridesCustomReturn(torch.Tensor):
2359                @staticmethod
2360                def __new__(cls, data, wrapper):
2361                    return TestPythonDispatch.subclass_helper(
2362                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
2363                    )
2364
2365                @classmethod
2366                def __torch_dispatch__(cls, func, types, args, kwargs):
2367                    if func == torch.ops.aten.sym_stride.default:
2368                        return (4, 2)
2369                    return NotImplemented
2370
2371            class StridesDefaultReturn(torch.Tensor):
2372                @staticmethod
2373                def __new__(cls, data, wrapper):
2374                    return TestPythonDispatch.subclass_helper(
2375                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
2376                    )
2377
2378                @classmethod
2379                def __torch_dispatch__(cls, func, types, args, kwargs):
2380                    if func == torch.ops.aten.sym_stride.default:
2381                        return None
2382                    return NotImplemented
2383
2384            err_msg = "Multiple dispatch failed for 'torch.ops.aten.sym_stride'"
2385            e = StridesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
2386            with self.assertRaisesRegex(TypeError, err_msg):
2387                e.stride()
2388
2389            e = StridesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
2390            self.assertEqual(e.stride(), (4, 2))
2391
2392            e = StridesDefaultReturn(torch.randn(6, 2), use_wrapper_subclass)
2393            self.assertEqual(e.stride(), (2, 1))
2394
2395    def test_sizes_slow_path(self):
2396        for use_wrapper_subclass in [True, False]:
2397            data = torch.randn(6, 2)
2398
2399            class SizesNotImplemented(torch.Tensor):
2400                @staticmethod
2401                def __new__(cls, data, wrapper):
2402                    return TestPythonDispatch.subclass_helper(
2403                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
2404                    )
2405
2406                @classmethod
2407                def __torch_dispatch__(cls, func, types, args, kwargs):
2408                    if func.overloadpacket == torch.ops.aten.dim:
2409                        return data.dim()
2410                    return NotImplemented
2411
2412            class SizesCustomReturn(torch.Tensor):
2413                @staticmethod
2414                def __new__(cls, data, wrapper):
2415                    return TestPythonDispatch.subclass_helper(
2416                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
2417                    )
2418
2419                @classmethod
2420                def __torch_dispatch__(cls, func, types, args, kwargs):
2421                    if func.overloadpacket == torch.ops.aten.dim:
2422                        return data.dim()
2423                    if func.overloadpacket == torch.ops.aten.sym_size:
2424                        return (5, 3)
2425                    return NotImplemented
2426
2427            class SizesDefaultReturn(torch.Tensor):
2428                @staticmethod
2429                def __new__(cls, data, wrapper):
2430                    return TestPythonDispatch.subclass_helper(
2431                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
2432                    )
2433
2434                @classmethod
2435                def __torch_dispatch__(cls, func, types, args, kwargs):
2436                    if func.overloadpacket == torch.ops.aten.dim:
2437                        return data.dim()
2438                    if func.overloadpacket == torch.ops.aten.sym_size:
2439                        return None
2440                    return NotImplemented
2441
2442            err_msg = "Multiple dispatch failed for 'torch.ops.aten.sym_size'"
2443            e = SizesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
2444            with self.assertRaisesRegex(TypeError, err_msg):
2445                e.size()
2446
2447            e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
2448            self.assertEqual(e.size(), (5, 3))
2449
2450            e = SizesDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
2451            self.assertEqual(e.size(), (4, 2))
2452
2453    def test_custom_size_policy_dynamic_shapes(self):
2454        data = torch.randn(6, 2)
2455
2456        class CustomSizeDynamicShapesTensor(torch.Tensor):
2457            @staticmethod
2458            def __new__(cls, inner):
2459                return torch.Tensor._make_wrapper_subclass(
2460                    # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
2461                    # Calling the overload that has kwargs causes us to go down the first overload path,
2462                    # which will **always** specialize sizes.
2463                    # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
2464                    cls,
2465                    inner.size(),
2466                    inner.stride(),
2467                    None,
2468                    None,
2469                    inner.dtype,
2470                    inner.layout,
2471                    inner.device,
2472                    False,
2473                    inner.requires_grad,
2474                    "sizes",
2475                )
2476
2477            def __init__(self, inner):
2478                self.inner = inner
2479
2480            @classmethod
2481            def __torch_dispatch__(cls, func, types, args, kwargs):
2482                if func == torch.ops.aten.sym_size.default:
2483                    return args[0].inner.shape
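                # Strides are reported as the inner shape as well here; the expected
                # FX graph below reflects that.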
2484                if func == torch.ops.aten.sym_stride.default:
2485                    return args[0].inner.shape
2486                return NotImplemented
2487
2488        x = torch.ones(2, 2)
2489
2490        def trace_fn(x):
2491            x_wrapper = CustomSizeDynamicShapesTensor(x)
2492            return x_wrapper.size(), x_wrapper.stride()
2493
2494        fx_g = make_fx(trace_fn, tracing_mode="symbolic")(x)
2495        self.assertExpectedInline(
2496            fx_g.code.strip(),
2497            """\
2498def forward(self, x_1):
2499    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
2500    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1);  x_1 = None
2501    return ((sym_size_int, sym_size_int_1), (sym_size_int, sym_size_int_1))""",
2502        )
2503
2504    def test_data_ptr_respects_numel_slow_path(self):
2505        data = torch.randn(6, 2)
2506
2507        class NumelDefaultReturn(torch.Tensor):
2508            @staticmethod
2509            def __new__(cls, data, wrapper):
2510                return TestPythonDispatch.subclass_helper(
2511                    cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
2512                )
2513
2514            @classmethod
2515            def __torch_dispatch__(cls, func, types, args, kwargs):
2516                if func.overloadpacket == torch.ops.aten.dim:
2517                    return data.dim()
2518                if func.overloadpacket == torch.ops.aten.numel:
2519                    numel_called[0] = True
2520                    return None
2521                return NotImplemented
2522
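        # numel_called (defined just below) is a free variable of __torch_dispatch__
        # above; it is only looked up at call time, so this ordering is fine.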
2523        for use_wrapper_subclass in (False, True):
2524            numel_called = [False]
2525            e = NumelDefaultReturn(torch.randn(2, 2), use_wrapper_subclass)
2526            e.data_ptr()
2527            self.assertTrue(numel_called[0])
2528
2529    def test_layout_slow_path(self):
2530        for use_wrapper_subclass in [True, False]:
2531            data = torch.randn(6, 2)
2532
2533            class LayoutNotImplemented(torch.Tensor):
2534                @staticmethod
2535                def __new__(cls, data, wrapper):
2536                    return TestPythonDispatch.subclass_helper(
2537                        cls, data, wrapper, dispatch_layout=True
2538                    )
2539
2540                @classmethod
2541                def __torch_dispatch__(cls, func, types, args, kwargs):
2542                    return NotImplemented
2543
2544            class LayoutCustomReturn(torch.Tensor):
2545                @staticmethod
2546                def __new__(cls, data, wrapper):
2547                    return TestPythonDispatch.subclass_helper(
2548                        cls, data, wrapper, dispatch_layout=True
2549                    )
2550
2551                @classmethod
2552                def __torch_dispatch__(cls, func, types, args, kwargs):
2553                    if func.overloadpacket == torch.ops.prim.layout:
2554                        return torch.sparse_csr
2555                    return NotImplemented
2556
2557            class LayoutDefaultReturn(torch.Tensor):
2558                @staticmethod
2559                def __new__(cls, data, wrapper):
2560                    return TestPythonDispatch.subclass_helper(
2561                        cls, data, wrapper, dispatch_layout=True
2562                    )
2563
2564                @classmethod
2565                def __torch_dispatch__(cls, func, types, args, kwargs):
2566                    if func.overloadpacket == torch.ops.prim.layout:
2567                        return data.layout
2568                    return NotImplemented
2569
2570            err_msg = "Multiple dispatch failed for 'torch.ops.prim.layout'"
2571            e = LayoutNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
2572            with self.assertRaisesRegex(TypeError, err_msg):
2573                e.layout
2574
2575            e = LayoutCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
2576            self.assertEqual(e.layout, torch.sparse_csr)
2577
2578            e = LayoutDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
2579            self.assertEqual(e.layout, torch.strided)
2580
2581
2582class TestPythonDispatcher(TestCase):
2583    def test_basic(self):
2584        x = torch.randn(2, requires_grad=True)
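        # _EnablePythonDispatcher is a guard object; binding it to a local keeps the
        # Python dispatcher enabled for the remainder of the test.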
2585        r = torch._C._EnablePythonDispatcher()
2586        torch.add(x, x)
2587
2588    def test_lstsq(self):
2589        a = torch.randn(4, 3)
2590        b = torch.rand(4, 3)
2591        expected_shape = torch.linalg.lstsq(a, b).solution.shape
2592        r = torch._C._EnablePythonDispatcher()
2593        python_disp_shape = torch.linalg.lstsq(a, b).solution.shape
2594        self.assertEqual(expected_shape, python_disp_shape)
2595
2596
2597class TestWrapperSubclassAliasing(TestCase):
2598    def _test_wrapper_subclass_aliasing(self, op, args, kwargs):
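        # Run `op` on plain tensors and on TwoTensor-wrapped copies of the same args,
        # then check that output/input identity and storage-aliasing relationships
        # match between the two runs.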
2599        def to_subclass(t: torch.Tensor):
2600            return TwoTensor(t, t.clone())
2601
2602        result_ref = op(*args, **kwargs)
2603
2604        args_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, args)
2605        kwargs_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, kwargs)
2606
2607        result_test = op(*args_subclass, **kwargs_subclass)
2608
2609        args_ref_flat = pytree.arg_tree_leaves(*args, **kwargs)
2610        args_ref_flat_tensors = [
2611            x for x in args_ref_flat if isinstance(x, torch.Tensor)
2612        ]
2613
2614        args_test_flat = pytree.tree_leaves((args_subclass, kwargs_subclass))
2615        args_test_flat_tensors = [
2616            x for x in args_test_flat if isinstance(x, torch.Tensor)
2617        ]
2618
2619        result_ref_flat = pytree.tree_leaves(result_ref)
2620        result_ref_flat_tensors = [
2621            x for x in result_ref_flat if isinstance(x, torch.Tensor)
2622        ]
2623
2624        result_test_flat = pytree.tree_leaves(result_test)
2625        result_test_flat_tensors = [
2626            x for x in result_test_flat if isinstance(x, torch.Tensor)
2627        ]
2628
2629        for o_ref, o_test in zip(result_ref_flat_tensors, result_test_flat_tensors):
2630            for a_ref, a_test in zip(args_ref_flat_tensors, args_test_flat_tensors):
2631                out_is_inpt = o_ref is a_ref
2632                if out_is_inpt:
2633                    self.assertTrue(o_test is a_test)
2634
2635                out_aliases_inpt = StorageWeakRef(
2636                    o_ref.untyped_storage()
2637                ) == StorageWeakRef(a_ref.untyped_storage())
2638                if out_aliases_inpt:
2639                    self.assertTrue(
2640                        StorageWeakRef(o_test.untyped_storage())
2641                        == StorageWeakRef(a_test.untyped_storage())
2642                    )
2643                else:
2644                    self.assertFalse(
2645                        StorageWeakRef(o_test.untyped_storage())
2646                        == StorageWeakRef(a_test.untyped_storage())
2647                    )
2648
2649    # This tests the correctness of `torch.utils._python_dispatch.return_and_correct_aliasing`,
2650    # a util for wrapper subclasses to promise correct aliasing behavior.
2651    # It's probably overkill to test every OpInfo,
2652    # so I picked a sampling of ops with representative schemas.
2653    @ops(
2654        [
2655            op
2656            for op in op_db
2657            if op.name
2658            in [
2659                "mul",  # out-of-place
2660                "cat",  # out-of-place (TensorList input)
2661                "index",  # out-of-place (Optional TensorList input)
2662                "mul_",  # inplace
2663                "view",  # view
2664                "t_",  # inplace-view
2665                "split",  # view (multi-return)
2666                "native_batch_norm",  # mutable op (returns outputs and mutates some inputs)
2667            ]
2668        ],
2669        allowed_dtypes=(torch.float,),
2670    )
2671    def test_wrapper_subclass_aliasing(self, device, dtype, op):
2672        samples = op.sample_inputs(device, dtype)
2673        sample = first_sample(self, samples)
2674        args = (sample.input, *sample.args)
2675        kwargs = sample.kwargs
2676        self._test_wrapper_subclass_aliasing(op, args, kwargs)
2677
2678    @ops(custom_op_db, allowed_dtypes=(torch.float,))
2679    def test_wrapper_subclass_aliasing_custom(self, device, dtype, op):
2680        samples = op.sample_inputs(device, dtype)
2681        sample = first_sample(self, samples)
2682        args = (sample.input, *sample.args)
2683        kwargs = sample.kwargs
2684        self._test_wrapper_subclass_aliasing(op, args, kwargs)
2685
2686    def test_wrapper_subclass_aliasing_conv2d(self, device):
2687        args = (torch.randn(4, 4, 4, 4), torch.randn(4, 4, 4, 4))
2688        kwargs = {}
2689        # conv2d has int[2] args with scalar defaults (e.g. 'int[2] padding=0'),
2690        # which torchscript expands into lists (e.g. '[0, 0]').
2691        # Make sure that return_and_correct_aliasing can handle this case
2692        # (I'm using inference_mode to make sure conv2d doesn't decompose and goes to torch_dispatch)
2693        with torch.inference_mode():
2694            self._test_wrapper_subclass_aliasing(
2695                torch.ops.aten.conv2d.default, args, kwargs
2696            )
2697
2698    def test_wrapper_subclass_aliasing_out_op(self, device):
2699        # Make sure that return_and_correct_aliasing can handle kwargs with mutable tensors
2700        args = (torch.ones(4), torch.ones(4))
2701        kwargs = {"out": torch.empty(4)}
2702        self._test_wrapper_subclass_aliasing(torch.ops.aten.add.out, args, kwargs)
2703
2704
2705instantiate_device_type_tests(TestWrapperSubclassAliasing, globals())
2706
2707if __name__ == "__main__":
2708    run_tests()
2709