xref: /aosp_15_r20/external/pytorch/test/test_functionalization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: codegen"]
2
3import unittest
4from contextlib import nullcontext
5
6import torch
7from torch._dispatch.python import (
8    enable_crossref_functionalize,
9    enable_python_dispatcher,
10)
11from torch._subclasses.functional_tensor import (
12    dispatch_functionalize,
13    FunctionalTensor,
14    FunctionalTensorMode,
15)
16from torch.fx.experimental.proxy_tensor import make_fx
17from torch.fx.passes.reinplace import reinplace
18from torch.multiprocessing.reductions import StorageWeakRef
19from torch.testing._internal.common_utils import (
20    IS_WINDOWS,
21    run_tests,
22    skipIfTorchDynamo,
23    TEST_WITH_TORCHDYNAMO,
24    TestCase,
25    xfail_inherited_tests,
26)
27from torch.testing._internal.logging_tensor import capture_logs, LoggingTensor
28from torch.utils import _pytree as pytree
29from torch.utils._pytree import tree_map_only
30
31
32def are_aliased(x, y):
33    x_storage = StorageWeakRef(x.storage())
34    y_storage = StorageWeakRef(y.storage())
35    return x_storage == y_storage
36
37
38# We can unify testing and use functionalize() here instead
39# if/when functorch moves into core.
40# This is basically a crappy version of `functionalize()`.
41def _functionalize(
42    f, *, reapply_views: bool, crossref: bool, skip_input_mutations: bool = False
43):
44    def to_fun(t: torch.Tensor):
45        func_t = torch._to_functional_tensor(t)
46        func_t.requires_grad = t.requires_grad
47        return func_t
48
49    def wrapped(*inputs):
50        ctx = nullcontext()
51        if crossref:
52            ctx = enable_crossref_functionalize()
53        with ctx:
54            inputs_functional = tree_map_only(torch.Tensor, to_fun, inputs)
55            torch._enable_functionalization(reapply_views=reapply_views)
56            try:
57                out = f(*inputs_functional)
58            finally:
59                torch._disable_functionalization()
60            flat_inputs = pytree.tree_leaves(inputs)
61            flat_inputs_functional = pytree.tree_leaves(inputs_functional)
62
63            for inpt, input_functional in zip(flat_inputs, flat_inputs_functional):
64                torch._sync(input_functional)
65                inpt_new = torch._from_functional_tensor(input_functional)
66                if inpt_new is not inpt and not skip_input_mutations:
67                    # Existing deficiency in functionalize():
68                    # we don't correctly mutate input metadata (yet?)
69                    if inpt_new.shape == inpt.shape:
70                        inpt.copy_(inpt_new)
71            tree_map_only(torch.Tensor, torch._sync, out)
72            out_unwrapped = tree_map_only(
73                torch.Tensor, torch._from_functional_tensor, out
74            )
75            return out_unwrapped
76
77    return wrapped
78
79
80@unittest.skipIf(
81    TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457"
82)
83class TestFunctionalization(TestCase):
84    crossref = False
85
86    def get_logs(self, func, *inpts, reapply_views=False, run_reinplace=False):
87        inpts_clone = tree_map_only(torch.Tensor, torch.clone, inpts)
88        traced_f = make_fx(
89            _functionalize(func, reapply_views=reapply_views, crossref=self.crossref)
90        )(*inpts)
91        if run_reinplace:
92            traced_f = reinplace(traced_f, *inpts_clone)
93        return traced_f.code
94
95    def assert_functionalization(
96        self, func, *inpts, reapply_views=False, mutated_input_metadata=False
97    ):
98        clones1 = tree_map_only(torch.Tensor, torch.clone, inpts)
99        clones2 = tree_map_only(torch.Tensor, torch.clone, inpts)
100        clones3 = tree_map_only(torch.Tensor, torch.clone, inpts)
101
102        # Compare outputs (and mutated inputs), with and without functionalization.
103        out_ref = func(*inpts)
104        out_functional = _functionalize(
105            func, reapply_views=reapply_views, crossref=self.crossref
106        )(*clones1)
107
108        # The reinplacing pass is only valid to run with reapply_views=True.
109        functional_func = make_fx(
110            _functionalize(func, reapply_views=True, crossref=self.crossref)
111        )(*clones2)
112        reinplace_func = reinplace(functional_func, *clones2)
113
114        # NOTE: for now, need to pass in fresh inputs here, because make_fx
115        # will directly mutate the inputs that you trace with.
116        # Once this is fixed we can clean this up.
117        out_reinplace = reinplace_func(*clones3)
118
119        # functionalize() deficiency: input metadata mutations aren't propagated properly,
120        # so we just need to skip checks here for the tests that exercise that.
121        if not mutated_input_metadata:
122            flat_inpts = pytree.tree_leaves(inpts)
123            flat_clones1 = pytree.tree_leaves(clones1)
124            flat_clones3 = pytree.tree_leaves(clones3)
125            for inpt, input_clone, input_clone3 in zip(
126                flat_inpts, flat_clones1, flat_clones3
127            ):
128                self.assertEqual(
129                    inpt, input_clone
130                )  # input mutations should still occur
131                self.assertEqual(inpt, input_clone3)
132
133        # Handle tests with multi-tensor outputs
134        if isinstance(out_ref, tuple):
135            out_refs, out_functionals, out_reinplaces = (
136                list(out_ref),
137                list(out_functional),
138                list(out_reinplace),
139            )
140        else:
141            out_refs, out_functionals, out_reinplaces = (
142                [out_ref],
143                [out_functional],
144                [out_reinplace],
145            )
146
147        for out_ref_, out_functional_, out_reinplace_ in zip(
148            out_refs, out_functionals, out_reinplaces
149        ):
150            self.assertEqual(out_ref_, out_functional_)
151            self.assertEqual(out_ref_, out_reinplace_)
152
153    def test_save_for_backwards_segfault(self):
154        inp = torch._to_functional_tensor(
155            LoggingTensor(torch.randn(2, 2))
156        ).requires_grad_(True)
157        inp.exp()
158
159    def test_multiple_views_of_same_base(self):
160        def f(x):
161            y = x.view(-1)
162            z = x.view(-1)
163            x.add_(1)
164            # y should have been updated.
165            y2 = y + 1
166            # z should have been updated too.
167            z2 = z + 1
168            return z2
169
170        self.assert_functionalization(f, torch.ones(4))
171
172    def test_freeze(self):
173        def f(x):
174            y = x.clone()
175            z = y[0]
176            torch._freeze_functional_tensor(y)
177            x.add_(1)
178            self.assertRaises(RuntimeError, lambda: y.add_(1))
179            self.assertRaises(RuntimeError, lambda: z.add_(1))
180            return z
181
182        _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(3, 3))
183
184    def test_copy_stride_mismatch(self):
185        def f(x):
186            y = torch.empty_strided((2, 2), (5, 1))
187            y.copy_(x)
188            return y
189
190        r = _functionalize(f, reapply_views=True, crossref=self.crossref)(
191            torch.ones(2, 2)
192        )
193        self.assertEqual(r.stride(), (5, 1))
194
195    def test_set_(self):
196        def f(x):
197            y = torch.ones(2)
198            y.set_(x.storage())
199            return y
200
201        # We should probaby get the crossref test to work,
202        # but fixing it for Storage() objects is annoying.
203        r = _functionalize(f, reapply_views=True, crossref=False)(torch.ones(2))
204        self.assertEqual(str(r.device), "cpu")
205
206    def test_advanced_indexing(self):
207        def f():
208            x = torch.zeros(3, 3)
209            idx = torch.tensor([0])
210            val = torch.ones(3, 1)
211            x[:, idx] = val
212            return x
213
214        self.assert_functionalization(f)
215
216    def test_view_clone_view_inplace(self):
217        def f(input):
218            shape = [1, 1024, 128, 128]
219            input_reshaped = input.view(shape)
220            out = input_reshaped.clone()
221            r = out.view(input.shape)
222            r.relu_()
223            return r
224
225        def g(x):
226            loss = f(x).sum()
227            import torch.fx.traceback as fx_traceback
228            from torch._functorch.aot_autograd import (
229                setup_stacktrace_preservation_hooks,
230            )
231
232            setup_stacktrace_preservation_hooks([loss.grad_fn])
233            with fx_traceback.preserve_node_meta():
234                loss.backward()
235            return x.grad
236
237        with torch.autograd.detect_anomaly(check_nan=False):
238            logs = self.get_logs(g, torch.ones(16, 64, 128, 128, requires_grad=True))
239        self.assertExpectedInline(
240            logs,
241            """\
242
243
244
245def forward(self, arg0_1):
246    view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 1024, 128, 128]);  arg0_1 = None
247    clone = torch.ops.aten.clone.default(view_copy);  view_copy = None
248    view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128])
249    relu = torch.ops.aten.relu.default(view_copy_1);  view_copy_1 = None
250    view_copy_2 = torch.ops.aten.view_copy.default(relu, [1, 1024, 128, 128]);  relu = None
251    view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [16, 64, 128, 128]);  view_copy_2 = None
252    view_copy_4 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]);  clone = view_copy_4 = None
253    sum_1 = torch.ops.aten.sum.default(view_copy_3)
254    ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format);  sum_1 = None
255    expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]);  ones_like = None
256    view_copy_5 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]);  expand_copy = None
257    new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_5, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
258    copy = torch.ops.aten.copy.default(new_empty_strided, view_copy_5);  new_empty_strided = view_copy_5 = None
259    view_copy_6 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]);  view_copy_6 = None
260    view_copy_7 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128])
261    clone_1 = torch.ops.aten.clone.default(view_copy_7, memory_format = torch.contiguous_format)
262    threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, view_copy_3, 0);  clone_1 = view_copy_3 = None
263    copy_1 = torch.ops.aten.copy.default(view_copy_7, threshold_backward);  view_copy_7 = threshold_backward = None
264    view_copy_8 = torch.ops.aten.view_copy.default(copy_1, [1, 1024, 128, 128]);  copy_1 = None
265    view_copy_9 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]);  view_copy_9 = None
266    view_copy_10 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]);  copy = None
267    detach_copy = torch.ops.aten.detach_copy.default(view_copy_10);  view_copy_10 = detach_copy = None
268    view_copy_11 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]);  view_copy_8 = None
269    detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_11);  view_copy_11 = None
270    return detach_copy_1
271    """,
272        )  # noqa: B950
273
274    def test_simple(self):
275        def f(x):
276            # simple test: 1 view op, 1 inplace op
277            tmp = torch.ones(4, 2)
278            y = x.view(4, 2)
279            y.add_(tmp)
280            z = x * x
281            return y
282
283        self.assert_functionalization(f, torch.ones(4, 2))
284        logs = self.get_logs(f, torch.ones(4, 2))
285        self.assertExpectedInline(
286            logs,
287            """\
288
289
290
291def forward(self, arg0_1):
292    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
293    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
294    add = torch.ops.aten.add.Tensor(view_copy, ones);  view_copy = ones = None
295    view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]);  add = None
296    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2])
297    mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1);  mul = None
298    copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1);  arg0_1 = view_copy_1 = copy_ = None
299    return view_copy_2
300    """,
301        )
302
303        reinplaced_logs = self.get_logs(
304            f, torch.ones(4, 2), reapply_views=True, run_reinplace=True
305        )
306        self.assertExpectedInline(
307            reinplaced_logs,
308            """\
309
310
311
312def forward(self, arg0_1):
313    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
314    view = torch.ops.aten.view.default(arg0_1, [4, 2])
315    add = torch.ops.aten.add.Tensor(view, ones);  view = ones = None
316    view_1 = torch.ops.aten.view.default(add, [4, 2]);  add = None
317    view_2 = torch.ops.aten.view.default(view_1, [4, 2])
318    mul = torch.ops.aten.mul.Tensor(view_1, view_1);  mul = None
319    copy_ = torch.ops.aten.copy_.default(arg0_1, view_1);  arg0_1 = view_1 = copy_ = None
320    return view_2
321    """,
322        )
323
324    def test_simple_out(self):
325        def f(x):
326            tmp = torch.ones(4, 2)
327            y = x.view(4, 2)
328            # the out= tensor will get resized, since it has size=0 to start.
329            z = torch.empty(())
330            torch.add(y, tmp, out=z)
331            w = z * z
332            return w
333
334        self.assert_functionalization(f, torch.ones(4, 2))
335        logs = self.get_logs(f, torch.ones(4, 2))
336        self.assertExpectedInline(
337            logs,
338            """\
339
340
341
342def forward(self, arg0_1):
343    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
344    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]);  arg0_1 = None
345    empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False);  empty = None
346    add = torch.ops.aten.add.Tensor(view_copy, ones);  view_copy = ones = None
347    mul = torch.ops.aten.mul.Tensor(add, add);  add = None
348    return mul
349    """,
350        )
351
352        reinplaced_logs = self.get_logs(
353            f, torch.ones(4, 2), reapply_views=True, run_reinplace=True
354        )
355        self.assertExpectedInline(
356            reinplaced_logs,
357            """\
358
359
360
361def forward(self, arg0_1):
362    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
363    view = torch.ops.aten.view.default(arg0_1, [4, 2]);  arg0_1 = None
364    empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False);  empty = None
365    add = torch.ops.aten.add.Tensor(view, ones);  view = ones = None
366    mul = torch.ops.aten.mul.Tensor(add, add);  add = None
367    return mul
368    """,
369        )
370
371    def test_multi_out(self):
372        def f(x):
373            # aminmax.out returns a tuple of tensors.
374            # functionalization should properly handle the tuple.
375            out_min = torch.empty(4)
376            out_max = torch.empty(4)
377            torch.aminmax(x, dim=0, out=(out_max, out_min))
378            return out_max
379
380        self.assert_functionalization(f, torch.arange(8, dtype=torch.float32))
381        logs = self.get_logs(f, torch.arange(8, dtype=torch.float32))
382        self.assertExpectedInline(
383            logs,
384            """\
385
386
387
388def forward(self, arg0_1):
389    empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False);  empty = None
390    empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False);  empty_1 = None
391    aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0);  arg0_1 = None
392    getitem = aminmax[0]
393    getitem_1 = aminmax[1];  aminmax = getitem_1 = None
394    return getitem
395    """,
396        )
397
398        reinplaced_logs = self.get_logs(
399            f,
400            torch.arange(8, dtype=torch.float32),
401            reapply_views=True,
402            run_reinplace=True,
403        )
404        self.assertExpectedInline(
405            reinplaced_logs,
406            """\
407
408
409
410def forward(self, arg0_1):
411    empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False);  empty = None
412    empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False);  empty_1 = None
413    aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0);  arg0_1 = None
414    getitem = aminmax[0]
415    getitem_1 = aminmax[1];  aminmax = getitem_1 = None
416    return getitem
417    """,
418        )
419
420    def test_tensor_ctr(self):
421        def f(x):
422            y = torch.tensor((1, 2, 3))
423            z = y.view(-1)
424            z.add_(1)
425            return y
426
427        inpt = torch.arange(3, dtype=torch.float32)
428        self.assert_functionalization(f, inpt)
429
430        logs = self.get_logs(f, inpt)
431        self.assertExpectedInline(
432            logs,
433            """\
434
435
436
437def forward(self, arg0_1):
438    _tensor_constant0 = self._tensor_constant0
439    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
440    view_copy = torch.ops.aten.view_copy.default(lift_fresh_copy, [-1]);  lift_fresh_copy = None
441    add = torch.ops.aten.add.Tensor(view_copy, 1);  view_copy = None
442    view_copy_1 = torch.ops.aten.view_copy.default(add, [3]);  add = None
443    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [-1]);  view_copy_2 = None
444    return view_copy_1
445    """,
446        )
447
448        reinplaced_logs = self.get_logs(f, inpt, reapply_views=True, run_reinplace=True)
449        self.assertExpectedInline(
450            reinplaced_logs,
451            """\
452
453
454
455def forward(self, arg0_1):
456    _tensor_constant0 = self._tensor_constant0
457    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
458    view = torch.ops.aten.view.default(lift_fresh_copy, [-1]);  lift_fresh_copy = None
459    add = torch.ops.aten.add_.Tensor(view, 1);  add = None
460    view_1 = torch.ops.aten.view.default(view, [3]);  view = None
461    view_2 = torch.ops.aten.view.default(view_1, [-1]);  view_2 = None
462    return view_1
463    """,
464        )
465
466    def test_advanced_indexing_correct_strides(self):
467        def f(a):
468            # This test requires that *_scatter ops are able to return
469            # non-contiguous tensors.
470            b = a.clone()[:, 1]
471            c = torch.ones_like(b, dtype=torch.bool)
472            d = b.masked_fill_(c, 0)
473            return d
474
475        self.assert_functionalization(f, torch.ones(2, 2), reapply_views=True)
476
477    def test_tensor_list_mixed_functional_nonfunctional(self):
478        nonfunctional_tensor = torch.ones(2, dtype=torch.long)
479
480        def f(x):
481            # simple test: 1 view op, 1 inplace op
482            functional_tensor = torch.ones(2, dtype=torch.long)
483            out = x[functional_tensor, nonfunctional_tensor]
484            return out
485
486        out = f(torch.ones(2, 2))
487        out_functional = _functionalize(f, reapply_views=True, crossref=self.crossref)(
488            torch.ones(2, 2)
489        )
490        self.assertEqual(out, out_functional)
491
492    def test_inplace_on_non_view(self):
493        def f(x):
494            # test for the case where we functionalize an inplace op on the other tensor - not a view.
495            # This is worth checking because the tensor will have an empty ViewMeta stack, which needs to be special cased.
496            tmp = torch.ones(4, 2)
497            y = x.view(4, 2)
498            x.add_(tmp)
499            return y
500
501        self.assert_functionalization(f, torch.ones(4, 2))
502        logs = self.get_logs(f, torch.ones(4, 2))
503        self.assertExpectedInline(
504            logs,
505            """\
506
507
508
509def forward(self, arg0_1):
510    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
511    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]);  view_copy = None
512    add = torch.ops.aten.add.Tensor(arg0_1, ones);  ones = None
513    copy_ = torch.ops.aten.copy_.default(arg0_1, add);  arg0_1 = copy_ = None
514    view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]);  add = None
515    return view_copy_1
516    """,
517        )
518
519        reinplaced_logs = self.get_logs(
520            f, torch.ones(4, 2), reapply_views=True, run_reinplace=True
521        )
522        self.assertExpectedInline(
523            reinplaced_logs,
524            """\
525
526
527
528def forward(self, arg0_1):
529    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
530    view = torch.ops.aten.view.default(arg0_1, [4, 2]);  view = None
531    add = torch.ops.aten.add.Tensor(arg0_1, ones);  ones = None
532    copy_ = torch.ops.aten.copy_.default(arg0_1, add);  arg0_1 = copy_ = None
533    view_1 = torch.ops.aten.view.default(add, [4, 2]);  add = None
534    return view_1
535    """,
536        )
537
538    # Some ops that are mutable are neither inplace nor out= ops.
539    # They also need special handling.
540    def test_mutable_op_not_inplace_or_other(self):
541        def f(x):
542            return torch._fused_moving_avg_obs_fq_helper(
543                x, x, x, x, x, x, x, 1.0, 0, 1, 0
544            )
545
546        logs = self.get_logs(f, torch.ones(1))
547        self.assertExpectedInline(
548            logs,
549            """\
550
551
552
553def forward(self, arg0_1):
554    _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0)
555    getitem = _fused_moving_avg_obs_fq_helper_functional[0]
556    getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1]
557    getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2];  getitem_2 = None
558    getitem_3 = _fused_moving_avg_obs_fq_helper_functional[3];  getitem_3 = None
559    getitem_4 = _fused_moving_avg_obs_fq_helper_functional[4];  getitem_4 = None
560    getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5];  _fused_moving_avg_obs_fq_helper_functional = None
561    copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_5);  arg0_1 = getitem_5 = copy_ = None
562    return (getitem, getitem_1)
563    """,  # noqa: B950
564        )
565
566    def test_as_strided(self):
567        def f(x):
568            y = x.as_strided((2,), (2,), 1)
569            y.add_(1)
570            return x
571
572        self.assert_functionalization(f, torch.ones(9))
573        logs = self.get_logs(f, torch.ones(9))
574        self.assertExpectedInline(
575            logs,
576            """\
577
578
579
580def forward(self, arg0_1):
581    as_strided_copy = torch.ops.aten.as_strided_copy.default(arg0_1, [2], [2], 1)
582    add = torch.ops.aten.add.Tensor(as_strided_copy, 1);  as_strided_copy = None
583    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1);  add = None
584    as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [2], [2], 1);  as_strided_copy_1 = None
585    copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter);  arg0_1 = copy_ = None
586    return as_strided_scatter
587    """,
588        )
589
590        # NB: even with reapply_views=True, we expect to see scatter op
591        reinplaced_logs = self.get_logs(
592            f, torch.ones(2, 2), reapply_views=True, run_reinplace=False
593        )
594        self.assertExpectedInline(
595            reinplaced_logs,
596            """\
597
598
599
600def forward(self, arg0_1):
601    as_strided = torch.ops.aten.as_strided.default(arg0_1, [2], [2], 1)
602    add = torch.ops.aten.add.Tensor(as_strided, 1);  as_strided = None
603    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1);  add = None
604    as_strided_1 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [2], 1);  as_strided_1 = None
605    copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter);  arg0_1 = copy_ = None
606    return as_strided_scatter
607    """,
608        )
609
610    def test_tensor_list_composite(self):
611        def f(x):
612            # Test an op with TensorList input
613            y = torch.block_diag(x, x)
614            return y
615
616        self.assert_functionalization(f, torch.ones(2, 2))
617        logs = self.get_logs(f, torch.ones(2, 2))
618        self.assertExpectedInline(
619            logs,
620            """\
621
622
623
624def forward(self, arg0_1):
625    block_diag = torch.ops.aten.block_diag.default([arg0_1, arg0_1]);  arg0_1 = None
626    return block_diag
627    """,
628        )
629
630    def test_cat(self):
631        def f(x):
632            out = torch.empty(0)
633            torch.cat((x,), out=out)
634            return out
635
636        self.assert_functionalization(f, torch.ones(2, 2))
637        logs = self.get_logs(f, torch.ones(2, 2))
638        self.assertExpectedInline(
639            logs,
640            """\
641
642
643
644def forward(self, arg0_1):
645    empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False);  empty = None
646    cat = torch.ops.aten.cat.default([arg0_1]);  arg0_1 = None
647    return cat
648    """,
649        )
650
651        reinplaced_logs = self.get_logs(
652            f, torch.ones(2, 2), reapply_views=True, run_reinplace=True
653        )
654        self.assertExpectedInline(
655            reinplaced_logs,
656            """\
657
658
659
660def forward(self, arg0_1):
661    empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False);  empty = None
662    cat = torch.ops.aten.cat.default([arg0_1]);  arg0_1 = None
663    return cat
664    """,
665        )
666
667    def test_diagonal(self):
668        def f(x):
669            # test: view ops that take a subset of the original tensor (select/diagonal)
670            tmp = torch.ones(2)
671            y = x.clone().diagonal()
672            y.add_(tmp)
673            z = x * x
674            return z
675
676        self.assert_functionalization(f, torch.ones(2, 2))
677        logs = self.get_logs(f, torch.ones(2, 2))
678        self.assertExpectedInline(
679            logs,
680            """\
681
682
683
684def forward(self, arg0_1):
685    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
686    clone = torch.ops.aten.clone.default(arg0_1)
687    diagonal_copy = torch.ops.aten.diagonal_copy.default(clone)
688    add = torch.ops.aten.add.Tensor(diagonal_copy, ones);  diagonal_copy = ones = None
689    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(clone, add);  clone = add = None
690    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter);  diagonal_scatter = diagonal_copy_1 = None
691    mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None
692    return mul
693    """,
694        )
695
696        reinplaced_logs = self.get_logs(
697            f, torch.ones(2, 2), reapply_views=True, run_reinplace=True
698        )
699        self.assertExpectedInline(
700            reinplaced_logs,
701            """\
702
703
704
705def forward(self, arg0_1):
706    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
707    clone = torch.ops.aten.clone.default(arg0_1)
708    diagonal = torch.ops.aten.diagonal.default(clone)
709    add = torch.ops.aten.add_.Tensor(diagonal, ones);  diagonal = ones = add = None
710    diagonal_1 = torch.ops.aten.diagonal.default(clone);  clone = diagonal_1 = None
711    mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None
712    return mul
713    """,
714        )
715
716    def test_diagonal_mutated_input(self):
717        def f(x):
718            # simple test: there are pending updates afterwards, which the test syncs manually
719            tmp = torch.ones(2)
720            y = x.diagonal()
721            y.add_(tmp)
722            return x
723
724        x = torch.ones(2, 2)
725        self.assert_functionalization(f, x)
726        logs = self.get_logs(f, torch.ones(2, 2))
727        self.assertExpectedInline(
728            logs,
729            """\
730
731
732
733def forward(self, arg0_1):
734    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
735    diagonal_copy = torch.ops.aten.diagonal_copy.default(arg0_1)
736    add = torch.ops.aten.add.Tensor(diagonal_copy, ones);  diagonal_copy = ones = None
737    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add);  add = None
738    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter);  diagonal_copy_1 = None
739    copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter);  arg0_1 = copy_ = None
740    return diagonal_scatter
741    """,
742        )
743
744        # NB: even with reapply_views=True, we expect to see scatter op
745        reinplaced_logs = self.get_logs(
746            f, torch.ones(2, 2), reapply_views=True, run_reinplace=False
747        )
748        self.assertExpectedInline(
749            reinplaced_logs,
750            """\
751
752
753
754def forward(self, arg0_1):
755    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
756    diagonal = torch.ops.aten.diagonal.default(arg0_1)
757    add = torch.ops.aten.add.Tensor(diagonal, ones);  diagonal = ones = None
758    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add);  add = None
759    diagonal_1 = torch.ops.aten.diagonal.default(diagonal_scatter);  diagonal_1 = None
760    copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter);  arg0_1 = copy_ = None
761    return diagonal_scatter
762    """,
763        )
764
765    def test_channels_last_contiguous(self):
766        def f(x):
767            return x.contiguous(memory_format=torch.channels_last)
768            tmp = torch.ones(2)
769            y = x.diagonal()
770            y.add_(tmp)
771            return x
772
773        x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2)
774        self.assert_functionalization(f, x)
775        logs = self.get_logs(f, x).strip()
776        # There should be no clone in the graph
777        self.assertExpectedInline(
778            logs,
779            """\
780def forward(self, arg0_1):
781    return arg0_1""",
782        )
783
784    def test_split(self):
785        def f(x):
786            # test: view ops that return multiple tensors (split)
787            tmp = torch.ones(2)
788            y1, y2 = x.split(2)
789            y3 = y2.diagonal()
790            y3.add_(tmp)
791            z = x * x
792            return y3
793
794        self.assert_functionalization(f, torch.ones(4, 2))
795        logs = self.get_logs(f, torch.ones(4, 2))
796        self.assertExpectedInline(
797            logs,
798            """\
799
800
801
802def forward(self, arg0_1):
803    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
804    split_copy = torch.ops.aten.split_copy.Tensor(arg0_1, 2)
805    getitem = split_copy[0];  getitem = None
806    getitem_1 = split_copy[1];  split_copy = None
807    diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem_1);  getitem_1 = None
808    add = torch.ops.aten.add.Tensor(diagonal_copy, ones);  diagonal_copy = ones = None
809    split_copy_1 = torch.ops.aten.split_copy.Tensor(arg0_1, 2)
810    getitem_2 = split_copy_1[0];  getitem_2 = None
811    getitem_3 = split_copy_1[1];  split_copy_1 = None
812    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add);  getitem_3 = add = None
813    slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4);  diagonal_scatter = None
814    split_copy_2 = torch.ops.aten.split_copy.Tensor(slice_scatter, 2)
815    getitem_4 = split_copy_2[0];  getitem_4 = None
816    getitem_5 = split_copy_2[1];  split_copy_2 = None
817    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_5);  getitem_5 = None
818    mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter);  mul = None
819    copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter);  arg0_1 = slice_scatter = copy_ = None
820    return diagonal_copy_1
821    """,
822        )  # noqa: B950
823
824        # NB: even with reapply_views=True, we expect to see scatter op
825        reinplaced_logs = self.get_logs(
826            f, torch.ones(4, 2), reapply_views=True, run_reinplace=False
827        )
828        self.assertExpectedInline(
829            reinplaced_logs,
830            """\
831
832
833
834def forward(self, arg0_1):
835    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
836    split = torch.ops.aten.split.Tensor(arg0_1, 2)
837    getitem = split[0];  getitem = None
838    getitem_1 = split[1];  split = None
839    diagonal = torch.ops.aten.diagonal.default(getitem_1);  getitem_1 = None
840    add = torch.ops.aten.add.Tensor(diagonal, ones);  diagonal = ones = None
841    split_1 = torch.ops.aten.split.Tensor(arg0_1, 2)
842    getitem_2 = split_1[0];  getitem_2 = None
843    getitem_3 = split_1[1];  split_1 = None
844    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add);  getitem_3 = add = None
845    slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4);  diagonal_scatter = None
846    split_2 = torch.ops.aten.split.Tensor(slice_scatter, 2)
847    getitem_4 = split_2[0];  getitem_4 = None
848    getitem_5 = split_2[1];  split_2 = None
849    diagonal_1 = torch.ops.aten.diagonal.default(getitem_5);  getitem_5 = None
850    mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter);  mul = None
851    copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter);  arg0_1 = slice_scatter = copy_ = None
852    return diagonal_1
853    """,
854        )  # noqa: B950
855
856    def test_split_with_sizes(self):
857        def f(x):
858            # test: view ops that return multiple tensors (split_with_sizes)
859            tmp = torch.ones(2)
860            y1, y2 = x.split_with_sizes([2, 2])
861            y3 = y1.diagonal()
862            y3.add_(tmp)
863            z = x * x
864            return y3
865
866        self.assert_functionalization(f, torch.ones(4, 2))
867        logs = self.get_logs(f, torch.ones(4, 2))
868        self.assertExpectedInline(
869            logs,
870            """\
871
872
873
874def forward(self, arg0_1):
875    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
876    split_with_sizes_copy = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2])
877    getitem = split_with_sizes_copy[0]
878    getitem_1 = split_with_sizes_copy[1];  split_with_sizes_copy = getitem_1 = None
879    diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem);  getitem = None
880    add = torch.ops.aten.add.Tensor(diagonal_copy, ones);  diagonal_copy = ones = None
881    split_with_sizes_copy_1 = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2])
882    getitem_2 = split_with_sizes_copy_1[0]
883    getitem_3 = split_with_sizes_copy_1[1];  split_with_sizes_copy_1 = getitem_3 = None
884    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add);  getitem_2 = add = None
885    slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2);  diagonal_scatter = None
886    split_with_sizes_copy_2 = torch.ops.aten.split_with_sizes_copy.default(slice_scatter, [2, 2])
887    getitem_4 = split_with_sizes_copy_2[0]
888    getitem_5 = split_with_sizes_copy_2[1];  split_with_sizes_copy_2 = getitem_5 = None
889    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_4);  getitem_4 = None
890    mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter);  mul = None
891    copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter);  arg0_1 = slice_scatter = copy_ = None
892    return diagonal_copy_1
893    """,
894        )  # noqa: B950
895
896        # NB: even with reapply_views=True, we expect to see scatter op
897        reinplaced_logs = self.get_logs(
898            f, torch.ones(4, 2), reapply_views=True, run_reinplace=False
899        )
900        self.assertExpectedInline(
901            reinplaced_logs,
902            """\
903
904
905
906def forward(self, arg0_1):
907    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
908    split_with_sizes = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2])
909    getitem = split_with_sizes[0]
910    getitem_1 = split_with_sizes[1];  split_with_sizes = getitem_1 = None
911    diagonal = torch.ops.aten.diagonal.default(getitem);  getitem = None
912    add = torch.ops.aten.add.Tensor(diagonal, ones);  diagonal = ones = None
913    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2])
914    getitem_2 = split_with_sizes_1[0]
915    getitem_3 = split_with_sizes_1[1];  split_with_sizes_1 = getitem_3 = None
916    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add);  getitem_2 = add = None
917    slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2);  diagonal_scatter = None
918    split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(slice_scatter, [2, 2])
919    getitem_4 = split_with_sizes_2[0]
920    getitem_5 = split_with_sizes_2[1];  split_with_sizes_2 = getitem_5 = None
921    diagonal_1 = torch.ops.aten.diagonal.default(getitem_4);  getitem_4 = None
922    mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter);  mul = None
923    copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter);  arg0_1 = slice_scatter = copy_ = None
924    return diagonal_1
925    """,
926        )  # noqa: B950
927
928    def test_slice(self):
929        def f(x):
930            tmp = torch.ones(4)
931            x.transpose_(1, 0)
932            y = x[0:2]
933            y.add_(tmp)
934            return x
935
936        self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
937        logs = self.get_logs(f, torch.ones(4, 2))
938        self.assertExpectedInline(
939            logs,
940            """\
941
942
943
944def forward(self, arg0_1):
945    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
946    transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
947    slice_copy = torch.ops.aten.slice_copy.Tensor(transpose_copy, 0, 0, 2);  transpose_copy = None
948    add = torch.ops.aten.add.Tensor(slice_copy, ones);  slice_copy = ones = None
949    transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0);  arg0_1 = None
950    slice_scatter = torch.ops.aten.slice_scatter.default(transpose_copy_1, add, 0, 0, 2);  transpose_copy_1 = add = None
951    transpose_copy_2 = torch.ops.aten.transpose_copy.int(slice_scatter, 1, 0);  slice_scatter = None
952    transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
953    slice_copy_1 = torch.ops.aten.slice_copy.Tensor(transpose_copy_3, 0, 0, 2);  transpose_copy_3 = slice_copy_1 = None
954    transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0);  transpose_copy_2 = None
955    return transpose_copy_4
956    """,
957        )  # noqa: B950
958
959        # NB: even with reapply_views=True, we expect to see scatter op
960        reinplaced_logs = self.get_logs(
961            f, torch.ones(4, 2), reapply_views=True, run_reinplace=False
962        )
963        self.assertExpectedInline(
964            reinplaced_logs,
965            """\
966
967
968
969def forward(self, arg0_1):
970    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
971    transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
972    slice_1 = torch.ops.aten.slice.Tensor(transpose, 0, 0, 2);  transpose = None
973    add = torch.ops.aten.add.Tensor(slice_1, ones);  slice_1 = ones = None
974    transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0);  arg0_1 = None
975    slice_scatter = torch.ops.aten.slice_scatter.default(transpose_1, add, 0, 0, 2);  transpose_1 = add = None
976    transpose_2 = torch.ops.aten.transpose.int(slice_scatter, 1, 0);  slice_scatter = None
977    transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
978    slice_2 = torch.ops.aten.slice.Tensor(transpose_3, 0, 0, 2);  transpose_3 = slice_2 = None
979    transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0);  transpose_2 = None
980    return transpose_4
981    """,
982        )  # noqa: B950
983
984    def test_view_inplace(self):
985        def f(x):
986            # test: view + inplace op (transpose_)
987            tmp = torch.ones(4)
988            x.transpose_(1, 0)
989            y = x[0]
990            y.add_(tmp)
991            return x
992
993        self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
994        logs = self.get_logs(f, torch.ones(4, 2))
995        self.assertExpectedInline(
996            logs,
997            """\
998
999
1000
1001def forward(self, arg0_1):
1002    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
1003    transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
1004    select_copy = torch.ops.aten.select_copy.int(transpose_copy, 0, 0);  transpose_copy = None
1005    add = torch.ops.aten.add.Tensor(select_copy, ones);  select_copy = ones = None
1006    transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0);  arg0_1 = None
1007    select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0);  transpose_copy_1 = add = None
1008    transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0);  select_scatter = None
1009    transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
1010    select_copy_1 = torch.ops.aten.select_copy.int(transpose_copy_3, 0, 0);  transpose_copy_3 = select_copy_1 = None
1011    transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0);  transpose_copy_2 = None
1012    return transpose_copy_4
1013    """,
1014        )  # noqa: B950
1015
1016        # NB: even with reapply_views=True, we expect to see scatter op
1017        reinplaced_logs = self.get_logs(
1018            f, torch.ones(4, 2), reapply_views=True, run_reinplace=False
1019        )
1020        self.assertExpectedInline(
1021            reinplaced_logs,
1022            """\
1023
1024
1025
1026def forward(self, arg0_1):
1027    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
1028    transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
1029    select = torch.ops.aten.select.int(transpose, 0, 0);  transpose = None
1030    add = torch.ops.aten.add.Tensor(select, ones);  select = ones = None
1031    transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0);  arg0_1 = None
1032    select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0);  transpose_1 = add = None
1033    transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0);  select_scatter = None
1034    transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
1035    select_1 = torch.ops.aten.select.int(transpose_3, 0, 0);  transpose_3 = select_1 = None
1036    transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0);  transpose_2 = None
1037    return transpose_4
1038    """,
1039        )  # noqa: B950
1040
1041    def test_unbind(self):
1042        def f(x):
1043            # test: view + inplace op (transpose_)
1044            tmp = torch.ones(4)
1045            x.transpose_(1, 0)
1046            y, _ = x.unbind(0)
1047            y.add_(tmp)
1048            return x
1049
1050        self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
1051        logs = self.get_logs(f, torch.ones(4, 2))
1052        self.assertExpectedInline(
1053            logs,
1054            """\
1055
1056
1057
1058def forward(self, arg0_1):
1059    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
1060    transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
1061    unbind_copy = torch.ops.aten.unbind_copy.int(transpose_copy);  transpose_copy = None
1062    getitem = unbind_copy[0]
1063    getitem_1 = unbind_copy[1];  unbind_copy = getitem_1 = None
1064    add = torch.ops.aten.add.Tensor(getitem, ones);  getitem = ones = None
1065    transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0);  arg0_1 = None
1066    select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0);  transpose_copy_1 = add = None
1067    transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0);  select_scatter = None
1068    transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
1069    unbind_copy_1 = torch.ops.aten.unbind_copy.int(transpose_copy_3);  transpose_copy_3 = None
1070    getitem_2 = unbind_copy_1[0];  getitem_2 = None
1071    getitem_3 = unbind_copy_1[1];  unbind_copy_1 = getitem_3 = None
1072    transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0);  transpose_copy_2 = None
1073    return transpose_copy_4
1074    """,
1075        )  # noqa: B950
1076
1077        # NB: even with reapply_views=True, we expect to see scatter op
1078        reinplaced_logs = self.get_logs(
1079            f, torch.ones(4, 2), reapply_views=True, run_reinplace=False
1080        )
1081        self.assertExpectedInline(
1082            reinplaced_logs,
1083            """\
1084
1085
1086
1087def forward(self, arg0_1):
1088    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
1089    transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
1090    unbind = torch.ops.aten.unbind.int(transpose);  transpose = None
1091    getitem = unbind[0]
1092    getitem_1 = unbind[1];  unbind = getitem_1 = None
1093    add = torch.ops.aten.add.Tensor(getitem, ones);  getitem = ones = None
1094    transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0);  arg0_1 = None
1095    select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0);  transpose_1 = add = None
1096    transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0);  select_scatter = None
1097    transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
1098    unbind_1 = torch.ops.aten.unbind.int(transpose_3);  transpose_3 = None
1099    getitem_2 = unbind_1[0];  getitem_2 = None
1100    getitem_3 = unbind_1[1];  unbind_1 = getitem_3 = None
1101    transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0);  transpose_2 = None
1102    return transpose_4
1103    """,
1104        )  # noqa: B950
1105
1106    def test_optional_tensor_list(self):
1107        def f(x):
1108            # test: an operator that takes in a List[Optional[Tensor]] argument
1109            # (index_put)
1110            y = x.view(8)
1111            indices = torch.arange(4)
1112            values = torch.arange(4, dtype=y.dtype)
1113            y.index_put_((indices,), values, accumulate=False)
1114            return y
1115
1116        self.assert_functionalization(f, torch.ones(4, 2))
1117        logs = self.get_logs(f, torch.ones(4, 2))
1118        self.assertExpectedInline(
1119            logs,
1120            """\
1121
1122
1123
1124def forward(self, arg0_1):
1125    view_copy = torch.ops.aten.view_copy.default(arg0_1, [8])
1126    arange = torch.ops.aten.arange.default(4, device = device(type='cpu'), pin_memory = False)
1127    arange_1 = torch.ops.aten.arange.default(4, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
1128    index_put = torch.ops.aten.index_put.default(view_copy, [arange], arange_1);  view_copy = arange = arange_1 = None
1129    view_copy_1 = torch.ops.aten.view_copy.default(index_put, [4, 2]);  index_put = None
1130    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [8])
1131    copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1);  arg0_1 = view_copy_1 = copy_ = None
1132    return view_copy_2
1133    """,
1134        )  # noqa: B950
1135
1136    def test_scalars(self):
1137        def f(x):
1138            # test: the pass can handle scalar inputs properly
1139            tmp = torch.ones(4, 2)
1140            y = x.view(4, 2)
1141            y.add_(1)
1142            z = 2 * y
1143            z.div_(1)
1144            return z
1145
1146        self.assert_functionalization(f, torch.ones(4, 2))
1147        logs = self.get_logs(f, torch.ones(4, 2))
1148        self.assertExpectedInline(
1149            logs,
1150            """\
1151
1152
1153
1154def forward(self, arg0_1):
1155    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False);  ones = None
1156    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
1157    add = torch.ops.aten.add.Tensor(view_copy, 1);  view_copy = None
1158    view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]);  add = None
1159    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2])
1160    mul = torch.ops.aten.mul.Tensor(view_copy_2, 2);  view_copy_2 = None
1161    div = torch.ops.aten.div.Tensor(mul, 1);  mul = None
1162    copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1);  arg0_1 = view_copy_1 = copy_ = None
1163    return div
1164    """,
1165        )
1166
1167    @skipIfTorchDynamo("Test does not work with TorchDynamo")
1168    def test_metadata_change(self):
1169        def f(x):
1170            # ops like ge_() are allowed to change the dtype of the input.
1171            # functionalization should pick up on that.
1172            y = x.clone()
1173            out = y.ge_(0)
1174            return out
1175
1176        self.assert_functionalization(f, torch.ones(4, 2))
1177        logs = self.get_logs(f, torch.ones(4, 2))
1178        self.assertExpectedInline(
1179            logs,
1180            """\
1181
1182
1183
1184def forward(self, arg0_1):
1185    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
1186    ge = torch.ops.aten.ge.Scalar(clone, 0);  clone = None
1187    _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided);  ge = None
1188    return _to_copy
1189    """,
1190        )
1191
1192        reinplaced_logs = self.get_logs(
1193            f, torch.ones(2, 2), reapply_views=True, run_reinplace=True
1194        )
1195        self.assertExpectedInline(
1196            reinplaced_logs,
1197            """\
1198
1199
1200
1201def forward(self, arg0_1):
1202    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
1203    ge = torch.ops.aten.ge.Scalar(clone, 0);  clone = None
1204    _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided);  ge = None
1205    return _to_copy
1206    """,
1207        )  # noqa: B950
1208
1209    @skipIfTorchDynamo("Test does not work with TorchDynamo")
1210    def test_metadata_change_out_op(self):
1211        def f(t, y):
1212            out_1 = torch.ones(1)
1213            return torch.add(t, y, out=out_1)
1214
1215        inpt1, inpt2 = torch.tensor([1]), torch.tensor([1])
1216        inpt1_func, inpt2_func = (
1217            torch._to_functional_tensor(inpt1),
1218            torch._to_functional_tensor(inpt2),
1219        )
1220
1221        out_ref = f(inpt1, inpt2)
1222        torch._enable_functionalization(reapply_views=True)
1223        try:
1224            out_functional = f(inpt1_func, inpt2_func)
1225        finally:
1226            torch._disable_functionalization()
1227        self.assertEqual(out_ref, torch._from_functional_tensor(out_functional))
1228
1229    def test_only_one_view(self):
1230        def f(x):
1231            # This tests that we don't have any unnecessary views in the trace.
1232            # If the input wasn't mutated, we don't need to regenerate it,
1233            # so there should be a total of 1 op in the output trace.
1234            return x.view(4, 2)
1235
1236        logs = self.get_logs(f, torch.ones(4, 2))
1237        self.assertExpectedInline(
1238            logs,
1239            """\
1240
1241
1242
1243def forward(self, arg0_1):
1244    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]);  arg0_1 = None
1245    return view_copy
1246    """,
1247        )
1248
1249    def test_everything(self):
1250        def f(x):
1251            # test: everything
1252            tmp = torch.ones(2, 2)
1253            x2 = x + x
1254            y = x2.view(8)
1255            z0 = y.reshape(2, 4)
1256            z1 = z0.transpose(1, 0)
1257            z1.unsqueeze_(0)
1258            z1.squeeze_()
1259            z2, z3 = z1.split(2)
1260            z2.add_(tmp)
1261            z4 = z0[0] + z2.reshape(4)
1262            return z2
1263
1264        self.assert_functionalization(f, torch.ones(4, 2))
1265        logs = self.get_logs(f, torch.ones(4, 2))
1266        self.assertExpectedInline(
1267            logs,
1268            """\
1269
1270
1271
1272def forward(self, arg0_1):
1273    ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
1274    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
1275    view_copy = torch.ops.aten.view_copy.default(add, [8])
1276    view_copy_1 = torch.ops.aten.view_copy.default(view_copy, [2, 4]);  view_copy = None
1277    transpose_copy = torch.ops.aten.transpose_copy.int(view_copy_1, 1, 0)
1278    unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0);  transpose_copy = None
1279    squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy);  unsqueeze_copy = None
1280    split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2);  squeeze_copy = None
1281    getitem = split_copy[0]
1282    getitem_1 = split_copy[1];  split_copy = getitem_1 = None
1283    add_1 = torch.ops.aten.add.Tensor(getitem, ones);  getitem = ones = None
1284    view_copy_2 = torch.ops.aten.view_copy.default(add, [8]);  add = None
1285    view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [2, 4]);  view_copy_2 = None
1286    transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_3, 1, 0);  view_copy_3 = None
1287    unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0);  transpose_copy_1 = None
1288    squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1);  unsqueeze_copy_1 = None
1289    slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2);  squeeze_copy_1 = add_1 = None
1290    unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0);  slice_scatter = None
1291    squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0);  unsqueeze_copy_2 = None
1292    transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0);  squeeze_copy_2 = None
1293    view_copy_4 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]);  transpose_copy_2 = None
1294    view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 2]);  view_copy_4 = None
1295    view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [8])
1296    view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [2, 4]);  view_copy_6 = None
1297    transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_7, 1, 0);  view_copy_7 = None
1298    unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0);  transpose_copy_3 = None
1299    squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3);  unsqueeze_copy_3 = None
1300    split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2);  squeeze_copy_3 = None
1301    getitem_2 = split_copy_1[0]
1302    getitem_3 = split_copy_1[1];  split_copy_1 = getitem_3 = None
1303    select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0);  view_copy_1 = select_copy = None
1304    view_copy_8 = torch.ops.aten.view_copy.default(getitem_2, [4]);  view_copy_8 = None
1305    view_copy_9 = torch.ops.aten.view_copy.default(view_copy_5, [8])
1306    view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]);  view_copy_9 = None
1307    select_copy_1 = torch.ops.aten.select_copy.int(view_copy_10, 0, 0);  view_copy_10 = None
1308    view_copy_11 = torch.ops.aten.view_copy.default(view_copy_5, [8]);  view_copy_5 = None
1309    view_copy_12 = torch.ops.aten.view_copy.default(view_copy_11, [2, 4]);  view_copy_11 = None
1310    transpose_copy_4 = torch.ops.aten.transpose_copy.int(view_copy_12, 1, 0);  view_copy_12 = None
1311    unsqueeze_copy_4 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_4, 0);  transpose_copy_4 = None
1312    squeeze_copy_4 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_4);  unsqueeze_copy_4 = None
1313    split_copy_2 = torch.ops.aten.split_copy.Tensor(squeeze_copy_4, 2);  squeeze_copy_4 = None
1314    getitem_4 = split_copy_2[0]
1315    getitem_5 = split_copy_2[1];  split_copy_2 = getitem_5 = None
1316    view_copy_13 = torch.ops.aten.view_copy.default(getitem_4, [4]);  getitem_4 = None
1317    add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_13);  select_copy_1 = view_copy_13 = add_2 = None
1318    return getitem_2
1319    """,
1320        )  # noqa: B950
1321
1322        reinplaced_logs = self.get_logs(
1323            f, torch.ones(4, 2), reapply_views=True, run_reinplace=True
1324        )
1325        self.assertExpectedInline(
1326            reinplaced_logs,
1327            """\
1328
1329
1330
1331def forward(self, arg0_1):
1332    ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
1333    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
1334    view = torch.ops.aten.view.default(add, [8])
1335    view_1 = torch.ops.aten.view.default(view, [2, 4]);  view = None
1336    transpose = torch.ops.aten.transpose.int(view_1, 1, 0)
1337    unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0);  transpose = None
1338    squeeze = torch.ops.aten.squeeze.default(unsqueeze);  unsqueeze = None
1339    split = torch.ops.aten.split.Tensor(squeeze, 2);  squeeze = None
1340    getitem = split[0]
1341    getitem_1 = split[1];  split = getitem_1 = None
1342    add_1 = torch.ops.aten.add_.Tensor(getitem, ones);  getitem = ones = add_1 = None
1343    view_2 = torch.ops.aten.view.default(add, [8]);  add = None
1344    view_3 = torch.ops.aten.view.default(view_2, [2, 4]);  view_2 = None
1345    transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0);  view_3 = None
1346    unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0);  transpose_1 = None
1347    squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1);  unsqueeze_1 = None
1348    unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0);  squeeze_1 = None
1349    squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0);  unsqueeze_2 = None
1350    transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0);  squeeze_2 = None
1351    view_4 = torch.ops.aten.view.default(transpose_2, [8]);  transpose_2 = None
1352    view_5 = torch.ops.aten.view.default(view_4, [4, 2]);  view_4 = None
1353    view_6 = torch.ops.aten.view.default(view_5, [8])
1354    view_7 = torch.ops.aten.view.default(view_6, [2, 4]);  view_6 = None
1355    transpose_3 = torch.ops.aten.transpose.int(view_7, 1, 0);  view_7 = None
1356    unsqueeze_3 = torch.ops.aten.unsqueeze.default(transpose_3, 0);  transpose_3 = None
1357    squeeze_3 = torch.ops.aten.squeeze.default(unsqueeze_3);  unsqueeze_3 = None
1358    split_1 = torch.ops.aten.split.Tensor(squeeze_3, 2);  squeeze_3 = None
1359    getitem_2 = split_1[0]
1360    getitem_3 = split_1[1];  split_1 = getitem_3 = None
1361    select = torch.ops.aten.select.int(view_1, 0, 0);  view_1 = select = None
1362    clone = torch.ops.aten.clone.default(getitem_2, memory_format = torch.contiguous_format)
1363    _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]);  clone = None
1364    view_8 = torch.ops.aten.view.default(view_5, [8]);  view_5 = None
1365    view_9 = torch.ops.aten.view.default(view_8, [2, 4]);  view_8 = None
1366    select_1 = torch.ops.aten.select.int(view_9, 0, 0);  view_9 = None
1367    add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view);  select_1 = _unsafe_view = add_2 = None
1368    return getitem_2
1369    """,
1370        )
1371
1372    def test_reapply_views_simple(self):
1373        def f(x):
1374            tmp = torch.ones(4, 2)
1375            y = x.view(4, 2)
1376            y.add_(tmp)
1377            z = x * x
1378            return y
1379
1380        self.assert_functionalization(f, torch.ones(4, 2), reapply_views=True)
1381        logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True)
1382        self.assertExpectedInline(
1383            logs,
1384            """\
1385
1386
1387
1388def forward(self, arg0_1):
1389    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
1390    view = torch.ops.aten.view.default(arg0_1, [4, 2])
1391    add = torch.ops.aten.add.Tensor(view, ones);  view = ones = None
1392    view_1 = torch.ops.aten.view.default(add, [4, 2]);  add = None
1393    view_2 = torch.ops.aten.view.default(view_1, [4, 2])
1394    mul = torch.ops.aten.mul.Tensor(view_1, view_1);  mul = None
1395    copy_ = torch.ops.aten.copy_.default(arg0_1, view_1);  arg0_1 = view_1 = copy_ = None
1396    return view_2
1397    """,
1398        )
1399
1400    def test_aliases_maintained_after_pass_when_reapplying_views(self):
1401        def f(x):
1402            tmp = torch.ones(4, 2)
1403            y = x.view(4, 2)
1404            z = x.view(4, 2)
1405            y.add_(tmp)
1406            return y, z
1407
1408        input_functional = torch._to_functional_tensor(torch.ones(4, 2))
1409        torch._enable_functionalization(reapply_views=True)
1410        try:
1411            y, z = f(input_functional)
1412            torch._sync(y)
1413            torch._sync(z)
1414        finally:
1415            torch._disable_functionalization()
1416
1417        # y and z are aliases inside the function, and that aliasing relationship should be maintained.
1418        _y = torch._from_functional_tensor(y)
1419        _z = torch._from_functional_tensor(z)
1420        self.assertTrue(are_aliased(_y, _z))
1421
1422    # copy_() gets its own test, because it used to be special-cased in functionalization.
1423    # However, it now works much like other functional ops.
1424    def test_copy_(self):
1425        def f(x):
1426            tmp = torch.zeros(2, 2)
1427            tmp_slice = tmp.diagonal()
1428            y = tmp_slice.copy_(x)
1429            z = y.add_(x)
1430            return z
1431
1432        # Test 1: copy_() with same dtype and shape
1433        # to() is a composite op that no-ops when the dtype/shape match, so nothing gets logged.
1434        # self.assert_functionalization(f, torch.ones(2))
1435        logs = self.get_logs(f, torch.ones(2))
1436        self.assertExpectedInline(
1437            logs,
1438            """\
1439
1440
1441
1442def forward(self, arg0_1):
1443    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1444    diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1445    copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1);  diagonal_copy = None
1446    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy);  zeros = copy = None
1447    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1448    add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1);  diagonal_copy_1 = arg0_1 = None
1449    diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add);  diagonal_scatter = add = None
1450    diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1);  diagonal_scatter_1 = None
1451    return diagonal_copy_2
1452    """,
1453        )
1454
1455        reinplaced_logs = self.get_logs(
1456            f, torch.ones(2), reapply_views=True, run_reinplace=True
1457        )
1458        self.assertExpectedInline(
1459            reinplaced_logs,
1460            """\
1461
1462
1463
1464def forward(self, arg0_1):
1465    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1466    diagonal = torch.ops.aten.diagonal.default(zeros)
1467    copy = torch.ops.aten.copy_.default(diagonal, arg0_1);  diagonal = copy = None
1468    diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1469    add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1);  diagonal_1 = arg0_1 = add = None
1470    diagonal_2 = torch.ops.aten.diagonal.default(zeros);  zeros = None
1471    return diagonal_2
1472    """,
1473        )
1474
1475        # Test 2: copy_() with same dtype, different shape
1476        self.assert_functionalization(f, torch.ones(1))
1477        logs = self.get_logs(f, torch.ones(1))
1478        self.assertExpectedInline(
1479            logs,
1480            """\
1481
1482
1483
1484def forward(self, arg0_1):
1485    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1486    diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1487    copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1);  diagonal_copy = None
1488    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy);  zeros = copy = None
1489    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1490    add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1);  diagonal_copy_1 = arg0_1 = None
1491    diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add);  diagonal_scatter = add = None
1492    diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1);  diagonal_scatter_1 = None
1493    return diagonal_copy_2
1494    """,
1495        )
1496
1497        reinplaced_logs = self.get_logs(
1498            f, torch.ones(1), reapply_views=True, run_reinplace=True
1499        )
1500        self.assertExpectedInline(
1501            reinplaced_logs,
1502            """\
1503
1504
1505
1506def forward(self, arg0_1):
1507    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1508    diagonal = torch.ops.aten.diagonal.default(zeros)
1509    copy = torch.ops.aten.copy_.default(diagonal, arg0_1);  diagonal = copy = None
1510    diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1511    add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1);  diagonal_1 = arg0_1 = add = None
1512    diagonal_2 = torch.ops.aten.diagonal.default(zeros);  zeros = None
1513    return diagonal_2
1514    """,
1515        )
1516
1517        # Test 3: copy_() with different dtype, same shape
1518        self.assert_functionalization(f, torch.ones(2, dtype=torch.long))
1519        logs = self.get_logs(f, torch.ones(2, dtype=torch.long))
1520        self.assertExpectedInline(
1521            logs,
1522            """\
1523
1524
1525
1526def forward(self, arg0_1):
1527    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1528    diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1529    copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1);  diagonal_copy = None
1530    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy);  zeros = copy = None
1531    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1532    add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1);  diagonal_copy_1 = arg0_1 = None
1533    diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add);  diagonal_scatter = add = None
1534    diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1);  diagonal_scatter_1 = None
1535    return diagonal_copy_2
1536    """,
1537        )  # noqa: B950
1538
1539        reinplaced_logs = self.get_logs(
1540            f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True
1541        )
1542        self.assertExpectedInline(
1543            reinplaced_logs,
1544            """\
1545
1546
1547
1548def forward(self, arg0_1):
1549    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1550    diagonal = torch.ops.aten.diagonal.default(zeros)
1551    copy = torch.ops.aten.copy_.default(diagonal, arg0_1);  diagonal = copy = None
1552    diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1553    add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1);  diagonal_1 = arg0_1 = add = None
1554    diagonal_2 = torch.ops.aten.diagonal.default(zeros);  zeros = None
1555    return diagonal_2
1556    """,
1557        )  # noqa: B950
1558
1559        # Test 4: copy_() with different dtype, different shape
1560        self.assert_functionalization(f, torch.ones(1, dtype=torch.long))
1561        logs = self.get_logs(f, torch.ones(1, dtype=torch.long))
1562        self.assertExpectedInline(
1563            logs,
1564            """\
1565
1566
1567
1568def forward(self, arg0_1):
1569    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1570    diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1571    copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1);  diagonal_copy = None
1572    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy);  zeros = copy = None
1573    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1574    add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1);  diagonal_copy_1 = arg0_1 = None
1575    diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add);  diagonal_scatter = add = None
1576    diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1);  diagonal_scatter_1 = None
1577    return diagonal_copy_2
1578    """,
1579        )  # noqa: B950
1580
1581        reinplaced_logs = self.get_logs(
1582            f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True
1583        )
1584        self.assertExpectedInline(
1585            reinplaced_logs,
1586            """\
1587
1588
1589
1590def forward(self, arg0_1):
1591    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1592    diagonal = torch.ops.aten.diagonal.default(zeros)
1593    copy = torch.ops.aten.copy_.default(diagonal, arg0_1);  diagonal = copy = None
1594    diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1595    add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1);  diagonal_1 = arg0_1 = add = None
1596    diagonal_2 = torch.ops.aten.diagonal.default(zeros);  zeros = None
1597    return diagonal_2
1598    """,
1599        )  # noqa: B950
1600
1601    def test_expand_symint(self):
1602        # Once some existing SymInt bugs are ironed out, we should update
1603        # this test to plumb FakeSymbolicTensors through it
1604        def f(x):
1605            return x.expand(x.size(0), x.size(1))
1606
1607        self.assert_functionalization(f, torch.ones(2, 2))
1608        logs = self.get_logs(f, torch.ones(2, 2))
1609        self.assertExpectedInline(
1610            logs,
1611            """\
1612
1613
1614
1615def forward(self, arg0_1):
1616    expand_copy = torch.ops.aten.expand_copy.default(arg0_1, [2, 2]);  arg0_1 = None
1617    return expand_copy
1618    """,
1619        )
1620
1621    def test_fill_(self):
1622        def f(x):
1623            y = x + x
1624            z = y.diagonal()
1625            z.fill_(0)
1626            return y
1627
1628        self.assert_functionalization(f, torch.ones(2, 2))
1629        logs = self.get_logs(f, torch.ones(2, 2))
1630        self.assertExpectedInline(
1631            logs,
1632            """\
1633
1634
1635
1636def forward(self, arg0_1):
1637    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
1638    diagonal_copy = torch.ops.aten.diagonal_copy.default(add)
1639    fill = torch.ops.aten.fill.Scalar(diagonal_copy, 0);  diagonal_copy = None
1640    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(add, fill);  add = fill = None
1641    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter);  diagonal_copy_1 = None
1642    return diagonal_scatter
1643    """,
1644        )
1645
1646        reinplaced_logs = self.get_logs(
1647            f, torch.ones(2, 2), reapply_views=True, run_reinplace=True
1648        )
1649        self.assertExpectedInline(
1650            reinplaced_logs,
1651            """\
1652
1653
1654
1655def forward(self, arg0_1):
1656    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
1657    diagonal = torch.ops.aten.diagonal.default(add)
1658    fill = torch.ops.aten.fill_.Scalar(diagonal, 0);  diagonal = fill = None
1659    diagonal_1 = torch.ops.aten.diagonal.default(add);  diagonal_1 = None
1660    return add
1661    """,
1662        )
1663
1664    def test_resize_smaller(self):
1665        def f(w):
1666            # Resizing to a smaller size doesn't affect storage (see the storage sketch after this test)
1667            x = w + 1
1668            y = x.view(4, 4)
1669            y.resize_(3, 3)
1670            y2 = y.view(-1)
1671            y2.add_(1)
1672            z = y + 1
1673            return z
1674
1675        self.assert_functionalization(f, torch.ones(8, 2))
1676        logs = self.get_logs(f, torch.ones(8, 2))
1677        self.assertExpectedInline(
1678            logs,
1679            """\
1680
1681
1682
1683def forward(self, arg0_1):
1684    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
1685    view_copy = torch.ops.aten.view_copy.default(add, [4, 4])
1686    resize = torch.ops.aten.resize.default(view_copy, [3, 3]);  resize = None
1687    as_strided_copy = torch.ops.aten.as_strided_copy.default(view_copy, [3, 3], [3, 1]);  view_copy = None
1688    view_copy_1 = torch.ops.aten.view_copy.default(as_strided_copy, [-1]);  as_strided_copy = None
1689    add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1);  view_copy_1 = None
1690    view_copy_2 = torch.ops.aten.view_copy.default(add, [4, 4]);  add = None
1691    as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1]);  as_strided_copy_1 = None
1692    view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]);  add_1 = None
1693    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]);  view_copy_2 = view_copy_3 = None
1694    view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]);  as_strided_scatter = None
1695    view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4])
1696    as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]);  view_copy_5 = None
1697    view_copy_6 = torch.ops.aten.view_copy.default(as_strided_copy_2, [-1]);  as_strided_copy_2 = view_copy_6 = None
1698    view_copy_7 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]);  view_copy_4 = None
1699    as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(view_copy_7, [3, 3], [3, 1]);  view_copy_7 = None
1700    add_2 = torch.ops.aten.add.Tensor(as_strided_copy_3, 1);  as_strided_copy_3 = None
1701    return add_2
1702    """,  # noqa: B950
1703        )
1704
1705        reinplaced_logs = self.get_logs(
1706            f, torch.ones(8, 2), reapply_views=True, run_reinplace=True
1707        )
1708        self.assertExpectedInline(
1709            reinplaced_logs,
1710            """\
1711
1712
1713
1714def forward(self, arg0_1):
1715    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
1716    view = torch.ops.aten.view.default(add, [4, 4])
1717    resize = torch.ops.aten.resize.default(view, [3, 3]);  resize = None
1718    as_strided = torch.ops.aten.as_strided.default(view, [3, 3], [3, 1]);  view = None
1719    view_1 = torch.ops.aten.view.default(as_strided, [-1]);  as_strided = None
1720    add_1 = torch.ops.aten.add_.Tensor(view_1, 1);  add_1 = None
1721    view_2 = torch.ops.aten.view.default(add, [4, 4]);  add = None
1722    as_strided_1 = torch.ops.aten.as_strided.default(view_2, [3, 3], [3, 1]);  as_strided_1 = None
1723    view_3 = torch.ops.aten.view.default(view_1, [3, 3]);  view_1 = view_3 = None
1724    view_4 = torch.ops.aten.view.default(view_2, [8, 2]);  view_2 = None
1725    view_5 = torch.ops.aten.view.default(view_4, [4, 4])
1726    as_strided_2 = torch.ops.aten.as_strided.default(view_5, [3, 3], [3, 1]);  view_5 = None
1727    view_6 = torch.ops.aten.view.default(as_strided_2, [-1]);  as_strided_2 = view_6 = None
1728    view_7 = torch.ops.aten.view.default(view_4, [4, 4]);  view_4 = None
1729    as_strided_3 = torch.ops.aten.as_strided.default(view_7, [3, 3], [3, 1]);  view_7 = None
1730    add_2 = torch.ops.aten.add_.Tensor(as_strided_3, 1);  add_2 = None
1731    return as_strided_3
1732    """,
1733        )
1734
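
    # A minimal eager-mode sketch (illustrative only; not collected as a test since it doesn't
    # start with "test_") of the claim in test_resize_smaller above: shrinking a tensor with
    # resize_() does not reallocate its storage, which is why functionalization can model the
    # resize with as_strided / as_strided_scatter ops over the original buffer.
    def _resize_smaller_keeps_storage_sketch(self):
        t = torch.ones(4, 4)
        nbytes_before = t.untyped_storage().nbytes()
        t.resize_(3, 3)
        # resize_() only ever grows storage; shrinking leaves the allocation untouched
        self.assertEqual(t.untyped_storage().nbytes(), nbytes_before)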
1735    def test_resize_same_size_diff_rank(self):
1736        def f(x):
1737            y = x.clone()
1738            y.resize_(25, 5)
1739            return y
1740
1741        self.assert_functionalization(f, torch.ones(5, 5, 5))
1742
1743    def test_resize_larger_valid(self):
1744        def f(x):
1745            y = x + 1
1746            # Resizing a tensor to a larger size is currently only allowed if the
1747            # tensor being resized is not a view / has no outstanding views.
1748            # See Note [resize_() in functionalization pass] and the reallocation sketch after this test.
1749            y.resize_(5, 5)
1750            y2 = y.view(25)
1751            # Do a mutation to ensure that aliases of the output of resize_()
1752            # propagate mutations correctly.
1753            # I'm using fill_ specifically because I want to guarantee that
1754            # none of the output has uninitialized memory at the end
1755            # (since these tests compare the data output against a reference impl)
1756            y2.fill_(1)
1757            out = y + 1
1758            return y, out
1759
1760        self.assert_functionalization(f, torch.ones(8, 2))
1761        logs = self.get_logs(f, torch.ones(8, 2))
1762        self.assertExpectedInline(
1763            logs,
1764            """\
1765
1766
1767
1768def forward(self, arg0_1):
1769    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
1770    resize = torch.ops.aten.resize.default(add, [5, 5]);  add = None
1771    view_copy = torch.ops.aten.view_copy.default(resize, [25]);  resize = None
1772    fill = torch.ops.aten.fill.Scalar(view_copy, 1);  view_copy = None
1773    view_copy_1 = torch.ops.aten.view_copy.default(fill, [5, 5]);  fill = None
1774    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [25]);  view_copy_2 = None
1775    add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1)
1776    return (view_copy_1, add_1)
1777    """,
1778        )
1779
1780        reinplaced_logs = self.get_logs(
1781            f, torch.ones(8, 2), reapply_views=True, run_reinplace=True
1782        )
1783        self.assertExpectedInline(
1784            reinplaced_logs,
1785            """\
1786
1787
1788
1789def forward(self, arg0_1):
1790    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
1791    resize = torch.ops.aten.resize_.default(add, [5, 5]);  resize = None
1792    view = torch.ops.aten.view.default(add, [25]);  add = None
1793    fill = torch.ops.aten.fill_.Scalar(view, 1);  fill = None
1794    view_1 = torch.ops.aten.view.default(view, [5, 5]);  view = None
1795    view_2 = torch.ops.aten.view.default(view_1, [25]);  view_2 = None
1796    add_1 = torch.ops.aten.add.Tensor(view_1, 1)
1797    return (view_1, add_1)
1798    """,
1799        )
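
    # A minimal eager-mode sketch (illustrative only; not collected as a test) of why the
    # restriction above exists: growing a tensor past its current storage with resize_()
    # reallocates the buffer, so the result can no longer be expressed as a view or
    # scatter op over the original storage the way other mutations are.
    def _resize_larger_reallocates_sketch(self):
        t = torch.ones(2)
        nbytes_before = t.untyped_storage().nbytes()
        t.resize_(5, 5)
        # the storage had to grow to hold 25 elements, so it was reallocated
        self.assertGreater(t.untyped_storage().nbytes(), nbytes_before)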
1800
1801    def test_resize_larger_invalid(self):
1802        def f(x):
1803            y = x + 1
1804            z = y.view(4, 4)
1805            # Resizing a tensor to a larger size is currently only allowed if the
1806            # tensor being resized is not a view / has no outstanding views.
1807            # See Note [resize_() in functionalization pass]
1808            # This should fail
1809            z.resize_(5, 5)
1810            z2 = z.view(25)
1811            z2.fill_(1)
1812            out = z + 1
1813            return y, out
1814
1815        with self.assertRaisesRegex(
1816            RuntimeError,
1817            r"Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass",
1818        ):
1819            self.assert_functionalization(f, torch.ones(8, 2))
1820
1821    def test_nested_functions_propagate_updates(self):
1822        def g(x):
1823            # Create a view of x
1824            y = x[0]
1825            y.add_(1)
1826            # The view, y, gets deallocated at the end of this function
1827
1828        def f(x):
1829            # Calling g(x) should mutate x (see the eager-mode sketch after this test)
1830            g(x)
1831            # We expect x to be synced here, even though the alias created in g() has been deallocated!
1832            y = x + x
1833            return y
1834
1835        self.assert_functionalization(f, torch.ones(2, 2))
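
    # A minimal eager-mode sketch (illustrative only; not collected as a test) of the reference
    # semantics that test_nested_functions_propagate_updates checks: a mutation made through a
    # short-lived view inside a nested function must still be visible on the base tensor after
    # the view has been deallocated.
    def _nested_view_mutation_sketch(self):
        x = torch.zeros(2, 2)

        def g(t):
            v = t[0]  # view of t; goes out of scope when g returns
            v.add_(1)

        g(x)
        self.assertEqual(x, torch.tensor([[1.0, 1.0], [0.0, 0.0]]))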
1836
1837    def test_mixed_wrappers_valid(self):
1838        def f(x, y):
1839            z = x + y
1840            z.add_(1)
1841            return z
1842
1843        x1_not_functional = LoggingTensor(torch.ones(4))
1844        x2_functional = torch._to_functional_tensor(LoggingTensor(torch.ones(4)))
1845
1846        with capture_logs() as logs:
1847            y = f(x1_not_functional, x2_functional)
1848
1849        # Make sure that functionalization ran the "+" kernel
1850        # with a functional + non-functional tensor, and wrapped the output appropriately.
1851        self.assertExpectedInline(
1852            "\n".join(logs),
1853            """\
1854$2: f32[4] = torch._ops.aten.add.Tensor($0, $1)
1855$3: f32[4] = torch._ops.aten.add.Tensor($2, 1)""",
1856        )
1857
1858    def test_mixed_wrappers_invalid(self):
1859        x1_not_functional = torch.ones(4)
1860        x2_functional = torch._to_functional_tensor(torch.ones(4))
1861
1862        # When dealing with mixed functional + non-functional tensors,
1863        # normal_tensor.add_(functional_tensor) is not valid
1864        # because normal_tensor would need to be "promoted" to a functional tensor.
1865        with self.assertRaises(RuntimeError):
1866            x1_not_functional.add_(x2_functional)
1867
1868    def test_index_mutation_on_non_input(self):
1869        def f(x):
1870            tmp = torch.zeros(10)
1871            tmp[5].fill_(1)
1872            return tmp
1873
1874        self.assert_functionalization(f, torch.ones(2))
1875        logs = self.get_logs(f, torch.ones(2))
1876        self.assertExpectedInline(
1877            logs,
1878            """\
1879
1880
1881
1882def forward(self, arg0_1):
1883    zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
1884    select_copy = torch.ops.aten.select_copy.int(zeros, 0, 5)
1885    fill = torch.ops.aten.fill.Scalar(select_copy, 1);  select_copy = None
1886    select_scatter = torch.ops.aten.select_scatter.default(zeros, fill, 0, 5);  zeros = fill = None
1887    select_copy_1 = torch.ops.aten.select_copy.int(select_scatter, 0, 5);  select_copy_1 = None
1888    return select_scatter
1889    """,
1890        )  # noqa: B950
1891
1892        reinplaced_logs = self.get_logs(
1893            f, torch.ones(2), reapply_views=True, run_reinplace=True
1894        )
1895        self.assertExpectedInline(
1896            reinplaced_logs,
1897            """\
1898
1899
1900
1901def forward(self, arg0_1):
1902    zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
1903    select = torch.ops.aten.select.int(zeros, 0, 5)
1904    fill = torch.ops.aten.fill_.Scalar(select, 1);  select = fill = None
1905    select_1 = torch.ops.aten.select.int(zeros, 0, 5);  select_1 = None
1906    return zeros
1907    """,
1908        )
1909
1910    def test_instance_norm(self):
1911        size = 100
1912
1913        def f(x, running_mean, running_var):
1914            with enable_python_dispatcher():
1915                return torch.instance_norm(
1916                    x,
1917                    None,
1918                    None,
1919                    running_mean,
1920                    running_var,
1921                    use_input_stats=True,
1922                    momentum=0.1,
1923                    eps=1e-5,
1924                    cudnn_enabled=False,
1925                )
1926
1927        self.assert_functionalization(
1928            f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size)
1929        )
1930        # On Windows, the alias_copy ops for instance_norm are reordered to come right before their first use,
1931        # whereas on other platforms the alias_copy ops come before the view_copy ops.
1932        # E.g., the alias_copy after the getitem_4 assignment would be moved to right before the copy assignment.
1933        if not IS_WINDOWS:
1934            logs = self.get_logs(
1935                f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size)
1936            )
1937            self.assertExpectedInline(
1938                logs,
1939                """\
1940
1941
1942
1943def forward(self, arg0_1, arg1_1, arg2_1):
1944    repeat = torch.ops.aten.repeat.default(arg1_1, [20])
1945    repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20])
1946    view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 2000, 35, 45]);  arg0_1 = None
1947    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'));  empty = None
1948    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05);  view_copy = repeat = repeat_1 = None
1949    getitem = _native_batch_norm_legit_functional[0]
1950    getitem_1 = _native_batch_norm_legit_functional[1];  getitem_1 = None
1951    getitem_2 = _native_batch_norm_legit_functional[2];  getitem_2 = None
1952    getitem_3 = _native_batch_norm_legit_functional[3]
1953    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
1954    alias_copy = torch.ops.aten.alias_copy.default(arg1_1)
1955    view_copy_1 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]);  view_copy_1 = None
1956    view_copy_2 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]);  getitem_3 = None
1957    mean = torch.ops.aten.mean.dim(view_copy_2, [0]);  view_copy_2 = None
1958    copy = torch.ops.aten.copy.default(alias_copy, mean);  alias_copy = mean = None
1959    alias_copy_1 = torch.ops.aten.alias_copy.default(copy);  copy = None
1960    alias_copy_2 = torch.ops.aten.alias_copy.default(alias_copy_1);  alias_copy_2 = None
1961    alias_copy_3 = torch.ops.aten.alias_copy.default(arg2_1)
1962    view_copy_3 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]);  view_copy_3 = None
1963    view_copy_4 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]);  getitem_4 = None
1964    mean_1 = torch.ops.aten.mean.dim(view_copy_4, [0]);  view_copy_4 = None
1965    copy_1 = torch.ops.aten.copy.default(alias_copy_3, mean_1);  alias_copy_3 = mean_1 = None
1966    alias_copy_4 = torch.ops.aten.alias_copy.default(copy_1);  copy_1 = None
1967    alias_copy_5 = torch.ops.aten.alias_copy.default(alias_copy_4);  alias_copy_5 = None
1968    view_copy_5 = torch.ops.aten.view_copy.default(getitem, [20, 100, 35, 45]);  getitem = None
1969    copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_1);  arg1_1 = alias_copy_1 = copy_ = None
1970    copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_4);  arg2_1 = alias_copy_4 = copy__1 = None
1971    return view_copy_5
1972    """,  # noqa: B950
1973            )
1974
1975            reinplaced_logs = self.get_logs(
1976                f,
1977                torch.randn(20, size, 35, 45),
1978                torch.zeros(size),
1979                torch.ones(size),
1980                reapply_views=True,
1981                run_reinplace=True,
1982            )
1983            self.assertExpectedInline(
1984                reinplaced_logs,
1985                """\
1986
1987
1988
1989def forward(self, arg0_1, arg1_1, arg2_1):
1990    repeat = torch.ops.aten.repeat.default(arg1_1, [20])
1991    repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20])
1992    view = torch.ops.aten.view.default(arg0_1, [1, 2000, 35, 45]);  arg0_1 = None
1993    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'));  empty = None
1994    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05);  view = repeat = repeat_1 = None
1995    getitem = _native_batch_norm_legit_functional[0]
1996    getitem_1 = _native_batch_norm_legit_functional[1];  getitem_1 = None
1997    getitem_2 = _native_batch_norm_legit_functional[2];  getitem_2 = None
1998    getitem_3 = _native_batch_norm_legit_functional[3]
1999    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
2000    alias = torch.ops.aten.alias.default(arg1_1)
2001    view_1 = torch.ops.aten.view.default(getitem_3, [20, 100]);  view_1 = None
2002    view_2 = torch.ops.aten.view.default(getitem_3, [20, 100]);  getitem_3 = None
2003    mean = torch.ops.aten.mean.dim(view_2, [0]);  view_2 = None
2004    copy = torch.ops.aten.copy.default(alias, mean);  alias = mean = None
2005    alias_1 = torch.ops.aten.alias.default(copy);  copy = None
2006    alias_2 = torch.ops.aten.alias.default(alias_1);  alias_2 = None
2007    alias_3 = torch.ops.aten.alias.default(arg2_1)
2008    view_3 = torch.ops.aten.view.default(getitem_4, [20, 100]);  view_3 = None
2009    view_4 = torch.ops.aten.view.default(getitem_4, [20, 100]);  getitem_4 = None
2010    mean_1 = torch.ops.aten.mean.dim(view_4, [0]);  view_4 = None
2011    copy_1 = torch.ops.aten.copy.default(alias_3, mean_1);  alias_3 = mean_1 = None
2012    alias_4 = torch.ops.aten.alias.default(copy_1);  copy_1 = None
2013    alias_5 = torch.ops.aten.alias.default(alias_4);  alias_5 = None
2014    view_5 = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]);  getitem = None
2015    copy_ = torch.ops.aten.copy_.default(arg1_1, alias_1);  arg1_1 = alias_1 = copy_ = None
2016    copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_4);  arg2_1 = alias_4 = copy__1 = None
2017    return view_5
2018    """,  # noqa: B950
2019            )
2020
2021    def test_mutation_overlapping_mem(self):
2022        def fn(x):
2023            # x: (1, 5); unfold below creates overlapping windows (see the sketch after this test)
2024            t1 = torch.add(x, x)
2025            t2 = t1.unfold(1, 3, 2)
2026            t3 = t2.abs_()
2027            return t3
2028
2029        with self.assertRaisesRegex(
2030            RuntimeError,
2031            r"encountered a tensor being mutated that has internal overlap",
2032        ):
2033            x = torch.ones(1, 5)
2034            out = _functionalize(fn, reapply_views=True, crossref=False)(x)
2035
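
    # A minimal eager-mode sketch (illustrative only; not collected as a test) of the
    # "internal overlap" rejected above: unfold(1, 3, 2) on a (1, 5) tensor produces the
    # windows [0:3] and [2:5], which both alias element index 2 of the base, so an
    # in-place op on the unfolded tensor would write that element twice.
    def _unfold_internal_overlap_sketch(self):
        t = torch.arange(5.0).reshape(1, 5)
        u = t.unfold(1, 3, 2)  # shape (1, 2, 3), a strided view of t
        # the last element of the first window and the first element of the second
        # window live at the same memory address
        self.assertEqual(u[0, 0, 2].data_ptr(), u[0, 1, 0].data_ptr())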
2036    def test_batch_norm(self):
2037        def f(x, running_mean, running_var):
2038            with enable_python_dispatcher():
2039                return torch.batch_norm(
2040                    x, None, None, running_mean, running_var, True, 0.1, 1e-5, False
2041                )
2042
2043        self.assert_functionalization(
2044            f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100)
2045        )
2046        logs = self.get_logs(
2047            f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100)
2048        )
2049        self.assertExpectedInline(
2050            logs,
2051            """\
2052
2053
2054
2055def forward(self, arg0_1, arg1_1, arg2_1):
2056    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'));  empty = None
2057    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05);  arg0_1 = None
2058    getitem = _native_batch_norm_legit_functional[0]
2059    getitem_1 = _native_batch_norm_legit_functional[1];  getitem_1 = None
2060    getitem_2 = _native_batch_norm_legit_functional[2];  getitem_2 = None
2061    getitem_3 = _native_batch_norm_legit_functional[3]
2062    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
2063    copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3);  arg1_1 = getitem_3 = copy_ = None
2064    copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4);  arg2_1 = getitem_4 = copy__1 = None
2065    return getitem
2066    """,  # noqa: B950
2067        )
2068
2069        reinplaced_logs = self.get_logs(
2070            f,
2071            torch.randn(20, 100, 35, 45),
2072            torch.zeros(100),
2073            torch.ones(100),
2074            reapply_views=True,
2075            run_reinplace=True,
2076        )
2077        self.assertExpectedInline(
2078            reinplaced_logs,
2079            """\
2080
2081
2082
2083def forward(self, arg0_1, arg1_1, arg2_1):
2084    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'));  empty = None
2085    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05);  arg0_1 = None
2086    getitem = _native_batch_norm_legit_functional[0]
2087    getitem_1 = _native_batch_norm_legit_functional[1];  getitem_1 = None
2088    getitem_2 = _native_batch_norm_legit_functional[2];  getitem_2 = None
2089    getitem_3 = _native_batch_norm_legit_functional[3]
2090    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
2091    copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3);  arg1_1 = getitem_3 = copy_ = None
2092    copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4);  arg2_1 = getitem_4 = copy__1 = None
2093    return getitem
2094    """,  # noqa: B950
2095        )
2096
2097    # This tests our python shims around C++ Functionalization: FunctionalTensor and FunctionalTensorMode
2098    def test_python_functionalization(self):
2099        def f(x):
2100            x_view = x.view(-1)
2101            x.mul_(2)
2102            return x_view + 1
2103
2104        def f_functionalized(x):
2105            # Note [Disabling Functionalize TLS Above Python Functionalization]
2106            # This UX is pretty annoying (although python functionalization's main customer is AOTAutograd,
2107            # and is not really advertised as a user API).
2108            # We need to explicitly disable functionalization when using python FunctionalTensor and FunctionalTensorMode.
2109            # Why? FunctionalTensor is a wrapper tensor that holds an inner FunctionalTensorWrapper.
2110            # Since the inner tensor has `DispatchKey.Functionalize` in its keyset, our FunctionalTensor
2111            # will inherit the same keyset by default.
2112            # We don't have an easy way of directly mutating a tensor's keyset from python,
2113            # so globally disabling functionalization here is easier.
2114            maybe_disable = torch._C._ExcludeDispatchKeyGuard(
2115                torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
2116            )
2117            with maybe_disable, FunctionalTensorMode():
2118                x_wrapped = FunctionalTensor.to_functional(x)
2119                out_wrapped = f(x_wrapped)
2120            out_unwrapped = out_wrapped.elem
2121            torch._sync(out_unwrapped)
2122            return torch._from_functional_tensor(out_unwrapped)
2123
2124        # Make a non-leaf
2125        x = torch.randn(2, requires_grad=True) + 1
2126        fx_g = make_fx(f_functionalized)(x)
2127        # NB: view_1 below is expected (though unused) due to view replay. AOTAutograd runs a
2128        # DCE pass that will remove nodes like this later on (see the DCE sketch after this test).
2129        self.assertExpectedInline(
2130            fx_g.code.strip(),
2131            """\
2132def forward(self, x_1):
2133    view = torch.ops.aten.view.default(x_1, [-1]);  view = None
2134    mul = torch.ops.aten.mul.Tensor(x_1, 2);  x_1 = None
2135    view_1 = torch.ops.aten.view.default(mul, [-1]);  view_1 = None
2136    view_2 = torch.ops.aten.view.default(mul, [-1]);  mul = None
2137    add = torch.ops.aten.add.Tensor(view_2, 1);  view_2 = None
2138    return add""",
2139        )
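
    # A minimal sketch (illustrative only; not collected as a test) of the DCE mentioned in the
    # NB above: FX's dead-code elimination drops pure nodes with no users, which is how unused
    # view-replay nodes like `view_1` get cleaned up later on.
    def _dce_removes_unused_view_sketch(self):
        def g(x):
            x.view(-1)  # recorded by tracing even though unused, analogous to view_1 above
            return x * 2

        gm = make_fx(g)(torch.randn(2))
        num_nodes_before = len(list(gm.graph.nodes))
        gm.graph.eliminate_dead_code()
        gm.recompile()
        # the dangling view node is gone; the placeholder, the mul, and the output remain
        self.assertLess(len(list(gm.graph.nodes)), num_nodes_before)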
2140
2141    def test_python_functionalization_zero_tensor(self):
2142        def f(x):
2143            y = torch.ops.aten._efficientzerotensor([4])
2144            out = x + y
2145            out.mul_(2)
2146            return out
2147
2148        x = torch.randn(4)
2149        out_ref = f(x)
2150        out_test = dispatch_functionalize(f)(x)
2151        out_test_cpp = _functionalize(
2152            f, reapply_views=True, crossref=False, skip_input_mutations=True
2153        )(x)
2154        self.assertEqual(out_ref, out_test)
2155        self.assertEqual(out_ref, out_test_cpp)
2156        fx_g = make_fx(dispatch_functionalize(f))(x)
2157        fx_g_cpp = make_fx(
2158            _functionalize(
2159                f, reapply_views=True, crossref=False, skip_input_mutations=True
2160            )
2161        )(x)
2162        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
2163
2164    def test_python_functionalization_is_conj(self):
2165        def f(x):
2166            out = x.conj()
2167            return out, out.is_conj()
2168
2169        x = torch.randn(4, dtype=torch.complex64)
2170        out_ref = f(x)
2171        out_test = dispatch_functionalize(f)(x)
2172        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False)(x)
2173        self.assertEqual(out_ref[0], out_test[0])
2174        self.assertEqual(out_ref[1], out_test[1])
2175        self.assertEqual(out_ref[0], out_test_cpp[0])
2176        self.assertEqual(out_ref[1], out_test_cpp[1])
2177
2178    def test_python_functionalization_is_neg(self):
2179        def f(x):
2180            out = x.neg()
2181            return out, out.is_neg()
2182
2183        x = torch.randn(4, dtype=torch.complex64)
2184        out_ref = f(x)
2185        out_test = dispatch_functionalize(f)(x)
2186        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False)(x)
2187        self.assertEqual(out_ref[0], out_test[0])
2188        self.assertEqual(out_ref[1], out_test[1])
2189        self.assertEqual(out_ref[0], out_test_cpp[0])
2190        self.assertEqual(out_ref[1], out_test_cpp[1])
2191
2192    def test_python_functionalization_conj(self):
2193        def f(x):
2194            y = x.clone().conj()
2195            y.mul_(2)
2196            return torch.view_as_real(y.resolve_conj())
2197
2198        x = torch.randn(4, dtype=torch.complex64)
2199        out_ref = f(x)
2200        out_test = dispatch_functionalize(f)(x)
2201        out_test_cpp = _functionalize(
2202            f, reapply_views=True, crossref=False, skip_input_mutations=True
2203        )(x)
2204        self.assertEqual(out_ref, out_test)
2205        self.assertEqual(out_test, out_test_cpp)
2206        fx_g = make_fx(dispatch_functionalize(f))(x)
2207        fx_g_cpp = make_fx(
2208            _functionalize(
2209                f, reapply_views=True, crossref=False, skip_input_mutations=True
2210            )
2211        )(x)
2212        self.assertExpectedInline(
2213            fx_g.code.strip(),
2214            """\
2215def forward(self, arg0_1):
2216    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
2217    _conj = torch.ops.aten._conj.default(clone);  clone = None
2218    clone_1 = torch.ops.aten.clone.default(_conj)
2219    mul = torch.ops.aten.mul.Tensor(clone_1, 2);  clone_1 = None
2220    clone_2 = torch.ops.aten.clone.default(_conj);  _conj = None
2221    copy = torch.ops.aten.copy.default(clone_2, mul);  clone_2 = mul = None
2222    _conj_1 = torch.ops.aten._conj.default(copy);  copy = None
2223    _conj_2 = torch.ops.aten._conj.default(_conj_1);  _conj_1 = None
2224    clone_3 = torch.ops.aten.clone.default(_conj_2);  _conj_2 = None
2225    view_as_real = torch.ops.aten.view_as_real.default(clone_3);  clone_3 = None
2226    return view_as_real""",
2227        )
2228        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
2229
2230    def test_python_functionalization_neg(self):
2231        def f(x):
2232            y = x._neg_view()
2233            z = y.resolve_neg()
2234            return z + 1
2235
2236        x = torch.randn(4)
2237        out_ref = f(x)
2238        out_test = dispatch_functionalize(f)(x)
2239        out_test_cpp = _functionalize(
2240            f, reapply_views=True, crossref=False, skip_input_mutations=True
2241        )(x)
2242        self.assertEqual(out_ref, out_test)
2243        self.assertEqual(out_ref, out_test_cpp)
2244        fx_g = make_fx(dispatch_functionalize(f))(x)
2245        fx_g_cpp = make_fx(
2246            _functionalize(
2247                f, reapply_views=True, crossref=False, skip_input_mutations=True
2248            )
2249        )(x)
2250        self.assertExpectedInline(
2251            fx_g.code.strip(),
2252            """\
2253def forward(self, arg0_1):
2254    _neg_view = torch.ops.aten._neg_view.default(arg0_1);  arg0_1 = None
2255    clone = torch.ops.aten.clone.default(_neg_view);  _neg_view = None
2256    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
2257    return add""",
2258        )
2259        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
2260
2261    def test_python_functionalization_lift_fresh_storage(self):
2262        unlifted = torch.tensor([0.0])
2263
2264        maybe_disable = torch._C._ExcludeDispatchKeyGuard(
2265            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
2266        )
2267        with maybe_disable, FunctionalTensorMode():
2268            lifted = torch.ops.aten.lift_fresh.default(unlifted)
2269
2270        self.assertNotEqual(unlifted.untyped_storage(), lifted.untyped_storage())
2271
2272    def test_python_functionalization_lift_fresh(self):
2273        def f(x):
2274            tmp = torch.tensor([0.0])
2275            return tmp + x
2276
2277        x = torch.randn(4)
2278        out_ref = f(x)
2279        out_test = dispatch_functionalize(f)(x)
2280        out_test_cpp = _functionalize(
2281            f, reapply_views=True, crossref=False, skip_input_mutations=True
2282        )(x)
2283        self.assertEqual(out_ref, out_test)
2284        self.assertEqual(out_ref, out_test_cpp)
2285        fx_g = make_fx(dispatch_functionalize(f))(x)
2286        fx_g_cpp = make_fx(
2287            _functionalize(
2288                f, reapply_views=True, crossref=False, skip_input_mutations=True
2289            )
2290        )(x)
2291        self.assertExpectedInline(
2292            fx_g.code.strip(),
2293            """\
2294def forward(self, arg0_1):
2295    _tensor_constant0 = self._tensor_constant0
2296    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
2297    add = torch.ops.aten.add.Tensor(lift_fresh_copy, arg0_1);  lift_fresh_copy = arg0_1 = None
2298    return add""",
2299        )
2300        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
2301
2302
2303@xfail_inherited_tests(
2304    [
2305        "test_as_strided",
2306        "test_copy_",
2307        "test_diagonal",
2308        "test_diagonal_mutated_input",
2309        "test_everything",
2310        "test_fill_",
2311        "test_slice",
2312        "test_split",
2313        "test_split_with_sizes",
2314        "test_unbind",
2315        "test_view_clone_view_inplace",
2316        "test_view_inplace",
2317    ]
2318)
2319@unittest.skipIf(
2320    TEST_WITH_TORCHDYNAMO, "dynamo-ing code with proxy + fake doesn't work well"
2321)
2322class TestCrossRefFunctionalization(TestFunctionalization):
2323    crossref = True
2324
2325
2326if __name__ == "__main__":
2327    run_tests()
2328