xref: /aosp_15_r20/external/pytorch/test/test_functionalization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: codegen"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport unittest
4*da0073e9SAndroid Build Coastguard Workerfrom contextlib import nullcontext
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerfrom torch._dispatch.python import (
8*da0073e9SAndroid Build Coastguard Worker    enable_crossref_functionalize,
9*da0073e9SAndroid Build Coastguard Worker    enable_python_dispatcher,
10*da0073e9SAndroid Build Coastguard Worker)
11*da0073e9SAndroid Build Coastguard Workerfrom torch._subclasses.functional_tensor import (
12*da0073e9SAndroid Build Coastguard Worker    dispatch_functionalize,
13*da0073e9SAndroid Build Coastguard Worker    FunctionalTensor,
14*da0073e9SAndroid Build Coastguard Worker    FunctionalTensorMode,
15*da0073e9SAndroid Build Coastguard Worker)
16*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.proxy_tensor import make_fx
17*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes.reinplace import reinplace
18*da0073e9SAndroid Build Coastguard Workerfrom torch.multiprocessing.reductions import StorageWeakRef
19*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
20*da0073e9SAndroid Build Coastguard Worker    IS_WINDOWS,
21*da0073e9SAndroid Build Coastguard Worker    run_tests,
22*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
23*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TORCHDYNAMO,
24*da0073e9SAndroid Build Coastguard Worker    TestCase,
25*da0073e9SAndroid Build Coastguard Worker    xfail_inherited_tests,
26*da0073e9SAndroid Build Coastguard Worker)
27*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.logging_tensor import capture_logs, LoggingTensor
28*da0073e9SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree
29*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_map_only
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker
def are_aliased(x, y):
    """Return True if tensors ``x`` and ``y`` share the same underlying storage.

    Compares weak references to each tensor's storage. ``untyped_storage()``
    is used instead of the deprecated ``storage()`` (which returns a
    TypedStorage and warns); both refer to the same storage, so the
    aliasing answer is unchanged.
    """
    x_storage = StorageWeakRef(x.untyped_storage())
    y_storage = StorageWeakRef(y.untyped_storage())
    return x_storage == y_storage
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker
# We could unify testing and use functionalize() here instead
# if/when functorch moves into core.
# This is essentially a minimal reimplementation of `functionalize()`.
41*da0073e9SAndroid Build Coastguard Workerdef _functionalize(
42*da0073e9SAndroid Build Coastguard Worker    f, *, reapply_views: bool, crossref: bool, skip_input_mutations: bool = False
43*da0073e9SAndroid Build Coastguard Worker):
44*da0073e9SAndroid Build Coastguard Worker    def to_fun(t: torch.Tensor):
45*da0073e9SAndroid Build Coastguard Worker        func_t = torch._to_functional_tensor(t)
46*da0073e9SAndroid Build Coastguard Worker        func_t.requires_grad = t.requires_grad
47*da0073e9SAndroid Build Coastguard Worker        return func_t
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker    def wrapped(*inputs):
50*da0073e9SAndroid Build Coastguard Worker        ctx = nullcontext()
51*da0073e9SAndroid Build Coastguard Worker        if crossref:
52*da0073e9SAndroid Build Coastguard Worker            ctx = enable_crossref_functionalize()
53*da0073e9SAndroid Build Coastguard Worker        with ctx:
54*da0073e9SAndroid Build Coastguard Worker            inputs_functional = tree_map_only(torch.Tensor, to_fun, inputs)
55*da0073e9SAndroid Build Coastguard Worker            torch._enable_functionalization(reapply_views=reapply_views)
56*da0073e9SAndroid Build Coastguard Worker            try:
57*da0073e9SAndroid Build Coastguard Worker                out = f(*inputs_functional)
58*da0073e9SAndroid Build Coastguard Worker            finally:
59*da0073e9SAndroid Build Coastguard Worker                torch._disable_functionalization()
60*da0073e9SAndroid Build Coastguard Worker            flat_inputs = pytree.tree_leaves(inputs)
61*da0073e9SAndroid Build Coastguard Worker            flat_inputs_functional = pytree.tree_leaves(inputs_functional)
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker            for inpt, input_functional in zip(flat_inputs, flat_inputs_functional):
64*da0073e9SAndroid Build Coastguard Worker                torch._sync(input_functional)
65*da0073e9SAndroid Build Coastguard Worker                inpt_new = torch._from_functional_tensor(input_functional)
66*da0073e9SAndroid Build Coastguard Worker                if inpt_new is not inpt and not skip_input_mutations:
67*da0073e9SAndroid Build Coastguard Worker                    # Existing deficiency in functionalize():
68*da0073e9SAndroid Build Coastguard Worker                    # we don't correctly mutate input metadata (yet?)
69*da0073e9SAndroid Build Coastguard Worker                    if inpt_new.shape == inpt.shape:
70*da0073e9SAndroid Build Coastguard Worker                        inpt.copy_(inpt_new)
71*da0073e9SAndroid Build Coastguard Worker            tree_map_only(torch.Tensor, torch._sync, out)
72*da0073e9SAndroid Build Coastguard Worker            out_unwrapped = tree_map_only(
73*da0073e9SAndroid Build Coastguard Worker                torch.Tensor, torch._from_functional_tensor, out
74*da0073e9SAndroid Build Coastguard Worker            )
75*da0073e9SAndroid Build Coastguard Worker            return out_unwrapped
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker    return wrapped
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
81*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457"
82*da0073e9SAndroid Build Coastguard Worker)
83*da0073e9SAndroid Build Coastguard Workerclass TestFunctionalization(TestCase):
84*da0073e9SAndroid Build Coastguard Worker    crossref = False
85*da0073e9SAndroid Build Coastguard Worker
86*da0073e9SAndroid Build Coastguard Worker    def get_logs(self, func, *inpts, reapply_views=False, run_reinplace=False):
87*da0073e9SAndroid Build Coastguard Worker        inpts_clone = tree_map_only(torch.Tensor, torch.clone, inpts)
88*da0073e9SAndroid Build Coastguard Worker        traced_f = make_fx(
89*da0073e9SAndroid Build Coastguard Worker            _functionalize(func, reapply_views=reapply_views, crossref=self.crossref)
90*da0073e9SAndroid Build Coastguard Worker        )(*inpts)
91*da0073e9SAndroid Build Coastguard Worker        if run_reinplace:
92*da0073e9SAndroid Build Coastguard Worker            traced_f = reinplace(traced_f, *inpts_clone)
93*da0073e9SAndroid Build Coastguard Worker        return traced_f.code
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker    def assert_functionalization(
96*da0073e9SAndroid Build Coastguard Worker        self, func, *inpts, reapply_views=False, mutated_input_metadata=False
97*da0073e9SAndroid Build Coastguard Worker    ):
98*da0073e9SAndroid Build Coastguard Worker        clones1 = tree_map_only(torch.Tensor, torch.clone, inpts)
99*da0073e9SAndroid Build Coastguard Worker        clones2 = tree_map_only(torch.Tensor, torch.clone, inpts)
100*da0073e9SAndroid Build Coastguard Worker        clones3 = tree_map_only(torch.Tensor, torch.clone, inpts)
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker        # Compare outputs (and mutated inputs), with and without functionalization.
103*da0073e9SAndroid Build Coastguard Worker        out_ref = func(*inpts)
104*da0073e9SAndroid Build Coastguard Worker        out_functional = _functionalize(
105*da0073e9SAndroid Build Coastguard Worker            func, reapply_views=reapply_views, crossref=self.crossref
106*da0073e9SAndroid Build Coastguard Worker        )(*clones1)
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker        # The reinplacing pass is only valid to run with reapply_views=True.
109*da0073e9SAndroid Build Coastguard Worker        functional_func = make_fx(
110*da0073e9SAndroid Build Coastguard Worker            _functionalize(func, reapply_views=True, crossref=self.crossref)
111*da0073e9SAndroid Build Coastguard Worker        )(*clones2)
112*da0073e9SAndroid Build Coastguard Worker        reinplace_func = reinplace(functional_func, *clones2)
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker        # NOTE: for now, need to pass in fresh inputs here, because make_fx
115*da0073e9SAndroid Build Coastguard Worker        # will directly mutate the inputs that you trace with.
116*da0073e9SAndroid Build Coastguard Worker        # Once this is fixed we can clean this up.
117*da0073e9SAndroid Build Coastguard Worker        out_reinplace = reinplace_func(*clones3)
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker        # functionalize() deficiency: input metadata mutations aren't propagated properly,
120*da0073e9SAndroid Build Coastguard Worker        # so we just need to skip checks here for the tests that exercise that.
121*da0073e9SAndroid Build Coastguard Worker        if not mutated_input_metadata:
122*da0073e9SAndroid Build Coastguard Worker            flat_inpts = pytree.tree_leaves(inpts)
123*da0073e9SAndroid Build Coastguard Worker            flat_clones1 = pytree.tree_leaves(clones1)
124*da0073e9SAndroid Build Coastguard Worker            flat_clones3 = pytree.tree_leaves(clones3)
125*da0073e9SAndroid Build Coastguard Worker            for inpt, input_clone, input_clone3 in zip(
126*da0073e9SAndroid Build Coastguard Worker                flat_inpts, flat_clones1, flat_clones3
127*da0073e9SAndroid Build Coastguard Worker            ):
128*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
129*da0073e9SAndroid Build Coastguard Worker                    inpt, input_clone
130*da0073e9SAndroid Build Coastguard Worker                )  # input mutations should still occur
131*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(inpt, input_clone3)
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker        # Handle tests with multi-tensor outputs
134*da0073e9SAndroid Build Coastguard Worker        if isinstance(out_ref, tuple):
135*da0073e9SAndroid Build Coastguard Worker            out_refs, out_functionals, out_reinplaces = (
136*da0073e9SAndroid Build Coastguard Worker                list(out_ref),
137*da0073e9SAndroid Build Coastguard Worker                list(out_functional),
138*da0073e9SAndroid Build Coastguard Worker                list(out_reinplace),
139*da0073e9SAndroid Build Coastguard Worker            )
140*da0073e9SAndroid Build Coastguard Worker        else:
141*da0073e9SAndroid Build Coastguard Worker            out_refs, out_functionals, out_reinplaces = (
142*da0073e9SAndroid Build Coastguard Worker                [out_ref],
143*da0073e9SAndroid Build Coastguard Worker                [out_functional],
144*da0073e9SAndroid Build Coastguard Worker                [out_reinplace],
145*da0073e9SAndroid Build Coastguard Worker            )
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker        for out_ref_, out_functional_, out_reinplace_ in zip(
148*da0073e9SAndroid Build Coastguard Worker            out_refs, out_functionals, out_reinplaces
149*da0073e9SAndroid Build Coastguard Worker        ):
150*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_ref_, out_functional_)
151*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_ref_, out_reinplace_)
152*da0073e9SAndroid Build Coastguard Worker
153*da0073e9SAndroid Build Coastguard Worker    def test_save_for_backwards_segfault(self):
154*da0073e9SAndroid Build Coastguard Worker        inp = torch._to_functional_tensor(
155*da0073e9SAndroid Build Coastguard Worker            LoggingTensor(torch.randn(2, 2))
156*da0073e9SAndroid Build Coastguard Worker        ).requires_grad_(True)
157*da0073e9SAndroid Build Coastguard Worker        inp.exp()
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker    def test_multiple_views_of_same_base(self):
160*da0073e9SAndroid Build Coastguard Worker        def f(x):
161*da0073e9SAndroid Build Coastguard Worker            y = x.view(-1)
162*da0073e9SAndroid Build Coastguard Worker            z = x.view(-1)
163*da0073e9SAndroid Build Coastguard Worker            x.add_(1)
164*da0073e9SAndroid Build Coastguard Worker            # y should have been updated.
165*da0073e9SAndroid Build Coastguard Worker            y2 = y + 1
166*da0073e9SAndroid Build Coastguard Worker            # z should have been updated too.
167*da0073e9SAndroid Build Coastguard Worker            z2 = z + 1
168*da0073e9SAndroid Build Coastguard Worker            return z2
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(4))
171*da0073e9SAndroid Build Coastguard Worker
172*da0073e9SAndroid Build Coastguard Worker    def test_freeze(self):
173*da0073e9SAndroid Build Coastguard Worker        def f(x):
174*da0073e9SAndroid Build Coastguard Worker            y = x.clone()
175*da0073e9SAndroid Build Coastguard Worker            z = y[0]
176*da0073e9SAndroid Build Coastguard Worker            torch._freeze_functional_tensor(y)
177*da0073e9SAndroid Build Coastguard Worker            x.add_(1)
178*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: y.add_(1))
179*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: z.add_(1))
180*da0073e9SAndroid Build Coastguard Worker            return z
181*da0073e9SAndroid Build Coastguard Worker
182*da0073e9SAndroid Build Coastguard Worker        _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(3, 3))
183*da0073e9SAndroid Build Coastguard Worker
184*da0073e9SAndroid Build Coastguard Worker    def test_copy_stride_mismatch(self):
185*da0073e9SAndroid Build Coastguard Worker        def f(x):
186*da0073e9SAndroid Build Coastguard Worker            y = torch.empty_strided((2, 2), (5, 1))
187*da0073e9SAndroid Build Coastguard Worker            y.copy_(x)
188*da0073e9SAndroid Build Coastguard Worker            return y
189*da0073e9SAndroid Build Coastguard Worker
190*da0073e9SAndroid Build Coastguard Worker        r = _functionalize(f, reapply_views=True, crossref=self.crossref)(
191*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 2)
192*da0073e9SAndroid Build Coastguard Worker        )
193*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.stride(), (5, 1))
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker    def test_set_(self):
196*da0073e9SAndroid Build Coastguard Worker        def f(x):
197*da0073e9SAndroid Build Coastguard Worker            y = torch.ones(2)
198*da0073e9SAndroid Build Coastguard Worker            y.set_(x.storage())
199*da0073e9SAndroid Build Coastguard Worker            return y
200*da0073e9SAndroid Build Coastguard Worker
201*da0073e9SAndroid Build Coastguard Worker        # We should probaby get the crossref test to work,
202*da0073e9SAndroid Build Coastguard Worker        # but fixing it for Storage() objects is annoying.
203*da0073e9SAndroid Build Coastguard Worker        r = _functionalize(f, reapply_views=True, crossref=False)(torch.ones(2))
204*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(r.device), "cpu")
205*da0073e9SAndroid Build Coastguard Worker
206*da0073e9SAndroid Build Coastguard Worker    def test_advanced_indexing(self):
207*da0073e9SAndroid Build Coastguard Worker        def f():
208*da0073e9SAndroid Build Coastguard Worker            x = torch.zeros(3, 3)
209*da0073e9SAndroid Build Coastguard Worker            idx = torch.tensor([0])
210*da0073e9SAndroid Build Coastguard Worker            val = torch.ones(3, 1)
211*da0073e9SAndroid Build Coastguard Worker            x[:, idx] = val
212*da0073e9SAndroid Build Coastguard Worker            return x
213*da0073e9SAndroid Build Coastguard Worker
214*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f)
215*da0073e9SAndroid Build Coastguard Worker
    def test_view_clone_view_inplace(self):
        # Traces forward + backward of a view -> clone -> view -> in-place relu
        # chain through functionalization (get_logs, default reapply_views=False)
        # and checks the exact generated graph, including the backward ops.
        def f(input):
            # Same element count as the (16, 64, 128, 128) input used below
            # (16 * 64 == 1024).
            shape = [1, 1024, 128, 128]
            input_reshaped = input.view(shape)
            out = input_reshaped.clone()
            # View the clone back to the input's shape and mutate that view.
            r = out.view(input.shape)
            r.relu_()
            return r

        def g(x):
            # Returns d(sum(f(x)))/dx, so the traced graph covers both the
            # forward and the backward pass.
            loss = f(x).sum()
            import torch.fx.traceback as fx_traceback
            from torch._functorch.aot_autograd import (
                setup_stacktrace_preservation_hooks,
            )

            # NOTE(review): presumably these hooks plus preserve_node_meta()
            # carry node metadata into the backward pass as aot_autograd does;
            # their exact behavior is not visible from this file.
            setup_stacktrace_preservation_hooks([loss.grad_fn])
            with fx_traceback.preserve_node_meta():
                loss.backward()
            return x.grad

        # Trace under autograd anomaly detection, with NaN checking disabled.
        with torch.autograd.detect_anomaly(check_nan=False):
            logs = self.get_logs(g, torch.ones(16, 64, 128, 128, requires_grad=True))
        self.assertExpectedInline(
            logs,
            """\



def forward(self, arg0_1):
    view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 1024, 128, 128]);  arg0_1 = None
    clone = torch.ops.aten.clone.default(view_copy);  view_copy = None
    view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128])
    relu = torch.ops.aten.relu.default(view_copy_1);  view_copy_1 = None
    view_copy_2 = torch.ops.aten.view_copy.default(relu, [1, 1024, 128, 128]);  relu = None
    view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [16, 64, 128, 128]);  view_copy_2 = None
    view_copy_4 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]);  clone = view_copy_4 = None
    sum_1 = torch.ops.aten.sum.default(view_copy_3)
    ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format);  sum_1 = None
    expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]);  ones_like = None
    view_copy_5 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]);  expand_copy = None
    new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_5, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
    copy = torch.ops.aten.copy.default(new_empty_strided, view_copy_5);  new_empty_strided = view_copy_5 = None
    view_copy_6 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]);  view_copy_6 = None
    view_copy_7 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128])
    clone_1 = torch.ops.aten.clone.default(view_copy_7, memory_format = torch.contiguous_format)
    threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, view_copy_3, 0);  clone_1 = view_copy_3 = None
    copy_1 = torch.ops.aten.copy.default(view_copy_7, threshold_backward);  view_copy_7 = threshold_backward = None
    view_copy_8 = torch.ops.aten.view_copy.default(copy_1, [1, 1024, 128, 128]);  copy_1 = None
    view_copy_9 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]);  view_copy_9 = None
    view_copy_10 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]);  copy = None
    detach_copy = torch.ops.aten.detach_copy.default(view_copy_10);  view_copy_10 = detach_copy = None
    view_copy_11 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]);  view_copy_8 = None
    detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_11);  view_copy_11 = None
    return detach_copy_1
    """,
        )  # noqa: B950
273*da0073e9SAndroid Build Coastguard Worker
    def test_simple(self):
        # Checks a view + in-place op on a graph input: eagerly via
        # assert_functionalization, then against the exact traced graph
        # without and with the reinplacing pass.
        def f(x):
            # simple test: 1 view op, 1 inplace op
            tmp = torch.ones(4, 2)
            y = x.view(4, 2)
            y.add_(tmp)
            # Unused on purpose: x is read after being mutated through its
            # view, so the graph must compute `mul` from the updated value.
            z = x * x
            return y

        self.assert_functionalization(f, torch.ones(4, 2))
        # Default reapply_views=False: views appear as view_copy ops, and the
        # input mutation appears as a trailing copy_ into arg0_1.
        logs = self.get_logs(f, torch.ones(4, 2))
        self.assertExpectedInline(
            logs,
            """\



def forward(self, arg0_1):
    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
    add = torch.ops.aten.add.Tensor(view_copy, ones);  view_copy = ones = None
    view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]);  add = None
    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2])
    mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1);  mul = None
    copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1);  arg0_1 = view_copy_1 = copy_ = None
    return view_copy_2
    """,
        )

        # reapply_views=True + reinplacing: the same graph with view ops.
        reinplaced_logs = self.get_logs(
            f, torch.ones(4, 2), reapply_views=True, run_reinplace=True
        )
        self.assertExpectedInline(
            reinplaced_logs,
            """\



def forward(self, arg0_1):
    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
    view = torch.ops.aten.view.default(arg0_1, [4, 2])
    add = torch.ops.aten.add.Tensor(view, ones);  view = ones = None
    view_1 = torch.ops.aten.view.default(add, [4, 2]);  add = None
    view_2 = torch.ops.aten.view.default(view_1, [4, 2])
    mul = torch.ops.aten.mul.Tensor(view_1, view_1);  mul = None
    copy_ = torch.ops.aten.copy_.default(arg0_1, view_1);  arg0_1 = view_1 = copy_ = None
    return view_2
    """,
        )
323*da0073e9SAndroid Build Coastguard Worker
    def test_simple_out(self):
        # Checks that an out= op whose output buffer must be resized is
        # functionalized into a plain functional op (exact graph compared).
        def f(x):
            tmp = torch.ones(4, 2)
            y = x.view(4, 2)
            # z starts as a 0-dim tensor (shape ()), so the out= call has to
            # resize it to the (4, 2) result shape.
            z = torch.empty(())
            torch.add(y, tmp, out=z)
            w = z * z
            return w

        self.assert_functionalization(f, torch.ones(4, 2))
        # In the expected graph the out= call becomes a functional `add`, and
        # the original `empty` buffer is dead (`empty = None`).
        logs = self.get_logs(f, torch.ones(4, 2))
        self.assertExpectedInline(
            logs,
            """\



def forward(self, arg0_1):
    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]);  arg0_1 = None
    empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False);  empty = None
    add = torch.ops.aten.add.Tensor(view_copy, ones);  view_copy = ones = None
    mul = torch.ops.aten.mul.Tensor(add, add);  add = None
    return mul
    """,
        )

        # reapply_views=True + reinplacing: same graph, with view instead of view_copy.
        reinplaced_logs = self.get_logs(
            f, torch.ones(4, 2), reapply_views=True, run_reinplace=True
        )
        self.assertExpectedInline(
            reinplaced_logs,
            """\



def forward(self, arg0_1):
    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
    view = torch.ops.aten.view.default(arg0_1, [4, 2]);  arg0_1 = None
    empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False);  empty = None
    add = torch.ops.aten.add.Tensor(view, ones);  view = ones = None
    mul = torch.ops.aten.mul.Tensor(add, add);  add = None
    return mul
    """,
        )
370*da0073e9SAndroid Build Coastguard Worker
371*da0073e9SAndroid Build Coastguard Worker    def test_multi_out(self):
372*da0073e9SAndroid Build Coastguard Worker        def f(x):
373*da0073e9SAndroid Build Coastguard Worker            # aminmax.out returns a tuple of tensors.
374*da0073e9SAndroid Build Coastguard Worker            # functionalization should properly handle the tuple.
375*da0073e9SAndroid Build Coastguard Worker            out_min = torch.empty(4)
376*da0073e9SAndroid Build Coastguard Worker            out_max = torch.empty(4)
377*da0073e9SAndroid Build Coastguard Worker            torch.aminmax(x, dim=0, out=(out_max, out_min))
378*da0073e9SAndroid Build Coastguard Worker            return out_max
379*da0073e9SAndroid Build Coastguard Worker
380*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.arange(8, dtype=torch.float32))
381*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.arange(8, dtype=torch.float32))
382*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
383*da0073e9SAndroid Build Coastguard Worker            logs,
384*da0073e9SAndroid Build Coastguard Worker            """\
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
389*da0073e9SAndroid Build Coastguard Worker    empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False);  empty = None
390*da0073e9SAndroid Build Coastguard Worker    empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False);  empty_1 = None
391*da0073e9SAndroid Build Coastguard Worker    aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0);  arg0_1 = None
392*da0073e9SAndroid Build Coastguard Worker    getitem = aminmax[0]
393*da0073e9SAndroid Build Coastguard Worker    getitem_1 = aminmax[1];  aminmax = getitem_1 = None
394*da0073e9SAndroid Build Coastguard Worker    return getitem
395*da0073e9SAndroid Build Coastguard Worker    """,
396*da0073e9SAndroid Build Coastguard Worker        )
397*da0073e9SAndroid Build Coastguard Worker
398*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
399*da0073e9SAndroid Build Coastguard Worker            f,
400*da0073e9SAndroid Build Coastguard Worker            torch.arange(8, dtype=torch.float32),
401*da0073e9SAndroid Build Coastguard Worker            reapply_views=True,
402*da0073e9SAndroid Build Coastguard Worker            run_reinplace=True,
403*da0073e9SAndroid Build Coastguard Worker        )
404*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
405*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
406*da0073e9SAndroid Build Coastguard Worker            """\
407*da0073e9SAndroid Build Coastguard Worker
408*da0073e9SAndroid Build Coastguard Worker
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
411*da0073e9SAndroid Build Coastguard Worker    empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False);  empty = None
412*da0073e9SAndroid Build Coastguard Worker    empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False);  empty_1 = None
413*da0073e9SAndroid Build Coastguard Worker    aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0);  arg0_1 = None
414*da0073e9SAndroid Build Coastguard Worker    getitem = aminmax[0]
415*da0073e9SAndroid Build Coastguard Worker    getitem_1 = aminmax[1];  aminmax = getitem_1 = None
416*da0073e9SAndroid Build Coastguard Worker    return getitem
417*da0073e9SAndroid Build Coastguard Worker    """,
418*da0073e9SAndroid Build Coastguard Worker        )
419*da0073e9SAndroid Build Coastguard Worker
420*da0073e9SAndroid Build Coastguard Worker    def test_tensor_ctr(self):
421*da0073e9SAndroid Build Coastguard Worker        def f(x):
422*da0073e9SAndroid Build Coastguard Worker            y = torch.tensor((1, 2, 3))
423*da0073e9SAndroid Build Coastguard Worker            z = y.view(-1)
424*da0073e9SAndroid Build Coastguard Worker            z.add_(1)
425*da0073e9SAndroid Build Coastguard Worker            return y
426*da0073e9SAndroid Build Coastguard Worker
427*da0073e9SAndroid Build Coastguard Worker        inpt = torch.arange(3, dtype=torch.float32)
428*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, inpt)
429*da0073e9SAndroid Build Coastguard Worker
430*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, inpt)
431*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
432*da0073e9SAndroid Build Coastguard Worker            logs,
433*da0073e9SAndroid Build Coastguard Worker            """\
434*da0073e9SAndroid Build Coastguard Worker
435*da0073e9SAndroid Build Coastguard Worker
436*da0073e9SAndroid Build Coastguard Worker
437*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
438*da0073e9SAndroid Build Coastguard Worker    _tensor_constant0 = self._tensor_constant0
439*da0073e9SAndroid Build Coastguard Worker    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
440*da0073e9SAndroid Build Coastguard Worker    view_copy = torch.ops.aten.view_copy.default(lift_fresh_copy, [-1]);  lift_fresh_copy = None
441*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(view_copy, 1);  view_copy = None
442*da0073e9SAndroid Build Coastguard Worker    view_copy_1 = torch.ops.aten.view_copy.default(add, [3]);  add = None
443*da0073e9SAndroid Build Coastguard Worker    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [-1]);  view_copy_2 = None
444*da0073e9SAndroid Build Coastguard Worker    return view_copy_1
445*da0073e9SAndroid Build Coastguard Worker    """,
446*da0073e9SAndroid Build Coastguard Worker        )
447*da0073e9SAndroid Build Coastguard Worker
448*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(f, inpt, reapply_views=True, run_reinplace=True)
449*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
450*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
451*da0073e9SAndroid Build Coastguard Worker            """\
452*da0073e9SAndroid Build Coastguard Worker
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker
455*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
456*da0073e9SAndroid Build Coastguard Worker    _tensor_constant0 = self._tensor_constant0
457*da0073e9SAndroid Build Coastguard Worker    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
458*da0073e9SAndroid Build Coastguard Worker    view = torch.ops.aten.view.default(lift_fresh_copy, [-1]);  lift_fresh_copy = None
459*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add_.Tensor(view, 1);  add = None
460*da0073e9SAndroid Build Coastguard Worker    view_1 = torch.ops.aten.view.default(view, [3]);  view = None
461*da0073e9SAndroid Build Coastguard Worker    view_2 = torch.ops.aten.view.default(view_1, [-1]);  view_2 = None
462*da0073e9SAndroid Build Coastguard Worker    return view_1
463*da0073e9SAndroid Build Coastguard Worker    """,
464*da0073e9SAndroid Build Coastguard Worker        )
465*da0073e9SAndroid Build Coastguard Worker
466*da0073e9SAndroid Build Coastguard Worker    def test_advanced_indexing_correct_strides(self):
467*da0073e9SAndroid Build Coastguard Worker        def f(a):
468*da0073e9SAndroid Build Coastguard Worker            # This test requires that *_scatter ops are able to return
469*da0073e9SAndroid Build Coastguard Worker            # non-contiguous tensors.
470*da0073e9SAndroid Build Coastguard Worker            b = a.clone()[:, 1]
471*da0073e9SAndroid Build Coastguard Worker            c = torch.ones_like(b, dtype=torch.bool)
472*da0073e9SAndroid Build Coastguard Worker            d = b.masked_fill_(c, 0)
473*da0073e9SAndroid Build Coastguard Worker            return d
474*da0073e9SAndroid Build Coastguard Worker
475*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(2, 2), reapply_views=True)
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Worker    def test_tensor_list_mixed_functional_nonfunctional(self):
478*da0073e9SAndroid Build Coastguard Worker        nonfunctional_tensor = torch.ones(2, dtype=torch.long)
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker        def f(x):
481*da0073e9SAndroid Build Coastguard Worker            # simple test: 1 view op, 1 inplace op
482*da0073e9SAndroid Build Coastguard Worker            functional_tensor = torch.ones(2, dtype=torch.long)
483*da0073e9SAndroid Build Coastguard Worker            out = x[functional_tensor, nonfunctional_tensor]
484*da0073e9SAndroid Build Coastguard Worker            return out
485*da0073e9SAndroid Build Coastguard Worker
486*da0073e9SAndroid Build Coastguard Worker        out = f(torch.ones(2, 2))
487*da0073e9SAndroid Build Coastguard Worker        out_functional = _functionalize(f, reapply_views=True, crossref=self.crossref)(
488*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 2)
489*da0073e9SAndroid Build Coastguard Worker        )
490*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, out_functional)
491*da0073e9SAndroid Build Coastguard Worker
492*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_non_view(self):
493*da0073e9SAndroid Build Coastguard Worker        def f(x):
494*da0073e9SAndroid Build Coastguard Worker            # test for the case where we functionalize an inplace op on the other tensor - not a view.
495*da0073e9SAndroid Build Coastguard Worker            # This is worth checking because the tensor will have an empty ViewMeta stack, which needs to be special cased.
496*da0073e9SAndroid Build Coastguard Worker            tmp = torch.ones(4, 2)
497*da0073e9SAndroid Build Coastguard Worker            y = x.view(4, 2)
498*da0073e9SAndroid Build Coastguard Worker            x.add_(tmp)
499*da0073e9SAndroid Build Coastguard Worker            return y
500*da0073e9SAndroid Build Coastguard Worker
501*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(4, 2))
502*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(4, 2))
503*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
504*da0073e9SAndroid Build Coastguard Worker            logs,
505*da0073e9SAndroid Build Coastguard Worker            """\
506*da0073e9SAndroid Build Coastguard Worker
507*da0073e9SAndroid Build Coastguard Worker
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
510*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
511*da0073e9SAndroid Build Coastguard Worker    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]);  view_copy = None
512*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(arg0_1, ones);  ones = None
513*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg0_1, add);  arg0_1 = copy_ = None
514*da0073e9SAndroid Build Coastguard Worker    view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]);  add = None
515*da0073e9SAndroid Build Coastguard Worker    return view_copy_1
516*da0073e9SAndroid Build Coastguard Worker    """,
517*da0073e9SAndroid Build Coastguard Worker        )
518*da0073e9SAndroid Build Coastguard Worker
519*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
520*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(4, 2), reapply_views=True, run_reinplace=True
521*da0073e9SAndroid Build Coastguard Worker        )
522*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
523*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
524*da0073e9SAndroid Build Coastguard Worker            """\
525*da0073e9SAndroid Build Coastguard Worker
526*da0073e9SAndroid Build Coastguard Worker
527*da0073e9SAndroid Build Coastguard Worker
528*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
529*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
530*da0073e9SAndroid Build Coastguard Worker    view = torch.ops.aten.view.default(arg0_1, [4, 2]);  view = None
531*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(arg0_1, ones);  ones = None
532*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg0_1, add);  arg0_1 = copy_ = None
533*da0073e9SAndroid Build Coastguard Worker    view_1 = torch.ops.aten.view.default(add, [4, 2]);  add = None
534*da0073e9SAndroid Build Coastguard Worker    return view_1
535*da0073e9SAndroid Build Coastguard Worker    """,
536*da0073e9SAndroid Build Coastguard Worker        )
537*da0073e9SAndroid Build Coastguard Worker
538*da0073e9SAndroid Build Coastguard Worker    # Some ops that are mutable are neither inplace nor out= ops.
539*da0073e9SAndroid Build Coastguard Worker    # They also need special handling.
540*da0073e9SAndroid Build Coastguard Worker    def test_mutable_op_not_inplace_or_other(self):
541*da0073e9SAndroid Build Coastguard Worker        def f(x):
542*da0073e9SAndroid Build Coastguard Worker            return torch._fused_moving_avg_obs_fq_helper(
543*da0073e9SAndroid Build Coastguard Worker                x, x, x, x, x, x, x, 1.0, 0, 1, 0
544*da0073e9SAndroid Build Coastguard Worker            )
545*da0073e9SAndroid Build Coastguard Worker
546*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(1))
547*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
548*da0073e9SAndroid Build Coastguard Worker            logs,
549*da0073e9SAndroid Build Coastguard Worker            """\
550*da0073e9SAndroid Build Coastguard Worker
551*da0073e9SAndroid Build Coastguard Worker
552*da0073e9SAndroid Build Coastguard Worker
553*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
554*da0073e9SAndroid Build Coastguard Worker    _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0)
555*da0073e9SAndroid Build Coastguard Worker    getitem = _fused_moving_avg_obs_fq_helper_functional[0]
556*da0073e9SAndroid Build Coastguard Worker    getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1]
557*da0073e9SAndroid Build Coastguard Worker    getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2];  getitem_2 = None
558*da0073e9SAndroid Build Coastguard Worker    getitem_3 = _fused_moving_avg_obs_fq_helper_functional[3];  getitem_3 = None
559*da0073e9SAndroid Build Coastguard Worker    getitem_4 = _fused_moving_avg_obs_fq_helper_functional[4];  getitem_4 = None
560*da0073e9SAndroid Build Coastguard Worker    getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5];  _fused_moving_avg_obs_fq_helper_functional = None
561*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_5);  arg0_1 = getitem_5 = copy_ = None
562*da0073e9SAndroid Build Coastguard Worker    return (getitem, getitem_1)
563*da0073e9SAndroid Build Coastguard Worker    """,  # noqa: B950
564*da0073e9SAndroid Build Coastguard Worker        )
565*da0073e9SAndroid Build Coastguard Worker
566*da0073e9SAndroid Build Coastguard Worker    def test_as_strided(self):
567*da0073e9SAndroid Build Coastguard Worker        def f(x):
568*da0073e9SAndroid Build Coastguard Worker            y = x.as_strided((2,), (2,), 1)
569*da0073e9SAndroid Build Coastguard Worker            y.add_(1)
570*da0073e9SAndroid Build Coastguard Worker            return x
571*da0073e9SAndroid Build Coastguard Worker
572*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(9))
573*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(9))
574*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
575*da0073e9SAndroid Build Coastguard Worker            logs,
576*da0073e9SAndroid Build Coastguard Worker            """\
577*da0073e9SAndroid Build Coastguard Worker
578*da0073e9SAndroid Build Coastguard Worker
579*da0073e9SAndroid Build Coastguard Worker
580*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
581*da0073e9SAndroid Build Coastguard Worker    as_strided_copy = torch.ops.aten.as_strided_copy.default(arg0_1, [2], [2], 1)
582*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(as_strided_copy, 1);  as_strided_copy = None
583*da0073e9SAndroid Build Coastguard Worker    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1);  add = None
584*da0073e9SAndroid Build Coastguard Worker    as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [2], [2], 1);  as_strided_copy_1 = None
585*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter);  arg0_1 = copy_ = None
586*da0073e9SAndroid Build Coastguard Worker    return as_strided_scatter
587*da0073e9SAndroid Build Coastguard Worker    """,
588*da0073e9SAndroid Build Coastguard Worker        )
589*da0073e9SAndroid Build Coastguard Worker
590*da0073e9SAndroid Build Coastguard Worker        # NB: even with reapply_views=True, we expect to see scatter op
591*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
592*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(2, 2), reapply_views=True, run_reinplace=False
593*da0073e9SAndroid Build Coastguard Worker        )
594*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
595*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
596*da0073e9SAndroid Build Coastguard Worker            """\
597*da0073e9SAndroid Build Coastguard Worker
598*da0073e9SAndroid Build Coastguard Worker
599*da0073e9SAndroid Build Coastguard Worker
600*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
601*da0073e9SAndroid Build Coastguard Worker    as_strided = torch.ops.aten.as_strided.default(arg0_1, [2], [2], 1)
602*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(as_strided, 1);  as_strided = None
603*da0073e9SAndroid Build Coastguard Worker    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1);  add = None
604*da0073e9SAndroid Build Coastguard Worker    as_strided_1 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [2], 1);  as_strided_1 = None
605*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter);  arg0_1 = copy_ = None
606*da0073e9SAndroid Build Coastguard Worker    return as_strided_scatter
607*da0073e9SAndroid Build Coastguard Worker    """,
608*da0073e9SAndroid Build Coastguard Worker        )
609*da0073e9SAndroid Build Coastguard Worker
610*da0073e9SAndroid Build Coastguard Worker    def test_tensor_list_composite(self):
611*da0073e9SAndroid Build Coastguard Worker        def f(x):
612*da0073e9SAndroid Build Coastguard Worker            # Test an op with TensorList input
613*da0073e9SAndroid Build Coastguard Worker            y = torch.block_diag(x, x)
614*da0073e9SAndroid Build Coastguard Worker            return y
615*da0073e9SAndroid Build Coastguard Worker
616*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(2, 2))
617*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(2, 2))
618*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
619*da0073e9SAndroid Build Coastguard Worker            logs,
620*da0073e9SAndroid Build Coastguard Worker            """\
621*da0073e9SAndroid Build Coastguard Worker
622*da0073e9SAndroid Build Coastguard Worker
623*da0073e9SAndroid Build Coastguard Worker
624*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
625*da0073e9SAndroid Build Coastguard Worker    block_diag = torch.ops.aten.block_diag.default([arg0_1, arg0_1]);  arg0_1 = None
626*da0073e9SAndroid Build Coastguard Worker    return block_diag
627*da0073e9SAndroid Build Coastguard Worker    """,
628*da0073e9SAndroid Build Coastguard Worker        )
629*da0073e9SAndroid Build Coastguard Worker
630*da0073e9SAndroid Build Coastguard Worker    def test_cat(self):
631*da0073e9SAndroid Build Coastguard Worker        def f(x):
632*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(0)
633*da0073e9SAndroid Build Coastguard Worker            torch.cat((x,), out=out)
634*da0073e9SAndroid Build Coastguard Worker            return out
635*da0073e9SAndroid Build Coastguard Worker
636*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(2, 2))
637*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(2, 2))
638*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
639*da0073e9SAndroid Build Coastguard Worker            logs,
640*da0073e9SAndroid Build Coastguard Worker            """\
641*da0073e9SAndroid Build Coastguard Worker
642*da0073e9SAndroid Build Coastguard Worker
643*da0073e9SAndroid Build Coastguard Worker
644*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
645*da0073e9SAndroid Build Coastguard Worker    empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False);  empty = None
646*da0073e9SAndroid Build Coastguard Worker    cat = torch.ops.aten.cat.default([arg0_1]);  arg0_1 = None
647*da0073e9SAndroid Build Coastguard Worker    return cat
648*da0073e9SAndroid Build Coastguard Worker    """,
649*da0073e9SAndroid Build Coastguard Worker        )
650*da0073e9SAndroid Build Coastguard Worker
651*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
652*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(2, 2), reapply_views=True, run_reinplace=True
653*da0073e9SAndroid Build Coastguard Worker        )
654*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
655*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
656*da0073e9SAndroid Build Coastguard Worker            """\
657*da0073e9SAndroid Build Coastguard Worker
658*da0073e9SAndroid Build Coastguard Worker
659*da0073e9SAndroid Build Coastguard Worker
660*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
661*da0073e9SAndroid Build Coastguard Worker    empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False);  empty = None
662*da0073e9SAndroid Build Coastguard Worker    cat = torch.ops.aten.cat.default([arg0_1]);  arg0_1 = None
663*da0073e9SAndroid Build Coastguard Worker    return cat
664*da0073e9SAndroid Build Coastguard Worker    """,
665*da0073e9SAndroid Build Coastguard Worker        )
666*da0073e9SAndroid Build Coastguard Worker
667*da0073e9SAndroid Build Coastguard Worker    def test_diagonal(self):
668*da0073e9SAndroid Build Coastguard Worker        def f(x):
669*da0073e9SAndroid Build Coastguard Worker            # test: view ops that take a subset of the original tensor (select/diagonal)
670*da0073e9SAndroid Build Coastguard Worker            tmp = torch.ones(2)
671*da0073e9SAndroid Build Coastguard Worker            y = x.clone().diagonal()
672*da0073e9SAndroid Build Coastguard Worker            y.add_(tmp)
673*da0073e9SAndroid Build Coastguard Worker            z = x * x
674*da0073e9SAndroid Build Coastguard Worker            return z
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(2, 2))
677*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(2, 2))
678*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
679*da0073e9SAndroid Build Coastguard Worker            logs,
680*da0073e9SAndroid Build Coastguard Worker            """\
681*da0073e9SAndroid Build Coastguard Worker
682*da0073e9SAndroid Build Coastguard Worker
683*da0073e9SAndroid Build Coastguard Worker
684*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
685*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
686*da0073e9SAndroid Build Coastguard Worker    clone = torch.ops.aten.clone.default(arg0_1)
687*da0073e9SAndroid Build Coastguard Worker    diagonal_copy = torch.ops.aten.diagonal_copy.default(clone)
688*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(diagonal_copy, ones);  diagonal_copy = ones = None
689*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(clone, add);  clone = add = None
690*da0073e9SAndroid Build Coastguard Worker    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter);  diagonal_scatter = diagonal_copy_1 = None
691*da0073e9SAndroid Build Coastguard Worker    mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None
692*da0073e9SAndroid Build Coastguard Worker    return mul
693*da0073e9SAndroid Build Coastguard Worker    """,
694*da0073e9SAndroid Build Coastguard Worker        )
695*da0073e9SAndroid Build Coastguard Worker
696*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
697*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(2, 2), reapply_views=True, run_reinplace=True
698*da0073e9SAndroid Build Coastguard Worker        )
699*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
700*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
701*da0073e9SAndroid Build Coastguard Worker            """\
702*da0073e9SAndroid Build Coastguard Worker
703*da0073e9SAndroid Build Coastguard Worker
704*da0073e9SAndroid Build Coastguard Worker
705*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
706*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
707*da0073e9SAndroid Build Coastguard Worker    clone = torch.ops.aten.clone.default(arg0_1)
708*da0073e9SAndroid Build Coastguard Worker    diagonal = torch.ops.aten.diagonal.default(clone)
709*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add_.Tensor(diagonal, ones);  diagonal = ones = add = None
710*da0073e9SAndroid Build Coastguard Worker    diagonal_1 = torch.ops.aten.diagonal.default(clone);  clone = diagonal_1 = None
711*da0073e9SAndroid Build Coastguard Worker    mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None
712*da0073e9SAndroid Build Coastguard Worker    return mul
713*da0073e9SAndroid Build Coastguard Worker    """,
714*da0073e9SAndroid Build Coastguard Worker        )
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker    def test_diagonal_mutated_input(self):
717*da0073e9SAndroid Build Coastguard Worker        def f(x):
718*da0073e9SAndroid Build Coastguard Worker            # simple test: there are pending updates afterwards, which the test syncs manually
719*da0073e9SAndroid Build Coastguard Worker            tmp = torch.ones(2)
720*da0073e9SAndroid Build Coastguard Worker            y = x.diagonal()
721*da0073e9SAndroid Build Coastguard Worker            y.add_(tmp)
722*da0073e9SAndroid Build Coastguard Worker            return x
723*da0073e9SAndroid Build Coastguard Worker
724*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, 2)
725*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, x)
726*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(2, 2))
727*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
728*da0073e9SAndroid Build Coastguard Worker            logs,
729*da0073e9SAndroid Build Coastguard Worker            """\
730*da0073e9SAndroid Build Coastguard Worker
731*da0073e9SAndroid Build Coastguard Worker
732*da0073e9SAndroid Build Coastguard Worker
733*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
734*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
735*da0073e9SAndroid Build Coastguard Worker    diagonal_copy = torch.ops.aten.diagonal_copy.default(arg0_1)
736*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(diagonal_copy, ones);  diagonal_copy = ones = None
737*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add);  add = None
738*da0073e9SAndroid Build Coastguard Worker    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter);  diagonal_copy_1 = None
739*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter);  arg0_1 = copy_ = None
740*da0073e9SAndroid Build Coastguard Worker    return diagonal_scatter
741*da0073e9SAndroid Build Coastguard Worker    """,
742*da0073e9SAndroid Build Coastguard Worker        )
743*da0073e9SAndroid Build Coastguard Worker
744*da0073e9SAndroid Build Coastguard Worker        # NB: even with reapply_views=True, we expect to see scatter op
745*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
746*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(2, 2), reapply_views=True, run_reinplace=False
747*da0073e9SAndroid Build Coastguard Worker        )
748*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
749*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
750*da0073e9SAndroid Build Coastguard Worker            """\
751*da0073e9SAndroid Build Coastguard Worker
752*da0073e9SAndroid Build Coastguard Worker
753*da0073e9SAndroid Build Coastguard Worker
754*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
755*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
756*da0073e9SAndroid Build Coastguard Worker    diagonal = torch.ops.aten.diagonal.default(arg0_1)
757*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(diagonal, ones);  diagonal = ones = None
758*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add);  add = None
759*da0073e9SAndroid Build Coastguard Worker    diagonal_1 = torch.ops.aten.diagonal.default(diagonal_scatter);  diagonal_1 = None
760*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter);  arg0_1 = copy_ = None
761*da0073e9SAndroid Build Coastguard Worker    return diagonal_scatter
762*da0073e9SAndroid Build Coastguard Worker    """,
763*da0073e9SAndroid Build Coastguard Worker        )
764*da0073e9SAndroid Build Coastguard Worker
765*da0073e9SAndroid Build Coastguard Worker    def test_channels_last_contiguous(self):
766*da0073e9SAndroid Build Coastguard Worker        def f(x):
767*da0073e9SAndroid Build Coastguard Worker            return x.contiguous(memory_format=torch.channels_last)
768*da0073e9SAndroid Build Coastguard Worker            tmp = torch.ones(2)
769*da0073e9SAndroid Build Coastguard Worker            y = x.diagonal()
770*da0073e9SAndroid Build Coastguard Worker            y.add_(tmp)
771*da0073e9SAndroid Build Coastguard Worker            return x
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2)
774*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, x)
775*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, x).strip()
776*da0073e9SAndroid Build Coastguard Worker        # There should be no clone in the graph
777*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
778*da0073e9SAndroid Build Coastguard Worker            logs,
779*da0073e9SAndroid Build Coastguard Worker            """\
780*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
781*da0073e9SAndroid Build Coastguard Worker    return arg0_1""",
782*da0073e9SAndroid Build Coastguard Worker        )
783*da0073e9SAndroid Build Coastguard Worker
784*da0073e9SAndroid Build Coastguard Worker    def test_split(self):
785*da0073e9SAndroid Build Coastguard Worker        def f(x):
786*da0073e9SAndroid Build Coastguard Worker            # test: view ops that return multiple tensors (split)
787*da0073e9SAndroid Build Coastguard Worker            tmp = torch.ones(2)
788*da0073e9SAndroid Build Coastguard Worker            y1, y2 = x.split(2)
789*da0073e9SAndroid Build Coastguard Worker            y3 = y2.diagonal()
790*da0073e9SAndroid Build Coastguard Worker            y3.add_(tmp)
791*da0073e9SAndroid Build Coastguard Worker            z = x * x
792*da0073e9SAndroid Build Coastguard Worker            return y3
793*da0073e9SAndroid Build Coastguard Worker
794*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(4, 2))
795*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(4, 2))
796*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
797*da0073e9SAndroid Build Coastguard Worker            logs,
798*da0073e9SAndroid Build Coastguard Worker            """\
799*da0073e9SAndroid Build Coastguard Worker
800*da0073e9SAndroid Build Coastguard Worker
801*da0073e9SAndroid Build Coastguard Worker
802*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
803*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
804*da0073e9SAndroid Build Coastguard Worker    split_copy = torch.ops.aten.split_copy.Tensor(arg0_1, 2)
805*da0073e9SAndroid Build Coastguard Worker    getitem = split_copy[0];  getitem = None
806*da0073e9SAndroid Build Coastguard Worker    getitem_1 = split_copy[1];  split_copy = None
807*da0073e9SAndroid Build Coastguard Worker    diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem_1);  getitem_1 = None
808*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(diagonal_copy, ones);  diagonal_copy = ones = None
809*da0073e9SAndroid Build Coastguard Worker    split_copy_1 = torch.ops.aten.split_copy.Tensor(arg0_1, 2)
810*da0073e9SAndroid Build Coastguard Worker    getitem_2 = split_copy_1[0];  getitem_2 = None
811*da0073e9SAndroid Build Coastguard Worker    getitem_3 = split_copy_1[1];  split_copy_1 = None
812*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add);  getitem_3 = add = None
813*da0073e9SAndroid Build Coastguard Worker    slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4);  diagonal_scatter = None
814*da0073e9SAndroid Build Coastguard Worker    split_copy_2 = torch.ops.aten.split_copy.Tensor(slice_scatter, 2)
815*da0073e9SAndroid Build Coastguard Worker    getitem_4 = split_copy_2[0];  getitem_4 = None
816*da0073e9SAndroid Build Coastguard Worker    getitem_5 = split_copy_2[1];  split_copy_2 = None
817*da0073e9SAndroid Build Coastguard Worker    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_5);  getitem_5 = None
818*da0073e9SAndroid Build Coastguard Worker    mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter);  mul = None
819*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter);  arg0_1 = slice_scatter = copy_ = None
820*da0073e9SAndroid Build Coastguard Worker    return diagonal_copy_1
821*da0073e9SAndroid Build Coastguard Worker    """,
822*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
823*da0073e9SAndroid Build Coastguard Worker
824*da0073e9SAndroid Build Coastguard Worker        # NB: even with reapply_views=True, we expect to see scatter op
825*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
826*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(4, 2), reapply_views=True, run_reinplace=False
827*da0073e9SAndroid Build Coastguard Worker        )
828*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
829*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
830*da0073e9SAndroid Build Coastguard Worker            """\
831*da0073e9SAndroid Build Coastguard Worker
832*da0073e9SAndroid Build Coastguard Worker
833*da0073e9SAndroid Build Coastguard Worker
834*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
835*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
836*da0073e9SAndroid Build Coastguard Worker    split = torch.ops.aten.split.Tensor(arg0_1, 2)
837*da0073e9SAndroid Build Coastguard Worker    getitem = split[0];  getitem = None
838*da0073e9SAndroid Build Coastguard Worker    getitem_1 = split[1];  split = None
839*da0073e9SAndroid Build Coastguard Worker    diagonal = torch.ops.aten.diagonal.default(getitem_1);  getitem_1 = None
840*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(diagonal, ones);  diagonal = ones = None
841*da0073e9SAndroid Build Coastguard Worker    split_1 = torch.ops.aten.split.Tensor(arg0_1, 2)
842*da0073e9SAndroid Build Coastguard Worker    getitem_2 = split_1[0];  getitem_2 = None
843*da0073e9SAndroid Build Coastguard Worker    getitem_3 = split_1[1];  split_1 = None
844*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add);  getitem_3 = add = None
845*da0073e9SAndroid Build Coastguard Worker    slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4);  diagonal_scatter = None
846*da0073e9SAndroid Build Coastguard Worker    split_2 = torch.ops.aten.split.Tensor(slice_scatter, 2)
847*da0073e9SAndroid Build Coastguard Worker    getitem_4 = split_2[0];  getitem_4 = None
848*da0073e9SAndroid Build Coastguard Worker    getitem_5 = split_2[1];  split_2 = None
849*da0073e9SAndroid Build Coastguard Worker    diagonal_1 = torch.ops.aten.diagonal.default(getitem_5);  getitem_5 = None
850*da0073e9SAndroid Build Coastguard Worker    mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter);  mul = None
851*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter);  arg0_1 = slice_scatter = copy_ = None
852*da0073e9SAndroid Build Coastguard Worker    return diagonal_1
853*da0073e9SAndroid Build Coastguard Worker    """,
854*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
855*da0073e9SAndroid Build Coastguard Worker
856*da0073e9SAndroid Build Coastguard Worker    def test_split_with_sizes(self):
857*da0073e9SAndroid Build Coastguard Worker        def f(x):
858*da0073e9SAndroid Build Coastguard Worker            # test: view ops that return multiple tensors (split_with_sizes)
859*da0073e9SAndroid Build Coastguard Worker            tmp = torch.ones(2)
860*da0073e9SAndroid Build Coastguard Worker            y1, y2 = x.split_with_sizes([2, 2])
861*da0073e9SAndroid Build Coastguard Worker            y3 = y1.diagonal()
862*da0073e9SAndroid Build Coastguard Worker            y3.add_(tmp)
863*da0073e9SAndroid Build Coastguard Worker            z = x * x
864*da0073e9SAndroid Build Coastguard Worker            return y3
865*da0073e9SAndroid Build Coastguard Worker
866*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(4, 2))
867*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(4, 2))
868*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
869*da0073e9SAndroid Build Coastguard Worker            logs,
870*da0073e9SAndroid Build Coastguard Worker            """\
871*da0073e9SAndroid Build Coastguard Worker
872*da0073e9SAndroid Build Coastguard Worker
873*da0073e9SAndroid Build Coastguard Worker
874*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
875*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
876*da0073e9SAndroid Build Coastguard Worker    split_with_sizes_copy = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2])
877*da0073e9SAndroid Build Coastguard Worker    getitem = split_with_sizes_copy[0]
878*da0073e9SAndroid Build Coastguard Worker    getitem_1 = split_with_sizes_copy[1];  split_with_sizes_copy = getitem_1 = None
879*da0073e9SAndroid Build Coastguard Worker    diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem);  getitem = None
880*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(diagonal_copy, ones);  diagonal_copy = ones = None
881*da0073e9SAndroid Build Coastguard Worker    split_with_sizes_copy_1 = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2])
882*da0073e9SAndroid Build Coastguard Worker    getitem_2 = split_with_sizes_copy_1[0]
883*da0073e9SAndroid Build Coastguard Worker    getitem_3 = split_with_sizes_copy_1[1];  split_with_sizes_copy_1 = getitem_3 = None
884*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add);  getitem_2 = add = None
885*da0073e9SAndroid Build Coastguard Worker    slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2);  diagonal_scatter = None
886*da0073e9SAndroid Build Coastguard Worker    split_with_sizes_copy_2 = torch.ops.aten.split_with_sizes_copy.default(slice_scatter, [2, 2])
887*da0073e9SAndroid Build Coastguard Worker    getitem_4 = split_with_sizes_copy_2[0]
888*da0073e9SAndroid Build Coastguard Worker    getitem_5 = split_with_sizes_copy_2[1];  split_with_sizes_copy_2 = getitem_5 = None
889*da0073e9SAndroid Build Coastguard Worker    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_4);  getitem_4 = None
890*da0073e9SAndroid Build Coastguard Worker    mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter);  mul = None
891*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter);  arg0_1 = slice_scatter = copy_ = None
892*da0073e9SAndroid Build Coastguard Worker    return diagonal_copy_1
893*da0073e9SAndroid Build Coastguard Worker    """,
894*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
895*da0073e9SAndroid Build Coastguard Worker
896*da0073e9SAndroid Build Coastguard Worker        # NB: even with reapply_views=True, we expect to see scatter op
897*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
898*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(4, 2), reapply_views=True, run_reinplace=False
899*da0073e9SAndroid Build Coastguard Worker        )
900*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
901*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
902*da0073e9SAndroid Build Coastguard Worker            """\
903*da0073e9SAndroid Build Coastguard Worker
904*da0073e9SAndroid Build Coastguard Worker
905*da0073e9SAndroid Build Coastguard Worker
906*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
907*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
908*da0073e9SAndroid Build Coastguard Worker    split_with_sizes = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2])
909*da0073e9SAndroid Build Coastguard Worker    getitem = split_with_sizes[0]
910*da0073e9SAndroid Build Coastguard Worker    getitem_1 = split_with_sizes[1];  split_with_sizes = getitem_1 = None
911*da0073e9SAndroid Build Coastguard Worker    diagonal = torch.ops.aten.diagonal.default(getitem);  getitem = None
912*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(diagonal, ones);  diagonal = ones = None
913*da0073e9SAndroid Build Coastguard Worker    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2])
914*da0073e9SAndroid Build Coastguard Worker    getitem_2 = split_with_sizes_1[0]
915*da0073e9SAndroid Build Coastguard Worker    getitem_3 = split_with_sizes_1[1];  split_with_sizes_1 = getitem_3 = None
916*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add);  getitem_2 = add = None
917*da0073e9SAndroid Build Coastguard Worker    slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2);  diagonal_scatter = None
918*da0073e9SAndroid Build Coastguard Worker    split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(slice_scatter, [2, 2])
919*da0073e9SAndroid Build Coastguard Worker    getitem_4 = split_with_sizes_2[0]
920*da0073e9SAndroid Build Coastguard Worker    getitem_5 = split_with_sizes_2[1];  split_with_sizes_2 = getitem_5 = None
921*da0073e9SAndroid Build Coastguard Worker    diagonal_1 = torch.ops.aten.diagonal.default(getitem_4);  getitem_4 = None
922*da0073e9SAndroid Build Coastguard Worker    mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter);  mul = None
923*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter);  arg0_1 = slice_scatter = copy_ = None
924*da0073e9SAndroid Build Coastguard Worker    return diagonal_1
925*da0073e9SAndroid Build Coastguard Worker    """,
926*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
927*da0073e9SAndroid Build Coastguard Worker
928*da0073e9SAndroid Build Coastguard Worker    def test_slice(self):
929*da0073e9SAndroid Build Coastguard Worker        def f(x):
930*da0073e9SAndroid Build Coastguard Worker            tmp = torch.ones(4)
931*da0073e9SAndroid Build Coastguard Worker            x.transpose_(1, 0)
932*da0073e9SAndroid Build Coastguard Worker            y = x[0:2]
933*da0073e9SAndroid Build Coastguard Worker            y.add_(tmp)
934*da0073e9SAndroid Build Coastguard Worker            return x
935*da0073e9SAndroid Build Coastguard Worker
936*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
937*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(4, 2))
938*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
939*da0073e9SAndroid Build Coastguard Worker            logs,
940*da0073e9SAndroid Build Coastguard Worker            """\
941*da0073e9SAndroid Build Coastguard Worker
942*da0073e9SAndroid Build Coastguard Worker
943*da0073e9SAndroid Build Coastguard Worker
944*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
945*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
946*da0073e9SAndroid Build Coastguard Worker    transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
947*da0073e9SAndroid Build Coastguard Worker    slice_copy = torch.ops.aten.slice_copy.Tensor(transpose_copy, 0, 0, 2);  transpose_copy = None
948*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(slice_copy, ones);  slice_copy = ones = None
949*da0073e9SAndroid Build Coastguard Worker    transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0);  arg0_1 = None
950*da0073e9SAndroid Build Coastguard Worker    slice_scatter = torch.ops.aten.slice_scatter.default(transpose_copy_1, add, 0, 0, 2);  transpose_copy_1 = add = None
951*da0073e9SAndroid Build Coastguard Worker    transpose_copy_2 = torch.ops.aten.transpose_copy.int(slice_scatter, 1, 0);  slice_scatter = None
952*da0073e9SAndroid Build Coastguard Worker    transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
953*da0073e9SAndroid Build Coastguard Worker    slice_copy_1 = torch.ops.aten.slice_copy.Tensor(transpose_copy_3, 0, 0, 2);  transpose_copy_3 = slice_copy_1 = None
954*da0073e9SAndroid Build Coastguard Worker    transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0);  transpose_copy_2 = None
955*da0073e9SAndroid Build Coastguard Worker    return transpose_copy_4
956*da0073e9SAndroid Build Coastguard Worker    """,
957*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
958*da0073e9SAndroid Build Coastguard Worker
959*da0073e9SAndroid Build Coastguard Worker        # NB: even with reapply_views=True, we expect to see scatter op
960*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
961*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(4, 2), reapply_views=True, run_reinplace=False
962*da0073e9SAndroid Build Coastguard Worker        )
963*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
964*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
965*da0073e9SAndroid Build Coastguard Worker            """\
966*da0073e9SAndroid Build Coastguard Worker
967*da0073e9SAndroid Build Coastguard Worker
968*da0073e9SAndroid Build Coastguard Worker
969*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
970*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
971*da0073e9SAndroid Build Coastguard Worker    transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
972*da0073e9SAndroid Build Coastguard Worker    slice_1 = torch.ops.aten.slice.Tensor(transpose, 0, 0, 2);  transpose = None
973*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(slice_1, ones);  slice_1 = ones = None
974*da0073e9SAndroid Build Coastguard Worker    transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0);  arg0_1 = None
975*da0073e9SAndroid Build Coastguard Worker    slice_scatter = torch.ops.aten.slice_scatter.default(transpose_1, add, 0, 0, 2);  transpose_1 = add = None
976*da0073e9SAndroid Build Coastguard Worker    transpose_2 = torch.ops.aten.transpose.int(slice_scatter, 1, 0);  slice_scatter = None
977*da0073e9SAndroid Build Coastguard Worker    transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
978*da0073e9SAndroid Build Coastguard Worker    slice_2 = torch.ops.aten.slice.Tensor(transpose_3, 0, 0, 2);  transpose_3 = slice_2 = None
979*da0073e9SAndroid Build Coastguard Worker    transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0);  transpose_2 = None
980*da0073e9SAndroid Build Coastguard Worker    return transpose_4
981*da0073e9SAndroid Build Coastguard Worker    """,
982*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
983*da0073e9SAndroid Build Coastguard Worker
984*da0073e9SAndroid Build Coastguard Worker    def test_view_inplace(self):
985*da0073e9SAndroid Build Coastguard Worker        def f(x):
986*da0073e9SAndroid Build Coastguard Worker            # test: view + inplace op (transpose_)
987*da0073e9SAndroid Build Coastguard Worker            tmp = torch.ones(4)
988*da0073e9SAndroid Build Coastguard Worker            x.transpose_(1, 0)
989*da0073e9SAndroid Build Coastguard Worker            y = x[0]
990*da0073e9SAndroid Build Coastguard Worker            y.add_(tmp)
991*da0073e9SAndroid Build Coastguard Worker            return x
992*da0073e9SAndroid Build Coastguard Worker
993*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
994*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(4, 2))
995*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
996*da0073e9SAndroid Build Coastguard Worker            logs,
997*da0073e9SAndroid Build Coastguard Worker            """\
998*da0073e9SAndroid Build Coastguard Worker
999*da0073e9SAndroid Build Coastguard Worker
1000*da0073e9SAndroid Build Coastguard Worker
1001*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1002*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
1003*da0073e9SAndroid Build Coastguard Worker    transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
1004*da0073e9SAndroid Build Coastguard Worker    select_copy = torch.ops.aten.select_copy.int(transpose_copy, 0, 0);  transpose_copy = None
1005*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(select_copy, ones);  select_copy = ones = None
1006*da0073e9SAndroid Build Coastguard Worker    transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0);  arg0_1 = None
1007*da0073e9SAndroid Build Coastguard Worker    select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0);  transpose_copy_1 = add = None
1008*da0073e9SAndroid Build Coastguard Worker    transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0);  select_scatter = None
1009*da0073e9SAndroid Build Coastguard Worker    transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
1010*da0073e9SAndroid Build Coastguard Worker    select_copy_1 = torch.ops.aten.select_copy.int(transpose_copy_3, 0, 0);  transpose_copy_3 = select_copy_1 = None
1011*da0073e9SAndroid Build Coastguard Worker    transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0);  transpose_copy_2 = None
1012*da0073e9SAndroid Build Coastguard Worker    return transpose_copy_4
1013*da0073e9SAndroid Build Coastguard Worker    """,
1014*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
1015*da0073e9SAndroid Build Coastguard Worker
1016*da0073e9SAndroid Build Coastguard Worker        # NB: even with reapply_views=True, we expect to see scatter op
1017*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
1018*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(4, 2), reapply_views=True, run_reinplace=False
1019*da0073e9SAndroid Build Coastguard Worker        )
1020*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1021*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
1022*da0073e9SAndroid Build Coastguard Worker            """\
1023*da0073e9SAndroid Build Coastguard Worker
1024*da0073e9SAndroid Build Coastguard Worker
1025*da0073e9SAndroid Build Coastguard Worker
1026*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1027*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
1028*da0073e9SAndroid Build Coastguard Worker    transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
1029*da0073e9SAndroid Build Coastguard Worker    select = torch.ops.aten.select.int(transpose, 0, 0);  transpose = None
1030*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(select, ones);  select = ones = None
1031*da0073e9SAndroid Build Coastguard Worker    transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0);  arg0_1 = None
1032*da0073e9SAndroid Build Coastguard Worker    select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0);  transpose_1 = add = None
1033*da0073e9SAndroid Build Coastguard Worker    transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0);  select_scatter = None
1034*da0073e9SAndroid Build Coastguard Worker    transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
1035*da0073e9SAndroid Build Coastguard Worker    select_1 = torch.ops.aten.select.int(transpose_3, 0, 0);  transpose_3 = select_1 = None
1036*da0073e9SAndroid Build Coastguard Worker    transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0);  transpose_2 = None
1037*da0073e9SAndroid Build Coastguard Worker    return transpose_4
1038*da0073e9SAndroid Build Coastguard Worker    """,
1039*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
1040*da0073e9SAndroid Build Coastguard Worker
1041*da0073e9SAndroid Build Coastguard Worker    def test_unbind(self):
1042*da0073e9SAndroid Build Coastguard Worker        def f(x):
1043*da0073e9SAndroid Build Coastguard Worker            # test: view + inplace op (transpose_)
1044*da0073e9SAndroid Build Coastguard Worker            tmp = torch.ones(4)
1045*da0073e9SAndroid Build Coastguard Worker            x.transpose_(1, 0)
1046*da0073e9SAndroid Build Coastguard Worker            y, _ = x.unbind(0)
1047*da0073e9SAndroid Build Coastguard Worker            y.add_(tmp)
1048*da0073e9SAndroid Build Coastguard Worker            return x
1049*da0073e9SAndroid Build Coastguard Worker
1050*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
1051*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(4, 2))
1052*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1053*da0073e9SAndroid Build Coastguard Worker            logs,
1054*da0073e9SAndroid Build Coastguard Worker            """\
1055*da0073e9SAndroid Build Coastguard Worker
1056*da0073e9SAndroid Build Coastguard Worker
1057*da0073e9SAndroid Build Coastguard Worker
1058*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1059*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
1060*da0073e9SAndroid Build Coastguard Worker    transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
1061*da0073e9SAndroid Build Coastguard Worker    unbind_copy = torch.ops.aten.unbind_copy.int(transpose_copy);  transpose_copy = None
1062*da0073e9SAndroid Build Coastguard Worker    getitem = unbind_copy[0]
1063*da0073e9SAndroid Build Coastguard Worker    getitem_1 = unbind_copy[1];  unbind_copy = getitem_1 = None
1064*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(getitem, ones);  getitem = ones = None
1065*da0073e9SAndroid Build Coastguard Worker    transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0);  arg0_1 = None
1066*da0073e9SAndroid Build Coastguard Worker    select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0);  transpose_copy_1 = add = None
1067*da0073e9SAndroid Build Coastguard Worker    transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0);  select_scatter = None
1068*da0073e9SAndroid Build Coastguard Worker    transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
1069*da0073e9SAndroid Build Coastguard Worker    unbind_copy_1 = torch.ops.aten.unbind_copy.int(transpose_copy_3);  transpose_copy_3 = None
1070*da0073e9SAndroid Build Coastguard Worker    getitem_2 = unbind_copy_1[0];  getitem_2 = None
1071*da0073e9SAndroid Build Coastguard Worker    getitem_3 = unbind_copy_1[1];  unbind_copy_1 = getitem_3 = None
1072*da0073e9SAndroid Build Coastguard Worker    transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0);  transpose_copy_2 = None
1073*da0073e9SAndroid Build Coastguard Worker    return transpose_copy_4
1074*da0073e9SAndroid Build Coastguard Worker    """,
1075*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
1076*da0073e9SAndroid Build Coastguard Worker
1077*da0073e9SAndroid Build Coastguard Worker        # NB: even with reapply_views=True, we expect to see scatter op
1078*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
1079*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(4, 2), reapply_views=True, run_reinplace=False
1080*da0073e9SAndroid Build Coastguard Worker        )
1081*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1082*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
1083*da0073e9SAndroid Build Coastguard Worker            """\
1084*da0073e9SAndroid Build Coastguard Worker
1085*da0073e9SAndroid Build Coastguard Worker
1086*da0073e9SAndroid Build Coastguard Worker
1087*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1088*da0073e9SAndroid Build Coastguard Worker    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
1089*da0073e9SAndroid Build Coastguard Worker    transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
1090*da0073e9SAndroid Build Coastguard Worker    unbind = torch.ops.aten.unbind.int(transpose);  transpose = None
1091*da0073e9SAndroid Build Coastguard Worker    getitem = unbind[0]
1092*da0073e9SAndroid Build Coastguard Worker    getitem_1 = unbind[1];  unbind = getitem_1 = None
1093*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(getitem, ones);  getitem = ones = None
1094*da0073e9SAndroid Build Coastguard Worker    transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0);  arg0_1 = None
1095*da0073e9SAndroid Build Coastguard Worker    select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0);  transpose_1 = add = None
1096*da0073e9SAndroid Build Coastguard Worker    transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0);  select_scatter = None
1097*da0073e9SAndroid Build Coastguard Worker    transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
1098*da0073e9SAndroid Build Coastguard Worker    unbind_1 = torch.ops.aten.unbind.int(transpose_3);  transpose_3 = None
1099*da0073e9SAndroid Build Coastguard Worker    getitem_2 = unbind_1[0];  getitem_2 = None
1100*da0073e9SAndroid Build Coastguard Worker    getitem_3 = unbind_1[1];  unbind_1 = getitem_3 = None
1101*da0073e9SAndroid Build Coastguard Worker    transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0);  transpose_2 = None
1102*da0073e9SAndroid Build Coastguard Worker    return transpose_4
1103*da0073e9SAndroid Build Coastguard Worker    """,
1104*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
1105*da0073e9SAndroid Build Coastguard Worker
1106*da0073e9SAndroid Build Coastguard Worker    def test_optional_tensor_list(self):
1107*da0073e9SAndroid Build Coastguard Worker        def f(x):
1108*da0073e9SAndroid Build Coastguard Worker            # test: an operator that takes in a List[Optional[Tensor]] argument
1109*da0073e9SAndroid Build Coastguard Worker            # (index_put)
1110*da0073e9SAndroid Build Coastguard Worker            y = x.view(8)
1111*da0073e9SAndroid Build Coastguard Worker            indices = torch.arange(4)
1112*da0073e9SAndroid Build Coastguard Worker            values = torch.arange(4, dtype=y.dtype)
1113*da0073e9SAndroid Build Coastguard Worker            y.index_put_((indices,), values, accumulate=False)
1114*da0073e9SAndroid Build Coastguard Worker            return y
1115*da0073e9SAndroid Build Coastguard Worker
1116*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(4, 2))
1117*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(4, 2))
1118*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1119*da0073e9SAndroid Build Coastguard Worker            logs,
1120*da0073e9SAndroid Build Coastguard Worker            """\
1121*da0073e9SAndroid Build Coastguard Worker
1122*da0073e9SAndroid Build Coastguard Worker
1123*da0073e9SAndroid Build Coastguard Worker
1124*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1125*da0073e9SAndroid Build Coastguard Worker    view_copy = torch.ops.aten.view_copy.default(arg0_1, [8])
1126*da0073e9SAndroid Build Coastguard Worker    arange = torch.ops.aten.arange.default(4, device = device(type='cpu'), pin_memory = False)
1127*da0073e9SAndroid Build Coastguard Worker    arange_1 = torch.ops.aten.arange.default(4, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
1128*da0073e9SAndroid Build Coastguard Worker    index_put = torch.ops.aten.index_put.default(view_copy, [arange], arange_1);  view_copy = arange = arange_1 = None
1129*da0073e9SAndroid Build Coastguard Worker    view_copy_1 = torch.ops.aten.view_copy.default(index_put, [4, 2]);  index_put = None
1130*da0073e9SAndroid Build Coastguard Worker    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [8])
1131*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1);  arg0_1 = view_copy_1 = copy_ = None
1132*da0073e9SAndroid Build Coastguard Worker    return view_copy_2
1133*da0073e9SAndroid Build Coastguard Worker    """,
1134*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
1135*da0073e9SAndroid Build Coastguard Worker
    def test_scalars(self):
        # Checks that Python scalar operands (the ``1`` in add_/div_ and the
        # ``2`` in ``2 * y``) survive functionalization: in the expected graph
        # they are passed straight through as literals to aten.add.Tensor,
        # aten.mul.Tensor and aten.div.Tensor.
        def f(x):
            # test: the pass can handle scalar inputs properly
            # ``tmp`` is never used, but its creation is still traced (the
            # ``ones`` node in the expected graph), so it must stay.
            tmp = torch.ones(4, 2)
            y = x.view(4, 2)
            # Mutates the input x through its view y; the expected graph
            # writes this back into the input with a trailing aten.copy_.
            y.add_(1)
            z = 2 * y
            # z is a fresh tensor (not a view of x), so this mutation appears
            # as a plain aten.div and needs no write-back.
            z.div_(1)
            return z

        self.assert_functionalization(f, torch.ones(4, 2))
        # Expected functionalized graph for f (view ops become view_copy).
        logs = self.get_logs(f, torch.ones(4, 2))
        self.assertExpectedInline(
            logs,
            """\



def forward(self, arg0_1):
    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False);  ones = None
    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
    add = torch.ops.aten.add.Tensor(view_copy, 1);  view_copy = None
    view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]);  add = None
    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2])
    mul = torch.ops.aten.mul.Tensor(view_copy_2, 2);  view_copy_2 = None
    div = torch.ops.aten.div.Tensor(mul, 1);  mul = None
    copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1);  arg0_1 = view_copy_1 = copy_ = None
    return div
    """,
        )
1166*da0073e9SAndroid Build Coastguard Worker
1167*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Test does not work with TorchDynamo")
    def test_metadata_change(self):
        def f(x):
            # The functional counterpart of an in-place op may produce a
            # different dtype than the tensor being mutated: aten.ge returns a
            # boolean tensor, while ge_() must leave y with its original
            # float32 dtype. Functionalization has to pick up on that and cast
            # the result back (the _to_copy(..., dtype = torch.float32) node in
            # the expected graph).
            y = x.clone()
            out = y.ge_(0)
            return out

        self.assert_functionalization(f, torch.ones(4, 2))
        logs = self.get_logs(f, torch.ones(4, 2))
        self.assertExpectedInline(
            logs,
            """\



def forward(self, arg0_1):
    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
    ge = torch.ops.aten.ge.Scalar(clone, 0);  clone = None
    _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided);  ge = None
    return _to_copy
    """,
        )

        # The reinplaced graph is expected to be identical to the one above:
        # aten.ge is not turned back into ge_() here (presumably because its
        # output dtype differs from the input's).
        # NOTE(review): this uses a (2, 2) input while the check above uses
        # (4, 2); the expected graph has no shape-dependent arguments, so both
        # produce the same text.
        reinplaced_logs = self.get_logs(
            f, torch.ones(2, 2), reapply_views=True, run_reinplace=True
        )
        self.assertExpectedInline(
            reinplaced_logs,
            """\



def forward(self, arg0_1):
    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
    ge = torch.ops.aten.ge.Scalar(clone, 0);  clone = None
    _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided);  ge = None
    return _to_copy
    """,
        )  # noqa: B950
1208*da0073e9SAndroid Build Coastguard Worker
1209*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Test does not work with TorchDynamo")
1210*da0073e9SAndroid Build Coastguard Worker    def test_metadata_change_out_op(self):
1211*da0073e9SAndroid Build Coastguard Worker        def f(t, y):
1212*da0073e9SAndroid Build Coastguard Worker            out_1 = torch.ones(1)
1213*da0073e9SAndroid Build Coastguard Worker            return torch.add(t, y, out=out_1)
1214*da0073e9SAndroid Build Coastguard Worker
1215*da0073e9SAndroid Build Coastguard Worker        inpt1, inpt2 = torch.tensor([1]), torch.tensor([1])
1216*da0073e9SAndroid Build Coastguard Worker        inpt1_func, inpt2_func = (
1217*da0073e9SAndroid Build Coastguard Worker            torch._to_functional_tensor(inpt1),
1218*da0073e9SAndroid Build Coastguard Worker            torch._to_functional_tensor(inpt2),
1219*da0073e9SAndroid Build Coastguard Worker        )
1220*da0073e9SAndroid Build Coastguard Worker
1221*da0073e9SAndroid Build Coastguard Worker        out_ref = f(inpt1, inpt2)
1222*da0073e9SAndroid Build Coastguard Worker        torch._enable_functionalization(reapply_views=True)
1223*da0073e9SAndroid Build Coastguard Worker        try:
1224*da0073e9SAndroid Build Coastguard Worker            out_functional = f(inpt1_func, inpt2_func)
1225*da0073e9SAndroid Build Coastguard Worker        finally:
1226*da0073e9SAndroid Build Coastguard Worker            torch._disable_functionalization()
1227*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, torch._from_functional_tensor(out_functional))
1228*da0073e9SAndroid Build Coastguard Worker
    def test_only_one_view(self):
        def f(x):
            # This tests that we don't have any unnecessary views in the trace.
            # If the input wasn't mutated, we don't need to regenerate it,
            # so there should be a total of 1 op in the output trace.
            return x.view(4, 2)

        # Since x is never mutated, no copy_ back into arg0_1 is expected.
        logs = self.get_logs(f, torch.ones(4, 2))
        self.assertExpectedInline(
            logs,
            """\



def forward(self, arg0_1):
    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]);  arg0_1 = None
    return view_copy
    """,
        )
1248*da0073e9SAndroid Build Coastguard Worker
    def test_everything(self):
        def f(x):
            # test: everything
            # Exercises a chain of view-producing ops (view, reshape,
            # transpose), in-place metadata mutations (unsqueeze_, squeeze_),
            # a multi-output view op (split) and an in-place mutation of one
            # split chunk (add_). x itself is never mutated (x2 = x + x is a
            # fresh tensor), so neither expected graph has a copy_ back into
            # the input. Unused locals (tmp, z3, z4) are still traced and
            # appear in the expected graphs, so they must stay.
            tmp = torch.ones(2, 2)
            x2 = x + x
            y = x2.view(8)
            z0 = y.reshape(2, 4)
            z1 = z0.transpose(1, 0)
            z1.unsqueeze_(0)
            z1.squeeze_()
            z2, z3 = z1.split(2)
            # Mutating a split chunk: in the functionalized graph this becomes
            # a slice_scatter into the base, followed by replaying the views.
            z2.add_(tmp)
            # Reads z0 after z2's mutation, so it must observe the updated base.
            z4 = z0[0] + z2.reshape(4)
            return z2

        self.assert_functionalization(f, torch.ones(4, 2))
        logs = self.get_logs(f, torch.ones(4, 2))
        self.assertExpectedInline(
            logs,
            """\



def forward(self, arg0_1):
    ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
    view_copy = torch.ops.aten.view_copy.default(add, [8])
    view_copy_1 = torch.ops.aten.view_copy.default(view_copy, [2, 4]);  view_copy = None
    transpose_copy = torch.ops.aten.transpose_copy.int(view_copy_1, 1, 0)
    unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0);  transpose_copy = None
    squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy);  unsqueeze_copy = None
    split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2);  squeeze_copy = None
    getitem = split_copy[0]
    getitem_1 = split_copy[1];  split_copy = getitem_1 = None
    add_1 = torch.ops.aten.add.Tensor(getitem, ones);  getitem = ones = None
    view_copy_2 = torch.ops.aten.view_copy.default(add, [8]);  add = None
    view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [2, 4]);  view_copy_2 = None
    transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_3, 1, 0);  view_copy_3 = None
    unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0);  transpose_copy_1 = None
    squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1);  unsqueeze_copy_1 = None
    slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2);  squeeze_copy_1 = add_1 = None
    unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0);  slice_scatter = None
    squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0);  unsqueeze_copy_2 = None
    transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0);  squeeze_copy_2 = None
    view_copy_4 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]);  transpose_copy_2 = None
    view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 2]);  view_copy_4 = None
    view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [8])
    view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [2, 4]);  view_copy_6 = None
    transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_7, 1, 0);  view_copy_7 = None
    unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0);  transpose_copy_3 = None
    squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3);  unsqueeze_copy_3 = None
    split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2);  squeeze_copy_3 = None
    getitem_2 = split_copy_1[0]
    getitem_3 = split_copy_1[1];  split_copy_1 = getitem_3 = None
    select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0);  view_copy_1 = select_copy = None
    view_copy_8 = torch.ops.aten.view_copy.default(getitem_2, [4]);  view_copy_8 = None
    view_copy_9 = torch.ops.aten.view_copy.default(view_copy_5, [8])
    view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]);  view_copy_9 = None
    select_copy_1 = torch.ops.aten.select_copy.int(view_copy_10, 0, 0);  view_copy_10 = None
    view_copy_11 = torch.ops.aten.view_copy.default(view_copy_5, [8]);  view_copy_5 = None
    view_copy_12 = torch.ops.aten.view_copy.default(view_copy_11, [2, 4]);  view_copy_11 = None
    transpose_copy_4 = torch.ops.aten.transpose_copy.int(view_copy_12, 1, 0);  view_copy_12 = None
    unsqueeze_copy_4 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_4, 0);  transpose_copy_4 = None
    squeeze_copy_4 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_4);  unsqueeze_copy_4 = None
    split_copy_2 = torch.ops.aten.split_copy.Tensor(squeeze_copy_4, 2);  squeeze_copy_4 = None
    getitem_4 = split_copy_2[0]
    getitem_5 = split_copy_2[1];  split_copy_2 = getitem_5 = None
    view_copy_13 = torch.ops.aten.view_copy.default(getitem_4, [4]);  getitem_4 = None
    add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_13);  select_copy_1 = view_copy_13 = add_2 = None
    return getitem_2
    """,
        )  # noqa: B950

        # With views reapplied and reinplacing run, the add on the split chunk
        # is expected back as an in-place aten.add_ (no slice_scatter), and
        # z2.reshape(4) shows up as clone + _unsafe_view.
        reinplaced_logs = self.get_logs(
            f, torch.ones(4, 2), reapply_views=True, run_reinplace=True
        )
        self.assertExpectedInline(
            reinplaced_logs,
            """\



def forward(self, arg0_1):
    ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
    view = torch.ops.aten.view.default(add, [8])
    view_1 = torch.ops.aten.view.default(view, [2, 4]);  view = None
    transpose = torch.ops.aten.transpose.int(view_1, 1, 0)
    unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0);  transpose = None
    squeeze = torch.ops.aten.squeeze.default(unsqueeze);  unsqueeze = None
    split = torch.ops.aten.split.Tensor(squeeze, 2);  squeeze = None
    getitem = split[0]
    getitem_1 = split[1];  split = getitem_1 = None
    add_1 = torch.ops.aten.add_.Tensor(getitem, ones);  getitem = ones = add_1 = None
    view_2 = torch.ops.aten.view.default(add, [8]);  add = None
    view_3 = torch.ops.aten.view.default(view_2, [2, 4]);  view_2 = None
    transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0);  view_3 = None
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0);  transpose_1 = None
    squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1);  unsqueeze_1 = None
    unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0);  squeeze_1 = None
    squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0);  unsqueeze_2 = None
    transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0);  squeeze_2 = None
    view_4 = torch.ops.aten.view.default(transpose_2, [8]);  transpose_2 = None
    view_5 = torch.ops.aten.view.default(view_4, [4, 2]);  view_4 = None
    view_6 = torch.ops.aten.view.default(view_5, [8])
    view_7 = torch.ops.aten.view.default(view_6, [2, 4]);  view_6 = None
    transpose_3 = torch.ops.aten.transpose.int(view_7, 1, 0);  view_7 = None
    unsqueeze_3 = torch.ops.aten.unsqueeze.default(transpose_3, 0);  transpose_3 = None
    squeeze_3 = torch.ops.aten.squeeze.default(unsqueeze_3);  unsqueeze_3 = None
    split_1 = torch.ops.aten.split.Tensor(squeeze_3, 2);  squeeze_3 = None
    getitem_2 = split_1[0]
    getitem_3 = split_1[1];  split_1 = getitem_3 = None
    select = torch.ops.aten.select.int(view_1, 0, 0);  view_1 = select = None
    clone = torch.ops.aten.clone.default(getitem_2, memory_format = torch.contiguous_format)
    _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]);  clone = None
    view_8 = torch.ops.aten.view.default(view_5, [8]);  view_5 = None
    view_9 = torch.ops.aten.view.default(view_8, [2, 4]);  view_8 = None
    select_1 = torch.ops.aten.select.int(view_9, 0, 0);  view_9 = None
    add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view);  select_1 = _unsafe_view = add_2 = None
    return getitem_2
    """,
        )
1371*da0073e9SAndroid Build Coastguard Worker
    def test_reapply_views_simple(self):
        def f(x):
            tmp = torch.ones(4, 2)
            y = x.view(4, 2)
            # Mutates the input x through its view y.
            y.add_(tmp)
            # Unused, but traced; it must observe the mutated x (the expected
            # graph multiplies the updated value view_1 by itself).
            z = x * x
            return y

        self.assert_functionalization(f, torch.ones(4, 2), reapply_views=True)
        # With reapply_views=True the graph keeps real aten.view ops instead of
        # view_copy, and the mutation of x is written back with a final copy_.
        logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True)
        self.assertExpectedInline(
            logs,
            """\



def forward(self, arg0_1):
    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
    view = torch.ops.aten.view.default(arg0_1, [4, 2])
    add = torch.ops.aten.add.Tensor(view, ones);  view = ones = None
    view_1 = torch.ops.aten.view.default(add, [4, 2]);  add = None
    view_2 = torch.ops.aten.view.default(view_1, [4, 2])
    mul = torch.ops.aten.mul.Tensor(view_1, view_1);  mul = None
    copy_ = torch.ops.aten.copy_.default(arg0_1, view_1);  arg0_1 = view_1 = copy_ = None
    return view_2
    """,
        )
1399*da0073e9SAndroid Build Coastguard Worker
1400*da0073e9SAndroid Build Coastguard Worker    def test_aliases_maintained_after_pass_when_reapplying_views(self):
1401*da0073e9SAndroid Build Coastguard Worker        def f(x):
1402*da0073e9SAndroid Build Coastguard Worker            tmp = torch.ones(4, 2)
1403*da0073e9SAndroid Build Coastguard Worker            y = x.view(4, 2)
1404*da0073e9SAndroid Build Coastguard Worker            z = x.view(4, 2)
1405*da0073e9SAndroid Build Coastguard Worker            y.add_(tmp)
1406*da0073e9SAndroid Build Coastguard Worker            return y, z
1407*da0073e9SAndroid Build Coastguard Worker
1408*da0073e9SAndroid Build Coastguard Worker        input_functional = torch._to_functional_tensor(torch.ones(4, 2))
1409*da0073e9SAndroid Build Coastguard Worker        torch._enable_functionalization(reapply_views=True)
1410*da0073e9SAndroid Build Coastguard Worker        try:
1411*da0073e9SAndroid Build Coastguard Worker            y, z = f(input_functional)
1412*da0073e9SAndroid Build Coastguard Worker            torch._sync(y)
1413*da0073e9SAndroid Build Coastguard Worker            torch._sync(z)
1414*da0073e9SAndroid Build Coastguard Worker        finally:
1415*da0073e9SAndroid Build Coastguard Worker            torch._disable_functionalization()
1416*da0073e9SAndroid Build Coastguard Worker
1417*da0073e9SAndroid Build Coastguard Worker        # y and z are aliases inside of the function, and that aliasing relationship should be maintained.
1418*da0073e9SAndroid Build Coastguard Worker        _y = torch._from_functional_tensor(y)
1419*da0073e9SAndroid Build Coastguard Worker        _z = torch._from_functional_tensor(z)
1420*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(are_aliased(_y, _z))
1421*da0073e9SAndroid Build Coastguard Worker
    # copy_() gets its own test because it used to be special-cased in functionalization.
    # Now it is handled like any other in-place op: it is replaced by its functional
    # counterpart (aten.copy in the expected graphs below).
1424*da0073e9SAndroid Build Coastguard Worker    def test_copy_(self):
1425*da0073e9SAndroid Build Coastguard Worker        def f(x):
1426*da0073e9SAndroid Build Coastguard Worker            tmp = torch.zeros(2, 2)
1427*da0073e9SAndroid Build Coastguard Worker            tmp_slice = tmp.diagonal()
1428*da0073e9SAndroid Build Coastguard Worker            y = tmp_slice.copy_(x)
1429*da0073e9SAndroid Build Coastguard Worker            z = y.add_(x)
1430*da0073e9SAndroid Build Coastguard Worker            return z
1431*da0073e9SAndroid Build Coastguard Worker
1432*da0073e9SAndroid Build Coastguard Worker        # Test 1: copy_() with same dtype and shape
1433*da0073e9SAndroid Build Coastguard Worker        # to() is a composite op that noops when the dtype/shape match, so nothing gets logged.
1434*da0073e9SAndroid Build Coastguard Worker        # self.assert_functionalization(f, torch.ones(2))
1435*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(2))
1436*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1437*da0073e9SAndroid Build Coastguard Worker            logs,
1438*da0073e9SAndroid Build Coastguard Worker            """\
1439*da0073e9SAndroid Build Coastguard Worker
1440*da0073e9SAndroid Build Coastguard Worker
1441*da0073e9SAndroid Build Coastguard Worker
1442*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1443*da0073e9SAndroid Build Coastguard Worker    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1444*da0073e9SAndroid Build Coastguard Worker    diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1445*da0073e9SAndroid Build Coastguard Worker    copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1);  diagonal_copy = None
1446*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy);  zeros = copy = None
1447*da0073e9SAndroid Build Coastguard Worker    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1448*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1);  diagonal_copy_1 = arg0_1 = None
1449*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add);  diagonal_scatter = add = None
1450*da0073e9SAndroid Build Coastguard Worker    diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1);  diagonal_scatter_1 = None
1451*da0073e9SAndroid Build Coastguard Worker    return diagonal_copy_2
1452*da0073e9SAndroid Build Coastguard Worker    """,
1453*da0073e9SAndroid Build Coastguard Worker        )
1454*da0073e9SAndroid Build Coastguard Worker
1455*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
1456*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(2), reapply_views=True, run_reinplace=True
1457*da0073e9SAndroid Build Coastguard Worker        )
1458*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1459*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
1460*da0073e9SAndroid Build Coastguard Worker            """\
1461*da0073e9SAndroid Build Coastguard Worker
1462*da0073e9SAndroid Build Coastguard Worker
1463*da0073e9SAndroid Build Coastguard Worker
1464*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1465*da0073e9SAndroid Build Coastguard Worker    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1466*da0073e9SAndroid Build Coastguard Worker    diagonal = torch.ops.aten.diagonal.default(zeros)
1467*da0073e9SAndroid Build Coastguard Worker    copy = torch.ops.aten.copy_.default(diagonal, arg0_1);  diagonal = copy = None
1468*da0073e9SAndroid Build Coastguard Worker    diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1469*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1);  diagonal_1 = arg0_1 = add = None
1470*da0073e9SAndroid Build Coastguard Worker    diagonal_2 = torch.ops.aten.diagonal.default(zeros);  zeros = None
1471*da0073e9SAndroid Build Coastguard Worker    return diagonal_2
1472*da0073e9SAndroid Build Coastguard Worker    """,
1473*da0073e9SAndroid Build Coastguard Worker        )
1474*da0073e9SAndroid Build Coastguard Worker
1475*da0073e9SAndroid Build Coastguard Worker        # Test 2: copy_() with same dtype, different shape
1476*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(1))
1477*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(1))
1478*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1479*da0073e9SAndroid Build Coastguard Worker            logs,
1480*da0073e9SAndroid Build Coastguard Worker            """\
1481*da0073e9SAndroid Build Coastguard Worker
1482*da0073e9SAndroid Build Coastguard Worker
1483*da0073e9SAndroid Build Coastguard Worker
1484*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1485*da0073e9SAndroid Build Coastguard Worker    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1486*da0073e9SAndroid Build Coastguard Worker    diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1487*da0073e9SAndroid Build Coastguard Worker    copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1);  diagonal_copy = None
1488*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy);  zeros = copy = None
1489*da0073e9SAndroid Build Coastguard Worker    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1490*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1);  diagonal_copy_1 = arg0_1 = None
1491*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add);  diagonal_scatter = add = None
1492*da0073e9SAndroid Build Coastguard Worker    diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1);  diagonal_scatter_1 = None
1493*da0073e9SAndroid Build Coastguard Worker    return diagonal_copy_2
1494*da0073e9SAndroid Build Coastguard Worker    """,
1495*da0073e9SAndroid Build Coastguard Worker        )
1496*da0073e9SAndroid Build Coastguard Worker
1497*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
1498*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(1), reapply_views=True, run_reinplace=True
1499*da0073e9SAndroid Build Coastguard Worker        )
1500*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1501*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
1502*da0073e9SAndroid Build Coastguard Worker            """\
1503*da0073e9SAndroid Build Coastguard Worker
1504*da0073e9SAndroid Build Coastguard Worker
1505*da0073e9SAndroid Build Coastguard Worker
1506*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1507*da0073e9SAndroid Build Coastguard Worker    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1508*da0073e9SAndroid Build Coastguard Worker    diagonal = torch.ops.aten.diagonal.default(zeros)
1509*da0073e9SAndroid Build Coastguard Worker    copy = torch.ops.aten.copy_.default(diagonal, arg0_1);  diagonal = copy = None
1510*da0073e9SAndroid Build Coastguard Worker    diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1511*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1);  diagonal_1 = arg0_1 = add = None
1512*da0073e9SAndroid Build Coastguard Worker    diagonal_2 = torch.ops.aten.diagonal.default(zeros);  zeros = None
1513*da0073e9SAndroid Build Coastguard Worker    return diagonal_2
1514*da0073e9SAndroid Build Coastguard Worker    """,
1515*da0073e9SAndroid Build Coastguard Worker        )
1516*da0073e9SAndroid Build Coastguard Worker
1517*da0073e9SAndroid Build Coastguard Worker        # Test 3: copy_() with different dtype, same shape
1518*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(2, dtype=torch.long))
1519*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(2, dtype=torch.long))
1520*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1521*da0073e9SAndroid Build Coastguard Worker            logs,
1522*da0073e9SAndroid Build Coastguard Worker            """\
1523*da0073e9SAndroid Build Coastguard Worker
1524*da0073e9SAndroid Build Coastguard Worker
1525*da0073e9SAndroid Build Coastguard Worker
1526*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1527*da0073e9SAndroid Build Coastguard Worker    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1528*da0073e9SAndroid Build Coastguard Worker    diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1529*da0073e9SAndroid Build Coastguard Worker    copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1);  diagonal_copy = None
1530*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy);  zeros = copy = None
1531*da0073e9SAndroid Build Coastguard Worker    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1532*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1);  diagonal_copy_1 = arg0_1 = None
1533*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add);  diagonal_scatter = add = None
1534*da0073e9SAndroid Build Coastguard Worker    diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1);  diagonal_scatter_1 = None
1535*da0073e9SAndroid Build Coastguard Worker    return diagonal_copy_2
1536*da0073e9SAndroid Build Coastguard Worker    """,
1537*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
1538*da0073e9SAndroid Build Coastguard Worker
1539*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
1540*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True
1541*da0073e9SAndroid Build Coastguard Worker        )
1542*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1543*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
1544*da0073e9SAndroid Build Coastguard Worker            """\
1545*da0073e9SAndroid Build Coastguard Worker
1546*da0073e9SAndroid Build Coastguard Worker
1547*da0073e9SAndroid Build Coastguard Worker
1548*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1549*da0073e9SAndroid Build Coastguard Worker    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1550*da0073e9SAndroid Build Coastguard Worker    diagonal = torch.ops.aten.diagonal.default(zeros)
1551*da0073e9SAndroid Build Coastguard Worker    copy = torch.ops.aten.copy_.default(diagonal, arg0_1);  diagonal = copy = None
1552*da0073e9SAndroid Build Coastguard Worker    diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1553*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1);  diagonal_1 = arg0_1 = add = None
1554*da0073e9SAndroid Build Coastguard Worker    diagonal_2 = torch.ops.aten.diagonal.default(zeros);  zeros = None
1555*da0073e9SAndroid Build Coastguard Worker    return diagonal_2
1556*da0073e9SAndroid Build Coastguard Worker    """,
1557*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
1558*da0073e9SAndroid Build Coastguard Worker
1559*da0073e9SAndroid Build Coastguard Worker        # Test 4: copy_() with different dtype, different shape
1560*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(1, dtype=torch.long))
1561*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(1, dtype=torch.long))
1562*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1563*da0073e9SAndroid Build Coastguard Worker            logs,
1564*da0073e9SAndroid Build Coastguard Worker            """\
1565*da0073e9SAndroid Build Coastguard Worker
1566*da0073e9SAndroid Build Coastguard Worker
1567*da0073e9SAndroid Build Coastguard Worker
1568*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1569*da0073e9SAndroid Build Coastguard Worker    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1570*da0073e9SAndroid Build Coastguard Worker    diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1571*da0073e9SAndroid Build Coastguard Worker    copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1);  diagonal_copy = None
1572*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy);  zeros = copy = None
1573*da0073e9SAndroid Build Coastguard Worker    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1574*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1);  diagonal_copy_1 = arg0_1 = None
1575*da0073e9SAndroid Build Coastguard Worker    diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add);  diagonal_scatter = add = None
1576*da0073e9SAndroid Build Coastguard Worker    diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1);  diagonal_scatter_1 = None
1577*da0073e9SAndroid Build Coastguard Worker    return diagonal_copy_2
1578*da0073e9SAndroid Build Coastguard Worker    """,
1579*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
1580*da0073e9SAndroid Build Coastguard Worker
1581*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
1582*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True
1583*da0073e9SAndroid Build Coastguard Worker        )
1584*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1585*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
1586*da0073e9SAndroid Build Coastguard Worker            """\
1587*da0073e9SAndroid Build Coastguard Worker
1588*da0073e9SAndroid Build Coastguard Worker
1589*da0073e9SAndroid Build Coastguard Worker
1590*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1591*da0073e9SAndroid Build Coastguard Worker    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1592*da0073e9SAndroid Build Coastguard Worker    diagonal = torch.ops.aten.diagonal.default(zeros)
1593*da0073e9SAndroid Build Coastguard Worker    copy = torch.ops.aten.copy_.default(diagonal, arg0_1);  diagonal = copy = None
1594*da0073e9SAndroid Build Coastguard Worker    diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1595*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1);  diagonal_1 = arg0_1 = add = None
1596*da0073e9SAndroid Build Coastguard Worker    diagonal_2 = torch.ops.aten.diagonal.default(zeros);  zeros = None
1597*da0073e9SAndroid Build Coastguard Worker    return diagonal_2
1598*da0073e9SAndroid Build Coastguard Worker    """,
1599*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
1600*da0073e9SAndroid Build Coastguard Worker
    def test_expand_symint(self):
        # Checks that an expand() whose target sizes are read back from the
        # input via x.size(...) is functionalized into a single expand_copy.
        # Once some existing SymInt bugs are ironed out, we should update
        # this test to plumb FakeSymbolicTensors through it
        def f(x):
            return x.expand(x.size(0), x.size(1))

        self.assert_functionalization(f, torch.ones(2, 2))
        # With a real (non-symbolic) input, the sizes show up in the traced
        # graph as the literal list [2, 2].
        logs = self.get_logs(f, torch.ones(2, 2))
        self.assertExpectedInline(
            logs,
            """\



def forward(self, arg0_1):
    expand_copy = torch.ops.aten.expand_copy.default(arg0_1, [2, 2]);  arg0_1 = None
    return expand_copy
    """,
        )
1620*da0073e9SAndroid Build Coastguard Worker
    def test_fill_(self):
        # fill_() on a diagonal view of an intermediate (not an input) must be
        # reflected in the base tensor y, which is what f returns.
        def f(x):
            y = x + x
            z = y.diagonal()
            z.fill_(0)
            return y

        self.assert_functionalization(f, torch.ones(2, 2))
        # Functionalized graph: the in-place fill_ on the view becomes an
        # out-of-place fill on diagonal_copy, and the result is written back
        # into the base with diagonal_scatter.
        logs = self.get_logs(f, torch.ones(2, 2))
        self.assertExpectedInline(
            logs,
            """\



def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
    diagonal_copy = torch.ops.aten.diagonal_copy.default(add)
    fill = torch.ops.aten.fill.Scalar(diagonal_copy, 0);  diagonal_copy = None
    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(add, fill);  add = fill = None
    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter);  diagonal_copy_1 = None
    return diagonal_scatter
    """,
        )

        # With reapply_views=True and run_reinplace=True, the view ops are
        # plain diagonal() again and the fill is turned back into fill_.
        reinplaced_logs = self.get_logs(
            f, torch.ones(2, 2), reapply_views=True, run_reinplace=True
        )
        self.assertExpectedInline(
            reinplaced_logs,
            """\



def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
    diagonal = torch.ops.aten.diagonal.default(add)
    fill = torch.ops.aten.fill_.Scalar(diagonal, 0);  diagonal = fill = None
    diagonal_1 = torch.ops.aten.diagonal.default(add);  diagonal_1 = None
    return add
    """,
        )
1663*da0073e9SAndroid Build Coastguard Worker
    def test_resize_smaller(self):
        # Shrinks a view with resize_(), then mutates the shrunken tensor
        # through a further view. The mutation must be visible when y is
        # read again at the end of f.
        def f(w):
            # Resizing to a smaller size doesn't affect storage
            x = w + 1
            y = x.view(4, 4)
            y.resize_(3, 3)
            y2 = y.view(-1)
            y2.add_(1)
            z = y + 1
            return z

        self.assert_functionalization(f, torch.ones(8, 2))
        # Functionalized graph: the shrunken tensor is re-derived with
        # as_strided_copy(..., [3, 3], [3, 1]), and the add_ through the view
        # is written back with as_strided_scatter.
        logs = self.get_logs(f, torch.ones(8, 2))
        self.assertExpectedInline(
            logs,
            """\



def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
    view_copy = torch.ops.aten.view_copy.default(add, [4, 4])
    resize = torch.ops.aten.resize.default(view_copy, [3, 3]);  resize = None
    as_strided_copy = torch.ops.aten.as_strided_copy.default(view_copy, [3, 3], [3, 1]);  view_copy = None
    view_copy_1 = torch.ops.aten.view_copy.default(as_strided_copy, [-1]);  as_strided_copy = None
    add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1);  view_copy_1 = None
    view_copy_2 = torch.ops.aten.view_copy.default(add, [4, 4]);  add = None
    as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1]);  as_strided_copy_1 = None
    view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]);  add_1 = None
    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]);  view_copy_2 = view_copy_3 = None
    view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]);  as_strided_scatter = None
    view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4])
    as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]);  view_copy_5 = None
    view_copy_6 = torch.ops.aten.view_copy.default(as_strided_copy_2, [-1]);  as_strided_copy_2 = view_copy_6 = None
    view_copy_7 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]);  view_copy_4 = None
    as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(view_copy_7, [3, 3], [3, 1]);  view_copy_7 = None
    add_2 = torch.ops.aten.add.Tensor(as_strided_copy_3, 1);  as_strided_copy_3 = None
    return add_2
    """,  # noqa: B950
        )

        # After reinplacing, the *_copy ops become views and the later adds
        # (add_1, add_2) become add_. The resize itself stays as the
        # out-of-place resize op in this graph.
        reinplaced_logs = self.get_logs(
            f, torch.ones(8, 2), reapply_views=True, run_reinplace=True
        )
        self.assertExpectedInline(
            reinplaced_logs,
            """\



def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
    view = torch.ops.aten.view.default(add, [4, 4])
    resize = torch.ops.aten.resize.default(view, [3, 3]);  resize = None
    as_strided = torch.ops.aten.as_strided.default(view, [3, 3], [3, 1]);  view = None
    view_1 = torch.ops.aten.view.default(as_strided, [-1]);  as_strided = None
    add_1 = torch.ops.aten.add_.Tensor(view_1, 1);  add_1 = None
    view_2 = torch.ops.aten.view.default(add, [4, 4]);  add = None
    as_strided_1 = torch.ops.aten.as_strided.default(view_2, [3, 3], [3, 1]);  as_strided_1 = None
    view_3 = torch.ops.aten.view.default(view_1, [3, 3]);  view_1 = view_3 = None
    view_4 = torch.ops.aten.view.default(view_2, [8, 2]);  view_2 = None
    view_5 = torch.ops.aten.view.default(view_4, [4, 4])
    as_strided_2 = torch.ops.aten.as_strided.default(view_5, [3, 3], [3, 1]);  view_5 = None
    view_6 = torch.ops.aten.view.default(as_strided_2, [-1]);  as_strided_2 = view_6 = None
    view_7 = torch.ops.aten.view.default(view_4, [4, 4]);  view_4 = None
    as_strided_3 = torch.ops.aten.as_strided.default(view_7, [3, 3], [3, 1]);  view_7 = None
    add_2 = torch.ops.aten.add_.Tensor(as_strided_3, 1);  add_2 = None
    return as_strided_3
    """,
        )
1734*da0073e9SAndroid Build Coastguard Worker
1735*da0073e9SAndroid Build Coastguard Worker    def test_resize_same_size_diff_rank(self):
1736*da0073e9SAndroid Build Coastguard Worker        def f(x):
1737*da0073e9SAndroid Build Coastguard Worker            y = x.clone()
1738*da0073e9SAndroid Build Coastguard Worker            y.resize_(25, 5)
1739*da0073e9SAndroid Build Coastguard Worker            return y
1740*da0073e9SAndroid Build Coastguard Worker
1741*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(5, 5, 5))
1742*da0073e9SAndroid Build Coastguard Worker
    def test_resize_larger_valid(self):
        # Grows a non-view intermediate (16 -> 25 elements) with resize_() and
        # then mutates it through a view. Both the resized tensor and a value
        # computed from it afterwards are returned.
        def f(x):
            y = x + 1
            # resizing a tensor to a larger size is only currently allowed
            # if the tensor-to-resize is not a view / has no outstanding views.
            # See Note [resize_() in functionalization pass]
            y.resize_(5, 5)
            y2 = y.view(25)
            # Do a mutation to ensure that aliases of the output of resize_()
            # propagate mutations correctly.
            # I'm using fill_ specifically because I want to guarantee that
            # none of the output has uninitialized memory at the end
            # (since these tests compare the data output against a reference impl)
            y2.fill_(1)
            out = y + 1
            return y, out

        self.assert_functionalization(f, torch.ones(8, 2))
        # Functionalized graph: resize is applied directly to the base `add`,
        # and the fill_ through the view becomes fill on a view_copy.
        logs = self.get_logs(f, torch.ones(8, 2))
        self.assertExpectedInline(
            logs,
            """\



def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
    resize = torch.ops.aten.resize.default(add, [5, 5]);  add = None
    view_copy = torch.ops.aten.view_copy.default(resize, [25]);  resize = None
    fill = torch.ops.aten.fill.Scalar(view_copy, 1);  view_copy = None
    view_copy_1 = torch.ops.aten.view_copy.default(fill, [5, 5]);  fill = None
    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [25]);  view_copy_2 = None
    add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1)
    return (view_copy_1, add_1)
    """,
        )

        # After reinplacing, resize and fill appear in their in-place forms
        # (resize_, fill_) and view_copy becomes view.
        reinplaced_logs = self.get_logs(
            f, torch.ones(8, 2), reapply_views=True, run_reinplace=True
        )
        self.assertExpectedInline(
            reinplaced_logs,
            """\



def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
    resize = torch.ops.aten.resize_.default(add, [5, 5]);  resize = None
    view = torch.ops.aten.view.default(add, [25]);  add = None
    fill = torch.ops.aten.fill_.Scalar(view, 1);  fill = None
    view_1 = torch.ops.aten.view.default(view, [5, 5]);  view = None
    view_2 = torch.ops.aten.view.default(view_1, [25]);  view_2 = None
    add_1 = torch.ops.aten.add.Tensor(view_1, 1)
    return (view_1, add_1)
    """,
        )
1800*da0073e9SAndroid Build Coastguard Worker
1801*da0073e9SAndroid Build Coastguard Worker    def test_resize_larger_invalid(self):
1802*da0073e9SAndroid Build Coastguard Worker        def f(x):
1803*da0073e9SAndroid Build Coastguard Worker            y = x + 1
1804*da0073e9SAndroid Build Coastguard Worker            z = y.view(4, 4)
1805*da0073e9SAndroid Build Coastguard Worker            # resizing a tensor to a larger size is only currently allowed
1806*da0073e9SAndroid Build Coastguard Worker            # if the tensor-to-resize is not a view / has no outstanding views.
1807*da0073e9SAndroid Build Coastguard Worker            # See Note [resize_() in functionalization pass]
1808*da0073e9SAndroid Build Coastguard Worker            # This should fail
1809*da0073e9SAndroid Build Coastguard Worker            z.resize_(5, 5)
1810*da0073e9SAndroid Build Coastguard Worker            z2 = z.view(25)
1811*da0073e9SAndroid Build Coastguard Worker            z2.fill_(1)
1812*da0073e9SAndroid Build Coastguard Worker            out = z + 1
1813*da0073e9SAndroid Build Coastguard Worker            return y, out
1814*da0073e9SAndroid Build Coastguard Worker
1815*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1816*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1817*da0073e9SAndroid Build Coastguard Worker            r"Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass",
1818*da0073e9SAndroid Build Coastguard Worker        ):
1819*da0073e9SAndroid Build Coastguard Worker            self.assert_functionalization(f, torch.ones(8, 2))
1820*da0073e9SAndroid Build Coastguard Worker
1821*da0073e9SAndroid Build Coastguard Worker    def test_nested_functions_propagate_updates(self):
1822*da0073e9SAndroid Build Coastguard Worker        def g(x):
1823*da0073e9SAndroid Build Coastguard Worker            # Create a view of x
1824*da0073e9SAndroid Build Coastguard Worker            y = x[0]
1825*da0073e9SAndroid Build Coastguard Worker            y.add_(1)
1826*da0073e9SAndroid Build Coastguard Worker            # The view, y, gets deallocated at the end of this function
1827*da0073e9SAndroid Build Coastguard Worker
1828*da0073e9SAndroid Build Coastguard Worker        def f(x):
1829*da0073e9SAndroid Build Coastguard Worker            # Calling g(x) should mutate x
1830*da0073e9SAndroid Build Coastguard Worker            g(x)
1831*da0073e9SAndroid Build Coastguard Worker            # We expect x to be synced here, even though the alias created in g() has been deallocated!
1832*da0073e9SAndroid Build Coastguard Worker            y = x + x
1833*da0073e9SAndroid Build Coastguard Worker            return y
1834*da0073e9SAndroid Build Coastguard Worker
1835*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(2, 2))
1836*da0073e9SAndroid Build Coastguard Worker
1837*da0073e9SAndroid Build Coastguard Worker    def test_mixed_wrappers_valid(self):
1838*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
1839*da0073e9SAndroid Build Coastguard Worker            z = x + y
1840*da0073e9SAndroid Build Coastguard Worker            z.add_(1)
1841*da0073e9SAndroid Build Coastguard Worker            return z
1842*da0073e9SAndroid Build Coastguard Worker
1843*da0073e9SAndroid Build Coastguard Worker        x1_not_functional = LoggingTensor(torch.ones(4))
1844*da0073e9SAndroid Build Coastguard Worker        x2_functional = torch._to_functional_tensor(LoggingTensor(torch.ones(4)))
1845*da0073e9SAndroid Build Coastguard Worker
1846*da0073e9SAndroid Build Coastguard Worker        with capture_logs() as logs:
1847*da0073e9SAndroid Build Coastguard Worker            y = f(x1_not_functional, x2_functional)
1848*da0073e9SAndroid Build Coastguard Worker
1849*da0073e9SAndroid Build Coastguard Worker        # Make sure that functionalization ran the "+" kernel
1850*da0073e9SAndroid Build Coastguard Worker        # with a functional + non-functional tensor, and wrapped the output appropriately.
1851*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1852*da0073e9SAndroid Build Coastguard Worker            "\n".join(logs),
1853*da0073e9SAndroid Build Coastguard Worker            """\
1854*da0073e9SAndroid Build Coastguard Worker$2: f32[4] = torch._ops.aten.add.Tensor($0, $1)
1855*da0073e9SAndroid Build Coastguard Worker$3: f32[4] = torch._ops.aten.add.Tensor($2, 1)""",
1856*da0073e9SAndroid Build Coastguard Worker        )
1857*da0073e9SAndroid Build Coastguard Worker
1858*da0073e9SAndroid Build Coastguard Worker    def test_mixed_wrappers_invalid(self):
1859*da0073e9SAndroid Build Coastguard Worker        x1_not_functional = torch.ones(4)
1860*da0073e9SAndroid Build Coastguard Worker        x2_functional = torch._to_functional_tensor(torch.ones(4))
1861*da0073e9SAndroid Build Coastguard Worker
1862*da0073e9SAndroid Build Coastguard Worker        # When dealing with mixed functional + non functional tensors,
1863*da0073e9SAndroid Build Coastguard Worker        # normal_tensor.add_(functional_tensor) is not valid
1864*da0073e9SAndroid Build Coastguard Worker        # because normal_tensor would need to be "promoted" to a functional tensor.
1865*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
1866*da0073e9SAndroid Build Coastguard Worker            x1_not_functional.add_(x2_functional)
1867*da0073e9SAndroid Build Coastguard Worker
1868*da0073e9SAndroid Build Coastguard Worker    def test_index_mutation_on_non_input(self):
1869*da0073e9SAndroid Build Coastguard Worker        def f(x):
1870*da0073e9SAndroid Build Coastguard Worker            tmp = torch.zeros(10)
1871*da0073e9SAndroid Build Coastguard Worker            tmp[5].fill_(1)
1872*da0073e9SAndroid Build Coastguard Worker            return tmp
1873*da0073e9SAndroid Build Coastguard Worker
1874*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(f, torch.ones(2))
1875*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(f, torch.ones(2))
1876*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1877*da0073e9SAndroid Build Coastguard Worker            logs,
1878*da0073e9SAndroid Build Coastguard Worker            """\
1879*da0073e9SAndroid Build Coastguard Worker
1880*da0073e9SAndroid Build Coastguard Worker
1881*da0073e9SAndroid Build Coastguard Worker
1882*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1883*da0073e9SAndroid Build Coastguard Worker    zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
1884*da0073e9SAndroid Build Coastguard Worker    select_copy = torch.ops.aten.select_copy.int(zeros, 0, 5)
1885*da0073e9SAndroid Build Coastguard Worker    fill = torch.ops.aten.fill.Scalar(select_copy, 1);  select_copy = None
1886*da0073e9SAndroid Build Coastguard Worker    select_scatter = torch.ops.aten.select_scatter.default(zeros, fill, 0, 5);  zeros = fill = None
1887*da0073e9SAndroid Build Coastguard Worker    select_copy_1 = torch.ops.aten.select_copy.int(select_scatter, 0, 5);  select_copy_1 = None
1888*da0073e9SAndroid Build Coastguard Worker    return select_scatter
1889*da0073e9SAndroid Build Coastguard Worker    """,
1890*da0073e9SAndroid Build Coastguard Worker        )  # noqa: B950
1891*da0073e9SAndroid Build Coastguard Worker
1892*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
1893*da0073e9SAndroid Build Coastguard Worker            f, torch.ones(2), reapply_views=True, run_reinplace=True
1894*da0073e9SAndroid Build Coastguard Worker        )
1895*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1896*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
1897*da0073e9SAndroid Build Coastguard Worker            """\
1898*da0073e9SAndroid Build Coastguard Worker
1899*da0073e9SAndroid Build Coastguard Worker
1900*da0073e9SAndroid Build Coastguard Worker
1901*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
1902*da0073e9SAndroid Build Coastguard Worker    zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
1903*da0073e9SAndroid Build Coastguard Worker    select = torch.ops.aten.select.int(zeros, 0, 5)
1904*da0073e9SAndroid Build Coastguard Worker    fill = torch.ops.aten.fill_.Scalar(select, 1);  select = fill = None
1905*da0073e9SAndroid Build Coastguard Worker    select_1 = torch.ops.aten.select.int(zeros, 0, 5);  select_1 = None
1906*da0073e9SAndroid Build Coastguard Worker    return zeros
1907*da0073e9SAndroid Build Coastguard Worker    """,
1908*da0073e9SAndroid Build Coastguard Worker        )
1909*da0073e9SAndroid Build Coastguard Worker
1910*da0073e9SAndroid Build Coastguard Worker    def test_instance_norm(self):
1911*da0073e9SAndroid Build Coastguard Worker        size = 100
1912*da0073e9SAndroid Build Coastguard Worker
1913*da0073e9SAndroid Build Coastguard Worker        def f(x, running_mean, running_var):
1914*da0073e9SAndroid Build Coastguard Worker            with enable_python_dispatcher():
1915*da0073e9SAndroid Build Coastguard Worker                return torch.instance_norm(
1916*da0073e9SAndroid Build Coastguard Worker                    x,
1917*da0073e9SAndroid Build Coastguard Worker                    None,
1918*da0073e9SAndroid Build Coastguard Worker                    None,
1919*da0073e9SAndroid Build Coastguard Worker                    running_mean,
1920*da0073e9SAndroid Build Coastguard Worker                    running_var,
1921*da0073e9SAndroid Build Coastguard Worker                    use_input_stats=True,
1922*da0073e9SAndroid Build Coastguard Worker                    momentum=0.1,
1923*da0073e9SAndroid Build Coastguard Worker                    eps=1e-5,
1924*da0073e9SAndroid Build Coastguard Worker                    cudnn_enabled=False,
1925*da0073e9SAndroid Build Coastguard Worker                )
1926*da0073e9SAndroid Build Coastguard Worker
1927*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(
1928*da0073e9SAndroid Build Coastguard Worker            f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size)
1929*da0073e9SAndroid Build Coastguard Worker        )
1930*da0073e9SAndroid Build Coastguard Worker        # On Windows, for instance_norm, the alias_copy's are reordered to come right before they need to be used
1931*da0073e9SAndroid Build Coastguard Worker        # whereas on other platforms, the alias_copy's are before the view_copy's.
1932*da0073e9SAndroid Build Coastguard Worker        # e.g., the alias_copy after the getitem_4 assignment would be moved to be right before the copy assignment.
1933*da0073e9SAndroid Build Coastguard Worker        if not IS_WINDOWS:
1934*da0073e9SAndroid Build Coastguard Worker            logs = self.get_logs(
1935*da0073e9SAndroid Build Coastguard Worker                f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size)
1936*da0073e9SAndroid Build Coastguard Worker            )
1937*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(
1938*da0073e9SAndroid Build Coastguard Worker                logs,
1939*da0073e9SAndroid Build Coastguard Worker                """\
1940*da0073e9SAndroid Build Coastguard Worker
1941*da0073e9SAndroid Build Coastguard Worker
1942*da0073e9SAndroid Build Coastguard Worker
1943*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1, arg1_1, arg2_1):
1944*da0073e9SAndroid Build Coastguard Worker    repeat = torch.ops.aten.repeat.default(arg1_1, [20])
1945*da0073e9SAndroid Build Coastguard Worker    repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20])
1946*da0073e9SAndroid Build Coastguard Worker    view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 2000, 35, 45]);  arg0_1 = None
1947*da0073e9SAndroid Build Coastguard Worker    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'));  empty = None
1948*da0073e9SAndroid Build Coastguard Worker    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05);  view_copy = repeat = repeat_1 = None
1949*da0073e9SAndroid Build Coastguard Worker    getitem = _native_batch_norm_legit_functional[0]
1950*da0073e9SAndroid Build Coastguard Worker    getitem_1 = _native_batch_norm_legit_functional[1];  getitem_1 = None
1951*da0073e9SAndroid Build Coastguard Worker    getitem_2 = _native_batch_norm_legit_functional[2];  getitem_2 = None
1952*da0073e9SAndroid Build Coastguard Worker    getitem_3 = _native_batch_norm_legit_functional[3]
1953*da0073e9SAndroid Build Coastguard Worker    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
1954*da0073e9SAndroid Build Coastguard Worker    alias_copy = torch.ops.aten.alias_copy.default(arg1_1)
1955*da0073e9SAndroid Build Coastguard Worker    view_copy_1 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]);  view_copy_1 = None
1956*da0073e9SAndroid Build Coastguard Worker    view_copy_2 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]);  getitem_3 = None
1957*da0073e9SAndroid Build Coastguard Worker    mean = torch.ops.aten.mean.dim(view_copy_2, [0]);  view_copy_2 = None
1958*da0073e9SAndroid Build Coastguard Worker    copy = torch.ops.aten.copy.default(alias_copy, mean);  alias_copy = mean = None
1959*da0073e9SAndroid Build Coastguard Worker    alias_copy_1 = torch.ops.aten.alias_copy.default(copy);  copy = None
1960*da0073e9SAndroid Build Coastguard Worker    alias_copy_2 = torch.ops.aten.alias_copy.default(alias_copy_1);  alias_copy_2 = None
1961*da0073e9SAndroid Build Coastguard Worker    alias_copy_3 = torch.ops.aten.alias_copy.default(arg2_1)
1962*da0073e9SAndroid Build Coastguard Worker    view_copy_3 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]);  view_copy_3 = None
1963*da0073e9SAndroid Build Coastguard Worker    view_copy_4 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]);  getitem_4 = None
1964*da0073e9SAndroid Build Coastguard Worker    mean_1 = torch.ops.aten.mean.dim(view_copy_4, [0]);  view_copy_4 = None
1965*da0073e9SAndroid Build Coastguard Worker    copy_1 = torch.ops.aten.copy.default(alias_copy_3, mean_1);  alias_copy_3 = mean_1 = None
1966*da0073e9SAndroid Build Coastguard Worker    alias_copy_4 = torch.ops.aten.alias_copy.default(copy_1);  copy_1 = None
1967*da0073e9SAndroid Build Coastguard Worker    alias_copy_5 = torch.ops.aten.alias_copy.default(alias_copy_4);  alias_copy_5 = None
1968*da0073e9SAndroid Build Coastguard Worker    view_copy_5 = torch.ops.aten.view_copy.default(getitem, [20, 100, 35, 45]);  getitem = None
1969*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_1);  arg1_1 = alias_copy_1 = copy_ = None
1970*da0073e9SAndroid Build Coastguard Worker    copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_4);  arg2_1 = alias_copy_4 = copy__1 = None
1971*da0073e9SAndroid Build Coastguard Worker    return view_copy_5
1972*da0073e9SAndroid Build Coastguard Worker    """,  # noqa: B950
1973*da0073e9SAndroid Build Coastguard Worker            )
1974*da0073e9SAndroid Build Coastguard Worker
1975*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs = self.get_logs(
1976*da0073e9SAndroid Build Coastguard Worker                f,
1977*da0073e9SAndroid Build Coastguard Worker                torch.randn(20, size, 35, 45),
1978*da0073e9SAndroid Build Coastguard Worker                torch.zeros(size),
1979*da0073e9SAndroid Build Coastguard Worker                torch.ones(size),
1980*da0073e9SAndroid Build Coastguard Worker                reapply_views=True,
1981*da0073e9SAndroid Build Coastguard Worker                run_reinplace=True,
1982*da0073e9SAndroid Build Coastguard Worker            )
1983*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(
1984*da0073e9SAndroid Build Coastguard Worker                reinplaced_logs,
1985*da0073e9SAndroid Build Coastguard Worker                """\
1986*da0073e9SAndroid Build Coastguard Worker
1987*da0073e9SAndroid Build Coastguard Worker
1988*da0073e9SAndroid Build Coastguard Worker
1989*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1, arg1_1, arg2_1):
1990*da0073e9SAndroid Build Coastguard Worker    repeat = torch.ops.aten.repeat.default(arg1_1, [20])
1991*da0073e9SAndroid Build Coastguard Worker    repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20])
1992*da0073e9SAndroid Build Coastguard Worker    view = torch.ops.aten.view.default(arg0_1, [1, 2000, 35, 45]);  arg0_1 = None
1993*da0073e9SAndroid Build Coastguard Worker    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'));  empty = None
1994*da0073e9SAndroid Build Coastguard Worker    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05);  view = repeat = repeat_1 = None
1995*da0073e9SAndroid Build Coastguard Worker    getitem = _native_batch_norm_legit_functional[0]
1996*da0073e9SAndroid Build Coastguard Worker    getitem_1 = _native_batch_norm_legit_functional[1];  getitem_1 = None
1997*da0073e9SAndroid Build Coastguard Worker    getitem_2 = _native_batch_norm_legit_functional[2];  getitem_2 = None
1998*da0073e9SAndroid Build Coastguard Worker    getitem_3 = _native_batch_norm_legit_functional[3]
1999*da0073e9SAndroid Build Coastguard Worker    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
2000*da0073e9SAndroid Build Coastguard Worker    alias = torch.ops.aten.alias.default(arg1_1)
2001*da0073e9SAndroid Build Coastguard Worker    view_1 = torch.ops.aten.view.default(getitem_3, [20, 100]);  view_1 = None
2002*da0073e9SAndroid Build Coastguard Worker    view_2 = torch.ops.aten.view.default(getitem_3, [20, 100]);  getitem_3 = None
2003*da0073e9SAndroid Build Coastguard Worker    mean = torch.ops.aten.mean.dim(view_2, [0]);  view_2 = None
2004*da0073e9SAndroid Build Coastguard Worker    copy = torch.ops.aten.copy.default(alias, mean);  alias = mean = None
2005*da0073e9SAndroid Build Coastguard Worker    alias_1 = torch.ops.aten.alias.default(copy);  copy = None
2006*da0073e9SAndroid Build Coastguard Worker    alias_2 = torch.ops.aten.alias.default(alias_1);  alias_2 = None
2007*da0073e9SAndroid Build Coastguard Worker    alias_3 = torch.ops.aten.alias.default(arg2_1)
2008*da0073e9SAndroid Build Coastguard Worker    view_3 = torch.ops.aten.view.default(getitem_4, [20, 100]);  view_3 = None
2009*da0073e9SAndroid Build Coastguard Worker    view_4 = torch.ops.aten.view.default(getitem_4, [20, 100]);  getitem_4 = None
2010*da0073e9SAndroid Build Coastguard Worker    mean_1 = torch.ops.aten.mean.dim(view_4, [0]);  view_4 = None
2011*da0073e9SAndroid Build Coastguard Worker    copy_1 = torch.ops.aten.copy.default(alias_3, mean_1);  alias_3 = mean_1 = None
2012*da0073e9SAndroid Build Coastguard Worker    alias_4 = torch.ops.aten.alias.default(copy_1);  copy_1 = None
2013*da0073e9SAndroid Build Coastguard Worker    alias_5 = torch.ops.aten.alias.default(alias_4);  alias_5 = None
2014*da0073e9SAndroid Build Coastguard Worker    view_5 = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]);  getitem = None
2015*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg1_1, alias_1);  arg1_1 = alias_1 = copy_ = None
2016*da0073e9SAndroid Build Coastguard Worker    copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_4);  arg2_1 = alias_4 = copy__1 = None
2017*da0073e9SAndroid Build Coastguard Worker    return view_5
2018*da0073e9SAndroid Build Coastguard Worker    """,  # noqa: B950
2019*da0073e9SAndroid Build Coastguard Worker            )
2020*da0073e9SAndroid Build Coastguard Worker
2021*da0073e9SAndroid Build Coastguard Worker    def test_mutation_overlapping_mem(self):
2022*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2023*da0073e9SAndroid Build Coastguard Worker            # x: (1, 5)
2024*da0073e9SAndroid Build Coastguard Worker            t1 = torch.add(x, x)
2025*da0073e9SAndroid Build Coastguard Worker            t2 = t1.unfold(1, 3, 2)
2026*da0073e9SAndroid Build Coastguard Worker            t3 = t2.abs_()
2027*da0073e9SAndroid Build Coastguard Worker            return t3
2028*da0073e9SAndroid Build Coastguard Worker
2029*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2030*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2031*da0073e9SAndroid Build Coastguard Worker            r"encountered a tensor being mutated that has internal overlap",
2032*da0073e9SAndroid Build Coastguard Worker        ):
2033*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(1, 5)
2034*da0073e9SAndroid Build Coastguard Worker            out = _functionalize(fn, reapply_views=True, crossref=False)(x)
2035*da0073e9SAndroid Build Coastguard Worker
2036*da0073e9SAndroid Build Coastguard Worker    def test_batch_norm(self):
2037*da0073e9SAndroid Build Coastguard Worker        def f(x, running_mean, running_var):
2038*da0073e9SAndroid Build Coastguard Worker            with enable_python_dispatcher():
2039*da0073e9SAndroid Build Coastguard Worker                return torch.batch_norm(
2040*da0073e9SAndroid Build Coastguard Worker                    x, None, None, running_mean, running_var, True, 0.1, 1e-5, False
2041*da0073e9SAndroid Build Coastguard Worker                )
2042*da0073e9SAndroid Build Coastguard Worker
2043*da0073e9SAndroid Build Coastguard Worker        self.assert_functionalization(
2044*da0073e9SAndroid Build Coastguard Worker            f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100)
2045*da0073e9SAndroid Build Coastguard Worker        )
2046*da0073e9SAndroid Build Coastguard Worker        logs = self.get_logs(
2047*da0073e9SAndroid Build Coastguard Worker            f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100)
2048*da0073e9SAndroid Build Coastguard Worker        )
2049*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
2050*da0073e9SAndroid Build Coastguard Worker            logs,
2051*da0073e9SAndroid Build Coastguard Worker            """\
2052*da0073e9SAndroid Build Coastguard Worker
2053*da0073e9SAndroid Build Coastguard Worker
2054*da0073e9SAndroid Build Coastguard Worker
2055*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1, arg1_1, arg2_1):
2056*da0073e9SAndroid Build Coastguard Worker    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'));  empty = None
2057*da0073e9SAndroid Build Coastguard Worker    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05);  arg0_1 = None
2058*da0073e9SAndroid Build Coastguard Worker    getitem = _native_batch_norm_legit_functional[0]
2059*da0073e9SAndroid Build Coastguard Worker    getitem_1 = _native_batch_norm_legit_functional[1];  getitem_1 = None
2060*da0073e9SAndroid Build Coastguard Worker    getitem_2 = _native_batch_norm_legit_functional[2];  getitem_2 = None
2061*da0073e9SAndroid Build Coastguard Worker    getitem_3 = _native_batch_norm_legit_functional[3]
2062*da0073e9SAndroid Build Coastguard Worker    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
2063*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3);  arg1_1 = getitem_3 = copy_ = None
2064*da0073e9SAndroid Build Coastguard Worker    copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4);  arg2_1 = getitem_4 = copy__1 = None
2065*da0073e9SAndroid Build Coastguard Worker    return getitem
2066*da0073e9SAndroid Build Coastguard Worker    """,  # noqa: B950
2067*da0073e9SAndroid Build Coastguard Worker        )
2068*da0073e9SAndroid Build Coastguard Worker
2069*da0073e9SAndroid Build Coastguard Worker        reinplaced_logs = self.get_logs(
2070*da0073e9SAndroid Build Coastguard Worker            f,
2071*da0073e9SAndroid Build Coastguard Worker            torch.randn(20, 100, 35, 45),
2072*da0073e9SAndroid Build Coastguard Worker            torch.zeros(100),
2073*da0073e9SAndroid Build Coastguard Worker            torch.ones(100),
2074*da0073e9SAndroid Build Coastguard Worker            reapply_views=True,
2075*da0073e9SAndroid Build Coastguard Worker            run_reinplace=True,
2076*da0073e9SAndroid Build Coastguard Worker        )
2077*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
2078*da0073e9SAndroid Build Coastguard Worker            reinplaced_logs,
2079*da0073e9SAndroid Build Coastguard Worker            """\
2080*da0073e9SAndroid Build Coastguard Worker
2081*da0073e9SAndroid Build Coastguard Worker
2082*da0073e9SAndroid Build Coastguard Worker
2083*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1, arg1_1, arg2_1):
2084*da0073e9SAndroid Build Coastguard Worker    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'));  empty = None
2085*da0073e9SAndroid Build Coastguard Worker    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05);  arg0_1 = None
2086*da0073e9SAndroid Build Coastguard Worker    getitem = _native_batch_norm_legit_functional[0]
2087*da0073e9SAndroid Build Coastguard Worker    getitem_1 = _native_batch_norm_legit_functional[1];  getitem_1 = None
2088*da0073e9SAndroid Build Coastguard Worker    getitem_2 = _native_batch_norm_legit_functional[2];  getitem_2 = None
2089*da0073e9SAndroid Build Coastguard Worker    getitem_3 = _native_batch_norm_legit_functional[3]
2090*da0073e9SAndroid Build Coastguard Worker    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
2091*da0073e9SAndroid Build Coastguard Worker    copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3);  arg1_1 = getitem_3 = copy_ = None
2092*da0073e9SAndroid Build Coastguard Worker    copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4);  arg2_1 = getitem_4 = copy__1 = None
2093*da0073e9SAndroid Build Coastguard Worker    return getitem
2094*da0073e9SAndroid Build Coastguard Worker    """,  # noqa: B950
2095*da0073e9SAndroid Build Coastguard Worker        )
2096*da0073e9SAndroid Build Coastguard Worker
2097*da0073e9SAndroid Build Coastguard Worker    # This tests our python shims around C++ Functionalization: FunctionalTensor and FunctionalTensorMode
2098*da0073e9SAndroid Build Coastguard Worker    def test_python_functionalization(self):
2099*da0073e9SAndroid Build Coastguard Worker        def f(x):
2100*da0073e9SAndroid Build Coastguard Worker            x_view = x.view(-1)
2101*da0073e9SAndroid Build Coastguard Worker            x.mul_(2)
2102*da0073e9SAndroid Build Coastguard Worker            return x_view + 1
2103*da0073e9SAndroid Build Coastguard Worker
2104*da0073e9SAndroid Build Coastguard Worker        def f_functionalized(x):
2105*da0073e9SAndroid Build Coastguard Worker            # Note [Disabling Functionalize TLS Above Python Functionalization]
2106*da0073e9SAndroid Build Coastguard Worker            # This UX is pretty annoying (although python functionalization's main customer is AOTAutograd,
2107*da0073e9SAndroid Build Coastguard Worker            # and is not really advertised as a user API).
2108*da0073e9SAndroid Build Coastguard Worker            # We need to explicitly disable functionalization when using python FunctionalTensor and FunctionalTensorMode.
2109*da0073e9SAndroid Build Coastguard Worker            # Why? FunctionalTensor is a wrapper tensor that holds an inner FunctionalTensorWrapper.
2110*da0073e9SAndroid Build Coastguard Worker            # Since the inner tensor has `DispatchKey.Functionalize` in its keyset, then by default,
2111*da0073e9SAndroid Build Coastguard Worker            # our FunctionalTensor will inherit the same keyset.
2112*da0073e9SAndroid Build Coastguard Worker            # We don't have an easy way of directly mutating a tensor's keyset from python,
2113*da0073e9SAndroid Build Coastguard Worker            # so globally disabling functionalization here is easier.
2114*da0073e9SAndroid Build Coastguard Worker            maybe_disable = torch._C._ExcludeDispatchKeyGuard(
2115*da0073e9SAndroid Build Coastguard Worker                torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
2116*da0073e9SAndroid Build Coastguard Worker            )
2117*da0073e9SAndroid Build Coastguard Worker            with maybe_disable, FunctionalTensorMode():
2118*da0073e9SAndroid Build Coastguard Worker                x_wrapped = FunctionalTensor.to_functional(x)
2119*da0073e9SAndroid Build Coastguard Worker                out_wrapped = f(x_wrapped)
2120*da0073e9SAndroid Build Coastguard Worker            out_unwrapped = out_wrapped.elem
2121*da0073e9SAndroid Build Coastguard Worker            torch._sync(out_unwrapped)
2122*da0073e9SAndroid Build Coastguard Worker            return torch._from_functional_tensor(out_unwrapped)
2123*da0073e9SAndroid Build Coastguard Worker
2124*da0073e9SAndroid Build Coastguard Worker        # Make a non-leaf
2125*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, requires_grad=True) + 1
2126*da0073e9SAndroid Build Coastguard Worker        fx_g = make_fx(f_functionalized)(x)
2127*da0073e9SAndroid Build Coastguard Worker        # NB: view_1 below is expected (though unused) due to view replay. AOTAutograd runs a
2128*da0073e9SAndroid Build Coastguard Worker        # DCE pass that will remove nodes like this later on.
2129*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
2130*da0073e9SAndroid Build Coastguard Worker            fx_g.code.strip(),
2131*da0073e9SAndroid Build Coastguard Worker            """\
2132*da0073e9SAndroid Build Coastguard Workerdef forward(self, x_1):
2133*da0073e9SAndroid Build Coastguard Worker    view = torch.ops.aten.view.default(x_1, [-1]);  view = None
2134*da0073e9SAndroid Build Coastguard Worker    mul = torch.ops.aten.mul.Tensor(x_1, 2);  x_1 = None
2135*da0073e9SAndroid Build Coastguard Worker    view_1 = torch.ops.aten.view.default(mul, [-1]);  view_1 = None
2136*da0073e9SAndroid Build Coastguard Worker    view_2 = torch.ops.aten.view.default(mul, [-1]);  mul = None
2137*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(view_2, 1);  view_2 = None
2138*da0073e9SAndroid Build Coastguard Worker    return add""",
2139*da0073e9SAndroid Build Coastguard Worker        )
2140*da0073e9SAndroid Build Coastguard Worker
2141*da0073e9SAndroid Build Coastguard Worker    def test_python_functionalization_zero_tensor(self):
2142*da0073e9SAndroid Build Coastguard Worker        def f(x):
2143*da0073e9SAndroid Build Coastguard Worker            y = torch.ops.aten._efficientzerotensor([4])
2144*da0073e9SAndroid Build Coastguard Worker            out = x + y
2145*da0073e9SAndroid Build Coastguard Worker            out.mul_(2)
2146*da0073e9SAndroid Build Coastguard Worker            return out
2147*da0073e9SAndroid Build Coastguard Worker
2148*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
2149*da0073e9SAndroid Build Coastguard Worker        out_ref = f(x)
2150*da0073e9SAndroid Build Coastguard Worker        out_test = dispatch_functionalize(f)(x)
2151*da0073e9SAndroid Build Coastguard Worker        out_test_cpp = _functionalize(
2152*da0073e9SAndroid Build Coastguard Worker            f, reapply_views=True, crossref=False, skip_input_mutations=True
2153*da0073e9SAndroid Build Coastguard Worker        )(x)
2154*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out_test)
2155*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out_test_cpp)
2156*da0073e9SAndroid Build Coastguard Worker        fx_g = make_fx(dispatch_functionalize(f))(x)
2157*da0073e9SAndroid Build Coastguard Worker        fx_g_cpp = make_fx(
2158*da0073e9SAndroid Build Coastguard Worker            _functionalize(
2159*da0073e9SAndroid Build Coastguard Worker                f, reapply_views=True, crossref=False, skip_input_mutations=True
2160*da0073e9SAndroid Build Coastguard Worker            )
2161*da0073e9SAndroid Build Coastguard Worker        )(x)
2162*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
2163*da0073e9SAndroid Build Coastguard Worker
2164*da0073e9SAndroid Build Coastguard Worker    def test_python_functionalization_is_conj(self):
2165*da0073e9SAndroid Build Coastguard Worker        def f(x):
2166*da0073e9SAndroid Build Coastguard Worker            out = x.conj()
2167*da0073e9SAndroid Build Coastguard Worker            return out, out.is_conj()
2168*da0073e9SAndroid Build Coastguard Worker
2169*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, dtype=torch.complex64)
2170*da0073e9SAndroid Build Coastguard Worker        out_ref = f(x)
2171*da0073e9SAndroid Build Coastguard Worker        out_test = dispatch_functionalize(f)(x)
2172*da0073e9SAndroid Build Coastguard Worker        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False)(x)
2173*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref[0], out_test[0])
2174*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref[1], out_test[1])
2175*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref[0], out_test_cpp[0])
2176*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref[1], out_test_cpp[1])
2177*da0073e9SAndroid Build Coastguard Worker
2178*da0073e9SAndroid Build Coastguard Worker    def test_python_functionalization_is_neg(self):
2179*da0073e9SAndroid Build Coastguard Worker        def f(x):
2180*da0073e9SAndroid Build Coastguard Worker            out = x.neg()
2181*da0073e9SAndroid Build Coastguard Worker            return out, out.is_neg()
2182*da0073e9SAndroid Build Coastguard Worker
2183*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, dtype=torch.complex64)
2184*da0073e9SAndroid Build Coastguard Worker        out_ref = f(x)
2185*da0073e9SAndroid Build Coastguard Worker        out_test = dispatch_functionalize(f)(x)
2186*da0073e9SAndroid Build Coastguard Worker        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False)(x)
2187*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref[0], out_test[0])
2188*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref[1], out_test[1])
2189*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref[0], out_test_cpp[0])
2190*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref[1], out_test_cpp[1])
2191*da0073e9SAndroid Build Coastguard Worker
2192*da0073e9SAndroid Build Coastguard Worker    def test_python_functionalization_conj(self):
2193*da0073e9SAndroid Build Coastguard Worker        def f(x):
2194*da0073e9SAndroid Build Coastguard Worker            y = x.clone().conj()
2195*da0073e9SAndroid Build Coastguard Worker            y.mul_(2)
2196*da0073e9SAndroid Build Coastguard Worker            return torch.view_as_real(y.resolve_conj())
2197*da0073e9SAndroid Build Coastguard Worker
2198*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, dtype=torch.complex64)
2199*da0073e9SAndroid Build Coastguard Worker        out_ref = f(x)
2200*da0073e9SAndroid Build Coastguard Worker        out_test = dispatch_functionalize(f)(x)
2201*da0073e9SAndroid Build Coastguard Worker        out_test_cpp = _functionalize(
2202*da0073e9SAndroid Build Coastguard Worker            f, reapply_views=True, crossref=False, skip_input_mutations=True
2203*da0073e9SAndroid Build Coastguard Worker        )(x)
2204*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out_test)
2205*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_test, out_test_cpp)
2206*da0073e9SAndroid Build Coastguard Worker        fx_g = make_fx(dispatch_functionalize(f))(x)
2207*da0073e9SAndroid Build Coastguard Worker        fx_g_cpp = make_fx(
2208*da0073e9SAndroid Build Coastguard Worker            _functionalize(
2209*da0073e9SAndroid Build Coastguard Worker                f, reapply_views=True, crossref=False, skip_input_mutations=True
2210*da0073e9SAndroid Build Coastguard Worker            )
2211*da0073e9SAndroid Build Coastguard Worker        )(x)
2212*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
2213*da0073e9SAndroid Build Coastguard Worker            fx_g.code.strip(),
2214*da0073e9SAndroid Build Coastguard Worker            """\
2215*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
2216*da0073e9SAndroid Build Coastguard Worker    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
2217*da0073e9SAndroid Build Coastguard Worker    _conj = torch.ops.aten._conj.default(clone);  clone = None
2218*da0073e9SAndroid Build Coastguard Worker    clone_1 = torch.ops.aten.clone.default(_conj)
2219*da0073e9SAndroid Build Coastguard Worker    mul = torch.ops.aten.mul.Tensor(clone_1, 2);  clone_1 = None
2220*da0073e9SAndroid Build Coastguard Worker    clone_2 = torch.ops.aten.clone.default(_conj);  _conj = None
2221*da0073e9SAndroid Build Coastguard Worker    copy = torch.ops.aten.copy.default(clone_2, mul);  clone_2 = mul = None
2222*da0073e9SAndroid Build Coastguard Worker    _conj_1 = torch.ops.aten._conj.default(copy);  copy = None
2223*da0073e9SAndroid Build Coastguard Worker    _conj_2 = torch.ops.aten._conj.default(_conj_1);  _conj_1 = None
2224*da0073e9SAndroid Build Coastguard Worker    clone_3 = torch.ops.aten.clone.default(_conj_2);  _conj_2 = None
2225*da0073e9SAndroid Build Coastguard Worker    view_as_real = torch.ops.aten.view_as_real.default(clone_3);  clone_3 = None
2226*da0073e9SAndroid Build Coastguard Worker    return view_as_real""",
2227*da0073e9SAndroid Build Coastguard Worker        )
2228*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
2229*da0073e9SAndroid Build Coastguard Worker
2230*da0073e9SAndroid Build Coastguard Worker    def test_python_functionalization_neg(self):
2231*da0073e9SAndroid Build Coastguard Worker        def f(x):
2232*da0073e9SAndroid Build Coastguard Worker            y = x._neg_view()
2233*da0073e9SAndroid Build Coastguard Worker            z = y.resolve_neg()
2234*da0073e9SAndroid Build Coastguard Worker            return z + 1
2235*da0073e9SAndroid Build Coastguard Worker
2236*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
2237*da0073e9SAndroid Build Coastguard Worker        out_ref = f(x)
2238*da0073e9SAndroid Build Coastguard Worker        out_test = dispatch_functionalize(f)(x)
2239*da0073e9SAndroid Build Coastguard Worker        out_test_cpp = _functionalize(
2240*da0073e9SAndroid Build Coastguard Worker            f, reapply_views=True, crossref=False, skip_input_mutations=True
2241*da0073e9SAndroid Build Coastguard Worker        )(x)
2242*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out_test)
2243*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out_test_cpp)
2244*da0073e9SAndroid Build Coastguard Worker        fx_g = make_fx(dispatch_functionalize(f))(x)
2245*da0073e9SAndroid Build Coastguard Worker        fx_g_cpp = make_fx(
2246*da0073e9SAndroid Build Coastguard Worker            _functionalize(
2247*da0073e9SAndroid Build Coastguard Worker                f, reapply_views=True, crossref=False, skip_input_mutations=True
2248*da0073e9SAndroid Build Coastguard Worker            )
2249*da0073e9SAndroid Build Coastguard Worker        )(x)
2250*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
2251*da0073e9SAndroid Build Coastguard Worker            fx_g.code.strip(),
2252*da0073e9SAndroid Build Coastguard Worker            """\
2253*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
2254*da0073e9SAndroid Build Coastguard Worker    _neg_view = torch.ops.aten._neg_view.default(arg0_1);  arg0_1 = None
2255*da0073e9SAndroid Build Coastguard Worker    clone = torch.ops.aten.clone.default(_neg_view);  _neg_view = None
2256*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
2257*da0073e9SAndroid Build Coastguard Worker    return add""",
2258*da0073e9SAndroid Build Coastguard Worker        )
2259*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
2260*da0073e9SAndroid Build Coastguard Worker
2261*da0073e9SAndroid Build Coastguard Worker    def test_python_functionalization_lift_fresh_storage(self):
2262*da0073e9SAndroid Build Coastguard Worker        unlifted = torch.tensor([0.0])
2263*da0073e9SAndroid Build Coastguard Worker
2264*da0073e9SAndroid Build Coastguard Worker        maybe_disable = torch._C._ExcludeDispatchKeyGuard(
2265*da0073e9SAndroid Build Coastguard Worker            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
2266*da0073e9SAndroid Build Coastguard Worker        )
2267*da0073e9SAndroid Build Coastguard Worker        with maybe_disable, FunctionalTensorMode():
2268*da0073e9SAndroid Build Coastguard Worker            lifted = torch.ops.aten.lift_fresh.default(unlifted)
2269*da0073e9SAndroid Build Coastguard Worker
2270*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(unlifted.untyped_storage(), lifted.untyped_storage())
2271*da0073e9SAndroid Build Coastguard Worker
2272*da0073e9SAndroid Build Coastguard Worker    def test_python_functionalization_lift_fresh(self):
2273*da0073e9SAndroid Build Coastguard Worker        def f(x):
2274*da0073e9SAndroid Build Coastguard Worker            tmp = torch.tensor([0.0])
2275*da0073e9SAndroid Build Coastguard Worker            return tmp + x
2276*da0073e9SAndroid Build Coastguard Worker
2277*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
2278*da0073e9SAndroid Build Coastguard Worker        out_ref = f(x)
2279*da0073e9SAndroid Build Coastguard Worker        out_test = dispatch_functionalize(f)(x)
2280*da0073e9SAndroid Build Coastguard Worker        out_test_cpp = _functionalize(
2281*da0073e9SAndroid Build Coastguard Worker            f, reapply_views=True, crossref=False, skip_input_mutations=True
2282*da0073e9SAndroid Build Coastguard Worker        )(x)
2283*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out_test)
2284*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out_test_cpp)
2285*da0073e9SAndroid Build Coastguard Worker        fx_g = make_fx(dispatch_functionalize(f))(x)
2286*da0073e9SAndroid Build Coastguard Worker        fx_g_cpp = make_fx(
2287*da0073e9SAndroid Build Coastguard Worker            _functionalize(
2288*da0073e9SAndroid Build Coastguard Worker                f, reapply_views=True, crossref=False, skip_input_mutations=True
2289*da0073e9SAndroid Build Coastguard Worker            )
2290*da0073e9SAndroid Build Coastguard Worker        )(x)
2291*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
2292*da0073e9SAndroid Build Coastguard Worker            fx_g.code.strip(),
2293*da0073e9SAndroid Build Coastguard Worker            """\
2294*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1):
2295*da0073e9SAndroid Build Coastguard Worker    _tensor_constant0 = self._tensor_constant0
2296*da0073e9SAndroid Build Coastguard Worker    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
2297*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.aten.add.Tensor(lift_fresh_copy, arg0_1);  lift_fresh_copy = arg0_1 = None
2298*da0073e9SAndroid Build Coastguard Worker    return add""",
2299*da0073e9SAndroid Build Coastguard Worker        )
2300*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
2301*da0073e9SAndroid Build Coastguard Worker
2302*da0073e9SAndroid Build Coastguard Worker
@xfail_inherited_tests(
    [
        "test_as_strided",
        "test_copy_",
        "test_diagonal",
        "test_diagonal_mutated_input",
        "test_everything",
        "test_fill_",
        "test_slice",
        "test_split",
        "test_split_with_sizes",
        "test_unbind",
        "test_view_clone_view_inplace",
        "test_view_inplace",
    ]
)
@unittest.skipIf(
    TEST_WITH_TORCHDYNAMO, "dynamo-ing code with proxy + fake doesnt work well"
)
class TestCrossRefFunctionalization(TestFunctionalization):
    """Inherit every TestFunctionalization test, with ``crossref`` set to True.

    The inherited tests named in the ``xfail_inherited_tests`` decorator are
    presumably marked as expected failures under this configuration (confirm
    against that helper). The whole class is skipped when running under
    TorchDynamo.
    """

    # Flag read by the inherited tests (defined outside this view).
    crossref = True
2324*da0073e9SAndroid Build Coastguard Worker
2325*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Executing this file directly hands every test case defined in it to
    # the shared PyTorch test runner.
    run_tests()
2328