1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: codegen"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport unittest 4*da0073e9SAndroid Build Coastguard Workerfrom contextlib import nullcontext 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Workerfrom torch._dispatch.python import ( 8*da0073e9SAndroid Build Coastguard Worker enable_crossref_functionalize, 9*da0073e9SAndroid Build Coastguard Worker enable_python_dispatcher, 10*da0073e9SAndroid Build Coastguard Worker) 11*da0073e9SAndroid Build Coastguard Workerfrom torch._subclasses.functional_tensor import ( 12*da0073e9SAndroid Build Coastguard Worker dispatch_functionalize, 13*da0073e9SAndroid Build Coastguard Worker FunctionalTensor, 14*da0073e9SAndroid Build Coastguard Worker FunctionalTensorMode, 15*da0073e9SAndroid Build Coastguard Worker) 16*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.proxy_tensor import make_fx 17*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes.reinplace import reinplace 18*da0073e9SAndroid Build Coastguard Workerfrom torch.multiprocessing.reductions import StorageWeakRef 19*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 20*da0073e9SAndroid Build Coastguard Worker IS_WINDOWS, 21*da0073e9SAndroid Build Coastguard Worker run_tests, 22*da0073e9SAndroid Build Coastguard Worker skipIfTorchDynamo, 23*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TORCHDYNAMO, 24*da0073e9SAndroid Build Coastguard Worker TestCase, 25*da0073e9SAndroid Build Coastguard Worker xfail_inherited_tests, 26*da0073e9SAndroid Build Coastguard Worker) 27*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.logging_tensor import capture_logs, LoggingTensor 28*da0073e9SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree 29*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_map_only 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Workerdef are_aliased(x, y): 33*da0073e9SAndroid Build Coastguard Worker x_storage = StorageWeakRef(x.storage()) 34*da0073e9SAndroid Build Coastguard Worker y_storage = StorageWeakRef(y.storage()) 35*da0073e9SAndroid Build Coastguard Worker return x_storage == y_storage 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker# We can unify testing and use functionalize() here instead 39*da0073e9SAndroid Build Coastguard Worker# if/when functorch moves into core. 40*da0073e9SAndroid Build Coastguard Worker# This is basically a crappy version of `functionalize()`. 41*da0073e9SAndroid Build Coastguard Workerdef _functionalize( 42*da0073e9SAndroid Build Coastguard Worker f, *, reapply_views: bool, crossref: bool, skip_input_mutations: bool = False 43*da0073e9SAndroid Build Coastguard Worker): 44*da0073e9SAndroid Build Coastguard Worker def to_fun(t: torch.Tensor): 45*da0073e9SAndroid Build Coastguard Worker func_t = torch._to_functional_tensor(t) 46*da0073e9SAndroid Build Coastguard Worker func_t.requires_grad = t.requires_grad 47*da0073e9SAndroid Build Coastguard Worker return func_t 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker def wrapped(*inputs): 50*da0073e9SAndroid Build Coastguard Worker ctx = nullcontext() 51*da0073e9SAndroid Build Coastguard Worker if crossref: 52*da0073e9SAndroid Build Coastguard Worker ctx = enable_crossref_functionalize() 53*da0073e9SAndroid Build Coastguard Worker with ctx: 54*da0073e9SAndroid Build Coastguard Worker inputs_functional = tree_map_only(torch.Tensor, to_fun, inputs) 55*da0073e9SAndroid Build Coastguard Worker torch._enable_functionalization(reapply_views=reapply_views) 56*da0073e9SAndroid Build Coastguard Worker try: 57*da0073e9SAndroid Build Coastguard Worker out = f(*inputs_functional) 58*da0073e9SAndroid Build Coastguard Worker finally: 59*da0073e9SAndroid Build Coastguard Worker torch._disable_functionalization() 60*da0073e9SAndroid Build Coastguard Worker flat_inputs = pytree.tree_leaves(inputs) 61*da0073e9SAndroid Build Coastguard Worker flat_inputs_functional = pytree.tree_leaves(inputs_functional) 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker for inpt, input_functional in zip(flat_inputs, flat_inputs_functional): 64*da0073e9SAndroid Build Coastguard Worker torch._sync(input_functional) 65*da0073e9SAndroid Build Coastguard Worker inpt_new = torch._from_functional_tensor(input_functional) 66*da0073e9SAndroid Build Coastguard Worker if inpt_new is not inpt and not skip_input_mutations: 67*da0073e9SAndroid Build Coastguard Worker # Existing deficiency in functionalize(): 68*da0073e9SAndroid Build Coastguard Worker # we don't correctly mutate input metadata (yet?) 69*da0073e9SAndroid Build Coastguard Worker if inpt_new.shape == inpt.shape: 70*da0073e9SAndroid Build Coastguard Worker inpt.copy_(inpt_new) 71*da0073e9SAndroid Build Coastguard Worker tree_map_only(torch.Tensor, torch._sync, out) 72*da0073e9SAndroid Build Coastguard Worker out_unwrapped = tree_map_only( 73*da0073e9SAndroid Build Coastguard Worker torch.Tensor, torch._from_functional_tensor, out 74*da0073e9SAndroid Build Coastguard Worker ) 75*da0073e9SAndroid Build Coastguard Worker return out_unwrapped 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker return wrapped 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 81*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457" 82*da0073e9SAndroid Build Coastguard Worker) 83*da0073e9SAndroid Build Coastguard Workerclass TestFunctionalization(TestCase): 84*da0073e9SAndroid Build Coastguard Worker crossref = False 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker def get_logs(self, func, *inpts, reapply_views=False, run_reinplace=False): 87*da0073e9SAndroid Build Coastguard Worker inpts_clone = tree_map_only(torch.Tensor, torch.clone, inpts) 88*da0073e9SAndroid Build Coastguard Worker traced_f = make_fx( 89*da0073e9SAndroid Build Coastguard Worker _functionalize(func, reapply_views=reapply_views, crossref=self.crossref) 90*da0073e9SAndroid Build Coastguard Worker )(*inpts) 91*da0073e9SAndroid Build Coastguard Worker if run_reinplace: 92*da0073e9SAndroid Build Coastguard Worker traced_f = reinplace(traced_f, *inpts_clone) 93*da0073e9SAndroid Build Coastguard Worker return traced_f.code 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker def assert_functionalization( 96*da0073e9SAndroid Build Coastguard Worker self, func, *inpts, reapply_views=False, mutated_input_metadata=False 97*da0073e9SAndroid Build Coastguard Worker ): 98*da0073e9SAndroid Build Coastguard Worker clones1 = tree_map_only(torch.Tensor, torch.clone, inpts) 99*da0073e9SAndroid Build Coastguard Worker clones2 = tree_map_only(torch.Tensor, torch.clone, inpts) 100*da0073e9SAndroid Build Coastguard Worker clones3 = tree_map_only(torch.Tensor, torch.clone, inpts) 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker # Compare outputs (and mutated inputs), with and without functionalization. 103*da0073e9SAndroid Build Coastguard Worker out_ref = func(*inpts) 104*da0073e9SAndroid Build Coastguard Worker out_functional = _functionalize( 105*da0073e9SAndroid Build Coastguard Worker func, reapply_views=reapply_views, crossref=self.crossref 106*da0073e9SAndroid Build Coastguard Worker )(*clones1) 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker # The reinplacing pass is only valid to run with reapply_views=True. 109*da0073e9SAndroid Build Coastguard Worker functional_func = make_fx( 110*da0073e9SAndroid Build Coastguard Worker _functionalize(func, reapply_views=True, crossref=self.crossref) 111*da0073e9SAndroid Build Coastguard Worker )(*clones2) 112*da0073e9SAndroid Build Coastguard Worker reinplace_func = reinplace(functional_func, *clones2) 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Worker # NOTE: for now, need to pass in fresh inputs here, because make_fx 115*da0073e9SAndroid Build Coastguard Worker # will directly mutate the inputs that you trace with. 116*da0073e9SAndroid Build Coastguard Worker # Once this is fixed we can clean this up. 117*da0073e9SAndroid Build Coastguard Worker out_reinplace = reinplace_func(*clones3) 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker # functionalize() deficiency: input metadata mutations aren't propagated properly, 120*da0073e9SAndroid Build Coastguard Worker # so we just need to skip checks here for the tests that exercise that. 121*da0073e9SAndroid Build Coastguard Worker if not mutated_input_metadata: 122*da0073e9SAndroid Build Coastguard Worker flat_inpts = pytree.tree_leaves(inpts) 123*da0073e9SAndroid Build Coastguard Worker flat_clones1 = pytree.tree_leaves(clones1) 124*da0073e9SAndroid Build Coastguard Worker flat_clones3 = pytree.tree_leaves(clones3) 125*da0073e9SAndroid Build Coastguard Worker for inpt, input_clone, input_clone3 in zip( 126*da0073e9SAndroid Build Coastguard Worker flat_inpts, flat_clones1, flat_clones3 127*da0073e9SAndroid Build Coastguard Worker ): 128*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 129*da0073e9SAndroid Build Coastguard Worker inpt, input_clone 130*da0073e9SAndroid Build Coastguard Worker ) # input mutations should still occur 131*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inpt, input_clone3) 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker # Handle tests with multi-tensor outputs 134*da0073e9SAndroid Build Coastguard Worker if isinstance(out_ref, tuple): 135*da0073e9SAndroid Build Coastguard Worker out_refs, out_functionals, out_reinplaces = ( 136*da0073e9SAndroid Build Coastguard Worker list(out_ref), 137*da0073e9SAndroid Build Coastguard Worker list(out_functional), 138*da0073e9SAndroid Build Coastguard Worker list(out_reinplace), 139*da0073e9SAndroid Build Coastguard Worker ) 140*da0073e9SAndroid Build Coastguard Worker else: 141*da0073e9SAndroid Build Coastguard Worker out_refs, out_functionals, out_reinplaces = ( 142*da0073e9SAndroid Build Coastguard Worker [out_ref], 143*da0073e9SAndroid Build Coastguard Worker [out_functional], 144*da0073e9SAndroid Build Coastguard Worker [out_reinplace], 145*da0073e9SAndroid Build Coastguard Worker ) 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker for out_ref_, out_functional_, out_reinplace_ in zip( 148*da0073e9SAndroid Build Coastguard Worker out_refs, out_functionals, out_reinplaces 149*da0073e9SAndroid Build Coastguard Worker ): 150*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref_, out_functional_) 151*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref_, out_reinplace_) 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid Build Coastguard Worker def test_save_for_backwards_segfault(self): 154*da0073e9SAndroid Build Coastguard Worker inp = torch._to_functional_tensor( 155*da0073e9SAndroid Build Coastguard Worker LoggingTensor(torch.randn(2, 2)) 156*da0073e9SAndroid Build Coastguard Worker ).requires_grad_(True) 157*da0073e9SAndroid Build Coastguard Worker inp.exp() 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker def test_multiple_views_of_same_base(self): 160*da0073e9SAndroid Build Coastguard Worker def f(x): 161*da0073e9SAndroid Build Coastguard Worker y = x.view(-1) 162*da0073e9SAndroid Build Coastguard Worker z = x.view(-1) 163*da0073e9SAndroid Build Coastguard Worker x.add_(1) 164*da0073e9SAndroid Build Coastguard Worker # y should have been updated. 165*da0073e9SAndroid Build Coastguard Worker y2 = y + 1 166*da0073e9SAndroid Build Coastguard Worker # z should have been updated too. 167*da0073e9SAndroid Build Coastguard Worker z2 = z + 1 168*da0073e9SAndroid Build Coastguard Worker return z2 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(4)) 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker def test_freeze(self): 173*da0073e9SAndroid Build Coastguard Worker def f(x): 174*da0073e9SAndroid Build Coastguard Worker y = x.clone() 175*da0073e9SAndroid Build Coastguard Worker z = y[0] 176*da0073e9SAndroid Build Coastguard Worker torch._freeze_functional_tensor(y) 177*da0073e9SAndroid Build Coastguard Worker x.add_(1) 178*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: y.add_(1)) 179*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: z.add_(1)) 180*da0073e9SAndroid Build Coastguard Worker return z 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(3, 3)) 183*da0073e9SAndroid Build Coastguard Worker 184*da0073e9SAndroid Build Coastguard Worker def test_copy_stride_mismatch(self): 185*da0073e9SAndroid Build Coastguard Worker def f(x): 186*da0073e9SAndroid Build Coastguard Worker y = torch.empty_strided((2, 2), (5, 1)) 187*da0073e9SAndroid Build Coastguard Worker y.copy_(x) 188*da0073e9SAndroid Build Coastguard Worker return y 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker r = _functionalize(f, reapply_views=True, crossref=self.crossref)( 191*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 2) 192*da0073e9SAndroid Build Coastguard Worker ) 193*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.stride(), (5, 1)) 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker def test_set_(self): 196*da0073e9SAndroid Build Coastguard Worker def f(x): 197*da0073e9SAndroid Build Coastguard Worker y = torch.ones(2) 198*da0073e9SAndroid Build Coastguard Worker y.set_(x.storage()) 199*da0073e9SAndroid Build Coastguard Worker return y 200*da0073e9SAndroid Build Coastguard Worker 201*da0073e9SAndroid Build Coastguard Worker # We should probaby get the crossref test to work, 202*da0073e9SAndroid Build Coastguard Worker # but fixing it for Storage() objects is annoying. 203*da0073e9SAndroid Build Coastguard Worker r = _functionalize(f, reapply_views=True, crossref=False)(torch.ones(2)) 204*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(r.device), "cpu") 205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker def test_advanced_indexing(self): 207*da0073e9SAndroid Build Coastguard Worker def f(): 208*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(3, 3) 209*da0073e9SAndroid Build Coastguard Worker idx = torch.tensor([0]) 210*da0073e9SAndroid Build Coastguard Worker val = torch.ones(3, 1) 211*da0073e9SAndroid Build Coastguard Worker x[:, idx] = val 212*da0073e9SAndroid Build Coastguard Worker return x 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f) 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker def test_view_clone_view_inplace(self): 217*da0073e9SAndroid Build Coastguard Worker def f(input): 218*da0073e9SAndroid Build Coastguard Worker shape = [1, 1024, 128, 128] 219*da0073e9SAndroid Build Coastguard Worker input_reshaped = input.view(shape) 220*da0073e9SAndroid Build Coastguard Worker out = input_reshaped.clone() 221*da0073e9SAndroid Build Coastguard Worker r = out.view(input.shape) 222*da0073e9SAndroid Build Coastguard Worker r.relu_() 223*da0073e9SAndroid Build Coastguard Worker return r 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Worker def g(x): 226*da0073e9SAndroid Build Coastguard Worker loss = f(x).sum() 227*da0073e9SAndroid Build Coastguard Worker import torch.fx.traceback as fx_traceback 228*da0073e9SAndroid Build Coastguard Worker from torch._functorch.aot_autograd import ( 229*da0073e9SAndroid Build Coastguard Worker setup_stacktrace_preservation_hooks, 230*da0073e9SAndroid Build Coastguard Worker ) 231*da0073e9SAndroid Build Coastguard Worker 232*da0073e9SAndroid Build Coastguard Worker setup_stacktrace_preservation_hooks([loss.grad_fn]) 233*da0073e9SAndroid Build Coastguard Worker with fx_traceback.preserve_node_meta(): 234*da0073e9SAndroid Build Coastguard Worker loss.backward() 235*da0073e9SAndroid Build Coastguard Worker return x.grad 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker with torch.autograd.detect_anomaly(check_nan=False): 238*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(g, torch.ones(16, 64, 128, 128, requires_grad=True)) 239*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 240*da0073e9SAndroid Build Coastguard Worker logs, 241*da0073e9SAndroid Build Coastguard Worker """\ 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 246*da0073e9SAndroid Build Coastguard Worker view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 1024, 128, 128]); arg0_1 = None 247*da0073e9SAndroid Build Coastguard Worker clone = torch.ops.aten.clone.default(view_copy); view_copy = None 248*da0073e9SAndroid Build Coastguard Worker view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]) 249*da0073e9SAndroid Build Coastguard Worker relu = torch.ops.aten.relu.default(view_copy_1); view_copy_1 = None 250*da0073e9SAndroid Build Coastguard Worker view_copy_2 = torch.ops.aten.view_copy.default(relu, [1, 1024, 128, 128]); relu = None 251*da0073e9SAndroid Build Coastguard Worker view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [16, 64, 128, 128]); view_copy_2 = None 252*da0073e9SAndroid Build Coastguard Worker view_copy_4 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]); clone = view_copy_4 = None 253*da0073e9SAndroid Build Coastguard Worker sum_1 = torch.ops.aten.sum.default(view_copy_3) 254*da0073e9SAndroid Build Coastguard Worker ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None 255*da0073e9SAndroid Build Coastguard Worker expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]); ones_like = None 256*da0073e9SAndroid Build Coastguard Worker view_copy_5 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]); expand_copy = None 257*da0073e9SAndroid Build Coastguard Worker new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_5, [1, 1024, 128, 128], [16777216, 16384, 128, 1]) 258*da0073e9SAndroid Build Coastguard Worker copy = torch.ops.aten.copy.default(new_empty_strided, view_copy_5); new_empty_strided = view_copy_5 = None 259*da0073e9SAndroid Build Coastguard Worker view_copy_6 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]); view_copy_6 = None 260*da0073e9SAndroid Build Coastguard Worker view_copy_7 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]) 261*da0073e9SAndroid Build Coastguard Worker clone_1 = torch.ops.aten.clone.default(view_copy_7, memory_format = torch.contiguous_format) 262*da0073e9SAndroid Build Coastguard Worker threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, view_copy_3, 0); clone_1 = view_copy_3 = None 263*da0073e9SAndroid Build Coastguard Worker copy_1 = torch.ops.aten.copy.default(view_copy_7, threshold_backward); view_copy_7 = threshold_backward = None 264*da0073e9SAndroid Build Coastguard Worker view_copy_8 = torch.ops.aten.view_copy.default(copy_1, [1, 1024, 128, 128]); copy_1 = None 265*da0073e9SAndroid Build Coastguard Worker view_copy_9 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]); view_copy_9 = None 266*da0073e9SAndroid Build Coastguard Worker view_copy_10 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]); copy = None 267*da0073e9SAndroid Build Coastguard Worker detach_copy = torch.ops.aten.detach_copy.default(view_copy_10); view_copy_10 = detach_copy = None 268*da0073e9SAndroid Build Coastguard Worker view_copy_11 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]); view_copy_8 = None 269*da0073e9SAndroid Build Coastguard Worker detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_11); view_copy_11 = None 270*da0073e9SAndroid Build Coastguard Worker return detach_copy_1 271*da0073e9SAndroid Build Coastguard Worker """, 272*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker def test_simple(self): 275*da0073e9SAndroid Build Coastguard Worker def f(x): 276*da0073e9SAndroid Build Coastguard Worker # simple test: 1 view op, 1 inplace op 277*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(4, 2) 278*da0073e9SAndroid Build Coastguard Worker y = x.view(4, 2) 279*da0073e9SAndroid Build Coastguard Worker y.add_(tmp) 280*da0073e9SAndroid Build Coastguard Worker z = x * x 281*da0073e9SAndroid Build Coastguard Worker return y 282*da0073e9SAndroid Build Coastguard Worker 283*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(4, 2)) 284*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(4, 2)) 285*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 286*da0073e9SAndroid Build Coastguard Worker logs, 287*da0073e9SAndroid Build Coastguard Worker """\ 288*da0073e9SAndroid Build Coastguard Worker 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 292*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) 293*da0073e9SAndroid Build Coastguard Worker view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]) 294*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None 295*da0073e9SAndroid Build Coastguard Worker view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None 296*da0073e9SAndroid Build Coastguard Worker view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2]) 297*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1); mul = None 298*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = copy_ = None 299*da0073e9SAndroid Build Coastguard Worker return view_copy_2 300*da0073e9SAndroid Build Coastguard Worker """, 301*da0073e9SAndroid Build Coastguard Worker ) 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 304*da0073e9SAndroid Build Coastguard Worker f, torch.ones(4, 2), reapply_views=True, run_reinplace=True 305*da0073e9SAndroid Build Coastguard Worker ) 306*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 307*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 308*da0073e9SAndroid Build Coastguard Worker """\ 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Worker 312*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 313*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) 314*da0073e9SAndroid Build Coastguard Worker view = torch.ops.aten.view.default(arg0_1, [4, 2]) 315*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(view, ones); view = ones = None 316*da0073e9SAndroid Build Coastguard Worker view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None 317*da0073e9SAndroid Build Coastguard Worker view_2 = torch.ops.aten.view.default(view_1, [4, 2]) 318*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(view_1, view_1); mul = None 319*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = copy_ = None 320*da0073e9SAndroid Build Coastguard Worker return view_2 321*da0073e9SAndroid Build Coastguard Worker """, 322*da0073e9SAndroid Build Coastguard Worker ) 323*da0073e9SAndroid Build Coastguard Worker 324*da0073e9SAndroid Build Coastguard Worker def test_simple_out(self): 325*da0073e9SAndroid Build Coastguard Worker def f(x): 326*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(4, 2) 327*da0073e9SAndroid Build Coastguard Worker y = x.view(4, 2) 328*da0073e9SAndroid Build Coastguard Worker # the out= tensor will get resized, since it has size=0 to start. 329*da0073e9SAndroid Build Coastguard Worker z = torch.empty(()) 330*da0073e9SAndroid Build Coastguard Worker torch.add(y, tmp, out=z) 331*da0073e9SAndroid Build Coastguard Worker w = z * z 332*da0073e9SAndroid Build Coastguard Worker return w 333*da0073e9SAndroid Build Coastguard Worker 334*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(4, 2)) 335*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(4, 2)) 336*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 337*da0073e9SAndroid Build Coastguard Worker logs, 338*da0073e9SAndroid Build Coastguard Worker """\ 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker 341*da0073e9SAndroid Build Coastguard Worker 342*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 343*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) 344*da0073e9SAndroid Build Coastguard Worker view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); arg0_1 = None 345*da0073e9SAndroid Build Coastguard Worker empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False); empty = None 346*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None 347*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(add, add); add = None 348*da0073e9SAndroid Build Coastguard Worker return mul 349*da0073e9SAndroid Build Coastguard Worker """, 350*da0073e9SAndroid Build Coastguard Worker ) 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 353*da0073e9SAndroid Build Coastguard Worker f, torch.ones(4, 2), reapply_views=True, run_reinplace=True 354*da0073e9SAndroid Build Coastguard Worker ) 355*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 356*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 357*da0073e9SAndroid Build Coastguard Worker """\ 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 362*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) 363*da0073e9SAndroid Build Coastguard Worker view = torch.ops.aten.view.default(arg0_1, [4, 2]); arg0_1 = None 364*da0073e9SAndroid Build Coastguard Worker empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False); empty = None 365*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(view, ones); view = ones = None 366*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(add, add); add = None 367*da0073e9SAndroid Build Coastguard Worker return mul 368*da0073e9SAndroid Build Coastguard Worker """, 369*da0073e9SAndroid Build Coastguard Worker ) 370*da0073e9SAndroid Build Coastguard Worker 371*da0073e9SAndroid Build Coastguard Worker def test_multi_out(self): 372*da0073e9SAndroid Build Coastguard Worker def f(x): 373*da0073e9SAndroid Build Coastguard Worker # aminmax.out returns a tuple of tensors. 374*da0073e9SAndroid Build Coastguard Worker # functionalization should properly handle the tuple. 375*da0073e9SAndroid Build Coastguard Worker out_min = torch.empty(4) 376*da0073e9SAndroid Build Coastguard Worker out_max = torch.empty(4) 377*da0073e9SAndroid Build Coastguard Worker torch.aminmax(x, dim=0, out=(out_max, out_min)) 378*da0073e9SAndroid Build Coastguard Worker return out_max 379*da0073e9SAndroid Build Coastguard Worker 380*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.arange(8, dtype=torch.float32)) 381*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.arange(8, dtype=torch.float32)) 382*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 383*da0073e9SAndroid Build Coastguard Worker logs, 384*da0073e9SAndroid Build Coastguard Worker """\ 385*da0073e9SAndroid Build Coastguard Worker 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 389*da0073e9SAndroid Build Coastguard Worker empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False); empty = None 390*da0073e9SAndroid Build Coastguard Worker empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False); empty_1 = None 391*da0073e9SAndroid Build Coastguard Worker aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0); arg0_1 = None 392*da0073e9SAndroid Build Coastguard Worker getitem = aminmax[0] 393*da0073e9SAndroid Build Coastguard Worker getitem_1 = aminmax[1]; aminmax = getitem_1 = None 394*da0073e9SAndroid Build Coastguard Worker return getitem 395*da0073e9SAndroid Build Coastguard Worker """, 396*da0073e9SAndroid Build Coastguard Worker ) 397*da0073e9SAndroid Build Coastguard Worker 398*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 399*da0073e9SAndroid Build Coastguard Worker f, 400*da0073e9SAndroid Build Coastguard Worker torch.arange(8, dtype=torch.float32), 401*da0073e9SAndroid Build Coastguard Worker reapply_views=True, 402*da0073e9SAndroid Build Coastguard Worker run_reinplace=True, 403*da0073e9SAndroid Build Coastguard Worker ) 404*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 405*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 406*da0073e9SAndroid Build Coastguard Worker """\ 407*da0073e9SAndroid Build Coastguard Worker 408*da0073e9SAndroid Build Coastguard Worker 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 411*da0073e9SAndroid Build Coastguard Worker empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False); empty = None 412*da0073e9SAndroid Build Coastguard Worker empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False); empty_1 = None 413*da0073e9SAndroid Build Coastguard Worker aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0); arg0_1 = None 414*da0073e9SAndroid Build Coastguard Worker getitem = aminmax[0] 415*da0073e9SAndroid Build Coastguard Worker getitem_1 = aminmax[1]; aminmax = getitem_1 = None 416*da0073e9SAndroid Build Coastguard Worker return getitem 417*da0073e9SAndroid Build Coastguard Worker """, 418*da0073e9SAndroid Build Coastguard Worker ) 419*da0073e9SAndroid Build Coastguard Worker 420*da0073e9SAndroid Build Coastguard Worker def test_tensor_ctr(self): 421*da0073e9SAndroid Build Coastguard Worker def f(x): 422*da0073e9SAndroid Build Coastguard Worker y = torch.tensor((1, 2, 3)) 423*da0073e9SAndroid Build Coastguard Worker z = y.view(-1) 424*da0073e9SAndroid Build Coastguard Worker z.add_(1) 425*da0073e9SAndroid Build Coastguard Worker return y 426*da0073e9SAndroid Build Coastguard Worker 427*da0073e9SAndroid Build Coastguard Worker inpt = torch.arange(3, dtype=torch.float32) 428*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, inpt) 429*da0073e9SAndroid Build Coastguard Worker 430*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, inpt) 431*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 432*da0073e9SAndroid Build Coastguard Worker logs, 433*da0073e9SAndroid Build Coastguard Worker """\ 434*da0073e9SAndroid Build Coastguard Worker 435*da0073e9SAndroid Build Coastguard Worker 436*da0073e9SAndroid Build Coastguard Worker 437*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 438*da0073e9SAndroid Build Coastguard Worker _tensor_constant0 = self._tensor_constant0 439*da0073e9SAndroid Build Coastguard Worker lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None 440*da0073e9SAndroid Build Coastguard Worker view_copy = torch.ops.aten.view_copy.default(lift_fresh_copy, [-1]); lift_fresh_copy = None 441*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None 442*da0073e9SAndroid Build Coastguard Worker view_copy_1 = torch.ops.aten.view_copy.default(add, [3]); add = None 443*da0073e9SAndroid Build Coastguard Worker view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [-1]); view_copy_2 = None 444*da0073e9SAndroid Build Coastguard Worker return view_copy_1 445*da0073e9SAndroid Build Coastguard Worker """, 446*da0073e9SAndroid Build Coastguard Worker ) 447*da0073e9SAndroid Build Coastguard Worker 448*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs(f, inpt, reapply_views=True, run_reinplace=True) 449*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 450*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 451*da0073e9SAndroid Build Coastguard Worker """\ 452*da0073e9SAndroid Build Coastguard Worker 453*da0073e9SAndroid Build Coastguard Worker 454*da0073e9SAndroid Build Coastguard Worker 455*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 456*da0073e9SAndroid Build Coastguard Worker _tensor_constant0 = self._tensor_constant0 457*da0073e9SAndroid Build Coastguard Worker lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None 458*da0073e9SAndroid Build Coastguard Worker view = torch.ops.aten.view.default(lift_fresh_copy, [-1]); lift_fresh_copy = None 459*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add_.Tensor(view, 1); add = None 460*da0073e9SAndroid Build Coastguard Worker view_1 = torch.ops.aten.view.default(view, [3]); view = None 461*da0073e9SAndroid Build Coastguard Worker view_2 = torch.ops.aten.view.default(view_1, [-1]); view_2 = None 462*da0073e9SAndroid Build Coastguard Worker return view_1 463*da0073e9SAndroid Build Coastguard Worker """, 464*da0073e9SAndroid Build Coastguard Worker ) 465*da0073e9SAndroid Build Coastguard Worker 466*da0073e9SAndroid Build Coastguard Worker def test_advanced_indexing_correct_strides(self): 467*da0073e9SAndroid Build Coastguard Worker def f(a): 468*da0073e9SAndroid Build Coastguard Worker # This test requires that *_scatter ops are able to return 469*da0073e9SAndroid Build Coastguard Worker # non-contiguous tensors. 470*da0073e9SAndroid Build Coastguard Worker b = a.clone()[:, 1] 471*da0073e9SAndroid Build Coastguard Worker c = torch.ones_like(b, dtype=torch.bool) 472*da0073e9SAndroid Build Coastguard Worker d = b.masked_fill_(c, 0) 473*da0073e9SAndroid Build Coastguard Worker return d 474*da0073e9SAndroid Build Coastguard Worker 475*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(2, 2), reapply_views=True) 476*da0073e9SAndroid Build Coastguard Worker 477*da0073e9SAndroid Build Coastguard Worker def test_tensor_list_mixed_functional_nonfunctional(self): 478*da0073e9SAndroid Build Coastguard Worker nonfunctional_tensor = torch.ones(2, dtype=torch.long) 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker def f(x): 481*da0073e9SAndroid Build Coastguard Worker # simple test: 1 view op, 1 inplace op 482*da0073e9SAndroid Build Coastguard Worker functional_tensor = torch.ones(2, dtype=torch.long) 483*da0073e9SAndroid Build Coastguard Worker out = x[functional_tensor, nonfunctional_tensor] 484*da0073e9SAndroid Build Coastguard Worker return out 485*da0073e9SAndroid Build Coastguard Worker 486*da0073e9SAndroid Build Coastguard Worker out = f(torch.ones(2, 2)) 487*da0073e9SAndroid Build Coastguard Worker out_functional = _functionalize(f, reapply_views=True, crossref=self.crossref)( 488*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 2) 489*da0073e9SAndroid Build Coastguard Worker ) 490*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, out_functional) 491*da0073e9SAndroid Build Coastguard Worker 492*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_non_view(self): 493*da0073e9SAndroid Build Coastguard Worker def f(x): 494*da0073e9SAndroid Build Coastguard Worker # test for the case where we functionalize an inplace op on the other tensor - not a view. 495*da0073e9SAndroid Build Coastguard Worker # This is worth checking because the tensor will have an empty ViewMeta stack, which needs to be special cased. 496*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(4, 2) 497*da0073e9SAndroid Build Coastguard Worker y = x.view(4, 2) 498*da0073e9SAndroid Build Coastguard Worker x.add_(tmp) 499*da0073e9SAndroid Build Coastguard Worker return y 500*da0073e9SAndroid Build Coastguard Worker 501*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(4, 2)) 502*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(4, 2)) 503*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 504*da0073e9SAndroid Build Coastguard Worker logs, 505*da0073e9SAndroid Build Coastguard Worker """\ 506*da0073e9SAndroid Build Coastguard Worker 507*da0073e9SAndroid Build Coastguard Worker 508*da0073e9SAndroid Build Coastguard Worker 509*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 510*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) 511*da0073e9SAndroid Build Coastguard Worker view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); view_copy = None 512*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(arg0_1, ones); ones = None 513*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = copy_ = None 514*da0073e9SAndroid Build Coastguard Worker view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None 515*da0073e9SAndroid Build Coastguard Worker return view_copy_1 516*da0073e9SAndroid Build Coastguard Worker """, 517*da0073e9SAndroid Build Coastguard Worker ) 518*da0073e9SAndroid Build Coastguard Worker 519*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 520*da0073e9SAndroid Build Coastguard Worker f, torch.ones(4, 2), reapply_views=True, run_reinplace=True 521*da0073e9SAndroid Build Coastguard Worker ) 522*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 523*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 524*da0073e9SAndroid Build Coastguard Worker """\ 525*da0073e9SAndroid Build Coastguard Worker 526*da0073e9SAndroid Build Coastguard Worker 527*da0073e9SAndroid Build Coastguard Worker 528*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 529*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) 530*da0073e9SAndroid Build Coastguard Worker view = torch.ops.aten.view.default(arg0_1, [4, 2]); view = None 531*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(arg0_1, ones); ones = None 532*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = copy_ = None 533*da0073e9SAndroid Build Coastguard Worker view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None 534*da0073e9SAndroid Build Coastguard Worker return view_1 535*da0073e9SAndroid Build Coastguard Worker """, 536*da0073e9SAndroid Build Coastguard Worker ) 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker # Some ops that are mutable are neither inplace nor out= ops. 539*da0073e9SAndroid Build Coastguard Worker # They also need special handling. 540*da0073e9SAndroid Build Coastguard Worker def test_mutable_op_not_inplace_or_other(self): 541*da0073e9SAndroid Build Coastguard Worker def f(x): 542*da0073e9SAndroid Build Coastguard Worker return torch._fused_moving_avg_obs_fq_helper( 543*da0073e9SAndroid Build Coastguard Worker x, x, x, x, x, x, x, 1.0, 0, 1, 0 544*da0073e9SAndroid Build Coastguard Worker ) 545*da0073e9SAndroid Build Coastguard Worker 546*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(1)) 547*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 548*da0073e9SAndroid Build Coastguard Worker logs, 549*da0073e9SAndroid Build Coastguard Worker """\ 550*da0073e9SAndroid Build Coastguard Worker 551*da0073e9SAndroid Build Coastguard Worker 552*da0073e9SAndroid Build Coastguard Worker 553*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 554*da0073e9SAndroid Build Coastguard Worker _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0) 555*da0073e9SAndroid Build Coastguard Worker getitem = _fused_moving_avg_obs_fq_helper_functional[0] 556*da0073e9SAndroid Build Coastguard Worker getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1] 557*da0073e9SAndroid Build Coastguard Worker getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2]; getitem_2 = None 558*da0073e9SAndroid Build Coastguard Worker getitem_3 = _fused_moving_avg_obs_fq_helper_functional[3]; getitem_3 = None 559*da0073e9SAndroid Build Coastguard Worker getitem_4 = _fused_moving_avg_obs_fq_helper_functional[4]; getitem_4 = None 560*da0073e9SAndroid Build Coastguard Worker getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5]; _fused_moving_avg_obs_fq_helper_functional = None 561*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_5); arg0_1 = getitem_5 = copy_ = None 562*da0073e9SAndroid Build Coastguard Worker return (getitem, getitem_1) 563*da0073e9SAndroid Build Coastguard Worker """, # noqa: B950 564*da0073e9SAndroid Build Coastguard Worker ) 565*da0073e9SAndroid Build Coastguard Worker 566*da0073e9SAndroid Build Coastguard Worker def test_as_strided(self): 567*da0073e9SAndroid Build Coastguard Worker def f(x): 568*da0073e9SAndroid Build Coastguard Worker y = x.as_strided((2,), (2,), 1) 569*da0073e9SAndroid Build Coastguard Worker y.add_(1) 570*da0073e9SAndroid Build Coastguard Worker return x 571*da0073e9SAndroid Build Coastguard Worker 572*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(9)) 573*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(9)) 574*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 575*da0073e9SAndroid Build Coastguard Worker logs, 576*da0073e9SAndroid Build Coastguard Worker """\ 577*da0073e9SAndroid Build Coastguard Worker 578*da0073e9SAndroid Build Coastguard Worker 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 581*da0073e9SAndroid Build Coastguard Worker as_strided_copy = torch.ops.aten.as_strided_copy.default(arg0_1, [2], [2], 1) 582*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(as_strided_copy, 1); as_strided_copy = None 583*da0073e9SAndroid Build Coastguard Worker as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1); add = None 584*da0073e9SAndroid Build Coastguard Worker as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [2], [2], 1); as_strided_copy_1 = None 585*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = copy_ = None 586*da0073e9SAndroid Build Coastguard Worker return as_strided_scatter 587*da0073e9SAndroid Build Coastguard Worker """, 588*da0073e9SAndroid Build Coastguard Worker ) 589*da0073e9SAndroid Build Coastguard Worker 590*da0073e9SAndroid Build Coastguard Worker # NB: even with reapply_views=True, we expect to see scatter op 591*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 592*da0073e9SAndroid Build Coastguard Worker f, torch.ones(2, 2), reapply_views=True, run_reinplace=False 593*da0073e9SAndroid Build Coastguard Worker ) 594*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 595*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 596*da0073e9SAndroid Build Coastguard Worker """\ 597*da0073e9SAndroid Build Coastguard Worker 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker 600*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 601*da0073e9SAndroid Build Coastguard Worker as_strided = torch.ops.aten.as_strided.default(arg0_1, [2], [2], 1) 602*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None 603*da0073e9SAndroid Build Coastguard Worker as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1); add = None 604*da0073e9SAndroid Build Coastguard Worker as_strided_1 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [2], 1); as_strided_1 = None 605*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = copy_ = None 606*da0073e9SAndroid Build Coastguard Worker return as_strided_scatter 607*da0073e9SAndroid Build Coastguard Worker """, 608*da0073e9SAndroid Build Coastguard Worker ) 609*da0073e9SAndroid Build Coastguard Worker 610*da0073e9SAndroid Build Coastguard Worker def test_tensor_list_composite(self): 611*da0073e9SAndroid Build Coastguard Worker def f(x): 612*da0073e9SAndroid Build Coastguard Worker # Test an op with TensorList input 613*da0073e9SAndroid Build Coastguard Worker y = torch.block_diag(x, x) 614*da0073e9SAndroid Build Coastguard Worker return y 615*da0073e9SAndroid Build Coastguard Worker 616*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(2, 2)) 617*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(2, 2)) 618*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 619*da0073e9SAndroid Build Coastguard Worker logs, 620*da0073e9SAndroid Build Coastguard Worker """\ 621*da0073e9SAndroid Build Coastguard Worker 622*da0073e9SAndroid Build Coastguard Worker 623*da0073e9SAndroid Build Coastguard Worker 624*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 625*da0073e9SAndroid Build Coastguard Worker block_diag = torch.ops.aten.block_diag.default([arg0_1, arg0_1]); arg0_1 = None 626*da0073e9SAndroid Build Coastguard Worker return block_diag 627*da0073e9SAndroid Build Coastguard Worker """, 628*da0073e9SAndroid Build Coastguard Worker ) 629*da0073e9SAndroid Build Coastguard Worker 630*da0073e9SAndroid Build Coastguard Worker def test_cat(self): 631*da0073e9SAndroid Build Coastguard Worker def f(x): 632*da0073e9SAndroid Build Coastguard Worker out = torch.empty(0) 633*da0073e9SAndroid Build Coastguard Worker torch.cat((x,), out=out) 634*da0073e9SAndroid Build Coastguard Worker return out 635*da0073e9SAndroid Build Coastguard Worker 636*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(2, 2)) 637*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(2, 2)) 638*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 639*da0073e9SAndroid Build Coastguard Worker logs, 640*da0073e9SAndroid Build Coastguard Worker """\ 641*da0073e9SAndroid Build Coastguard Worker 642*da0073e9SAndroid Build Coastguard Worker 643*da0073e9SAndroid Build Coastguard Worker 644*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 645*da0073e9SAndroid Build Coastguard Worker empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False); empty = None 646*da0073e9SAndroid Build Coastguard Worker cat = torch.ops.aten.cat.default([arg0_1]); arg0_1 = None 647*da0073e9SAndroid Build Coastguard Worker return cat 648*da0073e9SAndroid Build Coastguard Worker """, 649*da0073e9SAndroid Build Coastguard Worker ) 650*da0073e9SAndroid Build Coastguard Worker 651*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 652*da0073e9SAndroid Build Coastguard Worker f, torch.ones(2, 2), reapply_views=True, run_reinplace=True 653*da0073e9SAndroid Build Coastguard Worker ) 654*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 655*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 656*da0073e9SAndroid Build Coastguard Worker """\ 657*da0073e9SAndroid Build Coastguard Worker 658*da0073e9SAndroid Build Coastguard Worker 659*da0073e9SAndroid Build Coastguard Worker 660*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 661*da0073e9SAndroid Build Coastguard Worker empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False); empty = None 662*da0073e9SAndroid Build Coastguard Worker cat = torch.ops.aten.cat.default([arg0_1]); arg0_1 = None 663*da0073e9SAndroid Build Coastguard Worker return cat 664*da0073e9SAndroid Build Coastguard Worker """, 665*da0073e9SAndroid Build Coastguard Worker ) 666*da0073e9SAndroid Build Coastguard Worker 667*da0073e9SAndroid Build Coastguard Worker def test_diagonal(self): 668*da0073e9SAndroid Build Coastguard Worker def f(x): 669*da0073e9SAndroid Build Coastguard Worker # test: view ops that take a subset of the original tensor (select/diagonal) 670*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(2) 671*da0073e9SAndroid Build Coastguard Worker y = x.clone().diagonal() 672*da0073e9SAndroid Build Coastguard Worker y.add_(tmp) 673*da0073e9SAndroid Build Coastguard Worker z = x * x 674*da0073e9SAndroid Build Coastguard Worker return z 675*da0073e9SAndroid Build Coastguard Worker 676*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(2, 2)) 677*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(2, 2)) 678*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 679*da0073e9SAndroid Build Coastguard Worker logs, 680*da0073e9SAndroid Build Coastguard Worker """\ 681*da0073e9SAndroid Build Coastguard Worker 682*da0073e9SAndroid Build Coastguard Worker 683*da0073e9SAndroid Build Coastguard Worker 684*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 685*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 686*da0073e9SAndroid Build Coastguard Worker clone = torch.ops.aten.clone.default(arg0_1) 687*da0073e9SAndroid Build Coastguard Worker diagonal_copy = torch.ops.aten.diagonal_copy.default(clone) 688*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None 689*da0073e9SAndroid Build Coastguard Worker diagonal_scatter = torch.ops.aten.diagonal_scatter.default(clone, add); clone = add = None 690*da0073e9SAndroid Build Coastguard Worker diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter); diagonal_scatter = diagonal_copy_1 = None 691*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None 692*da0073e9SAndroid Build Coastguard Worker return mul 693*da0073e9SAndroid Build Coastguard Worker """, 694*da0073e9SAndroid Build Coastguard Worker ) 695*da0073e9SAndroid Build Coastguard Worker 696*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 697*da0073e9SAndroid Build Coastguard Worker f, torch.ones(2, 2), reapply_views=True, run_reinplace=True 698*da0073e9SAndroid Build Coastguard Worker ) 699*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 700*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 701*da0073e9SAndroid Build Coastguard Worker """\ 702*da0073e9SAndroid Build Coastguard Worker 703*da0073e9SAndroid Build Coastguard Worker 704*da0073e9SAndroid Build Coastguard Worker 705*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 706*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 707*da0073e9SAndroid Build Coastguard Worker clone = torch.ops.aten.clone.default(arg0_1) 708*da0073e9SAndroid Build Coastguard Worker diagonal = torch.ops.aten.diagonal.default(clone) 709*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add_.Tensor(diagonal, ones); diagonal = ones = add = None 710*da0073e9SAndroid Build Coastguard Worker diagonal_1 = torch.ops.aten.diagonal.default(clone); clone = diagonal_1 = None 711*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None 712*da0073e9SAndroid Build Coastguard Worker return mul 713*da0073e9SAndroid Build Coastguard Worker """, 714*da0073e9SAndroid Build Coastguard Worker ) 715*da0073e9SAndroid Build Coastguard Worker 716*da0073e9SAndroid Build Coastguard Worker def test_diagonal_mutated_input(self): 717*da0073e9SAndroid Build Coastguard Worker def f(x): 718*da0073e9SAndroid Build Coastguard Worker # simple test: there are pending updates afterwards, which the test syncs manually 719*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(2) 720*da0073e9SAndroid Build Coastguard Worker y = x.diagonal() 721*da0073e9SAndroid Build Coastguard Worker y.add_(tmp) 722*da0073e9SAndroid Build Coastguard Worker return x 723*da0073e9SAndroid Build Coastguard Worker 724*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2) 725*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, x) 726*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(2, 2)) 727*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 728*da0073e9SAndroid Build Coastguard Worker logs, 729*da0073e9SAndroid Build Coastguard Worker """\ 730*da0073e9SAndroid Build Coastguard Worker 731*da0073e9SAndroid Build Coastguard Worker 732*da0073e9SAndroid Build Coastguard Worker 733*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 734*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 735*da0073e9SAndroid Build Coastguard Worker diagonal_copy = torch.ops.aten.diagonal_copy.default(arg0_1) 736*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None 737*da0073e9SAndroid Build Coastguard Worker diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add); add = None 738*da0073e9SAndroid Build Coastguard Worker diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter); diagonal_copy_1 = None 739*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = copy_ = None 740*da0073e9SAndroid Build Coastguard Worker return diagonal_scatter 741*da0073e9SAndroid Build Coastguard Worker """, 742*da0073e9SAndroid Build Coastguard Worker ) 743*da0073e9SAndroid Build Coastguard Worker 744*da0073e9SAndroid Build Coastguard Worker # NB: even with reapply_views=True, we expect to see scatter op 745*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 746*da0073e9SAndroid Build Coastguard Worker f, torch.ones(2, 2), reapply_views=True, run_reinplace=False 747*da0073e9SAndroid Build Coastguard Worker ) 748*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 749*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 750*da0073e9SAndroid Build Coastguard Worker """\ 751*da0073e9SAndroid Build Coastguard Worker 752*da0073e9SAndroid Build Coastguard Worker 753*da0073e9SAndroid Build Coastguard Worker 754*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 755*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 756*da0073e9SAndroid Build Coastguard Worker diagonal = torch.ops.aten.diagonal.default(arg0_1) 757*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None 758*da0073e9SAndroid Build Coastguard Worker diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add); add = None 759*da0073e9SAndroid Build Coastguard Worker diagonal_1 = torch.ops.aten.diagonal.default(diagonal_scatter); diagonal_1 = None 760*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = copy_ = None 761*da0073e9SAndroid Build Coastguard Worker return diagonal_scatter 762*da0073e9SAndroid Build Coastguard Worker """, 763*da0073e9SAndroid Build Coastguard Worker ) 764*da0073e9SAndroid Build Coastguard Worker 765*da0073e9SAndroid Build Coastguard Worker def test_channels_last_contiguous(self): 766*da0073e9SAndroid Build Coastguard Worker def f(x): 767*da0073e9SAndroid Build Coastguard Worker return x.contiguous(memory_format=torch.channels_last) 768*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(2) 769*da0073e9SAndroid Build Coastguard Worker y = x.diagonal() 770*da0073e9SAndroid Build Coastguard Worker y.add_(tmp) 771*da0073e9SAndroid Build Coastguard Worker return x 772*da0073e9SAndroid Build Coastguard Worker 773*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2) 774*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, x) 775*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, x).strip() 776*da0073e9SAndroid Build Coastguard Worker # There should be no clone in the graph 777*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 778*da0073e9SAndroid Build Coastguard Worker logs, 779*da0073e9SAndroid Build Coastguard Worker """\ 780*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 781*da0073e9SAndroid Build Coastguard Worker return arg0_1""", 782*da0073e9SAndroid Build Coastguard Worker ) 783*da0073e9SAndroid Build Coastguard Worker 784*da0073e9SAndroid Build Coastguard Worker def test_split(self): 785*da0073e9SAndroid Build Coastguard Worker def f(x): 786*da0073e9SAndroid Build Coastguard Worker # test: view ops that return multiple tensors (split) 787*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(2) 788*da0073e9SAndroid Build Coastguard Worker y1, y2 = x.split(2) 789*da0073e9SAndroid Build Coastguard Worker y3 = y2.diagonal() 790*da0073e9SAndroid Build Coastguard Worker y3.add_(tmp) 791*da0073e9SAndroid Build Coastguard Worker z = x * x 792*da0073e9SAndroid Build Coastguard Worker return y3 793*da0073e9SAndroid Build Coastguard Worker 794*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(4, 2)) 795*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(4, 2)) 796*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 797*da0073e9SAndroid Build Coastguard Worker logs, 798*da0073e9SAndroid Build Coastguard Worker """\ 799*da0073e9SAndroid Build Coastguard Worker 800*da0073e9SAndroid Build Coastguard Worker 801*da0073e9SAndroid Build Coastguard Worker 802*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 803*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 804*da0073e9SAndroid Build Coastguard Worker split_copy = torch.ops.aten.split_copy.Tensor(arg0_1, 2) 805*da0073e9SAndroid Build Coastguard Worker getitem = split_copy[0]; getitem = None 806*da0073e9SAndroid Build Coastguard Worker getitem_1 = split_copy[1]; split_copy = None 807*da0073e9SAndroid Build Coastguard Worker diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None 808*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None 809*da0073e9SAndroid Build Coastguard Worker split_copy_1 = torch.ops.aten.split_copy.Tensor(arg0_1, 2) 810*da0073e9SAndroid Build Coastguard Worker getitem_2 = split_copy_1[0]; getitem_2 = None 811*da0073e9SAndroid Build Coastguard Worker getitem_3 = split_copy_1[1]; split_copy_1 = None 812*da0073e9SAndroid Build Coastguard Worker diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = add = None 813*da0073e9SAndroid Build Coastguard Worker slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None 814*da0073e9SAndroid Build Coastguard Worker split_copy_2 = torch.ops.aten.split_copy.Tensor(slice_scatter, 2) 815*da0073e9SAndroid Build Coastguard Worker getitem_4 = split_copy_2[0]; getitem_4 = None 816*da0073e9SAndroid Build Coastguard Worker getitem_5 = split_copy_2[1]; split_copy_2 = None 817*da0073e9SAndroid Build Coastguard Worker diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_5); getitem_5 = None 818*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter); mul = None 819*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None 820*da0073e9SAndroid Build Coastguard Worker return diagonal_copy_1 821*da0073e9SAndroid Build Coastguard Worker """, 822*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 823*da0073e9SAndroid Build Coastguard Worker 824*da0073e9SAndroid Build Coastguard Worker # NB: even with reapply_views=True, we expect to see scatter op 825*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 826*da0073e9SAndroid Build Coastguard Worker f, torch.ones(4, 2), reapply_views=True, run_reinplace=False 827*da0073e9SAndroid Build Coastguard Worker ) 828*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 829*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 830*da0073e9SAndroid Build Coastguard Worker """\ 831*da0073e9SAndroid Build Coastguard Worker 832*da0073e9SAndroid Build Coastguard Worker 833*da0073e9SAndroid Build Coastguard Worker 834*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 835*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 836*da0073e9SAndroid Build Coastguard Worker split = torch.ops.aten.split.Tensor(arg0_1, 2) 837*da0073e9SAndroid Build Coastguard Worker getitem = split[0]; getitem = None 838*da0073e9SAndroid Build Coastguard Worker getitem_1 = split[1]; split = None 839*da0073e9SAndroid Build Coastguard Worker diagonal = torch.ops.aten.diagonal.default(getitem_1); getitem_1 = None 840*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None 841*da0073e9SAndroid Build Coastguard Worker split_1 = torch.ops.aten.split.Tensor(arg0_1, 2) 842*da0073e9SAndroid Build Coastguard Worker getitem_2 = split_1[0]; getitem_2 = None 843*da0073e9SAndroid Build Coastguard Worker getitem_3 = split_1[1]; split_1 = None 844*da0073e9SAndroid Build Coastguard Worker diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = add = None 845*da0073e9SAndroid Build Coastguard Worker slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None 846*da0073e9SAndroid Build Coastguard Worker split_2 = torch.ops.aten.split.Tensor(slice_scatter, 2) 847*da0073e9SAndroid Build Coastguard Worker getitem_4 = split_2[0]; getitem_4 = None 848*da0073e9SAndroid Build Coastguard Worker getitem_5 = split_2[1]; split_2 = None 849*da0073e9SAndroid Build Coastguard Worker diagonal_1 = torch.ops.aten.diagonal.default(getitem_5); getitem_5 = None 850*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter); mul = None 851*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None 852*da0073e9SAndroid Build Coastguard Worker return diagonal_1 853*da0073e9SAndroid Build Coastguard Worker """, 854*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 855*da0073e9SAndroid Build Coastguard Worker 856*da0073e9SAndroid Build Coastguard Worker def test_split_with_sizes(self): 857*da0073e9SAndroid Build Coastguard Worker def f(x): 858*da0073e9SAndroid Build Coastguard Worker # test: view ops that return multiple tensors (split_with_sizes) 859*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(2) 860*da0073e9SAndroid Build Coastguard Worker y1, y2 = x.split_with_sizes([2, 2]) 861*da0073e9SAndroid Build Coastguard Worker y3 = y1.diagonal() 862*da0073e9SAndroid Build Coastguard Worker y3.add_(tmp) 863*da0073e9SAndroid Build Coastguard Worker z = x * x 864*da0073e9SAndroid Build Coastguard Worker return y3 865*da0073e9SAndroid Build Coastguard Worker 866*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(4, 2)) 867*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(4, 2)) 868*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 869*da0073e9SAndroid Build Coastguard Worker logs, 870*da0073e9SAndroid Build Coastguard Worker """\ 871*da0073e9SAndroid Build Coastguard Worker 872*da0073e9SAndroid Build Coastguard Worker 873*da0073e9SAndroid Build Coastguard Worker 874*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 875*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 876*da0073e9SAndroid Build Coastguard Worker split_with_sizes_copy = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2]) 877*da0073e9SAndroid Build Coastguard Worker getitem = split_with_sizes_copy[0] 878*da0073e9SAndroid Build Coastguard Worker getitem_1 = split_with_sizes_copy[1]; split_with_sizes_copy = getitem_1 = None 879*da0073e9SAndroid Build Coastguard Worker diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem); getitem = None 880*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None 881*da0073e9SAndroid Build Coastguard Worker split_with_sizes_copy_1 = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2]) 882*da0073e9SAndroid Build Coastguard Worker getitem_2 = split_with_sizes_copy_1[0] 883*da0073e9SAndroid Build Coastguard Worker getitem_3 = split_with_sizes_copy_1[1]; split_with_sizes_copy_1 = getitem_3 = None 884*da0073e9SAndroid Build Coastguard Worker diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add); getitem_2 = add = None 885*da0073e9SAndroid Build Coastguard Worker slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2); diagonal_scatter = None 886*da0073e9SAndroid Build Coastguard Worker split_with_sizes_copy_2 = torch.ops.aten.split_with_sizes_copy.default(slice_scatter, [2, 2]) 887*da0073e9SAndroid Build Coastguard Worker getitem_4 = split_with_sizes_copy_2[0] 888*da0073e9SAndroid Build Coastguard Worker getitem_5 = split_with_sizes_copy_2[1]; split_with_sizes_copy_2 = getitem_5 = None 889*da0073e9SAndroid Build Coastguard Worker diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_4); getitem_4 = None 890*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter); mul = None 891*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None 892*da0073e9SAndroid Build Coastguard Worker return diagonal_copy_1 893*da0073e9SAndroid Build Coastguard Worker """, 894*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 895*da0073e9SAndroid Build Coastguard Worker 896*da0073e9SAndroid Build Coastguard Worker # NB: even with reapply_views=True, we expect to see scatter op 897*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 898*da0073e9SAndroid Build Coastguard Worker f, torch.ones(4, 2), reapply_views=True, run_reinplace=False 899*da0073e9SAndroid Build Coastguard Worker ) 900*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 901*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 902*da0073e9SAndroid Build Coastguard Worker """\ 903*da0073e9SAndroid Build Coastguard Worker 904*da0073e9SAndroid Build Coastguard Worker 905*da0073e9SAndroid Build Coastguard Worker 906*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 907*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 908*da0073e9SAndroid Build Coastguard Worker split_with_sizes = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2]) 909*da0073e9SAndroid Build Coastguard Worker getitem = split_with_sizes[0] 910*da0073e9SAndroid Build Coastguard Worker getitem_1 = split_with_sizes[1]; split_with_sizes = getitem_1 = None 911*da0073e9SAndroid Build Coastguard Worker diagonal = torch.ops.aten.diagonal.default(getitem); getitem = None 912*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None 913*da0073e9SAndroid Build Coastguard Worker split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2]) 914*da0073e9SAndroid Build Coastguard Worker getitem_2 = split_with_sizes_1[0] 915*da0073e9SAndroid Build Coastguard Worker getitem_3 = split_with_sizes_1[1]; split_with_sizes_1 = getitem_3 = None 916*da0073e9SAndroid Build Coastguard Worker diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add); getitem_2 = add = None 917*da0073e9SAndroid Build Coastguard Worker slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2); diagonal_scatter = None 918*da0073e9SAndroid Build Coastguard Worker split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(slice_scatter, [2, 2]) 919*da0073e9SAndroid Build Coastguard Worker getitem_4 = split_with_sizes_2[0] 920*da0073e9SAndroid Build Coastguard Worker getitem_5 = split_with_sizes_2[1]; split_with_sizes_2 = getitem_5 = None 921*da0073e9SAndroid Build Coastguard Worker diagonal_1 = torch.ops.aten.diagonal.default(getitem_4); getitem_4 = None 922*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter); mul = None 923*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None 924*da0073e9SAndroid Build Coastguard Worker return diagonal_1 925*da0073e9SAndroid Build Coastguard Worker """, 926*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 927*da0073e9SAndroid Build Coastguard Worker 928*da0073e9SAndroid Build Coastguard Worker def test_slice(self): 929*da0073e9SAndroid Build Coastguard Worker def f(x): 930*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(4) 931*da0073e9SAndroid Build Coastguard Worker x.transpose_(1, 0) 932*da0073e9SAndroid Build Coastguard Worker y = x[0:2] 933*da0073e9SAndroid Build Coastguard Worker y.add_(tmp) 934*da0073e9SAndroid Build Coastguard Worker return x 935*da0073e9SAndroid Build Coastguard Worker 936*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True) 937*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(4, 2)) 938*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 939*da0073e9SAndroid Build Coastguard Worker logs, 940*da0073e9SAndroid Build Coastguard Worker """\ 941*da0073e9SAndroid Build Coastguard Worker 942*da0073e9SAndroid Build Coastguard Worker 943*da0073e9SAndroid Build Coastguard Worker 944*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 945*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) 946*da0073e9SAndroid Build Coastguard Worker transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0) 947*da0073e9SAndroid Build Coastguard Worker slice_copy = torch.ops.aten.slice_copy.Tensor(transpose_copy, 0, 0, 2); transpose_copy = None 948*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(slice_copy, ones); slice_copy = ones = None 949*da0073e9SAndroid Build Coastguard Worker transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None 950*da0073e9SAndroid Build Coastguard Worker slice_scatter = torch.ops.aten.slice_scatter.default(transpose_copy_1, add, 0, 0, 2); transpose_copy_1 = add = None 951*da0073e9SAndroid Build Coastguard Worker transpose_copy_2 = torch.ops.aten.transpose_copy.int(slice_scatter, 1, 0); slice_scatter = None 952*da0073e9SAndroid Build Coastguard Worker transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0) 953*da0073e9SAndroid Build Coastguard Worker slice_copy_1 = torch.ops.aten.slice_copy.Tensor(transpose_copy_3, 0, 0, 2); transpose_copy_3 = slice_copy_1 = None 954*da0073e9SAndroid Build Coastguard Worker transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None 955*da0073e9SAndroid Build Coastguard Worker return transpose_copy_4 956*da0073e9SAndroid Build Coastguard Worker """, 957*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 958*da0073e9SAndroid Build Coastguard Worker 959*da0073e9SAndroid Build Coastguard Worker # NB: even with reapply_views=True, we expect to see scatter op 960*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 961*da0073e9SAndroid Build Coastguard Worker f, torch.ones(4, 2), reapply_views=True, run_reinplace=False 962*da0073e9SAndroid Build Coastguard Worker ) 963*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 964*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 965*da0073e9SAndroid Build Coastguard Worker """\ 966*da0073e9SAndroid Build Coastguard Worker 967*da0073e9SAndroid Build Coastguard Worker 968*da0073e9SAndroid Build Coastguard Worker 969*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 970*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) 971*da0073e9SAndroid Build Coastguard Worker transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0) 972*da0073e9SAndroid Build Coastguard Worker slice_1 = torch.ops.aten.slice.Tensor(transpose, 0, 0, 2); transpose = None 973*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(slice_1, ones); slice_1 = ones = None 974*da0073e9SAndroid Build Coastguard Worker transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None 975*da0073e9SAndroid Build Coastguard Worker slice_scatter = torch.ops.aten.slice_scatter.default(transpose_1, add, 0, 0, 2); transpose_1 = add = None 976*da0073e9SAndroid Build Coastguard Worker transpose_2 = torch.ops.aten.transpose.int(slice_scatter, 1, 0); slice_scatter = None 977*da0073e9SAndroid Build Coastguard Worker transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0) 978*da0073e9SAndroid Build Coastguard Worker slice_2 = torch.ops.aten.slice.Tensor(transpose_3, 0, 0, 2); transpose_3 = slice_2 = None 979*da0073e9SAndroid Build Coastguard Worker transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None 980*da0073e9SAndroid Build Coastguard Worker return transpose_4 981*da0073e9SAndroid Build Coastguard Worker """, 982*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 983*da0073e9SAndroid Build Coastguard Worker 984*da0073e9SAndroid Build Coastguard Worker def test_view_inplace(self): 985*da0073e9SAndroid Build Coastguard Worker def f(x): 986*da0073e9SAndroid Build Coastguard Worker # test: view + inplace op (transpose_) 987*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(4) 988*da0073e9SAndroid Build Coastguard Worker x.transpose_(1, 0) 989*da0073e9SAndroid Build Coastguard Worker y = x[0] 990*da0073e9SAndroid Build Coastguard Worker y.add_(tmp) 991*da0073e9SAndroid Build Coastguard Worker return x 992*da0073e9SAndroid Build Coastguard Worker 993*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True) 994*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(4, 2)) 995*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 996*da0073e9SAndroid Build Coastguard Worker logs, 997*da0073e9SAndroid Build Coastguard Worker """\ 998*da0073e9SAndroid Build Coastguard Worker 999*da0073e9SAndroid Build Coastguard Worker 1000*da0073e9SAndroid Build Coastguard Worker 1001*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1002*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) 1003*da0073e9SAndroid Build Coastguard Worker transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0) 1004*da0073e9SAndroid Build Coastguard Worker select_copy = torch.ops.aten.select_copy.int(transpose_copy, 0, 0); transpose_copy = None 1005*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(select_copy, ones); select_copy = ones = None 1006*da0073e9SAndroid Build Coastguard Worker transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None 1007*da0073e9SAndroid Build Coastguard Worker select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None 1008*da0073e9SAndroid Build Coastguard Worker transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None 1009*da0073e9SAndroid Build Coastguard Worker transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0) 1010*da0073e9SAndroid Build Coastguard Worker select_copy_1 = torch.ops.aten.select_copy.int(transpose_copy_3, 0, 0); transpose_copy_3 = select_copy_1 = None 1011*da0073e9SAndroid Build Coastguard Worker transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None 1012*da0073e9SAndroid Build Coastguard Worker return transpose_copy_4 1013*da0073e9SAndroid Build Coastguard Worker """, 1014*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 1015*da0073e9SAndroid Build Coastguard Worker 1016*da0073e9SAndroid Build Coastguard Worker # NB: even with reapply_views=True, we expect to see scatter op 1017*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 1018*da0073e9SAndroid Build Coastguard Worker f, torch.ones(4, 2), reapply_views=True, run_reinplace=False 1019*da0073e9SAndroid Build Coastguard Worker ) 1020*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1021*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 1022*da0073e9SAndroid Build Coastguard Worker """\ 1023*da0073e9SAndroid Build Coastguard Worker 1024*da0073e9SAndroid Build Coastguard Worker 1025*da0073e9SAndroid Build Coastguard Worker 1026*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1027*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) 1028*da0073e9SAndroid Build Coastguard Worker transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0) 1029*da0073e9SAndroid Build Coastguard Worker select = torch.ops.aten.select.int(transpose, 0, 0); transpose = None 1030*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(select, ones); select = ones = None 1031*da0073e9SAndroid Build Coastguard Worker transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None 1032*da0073e9SAndroid Build Coastguard Worker select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0); transpose_1 = add = None 1033*da0073e9SAndroid Build Coastguard Worker transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0); select_scatter = None 1034*da0073e9SAndroid Build Coastguard Worker transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0) 1035*da0073e9SAndroid Build Coastguard Worker select_1 = torch.ops.aten.select.int(transpose_3, 0, 0); transpose_3 = select_1 = None 1036*da0073e9SAndroid Build Coastguard Worker transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None 1037*da0073e9SAndroid Build Coastguard Worker return transpose_4 1038*da0073e9SAndroid Build Coastguard Worker """, 1039*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 1040*da0073e9SAndroid Build Coastguard Worker 1041*da0073e9SAndroid Build Coastguard Worker def test_unbind(self): 1042*da0073e9SAndroid Build Coastguard Worker def f(x): 1043*da0073e9SAndroid Build Coastguard Worker # test: view + inplace op (transpose_) 1044*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(4) 1045*da0073e9SAndroid Build Coastguard Worker x.transpose_(1, 0) 1046*da0073e9SAndroid Build Coastguard Worker y, _ = x.unbind(0) 1047*da0073e9SAndroid Build Coastguard Worker y.add_(tmp) 1048*da0073e9SAndroid Build Coastguard Worker return x 1049*da0073e9SAndroid Build Coastguard Worker 1050*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True) 1051*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(4, 2)) 1052*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1053*da0073e9SAndroid Build Coastguard Worker logs, 1054*da0073e9SAndroid Build Coastguard Worker """\ 1055*da0073e9SAndroid Build Coastguard Worker 1056*da0073e9SAndroid Build Coastguard Worker 1057*da0073e9SAndroid Build Coastguard Worker 1058*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1059*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) 1060*da0073e9SAndroid Build Coastguard Worker transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0) 1061*da0073e9SAndroid Build Coastguard Worker unbind_copy = torch.ops.aten.unbind_copy.int(transpose_copy); transpose_copy = None 1062*da0073e9SAndroid Build Coastguard Worker getitem = unbind_copy[0] 1063*da0073e9SAndroid Build Coastguard Worker getitem_1 = unbind_copy[1]; unbind_copy = getitem_1 = None 1064*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None 1065*da0073e9SAndroid Build Coastguard Worker transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None 1066*da0073e9SAndroid Build Coastguard Worker select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None 1067*da0073e9SAndroid Build Coastguard Worker transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None 1068*da0073e9SAndroid Build Coastguard Worker transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0) 1069*da0073e9SAndroid Build Coastguard Worker unbind_copy_1 = torch.ops.aten.unbind_copy.int(transpose_copy_3); transpose_copy_3 = None 1070*da0073e9SAndroid Build Coastguard Worker getitem_2 = unbind_copy_1[0]; getitem_2 = None 1071*da0073e9SAndroid Build Coastguard Worker getitem_3 = unbind_copy_1[1]; unbind_copy_1 = getitem_3 = None 1072*da0073e9SAndroid Build Coastguard Worker transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None 1073*da0073e9SAndroid Build Coastguard Worker return transpose_copy_4 1074*da0073e9SAndroid Build Coastguard Worker """, 1075*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 1076*da0073e9SAndroid Build Coastguard Worker 1077*da0073e9SAndroid Build Coastguard Worker # NB: even with reapply_views=True, we expect to see scatter op 1078*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 1079*da0073e9SAndroid Build Coastguard Worker f, torch.ones(4, 2), reapply_views=True, run_reinplace=False 1080*da0073e9SAndroid Build Coastguard Worker ) 1081*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1082*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 1083*da0073e9SAndroid Build Coastguard Worker """\ 1084*da0073e9SAndroid Build Coastguard Worker 1085*da0073e9SAndroid Build Coastguard Worker 1086*da0073e9SAndroid Build Coastguard Worker 1087*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1088*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) 1089*da0073e9SAndroid Build Coastguard Worker transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0) 1090*da0073e9SAndroid Build Coastguard Worker unbind = torch.ops.aten.unbind.int(transpose); transpose = None 1091*da0073e9SAndroid Build Coastguard Worker getitem = unbind[0] 1092*da0073e9SAndroid Build Coastguard Worker getitem_1 = unbind[1]; unbind = getitem_1 = None 1093*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None 1094*da0073e9SAndroid Build Coastguard Worker transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None 1095*da0073e9SAndroid Build Coastguard Worker select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0); transpose_1 = add = None 1096*da0073e9SAndroid Build Coastguard Worker transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0); select_scatter = None 1097*da0073e9SAndroid Build Coastguard Worker transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0) 1098*da0073e9SAndroid Build Coastguard Worker unbind_1 = torch.ops.aten.unbind.int(transpose_3); transpose_3 = None 1099*da0073e9SAndroid Build Coastguard Worker getitem_2 = unbind_1[0]; getitem_2 = None 1100*da0073e9SAndroid Build Coastguard Worker getitem_3 = unbind_1[1]; unbind_1 = getitem_3 = None 1101*da0073e9SAndroid Build Coastguard Worker transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None 1102*da0073e9SAndroid Build Coastguard Worker return transpose_4 1103*da0073e9SAndroid Build Coastguard Worker """, 1104*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 1105*da0073e9SAndroid Build Coastguard Worker 1106*da0073e9SAndroid Build Coastguard Worker def test_optional_tensor_list(self): 1107*da0073e9SAndroid Build Coastguard Worker def f(x): 1108*da0073e9SAndroid Build Coastguard Worker # test: an operator that takes in a List[Optional[Tensor]] argument 1109*da0073e9SAndroid Build Coastguard Worker # (index_put) 1110*da0073e9SAndroid Build Coastguard Worker y = x.view(8) 1111*da0073e9SAndroid Build Coastguard Worker indices = torch.arange(4) 1112*da0073e9SAndroid Build Coastguard Worker values = torch.arange(4, dtype=y.dtype) 1113*da0073e9SAndroid Build Coastguard Worker y.index_put_((indices,), values, accumulate=False) 1114*da0073e9SAndroid Build Coastguard Worker return y 1115*da0073e9SAndroid Build Coastguard Worker 1116*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(4, 2)) 1117*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(4, 2)) 1118*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1119*da0073e9SAndroid Build Coastguard Worker logs, 1120*da0073e9SAndroid Build Coastguard Worker """\ 1121*da0073e9SAndroid Build Coastguard Worker 1122*da0073e9SAndroid Build Coastguard Worker 1123*da0073e9SAndroid Build Coastguard Worker 1124*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1125*da0073e9SAndroid Build Coastguard Worker view_copy = torch.ops.aten.view_copy.default(arg0_1, [8]) 1126*da0073e9SAndroid Build Coastguard Worker arange = torch.ops.aten.arange.default(4, device = device(type='cpu'), pin_memory = False) 1127*da0073e9SAndroid Build Coastguard Worker arange_1 = torch.ops.aten.arange.default(4, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) 1128*da0073e9SAndroid Build Coastguard Worker index_put = torch.ops.aten.index_put.default(view_copy, [arange], arange_1); view_copy = arange = arange_1 = None 1129*da0073e9SAndroid Build Coastguard Worker view_copy_1 = torch.ops.aten.view_copy.default(index_put, [4, 2]); index_put = None 1130*da0073e9SAndroid Build Coastguard Worker view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [8]) 1131*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = copy_ = None 1132*da0073e9SAndroid Build Coastguard Worker return view_copy_2 1133*da0073e9SAndroid Build Coastguard Worker """, 1134*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 1135*da0073e9SAndroid Build Coastguard Worker 1136*da0073e9SAndroid Build Coastguard Worker def test_scalars(self): 1137*da0073e9SAndroid Build Coastguard Worker def f(x): 1138*da0073e9SAndroid Build Coastguard Worker # test: the pass can handle scalar inputs properly 1139*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(4, 2) 1140*da0073e9SAndroid Build Coastguard Worker y = x.view(4, 2) 1141*da0073e9SAndroid Build Coastguard Worker y.add_(1) 1142*da0073e9SAndroid Build Coastguard Worker z = 2 * y 1143*da0073e9SAndroid Build Coastguard Worker z.div_(1) 1144*da0073e9SAndroid Build Coastguard Worker return z 1145*da0073e9SAndroid Build Coastguard Worker 1146*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(4, 2)) 1147*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(4, 2)) 1148*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1149*da0073e9SAndroid Build Coastguard Worker logs, 1150*da0073e9SAndroid Build Coastguard Worker """\ 1151*da0073e9SAndroid Build Coastguard Worker 1152*da0073e9SAndroid Build Coastguard Worker 1153*da0073e9SAndroid Build Coastguard Worker 1154*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1155*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False); ones = None 1156*da0073e9SAndroid Build Coastguard Worker view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]) 1157*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None 1158*da0073e9SAndroid Build Coastguard Worker view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None 1159*da0073e9SAndroid Build Coastguard Worker view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2]) 1160*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(view_copy_2, 2); view_copy_2 = None 1161*da0073e9SAndroid Build Coastguard Worker div = torch.ops.aten.div.Tensor(mul, 1); mul = None 1162*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = copy_ = None 1163*da0073e9SAndroid Build Coastguard Worker return div 1164*da0073e9SAndroid Build Coastguard Worker """, 1165*da0073e9SAndroid Build Coastguard Worker ) 1166*da0073e9SAndroid Build Coastguard Worker 1167*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Test does not work with TorchDynamo") 1168*da0073e9SAndroid Build Coastguard Worker def test_metadata_change(self): 1169*da0073e9SAndroid Build Coastguard Worker def f(x): 1170*da0073e9SAndroid Build Coastguard Worker # ops like ge_() are allowed to change the dtype of the input. 1171*da0073e9SAndroid Build Coastguard Worker # functionalization should pick up on that. 1172*da0073e9SAndroid Build Coastguard Worker y = x.clone() 1173*da0073e9SAndroid Build Coastguard Worker out = y.ge_(0) 1174*da0073e9SAndroid Build Coastguard Worker return out 1175*da0073e9SAndroid Build Coastguard Worker 1176*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(4, 2)) 1177*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(4, 2)) 1178*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1179*da0073e9SAndroid Build Coastguard Worker logs, 1180*da0073e9SAndroid Build Coastguard Worker """\ 1181*da0073e9SAndroid Build Coastguard Worker 1182*da0073e9SAndroid Build Coastguard Worker 1183*da0073e9SAndroid Build Coastguard Worker 1184*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1185*da0073e9SAndroid Build Coastguard Worker clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 1186*da0073e9SAndroid Build Coastguard Worker ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None 1187*da0073e9SAndroid Build Coastguard Worker _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None 1188*da0073e9SAndroid Build Coastguard Worker return _to_copy 1189*da0073e9SAndroid Build Coastguard Worker """, 1190*da0073e9SAndroid Build Coastguard Worker ) 1191*da0073e9SAndroid Build Coastguard Worker 1192*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 1193*da0073e9SAndroid Build Coastguard Worker f, torch.ones(2, 2), reapply_views=True, run_reinplace=True 1194*da0073e9SAndroid Build Coastguard Worker ) 1195*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1196*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 1197*da0073e9SAndroid Build Coastguard Worker """\ 1198*da0073e9SAndroid Build Coastguard Worker 1199*da0073e9SAndroid Build Coastguard Worker 1200*da0073e9SAndroid Build Coastguard Worker 1201*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1202*da0073e9SAndroid Build Coastguard Worker clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 1203*da0073e9SAndroid Build Coastguard Worker ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None 1204*da0073e9SAndroid Build Coastguard Worker _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None 1205*da0073e9SAndroid Build Coastguard Worker return _to_copy 1206*da0073e9SAndroid Build Coastguard Worker """, 1207*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 1208*da0073e9SAndroid Build Coastguard Worker 1209*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Test does not work with TorchDynamo") 1210*da0073e9SAndroid Build Coastguard Worker def test_metadata_change_out_op(self): 1211*da0073e9SAndroid Build Coastguard Worker def f(t, y): 1212*da0073e9SAndroid Build Coastguard Worker out_1 = torch.ones(1) 1213*da0073e9SAndroid Build Coastguard Worker return torch.add(t, y, out=out_1) 1214*da0073e9SAndroid Build Coastguard Worker 1215*da0073e9SAndroid Build Coastguard Worker inpt1, inpt2 = torch.tensor([1]), torch.tensor([1]) 1216*da0073e9SAndroid Build Coastguard Worker inpt1_func, inpt2_func = ( 1217*da0073e9SAndroid Build Coastguard Worker torch._to_functional_tensor(inpt1), 1218*da0073e9SAndroid Build Coastguard Worker torch._to_functional_tensor(inpt2), 1219*da0073e9SAndroid Build Coastguard Worker ) 1220*da0073e9SAndroid Build Coastguard Worker 1221*da0073e9SAndroid Build Coastguard Worker out_ref = f(inpt1, inpt2) 1222*da0073e9SAndroid Build Coastguard Worker torch._enable_functionalization(reapply_views=True) 1223*da0073e9SAndroid Build Coastguard Worker try: 1224*da0073e9SAndroid Build Coastguard Worker out_functional = f(inpt1_func, inpt2_func) 1225*da0073e9SAndroid Build Coastguard Worker finally: 1226*da0073e9SAndroid Build Coastguard Worker torch._disable_functionalization() 1227*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, torch._from_functional_tensor(out_functional)) 1228*da0073e9SAndroid Build Coastguard Worker 1229*da0073e9SAndroid Build Coastguard Worker def test_only_one_view(self): 1230*da0073e9SAndroid Build Coastguard Worker def f(x): 1231*da0073e9SAndroid Build Coastguard Worker # This tests that we don't have any unnecessary views in the trace. 1232*da0073e9SAndroid Build Coastguard Worker # If the input wasn't mutated, we don't need to regenerate it, 1233*da0073e9SAndroid Build Coastguard Worker # so there should be a total of 1 op in the output trace. 1234*da0073e9SAndroid Build Coastguard Worker return x.view(4, 2) 1235*da0073e9SAndroid Build Coastguard Worker 1236*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(4, 2)) 1237*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1238*da0073e9SAndroid Build Coastguard Worker logs, 1239*da0073e9SAndroid Build Coastguard Worker """\ 1240*da0073e9SAndroid Build Coastguard Worker 1241*da0073e9SAndroid Build Coastguard Worker 1242*da0073e9SAndroid Build Coastguard Worker 1243*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1244*da0073e9SAndroid Build Coastguard Worker view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); arg0_1 = None 1245*da0073e9SAndroid Build Coastguard Worker return view_copy 1246*da0073e9SAndroid Build Coastguard Worker """, 1247*da0073e9SAndroid Build Coastguard Worker ) 1248*da0073e9SAndroid Build Coastguard Worker 1249*da0073e9SAndroid Build Coastguard Worker def test_everything(self): 1250*da0073e9SAndroid Build Coastguard Worker def f(x): 1251*da0073e9SAndroid Build Coastguard Worker # test: everything 1252*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(2, 2) 1253*da0073e9SAndroid Build Coastguard Worker x2 = x + x 1254*da0073e9SAndroid Build Coastguard Worker y = x2.view(8) 1255*da0073e9SAndroid Build Coastguard Worker z0 = y.reshape(2, 4) 1256*da0073e9SAndroid Build Coastguard Worker z1 = z0.transpose(1, 0) 1257*da0073e9SAndroid Build Coastguard Worker z1.unsqueeze_(0) 1258*da0073e9SAndroid Build Coastguard Worker z1.squeeze_() 1259*da0073e9SAndroid Build Coastguard Worker z2, z3 = z1.split(2) 1260*da0073e9SAndroid Build Coastguard Worker z2.add_(tmp) 1261*da0073e9SAndroid Build Coastguard Worker z4 = z0[0] + z2.reshape(4) 1262*da0073e9SAndroid Build Coastguard Worker return z2 1263*da0073e9SAndroid Build Coastguard Worker 1264*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(4, 2)) 1265*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(4, 2)) 1266*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1267*da0073e9SAndroid Build Coastguard Worker logs, 1268*da0073e9SAndroid Build Coastguard Worker """\ 1269*da0073e9SAndroid Build Coastguard Worker 1270*da0073e9SAndroid Build Coastguard Worker 1271*da0073e9SAndroid Build Coastguard Worker 1272*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1273*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False) 1274*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None 1275*da0073e9SAndroid Build Coastguard Worker view_copy = torch.ops.aten.view_copy.default(add, [8]) 1276*da0073e9SAndroid Build Coastguard Worker view_copy_1 = torch.ops.aten.view_copy.default(view_copy, [2, 4]); view_copy = None 1277*da0073e9SAndroid Build Coastguard Worker transpose_copy = torch.ops.aten.transpose_copy.int(view_copy_1, 1, 0) 1278*da0073e9SAndroid Build Coastguard Worker unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0); transpose_copy = None 1279*da0073e9SAndroid Build Coastguard Worker squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy); unsqueeze_copy = None 1280*da0073e9SAndroid Build Coastguard Worker split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2); squeeze_copy = None 1281*da0073e9SAndroid Build Coastguard Worker getitem = split_copy[0] 1282*da0073e9SAndroid Build Coastguard Worker getitem_1 = split_copy[1]; split_copy = getitem_1 = None 1283*da0073e9SAndroid Build Coastguard Worker add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None 1284*da0073e9SAndroid Build Coastguard Worker view_copy_2 = torch.ops.aten.view_copy.default(add, [8]); add = None 1285*da0073e9SAndroid Build Coastguard Worker view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [2, 4]); view_copy_2 = None 1286*da0073e9SAndroid Build Coastguard Worker transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_3, 1, 0); view_copy_3 = None 1287*da0073e9SAndroid Build Coastguard Worker unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None 1288*da0073e9SAndroid Build Coastguard Worker squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None 1289*da0073e9SAndroid Build Coastguard Worker slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = add_1 = None 1290*da0073e9SAndroid Build Coastguard Worker unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None 1291*da0073e9SAndroid Build Coastguard Worker squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None 1292*da0073e9SAndroid Build Coastguard Worker transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None 1293*da0073e9SAndroid Build Coastguard Worker view_copy_4 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]); transpose_copy_2 = None 1294*da0073e9SAndroid Build Coastguard Worker view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 2]); view_copy_4 = None 1295*da0073e9SAndroid Build Coastguard Worker view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [8]) 1296*da0073e9SAndroid Build Coastguard Worker view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [2, 4]); view_copy_6 = None 1297*da0073e9SAndroid Build Coastguard Worker transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_7, 1, 0); view_copy_7 = None 1298*da0073e9SAndroid Build Coastguard Worker unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0); transpose_copy_3 = None 1299*da0073e9SAndroid Build Coastguard Worker squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3); unsqueeze_copy_3 = None 1300*da0073e9SAndroid Build Coastguard Worker split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2); squeeze_copy_3 = None 1301*da0073e9SAndroid Build Coastguard Worker getitem_2 = split_copy_1[0] 1302*da0073e9SAndroid Build Coastguard Worker getitem_3 = split_copy_1[1]; split_copy_1 = getitem_3 = None 1303*da0073e9SAndroid Build Coastguard Worker select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0); view_copy_1 = select_copy = None 1304*da0073e9SAndroid Build Coastguard Worker view_copy_8 = torch.ops.aten.view_copy.default(getitem_2, [4]); view_copy_8 = None 1305*da0073e9SAndroid Build Coastguard Worker view_copy_9 = torch.ops.aten.view_copy.default(view_copy_5, [8]) 1306*da0073e9SAndroid Build Coastguard Worker view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]); view_copy_9 = None 1307*da0073e9SAndroid Build Coastguard Worker select_copy_1 = torch.ops.aten.select_copy.int(view_copy_10, 0, 0); view_copy_10 = None 1308*da0073e9SAndroid Build Coastguard Worker view_copy_11 = torch.ops.aten.view_copy.default(view_copy_5, [8]); view_copy_5 = None 1309*da0073e9SAndroid Build Coastguard Worker view_copy_12 = torch.ops.aten.view_copy.default(view_copy_11, [2, 4]); view_copy_11 = None 1310*da0073e9SAndroid Build Coastguard Worker transpose_copy_4 = torch.ops.aten.transpose_copy.int(view_copy_12, 1, 0); view_copy_12 = None 1311*da0073e9SAndroid Build Coastguard Worker unsqueeze_copy_4 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_4, 0); transpose_copy_4 = None 1312*da0073e9SAndroid Build Coastguard Worker squeeze_copy_4 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_4); unsqueeze_copy_4 = None 1313*da0073e9SAndroid Build Coastguard Worker split_copy_2 = torch.ops.aten.split_copy.Tensor(squeeze_copy_4, 2); squeeze_copy_4 = None 1314*da0073e9SAndroid Build Coastguard Worker getitem_4 = split_copy_2[0] 1315*da0073e9SAndroid Build Coastguard Worker getitem_5 = split_copy_2[1]; split_copy_2 = getitem_5 = None 1316*da0073e9SAndroid Build Coastguard Worker view_copy_13 = torch.ops.aten.view_copy.default(getitem_4, [4]); getitem_4 = None 1317*da0073e9SAndroid Build Coastguard Worker add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_13); select_copy_1 = view_copy_13 = add_2 = None 1318*da0073e9SAndroid Build Coastguard Worker return getitem_2 1319*da0073e9SAndroid Build Coastguard Worker """, 1320*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 1321*da0073e9SAndroid Build Coastguard Worker 1322*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 1323*da0073e9SAndroid Build Coastguard Worker f, torch.ones(4, 2), reapply_views=True, run_reinplace=True 1324*da0073e9SAndroid Build Coastguard Worker ) 1325*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1326*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 1327*da0073e9SAndroid Build Coastguard Worker """\ 1328*da0073e9SAndroid Build Coastguard Worker 1329*da0073e9SAndroid Build Coastguard Worker 1330*da0073e9SAndroid Build Coastguard Worker 1331*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1332*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False) 1333*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None 1334*da0073e9SAndroid Build Coastguard Worker view = torch.ops.aten.view.default(add, [8]) 1335*da0073e9SAndroid Build Coastguard Worker view_1 = torch.ops.aten.view.default(view, [2, 4]); view = None 1336*da0073e9SAndroid Build Coastguard Worker transpose = torch.ops.aten.transpose.int(view_1, 1, 0) 1337*da0073e9SAndroid Build Coastguard Worker unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0); transpose = None 1338*da0073e9SAndroid Build Coastguard Worker squeeze = torch.ops.aten.squeeze.default(unsqueeze); unsqueeze = None 1339*da0073e9SAndroid Build Coastguard Worker split = torch.ops.aten.split.Tensor(squeeze, 2); squeeze = None 1340*da0073e9SAndroid Build Coastguard Worker getitem = split[0] 1341*da0073e9SAndroid Build Coastguard Worker getitem_1 = split[1]; split = getitem_1 = None 1342*da0073e9SAndroid Build Coastguard Worker add_1 = torch.ops.aten.add_.Tensor(getitem, ones); getitem = ones = add_1 = None 1343*da0073e9SAndroid Build Coastguard Worker view_2 = torch.ops.aten.view.default(add, [8]); add = None 1344*da0073e9SAndroid Build Coastguard Worker view_3 = torch.ops.aten.view.default(view_2, [2, 4]); view_2 = None 1345*da0073e9SAndroid Build Coastguard Worker transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0); view_3 = None 1346*da0073e9SAndroid Build Coastguard Worker unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0); transpose_1 = None 1347*da0073e9SAndroid Build Coastguard Worker squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1); unsqueeze_1 = None 1348*da0073e9SAndroid Build Coastguard Worker unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0); squeeze_1 = None 1349*da0073e9SAndroid Build Coastguard Worker squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0); unsqueeze_2 = None 1350*da0073e9SAndroid Build Coastguard Worker transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0); squeeze_2 = None 1351*da0073e9SAndroid Build Coastguard Worker view_4 = torch.ops.aten.view.default(transpose_2, [8]); transpose_2 = None 1352*da0073e9SAndroid Build Coastguard Worker view_5 = torch.ops.aten.view.default(view_4, [4, 2]); view_4 = None 1353*da0073e9SAndroid Build Coastguard Worker view_6 = torch.ops.aten.view.default(view_5, [8]) 1354*da0073e9SAndroid Build Coastguard Worker view_7 = torch.ops.aten.view.default(view_6, [2, 4]); view_6 = None 1355*da0073e9SAndroid Build Coastguard Worker transpose_3 = torch.ops.aten.transpose.int(view_7, 1, 0); view_7 = None 1356*da0073e9SAndroid Build Coastguard Worker unsqueeze_3 = torch.ops.aten.unsqueeze.default(transpose_3, 0); transpose_3 = None 1357*da0073e9SAndroid Build Coastguard Worker squeeze_3 = torch.ops.aten.squeeze.default(unsqueeze_3); unsqueeze_3 = None 1358*da0073e9SAndroid Build Coastguard Worker split_1 = torch.ops.aten.split.Tensor(squeeze_3, 2); squeeze_3 = None 1359*da0073e9SAndroid Build Coastguard Worker getitem_2 = split_1[0] 1360*da0073e9SAndroid Build Coastguard Worker getitem_3 = split_1[1]; split_1 = getitem_3 = None 1361*da0073e9SAndroid Build Coastguard Worker select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = select = None 1362*da0073e9SAndroid Build Coastguard Worker clone = torch.ops.aten.clone.default(getitem_2, memory_format = torch.contiguous_format) 1363*da0073e9SAndroid Build Coastguard Worker _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None 1364*da0073e9SAndroid Build Coastguard Worker view_8 = torch.ops.aten.view.default(view_5, [8]); view_5 = None 1365*da0073e9SAndroid Build Coastguard Worker view_9 = torch.ops.aten.view.default(view_8, [2, 4]); view_8 = None 1366*da0073e9SAndroid Build Coastguard Worker select_1 = torch.ops.aten.select.int(view_9, 0, 0); view_9 = None 1367*da0073e9SAndroid Build Coastguard Worker add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = add_2 = None 1368*da0073e9SAndroid Build Coastguard Worker return getitem_2 1369*da0073e9SAndroid Build Coastguard Worker """, 1370*da0073e9SAndroid Build Coastguard Worker ) 1371*da0073e9SAndroid Build Coastguard Worker 1372*da0073e9SAndroid Build Coastguard Worker def test_reapply_views_simple(self): 1373*da0073e9SAndroid Build Coastguard Worker def f(x): 1374*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(4, 2) 1375*da0073e9SAndroid Build Coastguard Worker y = x.view(4, 2) 1376*da0073e9SAndroid Build Coastguard Worker y.add_(tmp) 1377*da0073e9SAndroid Build Coastguard Worker z = x * x 1378*da0073e9SAndroid Build Coastguard Worker return y 1379*da0073e9SAndroid Build Coastguard Worker 1380*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(4, 2), reapply_views=True) 1381*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True) 1382*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1383*da0073e9SAndroid Build Coastguard Worker logs, 1384*da0073e9SAndroid Build Coastguard Worker """\ 1385*da0073e9SAndroid Build Coastguard Worker 1386*da0073e9SAndroid Build Coastguard Worker 1387*da0073e9SAndroid Build Coastguard Worker 1388*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1389*da0073e9SAndroid Build Coastguard Worker ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) 1390*da0073e9SAndroid Build Coastguard Worker view = torch.ops.aten.view.default(arg0_1, [4, 2]) 1391*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(view, ones); view = ones = None 1392*da0073e9SAndroid Build Coastguard Worker view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None 1393*da0073e9SAndroid Build Coastguard Worker view_2 = torch.ops.aten.view.default(view_1, [4, 2]) 1394*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(view_1, view_1); mul = None 1395*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = copy_ = None 1396*da0073e9SAndroid Build Coastguard Worker return view_2 1397*da0073e9SAndroid Build Coastguard Worker """, 1398*da0073e9SAndroid Build Coastguard Worker ) 1399*da0073e9SAndroid Build Coastguard Worker 1400*da0073e9SAndroid Build Coastguard Worker def test_aliases_maintained_after_pass_when_reapplying_views(self): 1401*da0073e9SAndroid Build Coastguard Worker def f(x): 1402*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(4, 2) 1403*da0073e9SAndroid Build Coastguard Worker y = x.view(4, 2) 1404*da0073e9SAndroid Build Coastguard Worker z = x.view(4, 2) 1405*da0073e9SAndroid Build Coastguard Worker y.add_(tmp) 1406*da0073e9SAndroid Build Coastguard Worker return y, z 1407*da0073e9SAndroid Build Coastguard Worker 1408*da0073e9SAndroid Build Coastguard Worker input_functional = torch._to_functional_tensor(torch.ones(4, 2)) 1409*da0073e9SAndroid Build Coastguard Worker torch._enable_functionalization(reapply_views=True) 1410*da0073e9SAndroid Build Coastguard Worker try: 1411*da0073e9SAndroid Build Coastguard Worker y, z = f(input_functional) 1412*da0073e9SAndroid Build Coastguard Worker torch._sync(y) 1413*da0073e9SAndroid Build Coastguard Worker torch._sync(z) 1414*da0073e9SAndroid Build Coastguard Worker finally: 1415*da0073e9SAndroid Build Coastguard Worker torch._disable_functionalization() 1416*da0073e9SAndroid Build Coastguard Worker 1417*da0073e9SAndroid Build Coastguard Worker # y and z are aliases inside of the function, and that aliasing relationship should be maintained. 1418*da0073e9SAndroid Build Coastguard Worker _y = torch._from_functional_tensor(y) 1419*da0073e9SAndroid Build Coastguard Worker _z = torch._from_functional_tensor(z) 1420*da0073e9SAndroid Build Coastguard Worker self.assertTrue(are_aliased(_y, _z)) 1421*da0073e9SAndroid Build Coastguard Worker 1422*da0073e9SAndroid Build Coastguard Worker # copy_() gets its own test, because it used to be special cased in functionalization. 1423*da0073e9SAndroid Build Coastguard Worker # However, now it works pretty similar to other functional ops 1424*da0073e9SAndroid Build Coastguard Worker def test_copy_(self): 1425*da0073e9SAndroid Build Coastguard Worker def f(x): 1426*da0073e9SAndroid Build Coastguard Worker tmp = torch.zeros(2, 2) 1427*da0073e9SAndroid Build Coastguard Worker tmp_slice = tmp.diagonal() 1428*da0073e9SAndroid Build Coastguard Worker y = tmp_slice.copy_(x) 1429*da0073e9SAndroid Build Coastguard Worker z = y.add_(x) 1430*da0073e9SAndroid Build Coastguard Worker return z 1431*da0073e9SAndroid Build Coastguard Worker 1432*da0073e9SAndroid Build Coastguard Worker # Test 1: copy_() with same dtype and shape 1433*da0073e9SAndroid Build Coastguard Worker # to() is a composite op that noops when the dtype/shape match, so nothing gets logged. 1434*da0073e9SAndroid Build Coastguard Worker # self.assert_functionalization(f, torch.ones(2)) 1435*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(2)) 1436*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1437*da0073e9SAndroid Build Coastguard Worker logs, 1438*da0073e9SAndroid Build Coastguard Worker """\ 1439*da0073e9SAndroid Build Coastguard Worker 1440*da0073e9SAndroid Build Coastguard Worker 1441*da0073e9SAndroid Build Coastguard Worker 1442*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1443*da0073e9SAndroid Build Coastguard Worker zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1444*da0073e9SAndroid Build Coastguard Worker diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros) 1445*da0073e9SAndroid Build Coastguard Worker copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None 1446*da0073e9SAndroid Build Coastguard Worker diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None 1447*da0073e9SAndroid Build Coastguard Worker diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter) 1448*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None 1449*da0073e9SAndroid Build Coastguard Worker diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None 1450*da0073e9SAndroid Build Coastguard Worker diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None 1451*da0073e9SAndroid Build Coastguard Worker return diagonal_copy_2 1452*da0073e9SAndroid Build Coastguard Worker """, 1453*da0073e9SAndroid Build Coastguard Worker ) 1454*da0073e9SAndroid Build Coastguard Worker 1455*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 1456*da0073e9SAndroid Build Coastguard Worker f, torch.ones(2), reapply_views=True, run_reinplace=True 1457*da0073e9SAndroid Build Coastguard Worker ) 1458*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1459*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 1460*da0073e9SAndroid Build Coastguard Worker """\ 1461*da0073e9SAndroid Build Coastguard Worker 1462*da0073e9SAndroid Build Coastguard Worker 1463*da0073e9SAndroid Build Coastguard Worker 1464*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1465*da0073e9SAndroid Build Coastguard Worker zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1466*da0073e9SAndroid Build Coastguard Worker diagonal = torch.ops.aten.diagonal.default(zeros) 1467*da0073e9SAndroid Build Coastguard Worker copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = copy = None 1468*da0073e9SAndroid Build Coastguard Worker diagonal_1 = torch.ops.aten.diagonal.default(zeros) 1469*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = add = None 1470*da0073e9SAndroid Build Coastguard Worker diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None 1471*da0073e9SAndroid Build Coastguard Worker return diagonal_2 1472*da0073e9SAndroid Build Coastguard Worker """, 1473*da0073e9SAndroid Build Coastguard Worker ) 1474*da0073e9SAndroid Build Coastguard Worker 1475*da0073e9SAndroid Build Coastguard Worker # Test 2: copy_() with same dtype, different shape 1476*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(1)) 1477*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(1)) 1478*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1479*da0073e9SAndroid Build Coastguard Worker logs, 1480*da0073e9SAndroid Build Coastguard Worker """\ 1481*da0073e9SAndroid Build Coastguard Worker 1482*da0073e9SAndroid Build Coastguard Worker 1483*da0073e9SAndroid Build Coastguard Worker 1484*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1485*da0073e9SAndroid Build Coastguard Worker zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1486*da0073e9SAndroid Build Coastguard Worker diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros) 1487*da0073e9SAndroid Build Coastguard Worker copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None 1488*da0073e9SAndroid Build Coastguard Worker diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None 1489*da0073e9SAndroid Build Coastguard Worker diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter) 1490*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None 1491*da0073e9SAndroid Build Coastguard Worker diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None 1492*da0073e9SAndroid Build Coastguard Worker diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None 1493*da0073e9SAndroid Build Coastguard Worker return diagonal_copy_2 1494*da0073e9SAndroid Build Coastguard Worker """, 1495*da0073e9SAndroid Build Coastguard Worker ) 1496*da0073e9SAndroid Build Coastguard Worker 1497*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 1498*da0073e9SAndroid Build Coastguard Worker f, torch.ones(1), reapply_views=True, run_reinplace=True 1499*da0073e9SAndroid Build Coastguard Worker ) 1500*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1501*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 1502*da0073e9SAndroid Build Coastguard Worker """\ 1503*da0073e9SAndroid Build Coastguard Worker 1504*da0073e9SAndroid Build Coastguard Worker 1505*da0073e9SAndroid Build Coastguard Worker 1506*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1507*da0073e9SAndroid Build Coastguard Worker zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1508*da0073e9SAndroid Build Coastguard Worker diagonal = torch.ops.aten.diagonal.default(zeros) 1509*da0073e9SAndroid Build Coastguard Worker copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = copy = None 1510*da0073e9SAndroid Build Coastguard Worker diagonal_1 = torch.ops.aten.diagonal.default(zeros) 1511*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = add = None 1512*da0073e9SAndroid Build Coastguard Worker diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None 1513*da0073e9SAndroid Build Coastguard Worker return diagonal_2 1514*da0073e9SAndroid Build Coastguard Worker """, 1515*da0073e9SAndroid Build Coastguard Worker ) 1516*da0073e9SAndroid Build Coastguard Worker 1517*da0073e9SAndroid Build Coastguard Worker # Test 3: copy_() with different dtype, same shape 1518*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(2, dtype=torch.long)) 1519*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(2, dtype=torch.long)) 1520*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1521*da0073e9SAndroid Build Coastguard Worker logs, 1522*da0073e9SAndroid Build Coastguard Worker """\ 1523*da0073e9SAndroid Build Coastguard Worker 1524*da0073e9SAndroid Build Coastguard Worker 1525*da0073e9SAndroid Build Coastguard Worker 1526*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1527*da0073e9SAndroid Build Coastguard Worker zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1528*da0073e9SAndroid Build Coastguard Worker diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros) 1529*da0073e9SAndroid Build Coastguard Worker copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None 1530*da0073e9SAndroid Build Coastguard Worker diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None 1531*da0073e9SAndroid Build Coastguard Worker diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter) 1532*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None 1533*da0073e9SAndroid Build Coastguard Worker diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None 1534*da0073e9SAndroid Build Coastguard Worker diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None 1535*da0073e9SAndroid Build Coastguard Worker return diagonal_copy_2 1536*da0073e9SAndroid Build Coastguard Worker """, 1537*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 1538*da0073e9SAndroid Build Coastguard Worker 1539*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 1540*da0073e9SAndroid Build Coastguard Worker f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True 1541*da0073e9SAndroid Build Coastguard Worker ) 1542*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1543*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 1544*da0073e9SAndroid Build Coastguard Worker """\ 1545*da0073e9SAndroid Build Coastguard Worker 1546*da0073e9SAndroid Build Coastguard Worker 1547*da0073e9SAndroid Build Coastguard Worker 1548*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1549*da0073e9SAndroid Build Coastguard Worker zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1550*da0073e9SAndroid Build Coastguard Worker diagonal = torch.ops.aten.diagonal.default(zeros) 1551*da0073e9SAndroid Build Coastguard Worker copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = copy = None 1552*da0073e9SAndroid Build Coastguard Worker diagonal_1 = torch.ops.aten.diagonal.default(zeros) 1553*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = add = None 1554*da0073e9SAndroid Build Coastguard Worker diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None 1555*da0073e9SAndroid Build Coastguard Worker return diagonal_2 1556*da0073e9SAndroid Build Coastguard Worker """, 1557*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 1558*da0073e9SAndroid Build Coastguard Worker 1559*da0073e9SAndroid Build Coastguard Worker # Test 4: copy_() with different dtype, different shape 1560*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(1, dtype=torch.long)) 1561*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(1, dtype=torch.long)) 1562*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1563*da0073e9SAndroid Build Coastguard Worker logs, 1564*da0073e9SAndroid Build Coastguard Worker """\ 1565*da0073e9SAndroid Build Coastguard Worker 1566*da0073e9SAndroid Build Coastguard Worker 1567*da0073e9SAndroid Build Coastguard Worker 1568*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1569*da0073e9SAndroid Build Coastguard Worker zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1570*da0073e9SAndroid Build Coastguard Worker diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros) 1571*da0073e9SAndroid Build Coastguard Worker copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None 1572*da0073e9SAndroid Build Coastguard Worker diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None 1573*da0073e9SAndroid Build Coastguard Worker diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter) 1574*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None 1575*da0073e9SAndroid Build Coastguard Worker diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None 1576*da0073e9SAndroid Build Coastguard Worker diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None 1577*da0073e9SAndroid Build Coastguard Worker return diagonal_copy_2 1578*da0073e9SAndroid Build Coastguard Worker """, 1579*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 1580*da0073e9SAndroid Build Coastguard Worker 1581*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 1582*da0073e9SAndroid Build Coastguard Worker f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True 1583*da0073e9SAndroid Build Coastguard Worker ) 1584*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1585*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 1586*da0073e9SAndroid Build Coastguard Worker """\ 1587*da0073e9SAndroid Build Coastguard Worker 1588*da0073e9SAndroid Build Coastguard Worker 1589*da0073e9SAndroid Build Coastguard Worker 1590*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1591*da0073e9SAndroid Build Coastguard Worker zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1592*da0073e9SAndroid Build Coastguard Worker diagonal = torch.ops.aten.diagonal.default(zeros) 1593*da0073e9SAndroid Build Coastguard Worker copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = copy = None 1594*da0073e9SAndroid Build Coastguard Worker diagonal_1 = torch.ops.aten.diagonal.default(zeros) 1595*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = add = None 1596*da0073e9SAndroid Build Coastguard Worker diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None 1597*da0073e9SAndroid Build Coastguard Worker return diagonal_2 1598*da0073e9SAndroid Build Coastguard Worker """, 1599*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 1600*da0073e9SAndroid Build Coastguard Worker 1601*da0073e9SAndroid Build Coastguard Worker def test_expand_symint(self): 1602*da0073e9SAndroid Build Coastguard Worker # Once some existing SymInt bugs are ironed out, we should update 1603*da0073e9SAndroid Build Coastguard Worker # this test to plumb FakeSymbolicTensors through it 1604*da0073e9SAndroid Build Coastguard Worker def f(x): 1605*da0073e9SAndroid Build Coastguard Worker return x.expand(x.size(0), x.size(1)) 1606*da0073e9SAndroid Build Coastguard Worker 1607*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(2, 2)) 1608*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(2, 2)) 1609*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1610*da0073e9SAndroid Build Coastguard Worker logs, 1611*da0073e9SAndroid Build Coastguard Worker """\ 1612*da0073e9SAndroid Build Coastguard Worker 1613*da0073e9SAndroid Build Coastguard Worker 1614*da0073e9SAndroid Build Coastguard Worker 1615*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1616*da0073e9SAndroid Build Coastguard Worker expand_copy = torch.ops.aten.expand_copy.default(arg0_1, [2, 2]); arg0_1 = None 1617*da0073e9SAndroid Build Coastguard Worker return expand_copy 1618*da0073e9SAndroid Build Coastguard Worker """, 1619*da0073e9SAndroid Build Coastguard Worker ) 1620*da0073e9SAndroid Build Coastguard Worker 1621*da0073e9SAndroid Build Coastguard Worker def test_fill_(self): 1622*da0073e9SAndroid Build Coastguard Worker def f(x): 1623*da0073e9SAndroid Build Coastguard Worker y = x + x 1624*da0073e9SAndroid Build Coastguard Worker z = y.diagonal() 1625*da0073e9SAndroid Build Coastguard Worker z.fill_(0) 1626*da0073e9SAndroid Build Coastguard Worker return y 1627*da0073e9SAndroid Build Coastguard Worker 1628*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(2, 2)) 1629*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(2, 2)) 1630*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1631*da0073e9SAndroid Build Coastguard Worker logs, 1632*da0073e9SAndroid Build Coastguard Worker """\ 1633*da0073e9SAndroid Build Coastguard Worker 1634*da0073e9SAndroid Build Coastguard Worker 1635*da0073e9SAndroid Build Coastguard Worker 1636*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1637*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None 1638*da0073e9SAndroid Build Coastguard Worker diagonal_copy = torch.ops.aten.diagonal_copy.default(add) 1639*da0073e9SAndroid Build Coastguard Worker fill = torch.ops.aten.fill.Scalar(diagonal_copy, 0); diagonal_copy = None 1640*da0073e9SAndroid Build Coastguard Worker diagonal_scatter = torch.ops.aten.diagonal_scatter.default(add, fill); add = fill = None 1641*da0073e9SAndroid Build Coastguard Worker diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter); diagonal_copy_1 = None 1642*da0073e9SAndroid Build Coastguard Worker return diagonal_scatter 1643*da0073e9SAndroid Build Coastguard Worker """, 1644*da0073e9SAndroid Build Coastguard Worker ) 1645*da0073e9SAndroid Build Coastguard Worker 1646*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 1647*da0073e9SAndroid Build Coastguard Worker f, torch.ones(2, 2), reapply_views=True, run_reinplace=True 1648*da0073e9SAndroid Build Coastguard Worker ) 1649*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1650*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 1651*da0073e9SAndroid Build Coastguard Worker """\ 1652*da0073e9SAndroid Build Coastguard Worker 1653*da0073e9SAndroid Build Coastguard Worker 1654*da0073e9SAndroid Build Coastguard Worker 1655*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1656*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None 1657*da0073e9SAndroid Build Coastguard Worker diagonal = torch.ops.aten.diagonal.default(add) 1658*da0073e9SAndroid Build Coastguard Worker fill = torch.ops.aten.fill_.Scalar(diagonal, 0); diagonal = fill = None 1659*da0073e9SAndroid Build Coastguard Worker diagonal_1 = torch.ops.aten.diagonal.default(add); diagonal_1 = None 1660*da0073e9SAndroid Build Coastguard Worker return add 1661*da0073e9SAndroid Build Coastguard Worker """, 1662*da0073e9SAndroid Build Coastguard Worker ) 1663*da0073e9SAndroid Build Coastguard Worker 1664*da0073e9SAndroid Build Coastguard Worker def test_resize_smaller(self): 1665*da0073e9SAndroid Build Coastguard Worker def f(w): 1666*da0073e9SAndroid Build Coastguard Worker # Resizing to a smaller size doesn't affect storage 1667*da0073e9SAndroid Build Coastguard Worker x = w + 1 1668*da0073e9SAndroid Build Coastguard Worker y = x.view(4, 4) 1669*da0073e9SAndroid Build Coastguard Worker y.resize_(3, 3) 1670*da0073e9SAndroid Build Coastguard Worker y2 = y.view(-1) 1671*da0073e9SAndroid Build Coastguard Worker y2.add_(1) 1672*da0073e9SAndroid Build Coastguard Worker z = y + 1 1673*da0073e9SAndroid Build Coastguard Worker return z 1674*da0073e9SAndroid Build Coastguard Worker 1675*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(8, 2)) 1676*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(8, 2)) 1677*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1678*da0073e9SAndroid Build Coastguard Worker logs, 1679*da0073e9SAndroid Build Coastguard Worker """\ 1680*da0073e9SAndroid Build Coastguard Worker 1681*da0073e9SAndroid Build Coastguard Worker 1682*da0073e9SAndroid Build Coastguard Worker 1683*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1684*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None 1685*da0073e9SAndroid Build Coastguard Worker view_copy = torch.ops.aten.view_copy.default(add, [4, 4]) 1686*da0073e9SAndroid Build Coastguard Worker resize = torch.ops.aten.resize.default(view_copy, [3, 3]); resize = None 1687*da0073e9SAndroid Build Coastguard Worker as_strided_copy = torch.ops.aten.as_strided_copy.default(view_copy, [3, 3], [3, 1]); view_copy = None 1688*da0073e9SAndroid Build Coastguard Worker view_copy_1 = torch.ops.aten.view_copy.default(as_strided_copy, [-1]); as_strided_copy = None 1689*da0073e9SAndroid Build Coastguard Worker add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None 1690*da0073e9SAndroid Build Coastguard Worker view_copy_2 = torch.ops.aten.view_copy.default(add, [4, 4]); add = None 1691*da0073e9SAndroid Build Coastguard Worker as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1]); as_strided_copy_1 = None 1692*da0073e9SAndroid Build Coastguard Worker view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]); add_1 = None 1693*da0073e9SAndroid Build Coastguard Worker as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None 1694*da0073e9SAndroid Build Coastguard Worker view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]); as_strided_scatter = None 1695*da0073e9SAndroid Build Coastguard Worker view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]) 1696*da0073e9SAndroid Build Coastguard Worker as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]); view_copy_5 = None 1697*da0073e9SAndroid Build Coastguard Worker view_copy_6 = torch.ops.aten.view_copy.default(as_strided_copy_2, [-1]); as_strided_copy_2 = view_copy_6 = None 1698*da0073e9SAndroid Build Coastguard Worker view_copy_7 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]); view_copy_4 = None 1699*da0073e9SAndroid Build Coastguard Worker as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(view_copy_7, [3, 3], [3, 1]); view_copy_7 = None 1700*da0073e9SAndroid Build Coastguard Worker add_2 = torch.ops.aten.add.Tensor(as_strided_copy_3, 1); as_strided_copy_3 = None 1701*da0073e9SAndroid Build Coastguard Worker return add_2 1702*da0073e9SAndroid Build Coastguard Worker """, # noqa: B950 1703*da0073e9SAndroid Build Coastguard Worker ) 1704*da0073e9SAndroid Build Coastguard Worker 1705*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 1706*da0073e9SAndroid Build Coastguard Worker f, torch.ones(8, 2), reapply_views=True, run_reinplace=True 1707*da0073e9SAndroid Build Coastguard Worker ) 1708*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1709*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 1710*da0073e9SAndroid Build Coastguard Worker """\ 1711*da0073e9SAndroid Build Coastguard Worker 1712*da0073e9SAndroid Build Coastguard Worker 1713*da0073e9SAndroid Build Coastguard Worker 1714*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1715*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None 1716*da0073e9SAndroid Build Coastguard Worker view = torch.ops.aten.view.default(add, [4, 4]) 1717*da0073e9SAndroid Build Coastguard Worker resize = torch.ops.aten.resize.default(view, [3, 3]); resize = None 1718*da0073e9SAndroid Build Coastguard Worker as_strided = torch.ops.aten.as_strided.default(view, [3, 3], [3, 1]); view = None 1719*da0073e9SAndroid Build Coastguard Worker view_1 = torch.ops.aten.view.default(as_strided, [-1]); as_strided = None 1720*da0073e9SAndroid Build Coastguard Worker add_1 = torch.ops.aten.add_.Tensor(view_1, 1); add_1 = None 1721*da0073e9SAndroid Build Coastguard Worker view_2 = torch.ops.aten.view.default(add, [4, 4]); add = None 1722*da0073e9SAndroid Build Coastguard Worker as_strided_1 = torch.ops.aten.as_strided.default(view_2, [3, 3], [3, 1]); as_strided_1 = None 1723*da0073e9SAndroid Build Coastguard Worker view_3 = torch.ops.aten.view.default(view_1, [3, 3]); view_1 = view_3 = None 1724*da0073e9SAndroid Build Coastguard Worker view_4 = torch.ops.aten.view.default(view_2, [8, 2]); view_2 = None 1725*da0073e9SAndroid Build Coastguard Worker view_5 = torch.ops.aten.view.default(view_4, [4, 4]) 1726*da0073e9SAndroid Build Coastguard Worker as_strided_2 = torch.ops.aten.as_strided.default(view_5, [3, 3], [3, 1]); view_5 = None 1727*da0073e9SAndroid Build Coastguard Worker view_6 = torch.ops.aten.view.default(as_strided_2, [-1]); as_strided_2 = view_6 = None 1728*da0073e9SAndroid Build Coastguard Worker view_7 = torch.ops.aten.view.default(view_4, [4, 4]); view_4 = None 1729*da0073e9SAndroid Build Coastguard Worker as_strided_3 = torch.ops.aten.as_strided.default(view_7, [3, 3], [3, 1]); view_7 = None 1730*da0073e9SAndroid Build Coastguard Worker add_2 = torch.ops.aten.add_.Tensor(as_strided_3, 1); add_2 = None 1731*da0073e9SAndroid Build Coastguard Worker return as_strided_3 1732*da0073e9SAndroid Build Coastguard Worker """, 1733*da0073e9SAndroid Build Coastguard Worker ) 1734*da0073e9SAndroid Build Coastguard Worker 1735*da0073e9SAndroid Build Coastguard Worker def test_resize_same_size_diff_rank(self): 1736*da0073e9SAndroid Build Coastguard Worker def f(x): 1737*da0073e9SAndroid Build Coastguard Worker y = x.clone() 1738*da0073e9SAndroid Build Coastguard Worker y.resize_(25, 5) 1739*da0073e9SAndroid Build Coastguard Worker return y 1740*da0073e9SAndroid Build Coastguard Worker 1741*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(5, 5, 5)) 1742*da0073e9SAndroid Build Coastguard Worker 1743*da0073e9SAndroid Build Coastguard Worker def test_resize_larger_valid(self): 1744*da0073e9SAndroid Build Coastguard Worker def f(x): 1745*da0073e9SAndroid Build Coastguard Worker y = x + 1 1746*da0073e9SAndroid Build Coastguard Worker # resizing a tensor to a larger size is only currently allowed 1747*da0073e9SAndroid Build Coastguard Worker # if the tensor-to-resize is not a view / has no outstanding views. 1748*da0073e9SAndroid Build Coastguard Worker # See Note [resize_() in functionalization pass] 1749*da0073e9SAndroid Build Coastguard Worker y.resize_(5, 5) 1750*da0073e9SAndroid Build Coastguard Worker y2 = y.view(25) 1751*da0073e9SAndroid Build Coastguard Worker # Do a mutation to ensure that aliases of the output of resize_() 1752*da0073e9SAndroid Build Coastguard Worker # propagate mutations correctly. 1753*da0073e9SAndroid Build Coastguard Worker # I'm using fill_ specifically because I want to guarantee that 1754*da0073e9SAndroid Build Coastguard Worker # none of the output has uninitialized memory at the end 1755*da0073e9SAndroid Build Coastguard Worker # (since these tests compare the data output against a reference impl) 1756*da0073e9SAndroid Build Coastguard Worker y2.fill_(1) 1757*da0073e9SAndroid Build Coastguard Worker out = y + 1 1758*da0073e9SAndroid Build Coastguard Worker return y, out 1759*da0073e9SAndroid Build Coastguard Worker 1760*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(8, 2)) 1761*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(8, 2)) 1762*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1763*da0073e9SAndroid Build Coastguard Worker logs, 1764*da0073e9SAndroid Build Coastguard Worker """\ 1765*da0073e9SAndroid Build Coastguard Worker 1766*da0073e9SAndroid Build Coastguard Worker 1767*da0073e9SAndroid Build Coastguard Worker 1768*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1769*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None 1770*da0073e9SAndroid Build Coastguard Worker resize = torch.ops.aten.resize.default(add, [5, 5]); add = None 1771*da0073e9SAndroid Build Coastguard Worker view_copy = torch.ops.aten.view_copy.default(resize, [25]); resize = None 1772*da0073e9SAndroid Build Coastguard Worker fill = torch.ops.aten.fill.Scalar(view_copy, 1); view_copy = None 1773*da0073e9SAndroid Build Coastguard Worker view_copy_1 = torch.ops.aten.view_copy.default(fill, [5, 5]); fill = None 1774*da0073e9SAndroid Build Coastguard Worker view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [25]); view_copy_2 = None 1775*da0073e9SAndroid Build Coastguard Worker add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1) 1776*da0073e9SAndroid Build Coastguard Worker return (view_copy_1, add_1) 1777*da0073e9SAndroid Build Coastguard Worker """, 1778*da0073e9SAndroid Build Coastguard Worker ) 1779*da0073e9SAndroid Build Coastguard Worker 1780*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 1781*da0073e9SAndroid Build Coastguard Worker f, torch.ones(8, 2), reapply_views=True, run_reinplace=True 1782*da0073e9SAndroid Build Coastguard Worker ) 1783*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1784*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 1785*da0073e9SAndroid Build Coastguard Worker """\ 1786*da0073e9SAndroid Build Coastguard Worker 1787*da0073e9SAndroid Build Coastguard Worker 1788*da0073e9SAndroid Build Coastguard Worker 1789*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1790*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None 1791*da0073e9SAndroid Build Coastguard Worker resize = torch.ops.aten.resize_.default(add, [5, 5]); resize = None 1792*da0073e9SAndroid Build Coastguard Worker view = torch.ops.aten.view.default(add, [25]); add = None 1793*da0073e9SAndroid Build Coastguard Worker fill = torch.ops.aten.fill_.Scalar(view, 1); fill = None 1794*da0073e9SAndroid Build Coastguard Worker view_1 = torch.ops.aten.view.default(view, [5, 5]); view = None 1795*da0073e9SAndroid Build Coastguard Worker view_2 = torch.ops.aten.view.default(view_1, [25]); view_2 = None 1796*da0073e9SAndroid Build Coastguard Worker add_1 = torch.ops.aten.add.Tensor(view_1, 1) 1797*da0073e9SAndroid Build Coastguard Worker return (view_1, add_1) 1798*da0073e9SAndroid Build Coastguard Worker """, 1799*da0073e9SAndroid Build Coastguard Worker ) 1800*da0073e9SAndroid Build Coastguard Worker 1801*da0073e9SAndroid Build Coastguard Worker def test_resize_larger_invalid(self): 1802*da0073e9SAndroid Build Coastguard Worker def f(x): 1803*da0073e9SAndroid Build Coastguard Worker y = x + 1 1804*da0073e9SAndroid Build Coastguard Worker z = y.view(4, 4) 1805*da0073e9SAndroid Build Coastguard Worker # resizing a tensor to a larger size is only currently allowed 1806*da0073e9SAndroid Build Coastguard Worker # if the tensor-to-resize is not a view / has no outstanding views. 1807*da0073e9SAndroid Build Coastguard Worker # See Note [resize_() in functionalization pass] 1808*da0073e9SAndroid Build Coastguard Worker # This should fail 1809*da0073e9SAndroid Build Coastguard Worker z.resize_(5, 5) 1810*da0073e9SAndroid Build Coastguard Worker z2 = z.view(25) 1811*da0073e9SAndroid Build Coastguard Worker z2.fill_(1) 1812*da0073e9SAndroid Build Coastguard Worker out = z + 1 1813*da0073e9SAndroid Build Coastguard Worker return y, out 1814*da0073e9SAndroid Build Coastguard Worker 1815*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1816*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1817*da0073e9SAndroid Build Coastguard Worker r"Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass", 1818*da0073e9SAndroid Build Coastguard Worker ): 1819*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(8, 2)) 1820*da0073e9SAndroid Build Coastguard Worker 1821*da0073e9SAndroid Build Coastguard Worker def test_nested_functions_propagate_updates(self): 1822*da0073e9SAndroid Build Coastguard Worker def g(x): 1823*da0073e9SAndroid Build Coastguard Worker # Create a view of x 1824*da0073e9SAndroid Build Coastguard Worker y = x[0] 1825*da0073e9SAndroid Build Coastguard Worker y.add_(1) 1826*da0073e9SAndroid Build Coastguard Worker # The view, y, gets deallocated at the end of this function 1827*da0073e9SAndroid Build Coastguard Worker 1828*da0073e9SAndroid Build Coastguard Worker def f(x): 1829*da0073e9SAndroid Build Coastguard Worker # Calling g(x) should mutate x 1830*da0073e9SAndroid Build Coastguard Worker g(x) 1831*da0073e9SAndroid Build Coastguard Worker # We expect x to be synced here, even though the alias created in g() has been deallocated! 1832*da0073e9SAndroid Build Coastguard Worker y = x + x 1833*da0073e9SAndroid Build Coastguard Worker return y 1834*da0073e9SAndroid Build Coastguard Worker 1835*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(2, 2)) 1836*da0073e9SAndroid Build Coastguard Worker 1837*da0073e9SAndroid Build Coastguard Worker def test_mixed_wrappers_valid(self): 1838*da0073e9SAndroid Build Coastguard Worker def f(x, y): 1839*da0073e9SAndroid Build Coastguard Worker z = x + y 1840*da0073e9SAndroid Build Coastguard Worker z.add_(1) 1841*da0073e9SAndroid Build Coastguard Worker return z 1842*da0073e9SAndroid Build Coastguard Worker 1843*da0073e9SAndroid Build Coastguard Worker x1_not_functional = LoggingTensor(torch.ones(4)) 1844*da0073e9SAndroid Build Coastguard Worker x2_functional = torch._to_functional_tensor(LoggingTensor(torch.ones(4))) 1845*da0073e9SAndroid Build Coastguard Worker 1846*da0073e9SAndroid Build Coastguard Worker with capture_logs() as logs: 1847*da0073e9SAndroid Build Coastguard Worker y = f(x1_not_functional, x2_functional) 1848*da0073e9SAndroid Build Coastguard Worker 1849*da0073e9SAndroid Build Coastguard Worker # Make sure that functionalization ran the "+" kernel 1850*da0073e9SAndroid Build Coastguard Worker # with a functional + non-functional tensor, and wrapped the output appropriately. 1851*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1852*da0073e9SAndroid Build Coastguard Worker "\n".join(logs), 1853*da0073e9SAndroid Build Coastguard Worker """\ 1854*da0073e9SAndroid Build Coastguard Worker$2: f32[4] = torch._ops.aten.add.Tensor($0, $1) 1855*da0073e9SAndroid Build Coastguard Worker$3: f32[4] = torch._ops.aten.add.Tensor($2, 1)""", 1856*da0073e9SAndroid Build Coastguard Worker ) 1857*da0073e9SAndroid Build Coastguard Worker 1858*da0073e9SAndroid Build Coastguard Worker def test_mixed_wrappers_invalid(self): 1859*da0073e9SAndroid Build Coastguard Worker x1_not_functional = torch.ones(4) 1860*da0073e9SAndroid Build Coastguard Worker x2_functional = torch._to_functional_tensor(torch.ones(4)) 1861*da0073e9SAndroid Build Coastguard Worker 1862*da0073e9SAndroid Build Coastguard Worker # When dealing with mixed functional + non functional tensors, 1863*da0073e9SAndroid Build Coastguard Worker # normal_tensor.add_(functional_tensor) is not valid 1864*da0073e9SAndroid Build Coastguard Worker # because normal_tensor would need to be "promoted" to a functional tensor. 1865*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 1866*da0073e9SAndroid Build Coastguard Worker x1_not_functional.add_(x2_functional) 1867*da0073e9SAndroid Build Coastguard Worker 1868*da0073e9SAndroid Build Coastguard Worker def test_index_mutation_on_non_input(self): 1869*da0073e9SAndroid Build Coastguard Worker def f(x): 1870*da0073e9SAndroid Build Coastguard Worker tmp = torch.zeros(10) 1871*da0073e9SAndroid Build Coastguard Worker tmp[5].fill_(1) 1872*da0073e9SAndroid Build Coastguard Worker return tmp 1873*da0073e9SAndroid Build Coastguard Worker 1874*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization(f, torch.ones(2)) 1875*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs(f, torch.ones(2)) 1876*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1877*da0073e9SAndroid Build Coastguard Worker logs, 1878*da0073e9SAndroid Build Coastguard Worker """\ 1879*da0073e9SAndroid Build Coastguard Worker 1880*da0073e9SAndroid Build Coastguard Worker 1881*da0073e9SAndroid Build Coastguard Worker 1882*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1883*da0073e9SAndroid Build Coastguard Worker zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False) 1884*da0073e9SAndroid Build Coastguard Worker select_copy = torch.ops.aten.select_copy.int(zeros, 0, 5) 1885*da0073e9SAndroid Build Coastguard Worker fill = torch.ops.aten.fill.Scalar(select_copy, 1); select_copy = None 1886*da0073e9SAndroid Build Coastguard Worker select_scatter = torch.ops.aten.select_scatter.default(zeros, fill, 0, 5); zeros = fill = None 1887*da0073e9SAndroid Build Coastguard Worker select_copy_1 = torch.ops.aten.select_copy.int(select_scatter, 0, 5); select_copy_1 = None 1888*da0073e9SAndroid Build Coastguard Worker return select_scatter 1889*da0073e9SAndroid Build Coastguard Worker """, 1890*da0073e9SAndroid Build Coastguard Worker ) # noqa: B950 1891*da0073e9SAndroid Build Coastguard Worker 1892*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 1893*da0073e9SAndroid Build Coastguard Worker f, torch.ones(2), reapply_views=True, run_reinplace=True 1894*da0073e9SAndroid Build Coastguard Worker ) 1895*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1896*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 1897*da0073e9SAndroid Build Coastguard Worker """\ 1898*da0073e9SAndroid Build Coastguard Worker 1899*da0073e9SAndroid Build Coastguard Worker 1900*da0073e9SAndroid Build Coastguard Worker 1901*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 1902*da0073e9SAndroid Build Coastguard Worker zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False) 1903*da0073e9SAndroid Build Coastguard Worker select = torch.ops.aten.select.int(zeros, 0, 5) 1904*da0073e9SAndroid Build Coastguard Worker fill = torch.ops.aten.fill_.Scalar(select, 1); select = fill = None 1905*da0073e9SAndroid Build Coastguard Worker select_1 = torch.ops.aten.select.int(zeros, 0, 5); select_1 = None 1906*da0073e9SAndroid Build Coastguard Worker return zeros 1907*da0073e9SAndroid Build Coastguard Worker """, 1908*da0073e9SAndroid Build Coastguard Worker ) 1909*da0073e9SAndroid Build Coastguard Worker 1910*da0073e9SAndroid Build Coastguard Worker def test_instance_norm(self): 1911*da0073e9SAndroid Build Coastguard Worker size = 100 1912*da0073e9SAndroid Build Coastguard Worker 1913*da0073e9SAndroid Build Coastguard Worker def f(x, running_mean, running_var): 1914*da0073e9SAndroid Build Coastguard Worker with enable_python_dispatcher(): 1915*da0073e9SAndroid Build Coastguard Worker return torch.instance_norm( 1916*da0073e9SAndroid Build Coastguard Worker x, 1917*da0073e9SAndroid Build Coastguard Worker None, 1918*da0073e9SAndroid Build Coastguard Worker None, 1919*da0073e9SAndroid Build Coastguard Worker running_mean, 1920*da0073e9SAndroid Build Coastguard Worker running_var, 1921*da0073e9SAndroid Build Coastguard Worker use_input_stats=True, 1922*da0073e9SAndroid Build Coastguard Worker momentum=0.1, 1923*da0073e9SAndroid Build Coastguard Worker eps=1e-5, 1924*da0073e9SAndroid Build Coastguard Worker cudnn_enabled=False, 1925*da0073e9SAndroid Build Coastguard Worker ) 1926*da0073e9SAndroid Build Coastguard Worker 1927*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization( 1928*da0073e9SAndroid Build Coastguard Worker f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size) 1929*da0073e9SAndroid Build Coastguard Worker ) 1930*da0073e9SAndroid Build Coastguard Worker # On Windows, for instance_norm, the alias_copy's are reordered to come right before they need to be used 1931*da0073e9SAndroid Build Coastguard Worker # whereas on other platforms, the alias_copy's are before the view_copy's. 1932*da0073e9SAndroid Build Coastguard Worker # e.g., the alias_copy after the getitem_4 assignment would be moved to be right before the copy assignment. 1933*da0073e9SAndroid Build Coastguard Worker if not IS_WINDOWS: 1934*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs( 1935*da0073e9SAndroid Build Coastguard Worker f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size) 1936*da0073e9SAndroid Build Coastguard Worker ) 1937*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1938*da0073e9SAndroid Build Coastguard Worker logs, 1939*da0073e9SAndroid Build Coastguard Worker """\ 1940*da0073e9SAndroid Build Coastguard Worker 1941*da0073e9SAndroid Build Coastguard Worker 1942*da0073e9SAndroid Build Coastguard Worker 1943*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1, arg1_1, arg2_1): 1944*da0073e9SAndroid Build Coastguard Worker repeat = torch.ops.aten.repeat.default(arg1_1, [20]) 1945*da0073e9SAndroid Build Coastguard Worker repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20]) 1946*da0073e9SAndroid Build Coastguard Worker view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None 1947*da0073e9SAndroid Build Coastguard Worker empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')); empty = None 1948*da0073e9SAndroid Build Coastguard Worker _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None 1949*da0073e9SAndroid Build Coastguard Worker getitem = _native_batch_norm_legit_functional[0] 1950*da0073e9SAndroid Build Coastguard Worker getitem_1 = _native_batch_norm_legit_functional[1]; getitem_1 = None 1951*da0073e9SAndroid Build Coastguard Worker getitem_2 = _native_batch_norm_legit_functional[2]; getitem_2 = None 1952*da0073e9SAndroid Build Coastguard Worker getitem_3 = _native_batch_norm_legit_functional[3] 1953*da0073e9SAndroid Build Coastguard Worker getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None 1954*da0073e9SAndroid Build Coastguard Worker alias_copy = torch.ops.aten.alias_copy.default(arg1_1) 1955*da0073e9SAndroid Build Coastguard Worker view_copy_1 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]); view_copy_1 = None 1956*da0073e9SAndroid Build Coastguard Worker view_copy_2 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]); getitem_3 = None 1957*da0073e9SAndroid Build Coastguard Worker mean = torch.ops.aten.mean.dim(view_copy_2, [0]); view_copy_2 = None 1958*da0073e9SAndroid Build Coastguard Worker copy = torch.ops.aten.copy.default(alias_copy, mean); alias_copy = mean = None 1959*da0073e9SAndroid Build Coastguard Worker alias_copy_1 = torch.ops.aten.alias_copy.default(copy); copy = None 1960*da0073e9SAndroid Build Coastguard Worker alias_copy_2 = torch.ops.aten.alias_copy.default(alias_copy_1); alias_copy_2 = None 1961*da0073e9SAndroid Build Coastguard Worker alias_copy_3 = torch.ops.aten.alias_copy.default(arg2_1) 1962*da0073e9SAndroid Build Coastguard Worker view_copy_3 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]); view_copy_3 = None 1963*da0073e9SAndroid Build Coastguard Worker view_copy_4 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]); getitem_4 = None 1964*da0073e9SAndroid Build Coastguard Worker mean_1 = torch.ops.aten.mean.dim(view_copy_4, [0]); view_copy_4 = None 1965*da0073e9SAndroid Build Coastguard Worker copy_1 = torch.ops.aten.copy.default(alias_copy_3, mean_1); alias_copy_3 = mean_1 = None 1966*da0073e9SAndroid Build Coastguard Worker alias_copy_4 = torch.ops.aten.alias_copy.default(copy_1); copy_1 = None 1967*da0073e9SAndroid Build Coastguard Worker alias_copy_5 = torch.ops.aten.alias_copy.default(alias_copy_4); alias_copy_5 = None 1968*da0073e9SAndroid Build Coastguard Worker view_copy_5 = torch.ops.aten.view_copy.default(getitem, [20, 100, 35, 45]); getitem = None 1969*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_1); arg1_1 = alias_copy_1 = copy_ = None 1970*da0073e9SAndroid Build Coastguard Worker copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_4); arg2_1 = alias_copy_4 = copy__1 = None 1971*da0073e9SAndroid Build Coastguard Worker return view_copy_5 1972*da0073e9SAndroid Build Coastguard Worker """, # noqa: B950 1973*da0073e9SAndroid Build Coastguard Worker ) 1974*da0073e9SAndroid Build Coastguard Worker 1975*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 1976*da0073e9SAndroid Build Coastguard Worker f, 1977*da0073e9SAndroid Build Coastguard Worker torch.randn(20, size, 35, 45), 1978*da0073e9SAndroid Build Coastguard Worker torch.zeros(size), 1979*da0073e9SAndroid Build Coastguard Worker torch.ones(size), 1980*da0073e9SAndroid Build Coastguard Worker reapply_views=True, 1981*da0073e9SAndroid Build Coastguard Worker run_reinplace=True, 1982*da0073e9SAndroid Build Coastguard Worker ) 1983*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1984*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 1985*da0073e9SAndroid Build Coastguard Worker """\ 1986*da0073e9SAndroid Build Coastguard Worker 1987*da0073e9SAndroid Build Coastguard Worker 1988*da0073e9SAndroid Build Coastguard Worker 1989*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1, arg1_1, arg2_1): 1990*da0073e9SAndroid Build Coastguard Worker repeat = torch.ops.aten.repeat.default(arg1_1, [20]) 1991*da0073e9SAndroid Build Coastguard Worker repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20]) 1992*da0073e9SAndroid Build Coastguard Worker view = torch.ops.aten.view.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None 1993*da0073e9SAndroid Build Coastguard Worker empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')); empty = None 1994*da0073e9SAndroid Build Coastguard Worker _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None 1995*da0073e9SAndroid Build Coastguard Worker getitem = _native_batch_norm_legit_functional[0] 1996*da0073e9SAndroid Build Coastguard Worker getitem_1 = _native_batch_norm_legit_functional[1]; getitem_1 = None 1997*da0073e9SAndroid Build Coastguard Worker getitem_2 = _native_batch_norm_legit_functional[2]; getitem_2 = None 1998*da0073e9SAndroid Build Coastguard Worker getitem_3 = _native_batch_norm_legit_functional[3] 1999*da0073e9SAndroid Build Coastguard Worker getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None 2000*da0073e9SAndroid Build Coastguard Worker alias = torch.ops.aten.alias.default(arg1_1) 2001*da0073e9SAndroid Build Coastguard Worker view_1 = torch.ops.aten.view.default(getitem_3, [20, 100]); view_1 = None 2002*da0073e9SAndroid Build Coastguard Worker view_2 = torch.ops.aten.view.default(getitem_3, [20, 100]); getitem_3 = None 2003*da0073e9SAndroid Build Coastguard Worker mean = torch.ops.aten.mean.dim(view_2, [0]); view_2 = None 2004*da0073e9SAndroid Build Coastguard Worker copy = torch.ops.aten.copy.default(alias, mean); alias = mean = None 2005*da0073e9SAndroid Build Coastguard Worker alias_1 = torch.ops.aten.alias.default(copy); copy = None 2006*da0073e9SAndroid Build Coastguard Worker alias_2 = torch.ops.aten.alias.default(alias_1); alias_2 = None 2007*da0073e9SAndroid Build Coastguard Worker alias_3 = torch.ops.aten.alias.default(arg2_1) 2008*da0073e9SAndroid Build Coastguard Worker view_3 = torch.ops.aten.view.default(getitem_4, [20, 100]); view_3 = None 2009*da0073e9SAndroid Build Coastguard Worker view_4 = torch.ops.aten.view.default(getitem_4, [20, 100]); getitem_4 = None 2010*da0073e9SAndroid Build Coastguard Worker mean_1 = torch.ops.aten.mean.dim(view_4, [0]); view_4 = None 2011*da0073e9SAndroid Build Coastguard Worker copy_1 = torch.ops.aten.copy.default(alias_3, mean_1); alias_3 = mean_1 = None 2012*da0073e9SAndroid Build Coastguard Worker alias_4 = torch.ops.aten.alias.default(copy_1); copy_1 = None 2013*da0073e9SAndroid Build Coastguard Worker alias_5 = torch.ops.aten.alias.default(alias_4); alias_5 = None 2014*da0073e9SAndroid Build Coastguard Worker view_5 = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]); getitem = None 2015*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg1_1, alias_1); arg1_1 = alias_1 = copy_ = None 2016*da0073e9SAndroid Build Coastguard Worker copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_4); arg2_1 = alias_4 = copy__1 = None 2017*da0073e9SAndroid Build Coastguard Worker return view_5 2018*da0073e9SAndroid Build Coastguard Worker """, # noqa: B950 2019*da0073e9SAndroid Build Coastguard Worker ) 2020*da0073e9SAndroid Build Coastguard Worker 2021*da0073e9SAndroid Build Coastguard Worker def test_mutation_overlapping_mem(self): 2022*da0073e9SAndroid Build Coastguard Worker def fn(x): 2023*da0073e9SAndroid Build Coastguard Worker # x: (1, 5) 2024*da0073e9SAndroid Build Coastguard Worker t1 = torch.add(x, x) 2025*da0073e9SAndroid Build Coastguard Worker t2 = t1.unfold(1, 3, 2) 2026*da0073e9SAndroid Build Coastguard Worker t3 = t2.abs_() 2027*da0073e9SAndroid Build Coastguard Worker return t3 2028*da0073e9SAndroid Build Coastguard Worker 2029*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2030*da0073e9SAndroid Build Coastguard Worker RuntimeError, 2031*da0073e9SAndroid Build Coastguard Worker r"encountered a tensor being mutated that has internal overlap", 2032*da0073e9SAndroid Build Coastguard Worker ): 2033*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1, 5) 2034*da0073e9SAndroid Build Coastguard Worker out = _functionalize(fn, reapply_views=True, crossref=False)(x) 2035*da0073e9SAndroid Build Coastguard Worker 2036*da0073e9SAndroid Build Coastguard Worker def test_batch_norm(self): 2037*da0073e9SAndroid Build Coastguard Worker def f(x, running_mean, running_var): 2038*da0073e9SAndroid Build Coastguard Worker with enable_python_dispatcher(): 2039*da0073e9SAndroid Build Coastguard Worker return torch.batch_norm( 2040*da0073e9SAndroid Build Coastguard Worker x, None, None, running_mean, running_var, True, 0.1, 1e-5, False 2041*da0073e9SAndroid Build Coastguard Worker ) 2042*da0073e9SAndroid Build Coastguard Worker 2043*da0073e9SAndroid Build Coastguard Worker self.assert_functionalization( 2044*da0073e9SAndroid Build Coastguard Worker f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100) 2045*da0073e9SAndroid Build Coastguard Worker ) 2046*da0073e9SAndroid Build Coastguard Worker logs = self.get_logs( 2047*da0073e9SAndroid Build Coastguard Worker f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100) 2048*da0073e9SAndroid Build Coastguard Worker ) 2049*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2050*da0073e9SAndroid Build Coastguard Worker logs, 2051*da0073e9SAndroid Build Coastguard Worker """\ 2052*da0073e9SAndroid Build Coastguard Worker 2053*da0073e9SAndroid Build Coastguard Worker 2054*da0073e9SAndroid Build Coastguard Worker 2055*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1, arg1_1, arg2_1): 2056*da0073e9SAndroid Build Coastguard Worker empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')); empty = None 2057*da0073e9SAndroid Build Coastguard Worker _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None 2058*da0073e9SAndroid Build Coastguard Worker getitem = _native_batch_norm_legit_functional[0] 2059*da0073e9SAndroid Build Coastguard Worker getitem_1 = _native_batch_norm_legit_functional[1]; getitem_1 = None 2060*da0073e9SAndroid Build Coastguard Worker getitem_2 = _native_batch_norm_legit_functional[2]; getitem_2 = None 2061*da0073e9SAndroid Build Coastguard Worker getitem_3 = _native_batch_norm_legit_functional[3] 2062*da0073e9SAndroid Build Coastguard Worker getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None 2063*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = copy_ = None 2064*da0073e9SAndroid Build Coastguard Worker copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = copy__1 = None 2065*da0073e9SAndroid Build Coastguard Worker return getitem 2066*da0073e9SAndroid Build Coastguard Worker """, # noqa: B950 2067*da0073e9SAndroid Build Coastguard Worker ) 2068*da0073e9SAndroid Build Coastguard Worker 2069*da0073e9SAndroid Build Coastguard Worker reinplaced_logs = self.get_logs( 2070*da0073e9SAndroid Build Coastguard Worker f, 2071*da0073e9SAndroid Build Coastguard Worker torch.randn(20, 100, 35, 45), 2072*da0073e9SAndroid Build Coastguard Worker torch.zeros(100), 2073*da0073e9SAndroid Build Coastguard Worker torch.ones(100), 2074*da0073e9SAndroid Build Coastguard Worker reapply_views=True, 2075*da0073e9SAndroid Build Coastguard Worker run_reinplace=True, 2076*da0073e9SAndroid Build Coastguard Worker ) 2077*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2078*da0073e9SAndroid Build Coastguard Worker reinplaced_logs, 2079*da0073e9SAndroid Build Coastguard Worker """\ 2080*da0073e9SAndroid Build Coastguard Worker 2081*da0073e9SAndroid Build Coastguard Worker 2082*da0073e9SAndroid Build Coastguard Worker 2083*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1, arg1_1, arg2_1): 2084*da0073e9SAndroid Build Coastguard Worker empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')); empty = None 2085*da0073e9SAndroid Build Coastguard Worker _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None 2086*da0073e9SAndroid Build Coastguard Worker getitem = _native_batch_norm_legit_functional[0] 2087*da0073e9SAndroid Build Coastguard Worker getitem_1 = _native_batch_norm_legit_functional[1]; getitem_1 = None 2088*da0073e9SAndroid Build Coastguard Worker getitem_2 = _native_batch_norm_legit_functional[2]; getitem_2 = None 2089*da0073e9SAndroid Build Coastguard Worker getitem_3 = _native_batch_norm_legit_functional[3] 2090*da0073e9SAndroid Build Coastguard Worker getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None 2091*da0073e9SAndroid Build Coastguard Worker copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = copy_ = None 2092*da0073e9SAndroid Build Coastguard Worker copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = copy__1 = None 2093*da0073e9SAndroid Build Coastguard Worker return getitem 2094*da0073e9SAndroid Build Coastguard Worker """, # noqa: B950 2095*da0073e9SAndroid Build Coastguard Worker ) 2096*da0073e9SAndroid Build Coastguard Worker 2097*da0073e9SAndroid Build Coastguard Worker # This tests our python shims around C++ Functionalization: FunctionalTensor and FunctionalTensorMode 2098*da0073e9SAndroid Build Coastguard Worker def test_python_functionalization(self): 2099*da0073e9SAndroid Build Coastguard Worker def f(x): 2100*da0073e9SAndroid Build Coastguard Worker x_view = x.view(-1) 2101*da0073e9SAndroid Build Coastguard Worker x.mul_(2) 2102*da0073e9SAndroid Build Coastguard Worker return x_view + 1 2103*da0073e9SAndroid Build Coastguard Worker 2104*da0073e9SAndroid Build Coastguard Worker def f_functionalized(x): 2105*da0073e9SAndroid Build Coastguard Worker # Note [Disabling Functionalize TLS Above Python Functionalization] 2106*da0073e9SAndroid Build Coastguard Worker # This UX is pretty annoying (although python functionalization's main customer is AOTAutograd, 2107*da0073e9SAndroid Build Coastguard Worker # and is not really advertised as a user API). 2108*da0073e9SAndroid Build Coastguard Worker # We need to explicitly disable functionalization when using python FunctionalTensor and FunctionalTensorMode. 2109*da0073e9SAndroid Build Coastguard Worker # Why? FunctionalTensor is a wrapper tensor that holds an inner FunctionalTensorWrapper. 2110*da0073e9SAndroid Build Coastguard Worker # Since the inner tensor has `DispatchKey.Functionalize` in its keyset, then by default, 2111*da0073e9SAndroid Build Coastguard Worker # our FunctionalTensor will inherit the same keyset. 2112*da0073e9SAndroid Build Coastguard Worker # We don't have an easy way of directly mutating a tensor's keyset from python, 2113*da0073e9SAndroid Build Coastguard Worker # so globally disabling functionalization here is easier. 2114*da0073e9SAndroid Build Coastguard Worker maybe_disable = torch._C._ExcludeDispatchKeyGuard( 2115*da0073e9SAndroid Build Coastguard Worker torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) 2116*da0073e9SAndroid Build Coastguard Worker ) 2117*da0073e9SAndroid Build Coastguard Worker with maybe_disable, FunctionalTensorMode(): 2118*da0073e9SAndroid Build Coastguard Worker x_wrapped = FunctionalTensor.to_functional(x) 2119*da0073e9SAndroid Build Coastguard Worker out_wrapped = f(x_wrapped) 2120*da0073e9SAndroid Build Coastguard Worker out_unwrapped = out_wrapped.elem 2121*da0073e9SAndroid Build Coastguard Worker torch._sync(out_unwrapped) 2122*da0073e9SAndroid Build Coastguard Worker return torch._from_functional_tensor(out_unwrapped) 2123*da0073e9SAndroid Build Coastguard Worker 2124*da0073e9SAndroid Build Coastguard Worker # Make a non-leaf 2125*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, requires_grad=True) + 1 2126*da0073e9SAndroid Build Coastguard Worker fx_g = make_fx(f_functionalized)(x) 2127*da0073e9SAndroid Build Coastguard Worker # NB: view_1 below is expected (though unused) due to view replay. AOTAutograd runs a 2128*da0073e9SAndroid Build Coastguard Worker # DCE pass that will remove nodes like this later on. 2129*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2130*da0073e9SAndroid Build Coastguard Worker fx_g.code.strip(), 2131*da0073e9SAndroid Build Coastguard Worker """\ 2132*da0073e9SAndroid Build Coastguard Workerdef forward(self, x_1): 2133*da0073e9SAndroid Build Coastguard Worker view = torch.ops.aten.view.default(x_1, [-1]); view = None 2134*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(x_1, 2); x_1 = None 2135*da0073e9SAndroid Build Coastguard Worker view_1 = torch.ops.aten.view.default(mul, [-1]); view_1 = None 2136*da0073e9SAndroid Build Coastguard Worker view_2 = torch.ops.aten.view.default(mul, [-1]); mul = None 2137*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(view_2, 1); view_2 = None 2138*da0073e9SAndroid Build Coastguard Worker return add""", 2139*da0073e9SAndroid Build Coastguard Worker ) 2140*da0073e9SAndroid Build Coastguard Worker 2141*da0073e9SAndroid Build Coastguard Worker def test_python_functionalization_zero_tensor(self): 2142*da0073e9SAndroid Build Coastguard Worker def f(x): 2143*da0073e9SAndroid Build Coastguard Worker y = torch.ops.aten._efficientzerotensor([4]) 2144*da0073e9SAndroid Build Coastguard Worker out = x + y 2145*da0073e9SAndroid Build Coastguard Worker out.mul_(2) 2146*da0073e9SAndroid Build Coastguard Worker return out 2147*da0073e9SAndroid Build Coastguard Worker 2148*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 2149*da0073e9SAndroid Build Coastguard Worker out_ref = f(x) 2150*da0073e9SAndroid Build Coastguard Worker out_test = dispatch_functionalize(f)(x) 2151*da0073e9SAndroid Build Coastguard Worker out_test_cpp = _functionalize( 2152*da0073e9SAndroid Build Coastguard Worker f, reapply_views=True, crossref=False, skip_input_mutations=True 2153*da0073e9SAndroid Build Coastguard Worker )(x) 2154*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out_test) 2155*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out_test_cpp) 2156*da0073e9SAndroid Build Coastguard Worker fx_g = make_fx(dispatch_functionalize(f))(x) 2157*da0073e9SAndroid Build Coastguard Worker fx_g_cpp = make_fx( 2158*da0073e9SAndroid Build Coastguard Worker _functionalize( 2159*da0073e9SAndroid Build Coastguard Worker f, reapply_views=True, crossref=False, skip_input_mutations=True 2160*da0073e9SAndroid Build Coastguard Worker ) 2161*da0073e9SAndroid Build Coastguard Worker )(x) 2162*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip()) 2163*da0073e9SAndroid Build Coastguard Worker 2164*da0073e9SAndroid Build Coastguard Worker def test_python_functionalization_is_conj(self): 2165*da0073e9SAndroid Build Coastguard Worker def f(x): 2166*da0073e9SAndroid Build Coastguard Worker out = x.conj() 2167*da0073e9SAndroid Build Coastguard Worker return out, out.is_conj() 2168*da0073e9SAndroid Build Coastguard Worker 2169*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, dtype=torch.complex64) 2170*da0073e9SAndroid Build Coastguard Worker out_ref = f(x) 2171*da0073e9SAndroid Build Coastguard Worker out_test = dispatch_functionalize(f)(x) 2172*da0073e9SAndroid Build Coastguard Worker out_test_cpp = _functionalize(f, reapply_views=True, crossref=False)(x) 2173*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref[0], out_test[0]) 2174*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref[1], out_test[1]) 2175*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref[0], out_test_cpp[0]) 2176*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref[1], out_test_cpp[1]) 2177*da0073e9SAndroid Build Coastguard Worker 2178*da0073e9SAndroid Build Coastguard Worker def test_python_functionalization_is_neg(self): 2179*da0073e9SAndroid Build Coastguard Worker def f(x): 2180*da0073e9SAndroid Build Coastguard Worker out = x.neg() 2181*da0073e9SAndroid Build Coastguard Worker return out, out.is_neg() 2182*da0073e9SAndroid Build Coastguard Worker 2183*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, dtype=torch.complex64) 2184*da0073e9SAndroid Build Coastguard Worker out_ref = f(x) 2185*da0073e9SAndroid Build Coastguard Worker out_test = dispatch_functionalize(f)(x) 2186*da0073e9SAndroid Build Coastguard Worker out_test_cpp = _functionalize(f, reapply_views=True, crossref=False)(x) 2187*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref[0], out_test[0]) 2188*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref[1], out_test[1]) 2189*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref[0], out_test_cpp[0]) 2190*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref[1], out_test_cpp[1]) 2191*da0073e9SAndroid Build Coastguard Worker 2192*da0073e9SAndroid Build Coastguard Worker def test_python_functionalization_conj(self): 2193*da0073e9SAndroid Build Coastguard Worker def f(x): 2194*da0073e9SAndroid Build Coastguard Worker y = x.clone().conj() 2195*da0073e9SAndroid Build Coastguard Worker y.mul_(2) 2196*da0073e9SAndroid Build Coastguard Worker return torch.view_as_real(y.resolve_conj()) 2197*da0073e9SAndroid Build Coastguard Worker 2198*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, dtype=torch.complex64) 2199*da0073e9SAndroid Build Coastguard Worker out_ref = f(x) 2200*da0073e9SAndroid Build Coastguard Worker out_test = dispatch_functionalize(f)(x) 2201*da0073e9SAndroid Build Coastguard Worker out_test_cpp = _functionalize( 2202*da0073e9SAndroid Build Coastguard Worker f, reapply_views=True, crossref=False, skip_input_mutations=True 2203*da0073e9SAndroid Build Coastguard Worker )(x) 2204*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out_test) 2205*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_test, out_test_cpp) 2206*da0073e9SAndroid Build Coastguard Worker fx_g = make_fx(dispatch_functionalize(f))(x) 2207*da0073e9SAndroid Build Coastguard Worker fx_g_cpp = make_fx( 2208*da0073e9SAndroid Build Coastguard Worker _functionalize( 2209*da0073e9SAndroid Build Coastguard Worker f, reapply_views=True, crossref=False, skip_input_mutations=True 2210*da0073e9SAndroid Build Coastguard Worker ) 2211*da0073e9SAndroid Build Coastguard Worker )(x) 2212*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2213*da0073e9SAndroid Build Coastguard Worker fx_g.code.strip(), 2214*da0073e9SAndroid Build Coastguard Worker """\ 2215*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 2216*da0073e9SAndroid Build Coastguard Worker clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 2217*da0073e9SAndroid Build Coastguard Worker _conj = torch.ops.aten._conj.default(clone); clone = None 2218*da0073e9SAndroid Build Coastguard Worker clone_1 = torch.ops.aten.clone.default(_conj) 2219*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(clone_1, 2); clone_1 = None 2220*da0073e9SAndroid Build Coastguard Worker clone_2 = torch.ops.aten.clone.default(_conj); _conj = None 2221*da0073e9SAndroid Build Coastguard Worker copy = torch.ops.aten.copy.default(clone_2, mul); clone_2 = mul = None 2222*da0073e9SAndroid Build Coastguard Worker _conj_1 = torch.ops.aten._conj.default(copy); copy = None 2223*da0073e9SAndroid Build Coastguard Worker _conj_2 = torch.ops.aten._conj.default(_conj_1); _conj_1 = None 2224*da0073e9SAndroid Build Coastguard Worker clone_3 = torch.ops.aten.clone.default(_conj_2); _conj_2 = None 2225*da0073e9SAndroid Build Coastguard Worker view_as_real = torch.ops.aten.view_as_real.default(clone_3); clone_3 = None 2226*da0073e9SAndroid Build Coastguard Worker return view_as_real""", 2227*da0073e9SAndroid Build Coastguard Worker ) 2228*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip()) 2229*da0073e9SAndroid Build Coastguard Worker 2230*da0073e9SAndroid Build Coastguard Worker def test_python_functionalization_neg(self): 2231*da0073e9SAndroid Build Coastguard Worker def f(x): 2232*da0073e9SAndroid Build Coastguard Worker y = x._neg_view() 2233*da0073e9SAndroid Build Coastguard Worker z = y.resolve_neg() 2234*da0073e9SAndroid Build Coastguard Worker return z + 1 2235*da0073e9SAndroid Build Coastguard Worker 2236*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 2237*da0073e9SAndroid Build Coastguard Worker out_ref = f(x) 2238*da0073e9SAndroid Build Coastguard Worker out_test = dispatch_functionalize(f)(x) 2239*da0073e9SAndroid Build Coastguard Worker out_test_cpp = _functionalize( 2240*da0073e9SAndroid Build Coastguard Worker f, reapply_views=True, crossref=False, skip_input_mutations=True 2241*da0073e9SAndroid Build Coastguard Worker )(x) 2242*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out_test) 2243*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out_test_cpp) 2244*da0073e9SAndroid Build Coastguard Worker fx_g = make_fx(dispatch_functionalize(f))(x) 2245*da0073e9SAndroid Build Coastguard Worker fx_g_cpp = make_fx( 2246*da0073e9SAndroid Build Coastguard Worker _functionalize( 2247*da0073e9SAndroid Build Coastguard Worker f, reapply_views=True, crossref=False, skip_input_mutations=True 2248*da0073e9SAndroid Build Coastguard Worker ) 2249*da0073e9SAndroid Build Coastguard Worker )(x) 2250*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2251*da0073e9SAndroid Build Coastguard Worker fx_g.code.strip(), 2252*da0073e9SAndroid Build Coastguard Worker """\ 2253*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 2254*da0073e9SAndroid Build Coastguard Worker _neg_view = torch.ops.aten._neg_view.default(arg0_1); arg0_1 = None 2255*da0073e9SAndroid Build Coastguard Worker clone = torch.ops.aten.clone.default(_neg_view); _neg_view = None 2256*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(clone, 1); clone = None 2257*da0073e9SAndroid Build Coastguard Worker return add""", 2258*da0073e9SAndroid Build Coastguard Worker ) 2259*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip()) 2260*da0073e9SAndroid Build Coastguard Worker 2261*da0073e9SAndroid Build Coastguard Worker def test_python_functionalization_lift_fresh_storage(self): 2262*da0073e9SAndroid Build Coastguard Worker unlifted = torch.tensor([0.0]) 2263*da0073e9SAndroid Build Coastguard Worker 2264*da0073e9SAndroid Build Coastguard Worker maybe_disable = torch._C._ExcludeDispatchKeyGuard( 2265*da0073e9SAndroid Build Coastguard Worker torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) 2266*da0073e9SAndroid Build Coastguard Worker ) 2267*da0073e9SAndroid Build Coastguard Worker with maybe_disable, FunctionalTensorMode(): 2268*da0073e9SAndroid Build Coastguard Worker lifted = torch.ops.aten.lift_fresh.default(unlifted) 2269*da0073e9SAndroid Build Coastguard Worker 2270*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(unlifted.untyped_storage(), lifted.untyped_storage()) 2271*da0073e9SAndroid Build Coastguard Worker 2272*da0073e9SAndroid Build Coastguard Worker def test_python_functionalization_lift_fresh(self): 2273*da0073e9SAndroid Build Coastguard Worker def f(x): 2274*da0073e9SAndroid Build Coastguard Worker tmp = torch.tensor([0.0]) 2275*da0073e9SAndroid Build Coastguard Worker return tmp + x 2276*da0073e9SAndroid Build Coastguard Worker 2277*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 2278*da0073e9SAndroid Build Coastguard Worker out_ref = f(x) 2279*da0073e9SAndroid Build Coastguard Worker out_test = dispatch_functionalize(f)(x) 2280*da0073e9SAndroid Build Coastguard Worker out_test_cpp = _functionalize( 2281*da0073e9SAndroid Build Coastguard Worker f, reapply_views=True, crossref=False, skip_input_mutations=True 2282*da0073e9SAndroid Build Coastguard Worker )(x) 2283*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out_test) 2284*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out_test_cpp) 2285*da0073e9SAndroid Build Coastguard Worker fx_g = make_fx(dispatch_functionalize(f))(x) 2286*da0073e9SAndroid Build Coastguard Worker fx_g_cpp = make_fx( 2287*da0073e9SAndroid Build Coastguard Worker _functionalize( 2288*da0073e9SAndroid Build Coastguard Worker f, reapply_views=True, crossref=False, skip_input_mutations=True 2289*da0073e9SAndroid Build Coastguard Worker ) 2290*da0073e9SAndroid Build Coastguard Worker )(x) 2291*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2292*da0073e9SAndroid Build Coastguard Worker fx_g.code.strip(), 2293*da0073e9SAndroid Build Coastguard Worker """\ 2294*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1): 2295*da0073e9SAndroid Build Coastguard Worker _tensor_constant0 = self._tensor_constant0 2296*da0073e9SAndroid Build Coastguard Worker lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None 2297*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(lift_fresh_copy, arg0_1); lift_fresh_copy = arg0_1 = None 2298*da0073e9SAndroid Build Coastguard Worker return add""", 2299*da0073e9SAndroid Build Coastguard Worker ) 2300*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip()) 2301*da0073e9SAndroid Build Coastguard Worker 2302*da0073e9SAndroid Build Coastguard Worker 2303*da0073e9SAndroid Build Coastguard Worker@xfail_inherited_tests( 2304*da0073e9SAndroid Build Coastguard Worker [ 2305*da0073e9SAndroid Build Coastguard Worker "test_as_strided", 2306*da0073e9SAndroid Build Coastguard Worker "test_copy_", 2307*da0073e9SAndroid Build Coastguard Worker "test_diagonal", 2308*da0073e9SAndroid Build Coastguard Worker "test_diagonal_mutated_input", 2309*da0073e9SAndroid Build Coastguard Worker "test_everything", 2310*da0073e9SAndroid Build Coastguard Worker "test_fill_", 2311*da0073e9SAndroid Build Coastguard Worker "test_slice", 2312*da0073e9SAndroid Build Coastguard Worker "test_split", 2313*da0073e9SAndroid Build Coastguard Worker "test_split_with_sizes", 2314*da0073e9SAndroid Build Coastguard Worker "test_unbind", 2315*da0073e9SAndroid Build Coastguard Worker "test_view_clone_view_inplace", 2316*da0073e9SAndroid Build Coastguard Worker "test_view_inplace", 2317*da0073e9SAndroid Build Coastguard Worker ] 2318*da0073e9SAndroid Build Coastguard Worker) 2319*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 2320*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TORCHDYNAMO, "dynamo-ing code with proxy + fake doesnt work well" 2321*da0073e9SAndroid Build Coastguard Worker) 2322*da0073e9SAndroid Build Coastguard Workerclass TestCrossRefFunctionalization(TestFunctionalization): 2323*da0073e9SAndroid Build Coastguard Worker crossref = True 2324*da0073e9SAndroid Build Coastguard Worker 2325*da0073e9SAndroid Build Coastguard Worker 2326*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 2327*da0073e9SAndroid Build Coastguard Worker run_tests() 2328