1# Owner(s): ["module: codegen"] 2 3import unittest 4from contextlib import nullcontext 5 6import torch 7from torch._dispatch.python import ( 8 enable_crossref_functionalize, 9 enable_python_dispatcher, 10) 11from torch._subclasses.functional_tensor import ( 12 dispatch_functionalize, 13 FunctionalTensor, 14 FunctionalTensorMode, 15) 16from torch.fx.experimental.proxy_tensor import make_fx 17from torch.fx.passes.reinplace import reinplace 18from torch.multiprocessing.reductions import StorageWeakRef 19from torch.testing._internal.common_utils import ( 20 IS_WINDOWS, 21 run_tests, 22 skipIfTorchDynamo, 23 TEST_WITH_TORCHDYNAMO, 24 TestCase, 25 xfail_inherited_tests, 26) 27from torch.testing._internal.logging_tensor import capture_logs, LoggingTensor 28from torch.utils import _pytree as pytree 29from torch.utils._pytree import tree_map_only 30 31 32def are_aliased(x, y): 33 x_storage = StorageWeakRef(x.storage()) 34 y_storage = StorageWeakRef(y.storage()) 35 return x_storage == y_storage 36 37 38# We can unify testing and use functionalize() here instead 39# if/when functorch moves into core. 40# This is basically a crappy version of `functionalize()`. 41def _functionalize( 42 f, *, reapply_views: bool, crossref: bool, skip_input_mutations: bool = False 43): 44 def to_fun(t: torch.Tensor): 45 func_t = torch._to_functional_tensor(t) 46 func_t.requires_grad = t.requires_grad 47 return func_t 48 49 def wrapped(*inputs): 50 ctx = nullcontext() 51 if crossref: 52 ctx = enable_crossref_functionalize() 53 with ctx: 54 inputs_functional = tree_map_only(torch.Tensor, to_fun, inputs) 55 torch._enable_functionalization(reapply_views=reapply_views) 56 try: 57 out = f(*inputs_functional) 58 finally: 59 torch._disable_functionalization() 60 flat_inputs = pytree.tree_leaves(inputs) 61 flat_inputs_functional = pytree.tree_leaves(inputs_functional) 62 63 for inpt, input_functional in zip(flat_inputs, flat_inputs_functional): 64 torch._sync(input_functional) 65 inpt_new = torch._from_functional_tensor(input_functional) 66 if inpt_new is not inpt and not skip_input_mutations: 67 # Existing deficiency in functionalize(): 68 # we don't correctly mutate input metadata (yet?) 69 if inpt_new.shape == inpt.shape: 70 inpt.copy_(inpt_new) 71 tree_map_only(torch.Tensor, torch._sync, out) 72 out_unwrapped = tree_map_only( 73 torch.Tensor, torch._from_functional_tensor, out 74 ) 75 return out_unwrapped 76 77 return wrapped 78 79 80@unittest.skipIf( 81 TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457" 82) 83class TestFunctionalization(TestCase): 84 crossref = False 85 86 def get_logs(self, func, *inpts, reapply_views=False, run_reinplace=False): 87 inpts_clone = tree_map_only(torch.Tensor, torch.clone, inpts) 88 traced_f = make_fx( 89 _functionalize(func, reapply_views=reapply_views, crossref=self.crossref) 90 )(*inpts) 91 if run_reinplace: 92 traced_f = reinplace(traced_f, *inpts_clone) 93 return traced_f.code 94 95 def assert_functionalization( 96 self, func, *inpts, reapply_views=False, mutated_input_metadata=False 97 ): 98 clones1 = tree_map_only(torch.Tensor, torch.clone, inpts) 99 clones2 = tree_map_only(torch.Tensor, torch.clone, inpts) 100 clones3 = tree_map_only(torch.Tensor, torch.clone, inpts) 101 102 # Compare outputs (and mutated inputs), with and without functionalization. 103 out_ref = func(*inpts) 104 out_functional = _functionalize( 105 func, reapply_views=reapply_views, crossref=self.crossref 106 )(*clones1) 107 108 # The reinplacing pass is only valid to run with reapply_views=True. 109 functional_func = make_fx( 110 _functionalize(func, reapply_views=True, crossref=self.crossref) 111 )(*clones2) 112 reinplace_func = reinplace(functional_func, *clones2) 113 114 # NOTE: for now, need to pass in fresh inputs here, because make_fx 115 # will directly mutate the inputs that you trace with. 116 # Once this is fixed we can clean this up. 117 out_reinplace = reinplace_func(*clones3) 118 119 # functionalize() deficiency: input metadata mutations aren't propagated properly, 120 # so we just need to skip checks here for the tests that exercise that. 121 if not mutated_input_metadata: 122 flat_inpts = pytree.tree_leaves(inpts) 123 flat_clones1 = pytree.tree_leaves(clones1) 124 flat_clones3 = pytree.tree_leaves(clones3) 125 for inpt, input_clone, input_clone3 in zip( 126 flat_inpts, flat_clones1, flat_clones3 127 ): 128 self.assertEqual( 129 inpt, input_clone 130 ) # input mutations should still occur 131 self.assertEqual(inpt, input_clone3) 132 133 # Handle tests with multi-tensor outputs 134 if isinstance(out_ref, tuple): 135 out_refs, out_functionals, out_reinplaces = ( 136 list(out_ref), 137 list(out_functional), 138 list(out_reinplace), 139 ) 140 else: 141 out_refs, out_functionals, out_reinplaces = ( 142 [out_ref], 143 [out_functional], 144 [out_reinplace], 145 ) 146 147 for out_ref_, out_functional_, out_reinplace_ in zip( 148 out_refs, out_functionals, out_reinplaces 149 ): 150 self.assertEqual(out_ref_, out_functional_) 151 self.assertEqual(out_ref_, out_reinplace_) 152 153 def test_save_for_backwards_segfault(self): 154 inp = torch._to_functional_tensor( 155 LoggingTensor(torch.randn(2, 2)) 156 ).requires_grad_(True) 157 inp.exp() 158 159 def test_multiple_views_of_same_base(self): 160 def f(x): 161 y = x.view(-1) 162 z = x.view(-1) 163 x.add_(1) 164 # y should have been updated. 165 y2 = y + 1 166 # z should have been updated too. 167 z2 = z + 1 168 return z2 169 170 self.assert_functionalization(f, torch.ones(4)) 171 172 def test_freeze(self): 173 def f(x): 174 y = x.clone() 175 z = y[0] 176 torch._freeze_functional_tensor(y) 177 x.add_(1) 178 self.assertRaises(RuntimeError, lambda: y.add_(1)) 179 self.assertRaises(RuntimeError, lambda: z.add_(1)) 180 return z 181 182 _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(3, 3)) 183 184 def test_copy_stride_mismatch(self): 185 def f(x): 186 y = torch.empty_strided((2, 2), (5, 1)) 187 y.copy_(x) 188 return y 189 190 r = _functionalize(f, reapply_views=True, crossref=self.crossref)( 191 torch.ones(2, 2) 192 ) 193 self.assertEqual(r.stride(), (5, 1)) 194 195 def test_set_(self): 196 def f(x): 197 y = torch.ones(2) 198 y.set_(x.storage()) 199 return y 200 201 # We should probaby get the crossref test to work, 202 # but fixing it for Storage() objects is annoying. 203 r = _functionalize(f, reapply_views=True, crossref=False)(torch.ones(2)) 204 self.assertEqual(str(r.device), "cpu") 205 206 def test_advanced_indexing(self): 207 def f(): 208 x = torch.zeros(3, 3) 209 idx = torch.tensor([0]) 210 val = torch.ones(3, 1) 211 x[:, idx] = val 212 return x 213 214 self.assert_functionalization(f) 215 216 def test_view_clone_view_inplace(self): 217 def f(input): 218 shape = [1, 1024, 128, 128] 219 input_reshaped = input.view(shape) 220 out = input_reshaped.clone() 221 r = out.view(input.shape) 222 r.relu_() 223 return r 224 225 def g(x): 226 loss = f(x).sum() 227 import torch.fx.traceback as fx_traceback 228 from torch._functorch.aot_autograd import ( 229 setup_stacktrace_preservation_hooks, 230 ) 231 232 setup_stacktrace_preservation_hooks([loss.grad_fn]) 233 with fx_traceback.preserve_node_meta(): 234 loss.backward() 235 return x.grad 236 237 with torch.autograd.detect_anomaly(check_nan=False): 238 logs = self.get_logs(g, torch.ones(16, 64, 128, 128, requires_grad=True)) 239 self.assertExpectedInline( 240 logs, 241 """\ 242 243 244 245def forward(self, arg0_1): 246 view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 1024, 128, 128]); arg0_1 = None 247 clone = torch.ops.aten.clone.default(view_copy); view_copy = None 248 view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]) 249 relu = torch.ops.aten.relu.default(view_copy_1); view_copy_1 = None 250 view_copy_2 = torch.ops.aten.view_copy.default(relu, [1, 1024, 128, 128]); relu = None 251 view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [16, 64, 128, 128]); view_copy_2 = None 252 view_copy_4 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]); clone = view_copy_4 = None 253 sum_1 = torch.ops.aten.sum.default(view_copy_3) 254 ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None 255 expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]); ones_like = None 256 view_copy_5 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]); expand_copy = None 257 new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_5, [1, 1024, 128, 128], [16777216, 16384, 128, 1]) 258 copy = torch.ops.aten.copy.default(new_empty_strided, view_copy_5); new_empty_strided = view_copy_5 = None 259 view_copy_6 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]); view_copy_6 = None 260 view_copy_7 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]) 261 clone_1 = torch.ops.aten.clone.default(view_copy_7, memory_format = torch.contiguous_format) 262 threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, view_copy_3, 0); clone_1 = view_copy_3 = None 263 copy_1 = torch.ops.aten.copy.default(view_copy_7, threshold_backward); view_copy_7 = threshold_backward = None 264 view_copy_8 = torch.ops.aten.view_copy.default(copy_1, [1, 1024, 128, 128]); copy_1 = None 265 view_copy_9 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]); view_copy_9 = None 266 view_copy_10 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]); copy = None 267 detach_copy = torch.ops.aten.detach_copy.default(view_copy_10); view_copy_10 = detach_copy = None 268 view_copy_11 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]); view_copy_8 = None 269 detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_11); view_copy_11 = None 270 return detach_copy_1 271 """, 272 ) # noqa: B950 273 274 def test_simple(self): 275 def f(x): 276 # simple test: 1 view op, 1 inplace op 277 tmp = torch.ones(4, 2) 278 y = x.view(4, 2) 279 y.add_(tmp) 280 z = x * x 281 return y 282 283 self.assert_functionalization(f, torch.ones(4, 2)) 284 logs = self.get_logs(f, torch.ones(4, 2)) 285 self.assertExpectedInline( 286 logs, 287 """\ 288 289 290 291def forward(self, arg0_1): 292 ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) 293 view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]) 294 add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None 295 view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None 296 view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2]) 297 mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1); mul = None 298 copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = copy_ = None 299 return view_copy_2 300 """, 301 ) 302 303 reinplaced_logs = self.get_logs( 304 f, torch.ones(4, 2), reapply_views=True, run_reinplace=True 305 ) 306 self.assertExpectedInline( 307 reinplaced_logs, 308 """\ 309 310 311 312def forward(self, arg0_1): 313 ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) 314 view = torch.ops.aten.view.default(arg0_1, [4, 2]) 315 add = torch.ops.aten.add.Tensor(view, ones); view = ones = None 316 view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None 317 view_2 = torch.ops.aten.view.default(view_1, [4, 2]) 318 mul = torch.ops.aten.mul.Tensor(view_1, view_1); mul = None 319 copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = copy_ = None 320 return view_2 321 """, 322 ) 323 324 def test_simple_out(self): 325 def f(x): 326 tmp = torch.ones(4, 2) 327 y = x.view(4, 2) 328 # the out= tensor will get resized, since it has size=0 to start. 329 z = torch.empty(()) 330 torch.add(y, tmp, out=z) 331 w = z * z 332 return w 333 334 self.assert_functionalization(f, torch.ones(4, 2)) 335 logs = self.get_logs(f, torch.ones(4, 2)) 336 self.assertExpectedInline( 337 logs, 338 """\ 339 340 341 342def forward(self, arg0_1): 343 ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) 344 view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); arg0_1 = None 345 empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False); empty = None 346 add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None 347 mul = torch.ops.aten.mul.Tensor(add, add); add = None 348 return mul 349 """, 350 ) 351 352 reinplaced_logs = self.get_logs( 353 f, torch.ones(4, 2), reapply_views=True, run_reinplace=True 354 ) 355 self.assertExpectedInline( 356 reinplaced_logs, 357 """\ 358 359 360 361def forward(self, arg0_1): 362 ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) 363 view = torch.ops.aten.view.default(arg0_1, [4, 2]); arg0_1 = None 364 empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False); empty = None 365 add = torch.ops.aten.add.Tensor(view, ones); view = ones = None 366 mul = torch.ops.aten.mul.Tensor(add, add); add = None 367 return mul 368 """, 369 ) 370 371 def test_multi_out(self): 372 def f(x): 373 # aminmax.out returns a tuple of tensors. 374 # functionalization should properly handle the tuple. 375 out_min = torch.empty(4) 376 out_max = torch.empty(4) 377 torch.aminmax(x, dim=0, out=(out_max, out_min)) 378 return out_max 379 380 self.assert_functionalization(f, torch.arange(8, dtype=torch.float32)) 381 logs = self.get_logs(f, torch.arange(8, dtype=torch.float32)) 382 self.assertExpectedInline( 383 logs, 384 """\ 385 386 387 388def forward(self, arg0_1): 389 empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False); empty = None 390 empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False); empty_1 = None 391 aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0); arg0_1 = None 392 getitem = aminmax[0] 393 getitem_1 = aminmax[1]; aminmax = getitem_1 = None 394 return getitem 395 """, 396 ) 397 398 reinplaced_logs = self.get_logs( 399 f, 400 torch.arange(8, dtype=torch.float32), 401 reapply_views=True, 402 run_reinplace=True, 403 ) 404 self.assertExpectedInline( 405 reinplaced_logs, 406 """\ 407 408 409 410def forward(self, arg0_1): 411 empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False); empty = None 412 empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False); empty_1 = None 413 aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0); arg0_1 = None 414 getitem = aminmax[0] 415 getitem_1 = aminmax[1]; aminmax = getitem_1 = None 416 return getitem 417 """, 418 ) 419 420 def test_tensor_ctr(self): 421 def f(x): 422 y = torch.tensor((1, 2, 3)) 423 z = y.view(-1) 424 z.add_(1) 425 return y 426 427 inpt = torch.arange(3, dtype=torch.float32) 428 self.assert_functionalization(f, inpt) 429 430 logs = self.get_logs(f, inpt) 431 self.assertExpectedInline( 432 logs, 433 """\ 434 435 436 437def forward(self, arg0_1): 438 _tensor_constant0 = self._tensor_constant0 439 lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None 440 view_copy = torch.ops.aten.view_copy.default(lift_fresh_copy, [-1]); lift_fresh_copy = None 441 add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None 442 view_copy_1 = torch.ops.aten.view_copy.default(add, [3]); add = None 443 view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [-1]); view_copy_2 = None 444 return view_copy_1 445 """, 446 ) 447 448 reinplaced_logs = self.get_logs(f, inpt, reapply_views=True, run_reinplace=True) 449 self.assertExpectedInline( 450 reinplaced_logs, 451 """\ 452 453 454 455def forward(self, arg0_1): 456 _tensor_constant0 = self._tensor_constant0 457 lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None 458 view = torch.ops.aten.view.default(lift_fresh_copy, [-1]); lift_fresh_copy = None 459 add = torch.ops.aten.add_.Tensor(view, 1); add = None 460 view_1 = torch.ops.aten.view.default(view, [3]); view = None 461 view_2 = torch.ops.aten.view.default(view_1, [-1]); view_2 = None 462 return view_1 463 """, 464 ) 465 466 def test_advanced_indexing_correct_strides(self): 467 def f(a): 468 # This test requires that *_scatter ops are able to return 469 # non-contiguous tensors. 470 b = a.clone()[:, 1] 471 c = torch.ones_like(b, dtype=torch.bool) 472 d = b.masked_fill_(c, 0) 473 return d 474 475 self.assert_functionalization(f, torch.ones(2, 2), reapply_views=True) 476 477 def test_tensor_list_mixed_functional_nonfunctional(self): 478 nonfunctional_tensor = torch.ones(2, dtype=torch.long) 479 480 def f(x): 481 # simple test: 1 view op, 1 inplace op 482 functional_tensor = torch.ones(2, dtype=torch.long) 483 out = x[functional_tensor, nonfunctional_tensor] 484 return out 485 486 out = f(torch.ones(2, 2)) 487 out_functional = _functionalize(f, reapply_views=True, crossref=self.crossref)( 488 torch.ones(2, 2) 489 ) 490 self.assertEqual(out, out_functional) 491 492 def test_inplace_on_non_view(self): 493 def f(x): 494 # test for the case where we functionalize an inplace op on the other tensor - not a view. 495 # This is worth checking because the tensor will have an empty ViewMeta stack, which needs to be special cased. 496 tmp = torch.ones(4, 2) 497 y = x.view(4, 2) 498 x.add_(tmp) 499 return y 500 501 self.assert_functionalization(f, torch.ones(4, 2)) 502 logs = self.get_logs(f, torch.ones(4, 2)) 503 self.assertExpectedInline( 504 logs, 505 """\ 506 507 508 509def forward(self, arg0_1): 510 ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) 511 view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); view_copy = None 512 add = torch.ops.aten.add.Tensor(arg0_1, ones); ones = None 513 copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = copy_ = None 514 view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None 515 return view_copy_1 516 """, 517 ) 518 519 reinplaced_logs = self.get_logs( 520 f, torch.ones(4, 2), reapply_views=True, run_reinplace=True 521 ) 522 self.assertExpectedInline( 523 reinplaced_logs, 524 """\ 525 526 527 528def forward(self, arg0_1): 529 ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) 530 view = torch.ops.aten.view.default(arg0_1, [4, 2]); view = None 531 add = torch.ops.aten.add.Tensor(arg0_1, ones); ones = None 532 copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = copy_ = None 533 view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None 534 return view_1 535 """, 536 ) 537 538 # Some ops that are mutable are neither inplace nor out= ops. 539 # They also need special handling. 540 def test_mutable_op_not_inplace_or_other(self): 541 def f(x): 542 return torch._fused_moving_avg_obs_fq_helper( 543 x, x, x, x, x, x, x, 1.0, 0, 1, 0 544 ) 545 546 logs = self.get_logs(f, torch.ones(1)) 547 self.assertExpectedInline( 548 logs, 549 """\ 550 551 552 553def forward(self, arg0_1): 554 _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0) 555 getitem = _fused_moving_avg_obs_fq_helper_functional[0] 556 getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1] 557 getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2]; getitem_2 = None 558 getitem_3 = _fused_moving_avg_obs_fq_helper_functional[3]; getitem_3 = None 559 getitem_4 = _fused_moving_avg_obs_fq_helper_functional[4]; getitem_4 = None 560 getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5]; _fused_moving_avg_obs_fq_helper_functional = None 561 copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_5); arg0_1 = getitem_5 = copy_ = None 562 return (getitem, getitem_1) 563 """, # noqa: B950 564 ) 565 566 def test_as_strided(self): 567 def f(x): 568 y = x.as_strided((2,), (2,), 1) 569 y.add_(1) 570 return x 571 572 self.assert_functionalization(f, torch.ones(9)) 573 logs = self.get_logs(f, torch.ones(9)) 574 self.assertExpectedInline( 575 logs, 576 """\ 577 578 579 580def forward(self, arg0_1): 581 as_strided_copy = torch.ops.aten.as_strided_copy.default(arg0_1, [2], [2], 1) 582 add = torch.ops.aten.add.Tensor(as_strided_copy, 1); as_strided_copy = None 583 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1); add = None 584 as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [2], [2], 1); as_strided_copy_1 = None 585 copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = copy_ = None 586 return as_strided_scatter 587 """, 588 ) 589 590 # NB: even with reapply_views=True, we expect to see scatter op 591 reinplaced_logs = self.get_logs( 592 f, torch.ones(2, 2), reapply_views=True, run_reinplace=False 593 ) 594 self.assertExpectedInline( 595 reinplaced_logs, 596 """\ 597 598 599 600def forward(self, arg0_1): 601 as_strided = torch.ops.aten.as_strided.default(arg0_1, [2], [2], 1) 602 add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None 603 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1); add = None 604 as_strided_1 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [2], 1); as_strided_1 = None 605 copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = copy_ = None 606 return as_strided_scatter 607 """, 608 ) 609 610 def test_tensor_list_composite(self): 611 def f(x): 612 # Test an op with TensorList input 613 y = torch.block_diag(x, x) 614 return y 615 616 self.assert_functionalization(f, torch.ones(2, 2)) 617 logs = self.get_logs(f, torch.ones(2, 2)) 618 self.assertExpectedInline( 619 logs, 620 """\ 621 622 623 624def forward(self, arg0_1): 625 block_diag = torch.ops.aten.block_diag.default([arg0_1, arg0_1]); arg0_1 = None 626 return block_diag 627 """, 628 ) 629 630 def test_cat(self): 631 def f(x): 632 out = torch.empty(0) 633 torch.cat((x,), out=out) 634 return out 635 636 self.assert_functionalization(f, torch.ones(2, 2)) 637 logs = self.get_logs(f, torch.ones(2, 2)) 638 self.assertExpectedInline( 639 logs, 640 """\ 641 642 643 644def forward(self, arg0_1): 645 empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False); empty = None 646 cat = torch.ops.aten.cat.default([arg0_1]); arg0_1 = None 647 return cat 648 """, 649 ) 650 651 reinplaced_logs = self.get_logs( 652 f, torch.ones(2, 2), reapply_views=True, run_reinplace=True 653 ) 654 self.assertExpectedInline( 655 reinplaced_logs, 656 """\ 657 658 659 660def forward(self, arg0_1): 661 empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False); empty = None 662 cat = torch.ops.aten.cat.default([arg0_1]); arg0_1 = None 663 return cat 664 """, 665 ) 666 667 def test_diagonal(self): 668 def f(x): 669 # test: view ops that take a subset of the original tensor (select/diagonal) 670 tmp = torch.ones(2) 671 y = x.clone().diagonal() 672 y.add_(tmp) 673 z = x * x 674 return z 675 676 self.assert_functionalization(f, torch.ones(2, 2)) 677 logs = self.get_logs(f, torch.ones(2, 2)) 678 self.assertExpectedInline( 679 logs, 680 """\ 681 682 683 684def forward(self, arg0_1): 685 ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 686 clone = torch.ops.aten.clone.default(arg0_1) 687 diagonal_copy = torch.ops.aten.diagonal_copy.default(clone) 688 add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None 689 diagonal_scatter = torch.ops.aten.diagonal_scatter.default(clone, add); clone = add = None 690 diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter); diagonal_scatter = diagonal_copy_1 = None 691 mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None 692 return mul 693 """, 694 ) 695 696 reinplaced_logs = self.get_logs( 697 f, torch.ones(2, 2), reapply_views=True, run_reinplace=True 698 ) 699 self.assertExpectedInline( 700 reinplaced_logs, 701 """\ 702 703 704 705def forward(self, arg0_1): 706 ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 707 clone = torch.ops.aten.clone.default(arg0_1) 708 diagonal = torch.ops.aten.diagonal.default(clone) 709 add = torch.ops.aten.add_.Tensor(diagonal, ones); diagonal = ones = add = None 710 diagonal_1 = torch.ops.aten.diagonal.default(clone); clone = diagonal_1 = None 711 mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None 712 return mul 713 """, 714 ) 715 716 def test_diagonal_mutated_input(self): 717 def f(x): 718 # simple test: there are pending updates afterwards, which the test syncs manually 719 tmp = torch.ones(2) 720 y = x.diagonal() 721 y.add_(tmp) 722 return x 723 724 x = torch.ones(2, 2) 725 self.assert_functionalization(f, x) 726 logs = self.get_logs(f, torch.ones(2, 2)) 727 self.assertExpectedInline( 728 logs, 729 """\ 730 731 732 733def forward(self, arg0_1): 734 ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 735 diagonal_copy = torch.ops.aten.diagonal_copy.default(arg0_1) 736 add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None 737 diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add); add = None 738 diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter); diagonal_copy_1 = None 739 copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = copy_ = None 740 return diagonal_scatter 741 """, 742 ) 743 744 # NB: even with reapply_views=True, we expect to see scatter op 745 reinplaced_logs = self.get_logs( 746 f, torch.ones(2, 2), reapply_views=True, run_reinplace=False 747 ) 748 self.assertExpectedInline( 749 reinplaced_logs, 750 """\ 751 752 753 754def forward(self, arg0_1): 755 ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 756 diagonal = torch.ops.aten.diagonal.default(arg0_1) 757 add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None 758 diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add); add = None 759 diagonal_1 = torch.ops.aten.diagonal.default(diagonal_scatter); diagonal_1 = None 760 copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = copy_ = None 761 return diagonal_scatter 762 """, 763 ) 764 765 def test_channels_last_contiguous(self): 766 def f(x): 767 return x.contiguous(memory_format=torch.channels_last) 768 tmp = torch.ones(2) 769 y = x.diagonal() 770 y.add_(tmp) 771 return x 772 773 x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2) 774 self.assert_functionalization(f, x) 775 logs = self.get_logs(f, x).strip() 776 # There should be no clone in the graph 777 self.assertExpectedInline( 778 logs, 779 """\ 780def forward(self, arg0_1): 781 return arg0_1""", 782 ) 783 784 def test_split(self): 785 def f(x): 786 # test: view ops that return multiple tensors (split) 787 tmp = torch.ones(2) 788 y1, y2 = x.split(2) 789 y3 = y2.diagonal() 790 y3.add_(tmp) 791 z = x * x 792 return y3 793 794 self.assert_functionalization(f, torch.ones(4, 2)) 795 logs = self.get_logs(f, torch.ones(4, 2)) 796 self.assertExpectedInline( 797 logs, 798 """\ 799 800 801 802def forward(self, arg0_1): 803 ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 804 split_copy = torch.ops.aten.split_copy.Tensor(arg0_1, 2) 805 getitem = split_copy[0]; getitem = None 806 getitem_1 = split_copy[1]; split_copy = None 807 diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None 808 add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None 809 split_copy_1 = torch.ops.aten.split_copy.Tensor(arg0_1, 2) 810 getitem_2 = split_copy_1[0]; getitem_2 = None 811 getitem_3 = split_copy_1[1]; split_copy_1 = None 812 diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = add = None 813 slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None 814 split_copy_2 = torch.ops.aten.split_copy.Tensor(slice_scatter, 2) 815 getitem_4 = split_copy_2[0]; getitem_4 = None 816 getitem_5 = split_copy_2[1]; split_copy_2 = None 817 diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_5); getitem_5 = None 818 mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter); mul = None 819 copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None 820 return diagonal_copy_1 821 """, 822 ) # noqa: B950 823 824 # NB: even with reapply_views=True, we expect to see scatter op 825 reinplaced_logs = self.get_logs( 826 f, torch.ones(4, 2), reapply_views=True, run_reinplace=False 827 ) 828 self.assertExpectedInline( 829 reinplaced_logs, 830 """\ 831 832 833 834def forward(self, arg0_1): 835 ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 836 split = torch.ops.aten.split.Tensor(arg0_1, 2) 837 getitem = split[0]; getitem = None 838 getitem_1 = split[1]; split = None 839 diagonal = torch.ops.aten.diagonal.default(getitem_1); getitem_1 = None 840 add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None 841 split_1 = torch.ops.aten.split.Tensor(arg0_1, 2) 842 getitem_2 = split_1[0]; getitem_2 = None 843 getitem_3 = split_1[1]; split_1 = None 844 diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = add = None 845 slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None 846 split_2 = torch.ops.aten.split.Tensor(slice_scatter, 2) 847 getitem_4 = split_2[0]; getitem_4 = None 848 getitem_5 = split_2[1]; split_2 = None 849 diagonal_1 = torch.ops.aten.diagonal.default(getitem_5); getitem_5 = None 850 mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter); mul = None 851 copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None 852 return diagonal_1 853 """, 854 ) # noqa: B950 855 856 def test_split_with_sizes(self): 857 def f(x): 858 # test: view ops that return multiple tensors (split_with_sizes) 859 tmp = torch.ones(2) 860 y1, y2 = x.split_with_sizes([2, 2]) 861 y3 = y1.diagonal() 862 y3.add_(tmp) 863 z = x * x 864 return y3 865 866 self.assert_functionalization(f, torch.ones(4, 2)) 867 logs = self.get_logs(f, torch.ones(4, 2)) 868 self.assertExpectedInline( 869 logs, 870 """\ 871 872 873 874def forward(self, arg0_1): 875 ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 876 split_with_sizes_copy = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2]) 877 getitem = split_with_sizes_copy[0] 878 getitem_1 = split_with_sizes_copy[1]; split_with_sizes_copy = getitem_1 = None 879 diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem); getitem = None 880 add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None 881 split_with_sizes_copy_1 = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2]) 882 getitem_2 = split_with_sizes_copy_1[0] 883 getitem_3 = split_with_sizes_copy_1[1]; split_with_sizes_copy_1 = getitem_3 = None 884 diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add); getitem_2 = add = None 885 slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2); diagonal_scatter = None 886 split_with_sizes_copy_2 = torch.ops.aten.split_with_sizes_copy.default(slice_scatter, [2, 2]) 887 getitem_4 = split_with_sizes_copy_2[0] 888 getitem_5 = split_with_sizes_copy_2[1]; split_with_sizes_copy_2 = getitem_5 = None 889 diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_4); getitem_4 = None 890 mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter); mul = None 891 copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None 892 return diagonal_copy_1 893 """, 894 ) # noqa: B950 895 896 # NB: even with reapply_views=True, we expect to see scatter op 897 reinplaced_logs = self.get_logs( 898 f, torch.ones(4, 2), reapply_views=True, run_reinplace=False 899 ) 900 self.assertExpectedInline( 901 reinplaced_logs, 902 """\ 903 904 905 906def forward(self, arg0_1): 907 ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) 908 split_with_sizes = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2]) 909 getitem = split_with_sizes[0] 910 getitem_1 = split_with_sizes[1]; split_with_sizes = getitem_1 = None 911 diagonal = torch.ops.aten.diagonal.default(getitem); getitem = None 912 add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None 913 split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2]) 914 getitem_2 = split_with_sizes_1[0] 915 getitem_3 = split_with_sizes_1[1]; split_with_sizes_1 = getitem_3 = None 916 diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add); getitem_2 = add = None 917 slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2); diagonal_scatter = None 918 split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(slice_scatter, [2, 2]) 919 getitem_4 = split_with_sizes_2[0] 920 getitem_5 = split_with_sizes_2[1]; split_with_sizes_2 = getitem_5 = None 921 diagonal_1 = torch.ops.aten.diagonal.default(getitem_4); getitem_4 = None 922 mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter); mul = None 923 copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None 924 return diagonal_1 925 """, 926 ) # noqa: B950 927 928 def test_slice(self): 929 def f(x): 930 tmp = torch.ones(4) 931 x.transpose_(1, 0) 932 y = x[0:2] 933 y.add_(tmp) 934 return x 935 936 self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True) 937 logs = self.get_logs(f, torch.ones(4, 2)) 938 self.assertExpectedInline( 939 logs, 940 """\ 941 942 943 944def forward(self, arg0_1): 945 ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) 946 transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0) 947 slice_copy = torch.ops.aten.slice_copy.Tensor(transpose_copy, 0, 0, 2); transpose_copy = None 948 add = torch.ops.aten.add.Tensor(slice_copy, ones); slice_copy = ones = None 949 transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None 950 slice_scatter = torch.ops.aten.slice_scatter.default(transpose_copy_1, add, 0, 0, 2); transpose_copy_1 = add = None 951 transpose_copy_2 = torch.ops.aten.transpose_copy.int(slice_scatter, 1, 0); slice_scatter = None 952 transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0) 953 slice_copy_1 = torch.ops.aten.slice_copy.Tensor(transpose_copy_3, 0, 0, 2); transpose_copy_3 = slice_copy_1 = None 954 transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None 955 return transpose_copy_4 956 """, 957 ) # noqa: B950 958 959 # NB: even with reapply_views=True, we expect to see scatter op 960 reinplaced_logs = self.get_logs( 961 f, torch.ones(4, 2), reapply_views=True, run_reinplace=False 962 ) 963 self.assertExpectedInline( 964 reinplaced_logs, 965 """\ 966 967 968 969def forward(self, arg0_1): 970 ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) 971 transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0) 972 slice_1 = torch.ops.aten.slice.Tensor(transpose, 0, 0, 2); transpose = None 973 add = torch.ops.aten.add.Tensor(slice_1, ones); slice_1 = ones = None 974 transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None 975 slice_scatter = torch.ops.aten.slice_scatter.default(transpose_1, add, 0, 0, 2); transpose_1 = add = None 976 transpose_2 = torch.ops.aten.transpose.int(slice_scatter, 1, 0); slice_scatter = None 977 transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0) 978 slice_2 = torch.ops.aten.slice.Tensor(transpose_3, 0, 0, 2); transpose_3 = slice_2 = None 979 transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None 980 return transpose_4 981 """, 982 ) # noqa: B950 983 984 def test_view_inplace(self): 985 def f(x): 986 # test: view + inplace op (transpose_) 987 tmp = torch.ones(4) 988 x.transpose_(1, 0) 989 y = x[0] 990 y.add_(tmp) 991 return x 992 993 self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True) 994 logs = self.get_logs(f, torch.ones(4, 2)) 995 self.assertExpectedInline( 996 logs, 997 """\ 998 999 1000 1001def forward(self, arg0_1): 1002 ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) 1003 transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0) 1004 select_copy = torch.ops.aten.select_copy.int(transpose_copy, 0, 0); transpose_copy = None 1005 add = torch.ops.aten.add.Tensor(select_copy, ones); select_copy = ones = None 1006 transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None 1007 select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None 1008 transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None 1009 transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0) 1010 select_copy_1 = torch.ops.aten.select_copy.int(transpose_copy_3, 0, 0); transpose_copy_3 = select_copy_1 = None 1011 transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None 1012 return transpose_copy_4 1013 """, 1014 ) # noqa: B950 1015 1016 # NB: even with reapply_views=True, we expect to see scatter op 1017 reinplaced_logs = self.get_logs( 1018 f, torch.ones(4, 2), reapply_views=True, run_reinplace=False 1019 ) 1020 self.assertExpectedInline( 1021 reinplaced_logs, 1022 """\ 1023 1024 1025 1026def forward(self, arg0_1): 1027 ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) 1028 transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0) 1029 select = torch.ops.aten.select.int(transpose, 0, 0); transpose = None 1030 add = torch.ops.aten.add.Tensor(select, ones); select = ones = None 1031 transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None 1032 select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0); transpose_1 = add = None 1033 transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0); select_scatter = None 1034 transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0) 1035 select_1 = torch.ops.aten.select.int(transpose_3, 0, 0); transpose_3 = select_1 = None 1036 transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None 1037 return transpose_4 1038 """, 1039 ) # noqa: B950 1040 1041 def test_unbind(self): 1042 def f(x): 1043 # test: view + inplace op (transpose_) 1044 tmp = torch.ones(4) 1045 x.transpose_(1, 0) 1046 y, _ = x.unbind(0) 1047 y.add_(tmp) 1048 return x 1049 1050 self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True) 1051 logs = self.get_logs(f, torch.ones(4, 2)) 1052 self.assertExpectedInline( 1053 logs, 1054 """\ 1055 1056 1057 1058def forward(self, arg0_1): 1059 ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) 1060 transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0) 1061 unbind_copy = torch.ops.aten.unbind_copy.int(transpose_copy); transpose_copy = None 1062 getitem = unbind_copy[0] 1063 getitem_1 = unbind_copy[1]; unbind_copy = getitem_1 = None 1064 add = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None 1065 transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None 1066 select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None 1067 transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None 1068 transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0) 1069 unbind_copy_1 = torch.ops.aten.unbind_copy.int(transpose_copy_3); transpose_copy_3 = None 1070 getitem_2 = unbind_copy_1[0]; getitem_2 = None 1071 getitem_3 = unbind_copy_1[1]; unbind_copy_1 = getitem_3 = None 1072 transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None 1073 return transpose_copy_4 1074 """, 1075 ) # noqa: B950 1076 1077 # NB: even with reapply_views=True, we expect to see scatter op 1078 reinplaced_logs = self.get_logs( 1079 f, torch.ones(4, 2), reapply_views=True, run_reinplace=False 1080 ) 1081 self.assertExpectedInline( 1082 reinplaced_logs, 1083 """\ 1084 1085 1086 1087def forward(self, arg0_1): 1088 ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) 1089 transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0) 1090 unbind = torch.ops.aten.unbind.int(transpose); transpose = None 1091 getitem = unbind[0] 1092 getitem_1 = unbind[1]; unbind = getitem_1 = None 1093 add = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None 1094 transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None 1095 select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0); transpose_1 = add = None 1096 transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0); select_scatter = None 1097 transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0) 1098 unbind_1 = torch.ops.aten.unbind.int(transpose_3); transpose_3 = None 1099 getitem_2 = unbind_1[0]; getitem_2 = None 1100 getitem_3 = unbind_1[1]; unbind_1 = getitem_3 = None 1101 transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None 1102 return transpose_4 1103 """, 1104 ) # noqa: B950 1105 1106 def test_optional_tensor_list(self): 1107 def f(x): 1108 # test: an operator that takes in a List[Optional[Tensor]] argument 1109 # (index_put) 1110 y = x.view(8) 1111 indices = torch.arange(4) 1112 values = torch.arange(4, dtype=y.dtype) 1113 y.index_put_((indices,), values, accumulate=False) 1114 return y 1115 1116 self.assert_functionalization(f, torch.ones(4, 2)) 1117 logs = self.get_logs(f, torch.ones(4, 2)) 1118 self.assertExpectedInline( 1119 logs, 1120 """\ 1121 1122 1123 1124def forward(self, arg0_1): 1125 view_copy = torch.ops.aten.view_copy.default(arg0_1, [8]) 1126 arange = torch.ops.aten.arange.default(4, device = device(type='cpu'), pin_memory = False) 1127 arange_1 = torch.ops.aten.arange.default(4, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) 1128 index_put = torch.ops.aten.index_put.default(view_copy, [arange], arange_1); view_copy = arange = arange_1 = None 1129 view_copy_1 = torch.ops.aten.view_copy.default(index_put, [4, 2]); index_put = None 1130 view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [8]) 1131 copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = copy_ = None 1132 return view_copy_2 1133 """, 1134 ) # noqa: B950 1135 1136 def test_scalars(self): 1137 def f(x): 1138 # test: the pass can handle scalar inputs properly 1139 tmp = torch.ones(4, 2) 1140 y = x.view(4, 2) 1141 y.add_(1) 1142 z = 2 * y 1143 z.div_(1) 1144 return z 1145 1146 self.assert_functionalization(f, torch.ones(4, 2)) 1147 logs = self.get_logs(f, torch.ones(4, 2)) 1148 self.assertExpectedInline( 1149 logs, 1150 """\ 1151 1152 1153 1154def forward(self, arg0_1): 1155 ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False); ones = None 1156 view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]) 1157 add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None 1158 view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None 1159 view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2]) 1160 mul = torch.ops.aten.mul.Tensor(view_copy_2, 2); view_copy_2 = None 1161 div = torch.ops.aten.div.Tensor(mul, 1); mul = None 1162 copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = copy_ = None 1163 return div 1164 """, 1165 ) 1166 1167 @skipIfTorchDynamo("Test does not work with TorchDynamo") 1168 def test_metadata_change(self): 1169 def f(x): 1170 # ops like ge_() are allowed to change the dtype of the input. 1171 # functionalization should pick up on that. 1172 y = x.clone() 1173 out = y.ge_(0) 1174 return out 1175 1176 self.assert_functionalization(f, torch.ones(4, 2)) 1177 logs = self.get_logs(f, torch.ones(4, 2)) 1178 self.assertExpectedInline( 1179 logs, 1180 """\ 1181 1182 1183 1184def forward(self, arg0_1): 1185 clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 1186 ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None 1187 _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None 1188 return _to_copy 1189 """, 1190 ) 1191 1192 reinplaced_logs = self.get_logs( 1193 f, torch.ones(2, 2), reapply_views=True, run_reinplace=True 1194 ) 1195 self.assertExpectedInline( 1196 reinplaced_logs, 1197 """\ 1198 1199 1200 1201def forward(self, arg0_1): 1202 clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 1203 ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None 1204 _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None 1205 return _to_copy 1206 """, 1207 ) # noqa: B950 1208 1209 @skipIfTorchDynamo("Test does not work with TorchDynamo") 1210 def test_metadata_change_out_op(self): 1211 def f(t, y): 1212 out_1 = torch.ones(1) 1213 return torch.add(t, y, out=out_1) 1214 1215 inpt1, inpt2 = torch.tensor([1]), torch.tensor([1]) 1216 inpt1_func, inpt2_func = ( 1217 torch._to_functional_tensor(inpt1), 1218 torch._to_functional_tensor(inpt2), 1219 ) 1220 1221 out_ref = f(inpt1, inpt2) 1222 torch._enable_functionalization(reapply_views=True) 1223 try: 1224 out_functional = f(inpt1_func, inpt2_func) 1225 finally: 1226 torch._disable_functionalization() 1227 self.assertEqual(out_ref, torch._from_functional_tensor(out_functional)) 1228 1229 def test_only_one_view(self): 1230 def f(x): 1231 # This tests that we don't have any unnecessary views in the trace. 1232 # If the input wasn't mutated, we don't need to regenerate it, 1233 # so there should be a total of 1 op in the output trace. 1234 return x.view(4, 2) 1235 1236 logs = self.get_logs(f, torch.ones(4, 2)) 1237 self.assertExpectedInline( 1238 logs, 1239 """\ 1240 1241 1242 1243def forward(self, arg0_1): 1244 view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); arg0_1 = None 1245 return view_copy 1246 """, 1247 ) 1248 1249 def test_everything(self): 1250 def f(x): 1251 # test: everything 1252 tmp = torch.ones(2, 2) 1253 x2 = x + x 1254 y = x2.view(8) 1255 z0 = y.reshape(2, 4) 1256 z1 = z0.transpose(1, 0) 1257 z1.unsqueeze_(0) 1258 z1.squeeze_() 1259 z2, z3 = z1.split(2) 1260 z2.add_(tmp) 1261 z4 = z0[0] + z2.reshape(4) 1262 return z2 1263 1264 self.assert_functionalization(f, torch.ones(4, 2)) 1265 logs = self.get_logs(f, torch.ones(4, 2)) 1266 self.assertExpectedInline( 1267 logs, 1268 """\ 1269 1270 1271 1272def forward(self, arg0_1): 1273 ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False) 1274 add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None 1275 view_copy = torch.ops.aten.view_copy.default(add, [8]) 1276 view_copy_1 = torch.ops.aten.view_copy.default(view_copy, [2, 4]); view_copy = None 1277 transpose_copy = torch.ops.aten.transpose_copy.int(view_copy_1, 1, 0) 1278 unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0); transpose_copy = None 1279 squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy); unsqueeze_copy = None 1280 split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2); squeeze_copy = None 1281 getitem = split_copy[0] 1282 getitem_1 = split_copy[1]; split_copy = getitem_1 = None 1283 add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None 1284 view_copy_2 = torch.ops.aten.view_copy.default(add, [8]); add = None 1285 view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [2, 4]); view_copy_2 = None 1286 transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_3, 1, 0); view_copy_3 = None 1287 unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None 1288 squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None 1289 slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = add_1 = None 1290 unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None 1291 squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None 1292 transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None 1293 view_copy_4 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]); transpose_copy_2 = None 1294 view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 2]); view_copy_4 = None 1295 view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [8]) 1296 view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [2, 4]); view_copy_6 = None 1297 transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_7, 1, 0); view_copy_7 = None 1298 unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0); transpose_copy_3 = None 1299 squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3); unsqueeze_copy_3 = None 1300 split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2); squeeze_copy_3 = None 1301 getitem_2 = split_copy_1[0] 1302 getitem_3 = split_copy_1[1]; split_copy_1 = getitem_3 = None 1303 select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0); view_copy_1 = select_copy = None 1304 view_copy_8 = torch.ops.aten.view_copy.default(getitem_2, [4]); view_copy_8 = None 1305 view_copy_9 = torch.ops.aten.view_copy.default(view_copy_5, [8]) 1306 view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]); view_copy_9 = None 1307 select_copy_1 = torch.ops.aten.select_copy.int(view_copy_10, 0, 0); view_copy_10 = None 1308 view_copy_11 = torch.ops.aten.view_copy.default(view_copy_5, [8]); view_copy_5 = None 1309 view_copy_12 = torch.ops.aten.view_copy.default(view_copy_11, [2, 4]); view_copy_11 = None 1310 transpose_copy_4 = torch.ops.aten.transpose_copy.int(view_copy_12, 1, 0); view_copy_12 = None 1311 unsqueeze_copy_4 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_4, 0); transpose_copy_4 = None 1312 squeeze_copy_4 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_4); unsqueeze_copy_4 = None 1313 split_copy_2 = torch.ops.aten.split_copy.Tensor(squeeze_copy_4, 2); squeeze_copy_4 = None 1314 getitem_4 = split_copy_2[0] 1315 getitem_5 = split_copy_2[1]; split_copy_2 = getitem_5 = None 1316 view_copy_13 = torch.ops.aten.view_copy.default(getitem_4, [4]); getitem_4 = None 1317 add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_13); select_copy_1 = view_copy_13 = add_2 = None 1318 return getitem_2 1319 """, 1320 ) # noqa: B950 1321 1322 reinplaced_logs = self.get_logs( 1323 f, torch.ones(4, 2), reapply_views=True, run_reinplace=True 1324 ) 1325 self.assertExpectedInline( 1326 reinplaced_logs, 1327 """\ 1328 1329 1330 1331def forward(self, arg0_1): 1332 ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False) 1333 add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None 1334 view = torch.ops.aten.view.default(add, [8]) 1335 view_1 = torch.ops.aten.view.default(view, [2, 4]); view = None 1336 transpose = torch.ops.aten.transpose.int(view_1, 1, 0) 1337 unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0); transpose = None 1338 squeeze = torch.ops.aten.squeeze.default(unsqueeze); unsqueeze = None 1339 split = torch.ops.aten.split.Tensor(squeeze, 2); squeeze = None 1340 getitem = split[0] 1341 getitem_1 = split[1]; split = getitem_1 = None 1342 add_1 = torch.ops.aten.add_.Tensor(getitem, ones); getitem = ones = add_1 = None 1343 view_2 = torch.ops.aten.view.default(add, [8]); add = None 1344 view_3 = torch.ops.aten.view.default(view_2, [2, 4]); view_2 = None 1345 transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0); view_3 = None 1346 unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0); transpose_1 = None 1347 squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1); unsqueeze_1 = None 1348 unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0); squeeze_1 = None 1349 squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0); unsqueeze_2 = None 1350 transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0); squeeze_2 = None 1351 view_4 = torch.ops.aten.view.default(transpose_2, [8]); transpose_2 = None 1352 view_5 = torch.ops.aten.view.default(view_4, [4, 2]); view_4 = None 1353 view_6 = torch.ops.aten.view.default(view_5, [8]) 1354 view_7 = torch.ops.aten.view.default(view_6, [2, 4]); view_6 = None 1355 transpose_3 = torch.ops.aten.transpose.int(view_7, 1, 0); view_7 = None 1356 unsqueeze_3 = torch.ops.aten.unsqueeze.default(transpose_3, 0); transpose_3 = None 1357 squeeze_3 = torch.ops.aten.squeeze.default(unsqueeze_3); unsqueeze_3 = None 1358 split_1 = torch.ops.aten.split.Tensor(squeeze_3, 2); squeeze_3 = None 1359 getitem_2 = split_1[0] 1360 getitem_3 = split_1[1]; split_1 = getitem_3 = None 1361 select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = select = None 1362 clone = torch.ops.aten.clone.default(getitem_2, memory_format = torch.contiguous_format) 1363 _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None 1364 view_8 = torch.ops.aten.view.default(view_5, [8]); view_5 = None 1365 view_9 = torch.ops.aten.view.default(view_8, [2, 4]); view_8 = None 1366 select_1 = torch.ops.aten.select.int(view_9, 0, 0); view_9 = None 1367 add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = add_2 = None 1368 return getitem_2 1369 """, 1370 ) 1371 1372 def test_reapply_views_simple(self): 1373 def f(x): 1374 tmp = torch.ones(4, 2) 1375 y = x.view(4, 2) 1376 y.add_(tmp) 1377 z = x * x 1378 return y 1379 1380 self.assert_functionalization(f, torch.ones(4, 2), reapply_views=True) 1381 logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True) 1382 self.assertExpectedInline( 1383 logs, 1384 """\ 1385 1386 1387 1388def forward(self, arg0_1): 1389 ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) 1390 view = torch.ops.aten.view.default(arg0_1, [4, 2]) 1391 add = torch.ops.aten.add.Tensor(view, ones); view = ones = None 1392 view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None 1393 view_2 = torch.ops.aten.view.default(view_1, [4, 2]) 1394 mul = torch.ops.aten.mul.Tensor(view_1, view_1); mul = None 1395 copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = copy_ = None 1396 return view_2 1397 """, 1398 ) 1399 1400 def test_aliases_maintained_after_pass_when_reapplying_views(self): 1401 def f(x): 1402 tmp = torch.ones(4, 2) 1403 y = x.view(4, 2) 1404 z = x.view(4, 2) 1405 y.add_(tmp) 1406 return y, z 1407 1408 input_functional = torch._to_functional_tensor(torch.ones(4, 2)) 1409 torch._enable_functionalization(reapply_views=True) 1410 try: 1411 y, z = f(input_functional) 1412 torch._sync(y) 1413 torch._sync(z) 1414 finally: 1415 torch._disable_functionalization() 1416 1417 # y and z are aliases inside of the function, and that aliasing relationship should be maintained. 1418 _y = torch._from_functional_tensor(y) 1419 _z = torch._from_functional_tensor(z) 1420 self.assertTrue(are_aliased(_y, _z)) 1421 1422 # copy_() gets its own test, because it used to be special cased in functionalization. 1423 # However, now it works pretty similar to other functional ops 1424 def test_copy_(self): 1425 def f(x): 1426 tmp = torch.zeros(2, 2) 1427 tmp_slice = tmp.diagonal() 1428 y = tmp_slice.copy_(x) 1429 z = y.add_(x) 1430 return z 1431 1432 # Test 1: copy_() with same dtype and shape 1433 # to() is a composite op that noops when the dtype/shape match, so nothing gets logged. 1434 # self.assert_functionalization(f, torch.ones(2)) 1435 logs = self.get_logs(f, torch.ones(2)) 1436 self.assertExpectedInline( 1437 logs, 1438 """\ 1439 1440 1441 1442def forward(self, arg0_1): 1443 zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1444 diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros) 1445 copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None 1446 diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None 1447 diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter) 1448 add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None 1449 diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None 1450 diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None 1451 return diagonal_copy_2 1452 """, 1453 ) 1454 1455 reinplaced_logs = self.get_logs( 1456 f, torch.ones(2), reapply_views=True, run_reinplace=True 1457 ) 1458 self.assertExpectedInline( 1459 reinplaced_logs, 1460 """\ 1461 1462 1463 1464def forward(self, arg0_1): 1465 zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1466 diagonal = torch.ops.aten.diagonal.default(zeros) 1467 copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = copy = None 1468 diagonal_1 = torch.ops.aten.diagonal.default(zeros) 1469 add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = add = None 1470 diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None 1471 return diagonal_2 1472 """, 1473 ) 1474 1475 # Test 2: copy_() with same dtype, different shape 1476 self.assert_functionalization(f, torch.ones(1)) 1477 logs = self.get_logs(f, torch.ones(1)) 1478 self.assertExpectedInline( 1479 logs, 1480 """\ 1481 1482 1483 1484def forward(self, arg0_1): 1485 zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1486 diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros) 1487 copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None 1488 diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None 1489 diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter) 1490 add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None 1491 diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None 1492 diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None 1493 return diagonal_copy_2 1494 """, 1495 ) 1496 1497 reinplaced_logs = self.get_logs( 1498 f, torch.ones(1), reapply_views=True, run_reinplace=True 1499 ) 1500 self.assertExpectedInline( 1501 reinplaced_logs, 1502 """\ 1503 1504 1505 1506def forward(self, arg0_1): 1507 zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1508 diagonal = torch.ops.aten.diagonal.default(zeros) 1509 copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = copy = None 1510 diagonal_1 = torch.ops.aten.diagonal.default(zeros) 1511 add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = add = None 1512 diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None 1513 return diagonal_2 1514 """, 1515 ) 1516 1517 # Test 3: copy_() with different dtype, same shape 1518 self.assert_functionalization(f, torch.ones(2, dtype=torch.long)) 1519 logs = self.get_logs(f, torch.ones(2, dtype=torch.long)) 1520 self.assertExpectedInline( 1521 logs, 1522 """\ 1523 1524 1525 1526def forward(self, arg0_1): 1527 zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1528 diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros) 1529 copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None 1530 diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None 1531 diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter) 1532 add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None 1533 diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None 1534 diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None 1535 return diagonal_copy_2 1536 """, 1537 ) # noqa: B950 1538 1539 reinplaced_logs = self.get_logs( 1540 f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True 1541 ) 1542 self.assertExpectedInline( 1543 reinplaced_logs, 1544 """\ 1545 1546 1547 1548def forward(self, arg0_1): 1549 zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1550 diagonal = torch.ops.aten.diagonal.default(zeros) 1551 copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = copy = None 1552 diagonal_1 = torch.ops.aten.diagonal.default(zeros) 1553 add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = add = None 1554 diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None 1555 return diagonal_2 1556 """, 1557 ) # noqa: B950 1558 1559 # Test 4: copy_() with different dtype, different shape 1560 self.assert_functionalization(f, torch.ones(1, dtype=torch.long)) 1561 logs = self.get_logs(f, torch.ones(1, dtype=torch.long)) 1562 self.assertExpectedInline( 1563 logs, 1564 """\ 1565 1566 1567 1568def forward(self, arg0_1): 1569 zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1570 diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros) 1571 copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None 1572 diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None 1573 diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter) 1574 add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None 1575 diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None 1576 diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None 1577 return diagonal_copy_2 1578 """, 1579 ) # noqa: B950 1580 1581 reinplaced_logs = self.get_logs( 1582 f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True 1583 ) 1584 self.assertExpectedInline( 1585 reinplaced_logs, 1586 """\ 1587 1588 1589 1590def forward(self, arg0_1): 1591 zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) 1592 diagonal = torch.ops.aten.diagonal.default(zeros) 1593 copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = copy = None 1594 diagonal_1 = torch.ops.aten.diagonal.default(zeros) 1595 add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = add = None 1596 diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None 1597 return diagonal_2 1598 """, 1599 ) # noqa: B950 1600 1601 def test_expand_symint(self): 1602 # Once some existing SymInt bugs are ironed out, we should update 1603 # this test to plumb FakeSymbolicTensors through it 1604 def f(x): 1605 return x.expand(x.size(0), x.size(1)) 1606 1607 self.assert_functionalization(f, torch.ones(2, 2)) 1608 logs = self.get_logs(f, torch.ones(2, 2)) 1609 self.assertExpectedInline( 1610 logs, 1611 """\ 1612 1613 1614 1615def forward(self, arg0_1): 1616 expand_copy = torch.ops.aten.expand_copy.default(arg0_1, [2, 2]); arg0_1 = None 1617 return expand_copy 1618 """, 1619 ) 1620 1621 def test_fill_(self): 1622 def f(x): 1623 y = x + x 1624 z = y.diagonal() 1625 z.fill_(0) 1626 return y 1627 1628 self.assert_functionalization(f, torch.ones(2, 2)) 1629 logs = self.get_logs(f, torch.ones(2, 2)) 1630 self.assertExpectedInline( 1631 logs, 1632 """\ 1633 1634 1635 1636def forward(self, arg0_1): 1637 add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None 1638 diagonal_copy = torch.ops.aten.diagonal_copy.default(add) 1639 fill = torch.ops.aten.fill.Scalar(diagonal_copy, 0); diagonal_copy = None 1640 diagonal_scatter = torch.ops.aten.diagonal_scatter.default(add, fill); add = fill = None 1641 diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter); diagonal_copy_1 = None 1642 return diagonal_scatter 1643 """, 1644 ) 1645 1646 reinplaced_logs = self.get_logs( 1647 f, torch.ones(2, 2), reapply_views=True, run_reinplace=True 1648 ) 1649 self.assertExpectedInline( 1650 reinplaced_logs, 1651 """\ 1652 1653 1654 1655def forward(self, arg0_1): 1656 add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None 1657 diagonal = torch.ops.aten.diagonal.default(add) 1658 fill = torch.ops.aten.fill_.Scalar(diagonal, 0); diagonal = fill = None 1659 diagonal_1 = torch.ops.aten.diagonal.default(add); diagonal_1 = None 1660 return add 1661 """, 1662 ) 1663 1664 def test_resize_smaller(self): 1665 def f(w): 1666 # Resizing to a smaller size doesn't affect storage 1667 x = w + 1 1668 y = x.view(4, 4) 1669 y.resize_(3, 3) 1670 y2 = y.view(-1) 1671 y2.add_(1) 1672 z = y + 1 1673 return z 1674 1675 self.assert_functionalization(f, torch.ones(8, 2)) 1676 logs = self.get_logs(f, torch.ones(8, 2)) 1677 self.assertExpectedInline( 1678 logs, 1679 """\ 1680 1681 1682 1683def forward(self, arg0_1): 1684 add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None 1685 view_copy = torch.ops.aten.view_copy.default(add, [4, 4]) 1686 resize = torch.ops.aten.resize.default(view_copy, [3, 3]); resize = None 1687 as_strided_copy = torch.ops.aten.as_strided_copy.default(view_copy, [3, 3], [3, 1]); view_copy = None 1688 view_copy_1 = torch.ops.aten.view_copy.default(as_strided_copy, [-1]); as_strided_copy = None 1689 add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None 1690 view_copy_2 = torch.ops.aten.view_copy.default(add, [4, 4]); add = None 1691 as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1]); as_strided_copy_1 = None 1692 view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]); add_1 = None 1693 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None 1694 view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]); as_strided_scatter = None 1695 view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]) 1696 as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]); view_copy_5 = None 1697 view_copy_6 = torch.ops.aten.view_copy.default(as_strided_copy_2, [-1]); as_strided_copy_2 = view_copy_6 = None 1698 view_copy_7 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]); view_copy_4 = None 1699 as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(view_copy_7, [3, 3], [3, 1]); view_copy_7 = None 1700 add_2 = torch.ops.aten.add.Tensor(as_strided_copy_3, 1); as_strided_copy_3 = None 1701 return add_2 1702 """, # noqa: B950 1703 ) 1704 1705 reinplaced_logs = self.get_logs( 1706 f, torch.ones(8, 2), reapply_views=True, run_reinplace=True 1707 ) 1708 self.assertExpectedInline( 1709 reinplaced_logs, 1710 """\ 1711 1712 1713 1714def forward(self, arg0_1): 1715 add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None 1716 view = torch.ops.aten.view.default(add, [4, 4]) 1717 resize = torch.ops.aten.resize.default(view, [3, 3]); resize = None 1718 as_strided = torch.ops.aten.as_strided.default(view, [3, 3], [3, 1]); view = None 1719 view_1 = torch.ops.aten.view.default(as_strided, [-1]); as_strided = None 1720 add_1 = torch.ops.aten.add_.Tensor(view_1, 1); add_1 = None 1721 view_2 = torch.ops.aten.view.default(add, [4, 4]); add = None 1722 as_strided_1 = torch.ops.aten.as_strided.default(view_2, [3, 3], [3, 1]); as_strided_1 = None 1723 view_3 = torch.ops.aten.view.default(view_1, [3, 3]); view_1 = view_3 = None 1724 view_4 = torch.ops.aten.view.default(view_2, [8, 2]); view_2 = None 1725 view_5 = torch.ops.aten.view.default(view_4, [4, 4]) 1726 as_strided_2 = torch.ops.aten.as_strided.default(view_5, [3, 3], [3, 1]); view_5 = None 1727 view_6 = torch.ops.aten.view.default(as_strided_2, [-1]); as_strided_2 = view_6 = None 1728 view_7 = torch.ops.aten.view.default(view_4, [4, 4]); view_4 = None 1729 as_strided_3 = torch.ops.aten.as_strided.default(view_7, [3, 3], [3, 1]); view_7 = None 1730 add_2 = torch.ops.aten.add_.Tensor(as_strided_3, 1); add_2 = None 1731 return as_strided_3 1732 """, 1733 ) 1734 1735 def test_resize_same_size_diff_rank(self): 1736 def f(x): 1737 y = x.clone() 1738 y.resize_(25, 5) 1739 return y 1740 1741 self.assert_functionalization(f, torch.ones(5, 5, 5)) 1742 1743 def test_resize_larger_valid(self): 1744 def f(x): 1745 y = x + 1 1746 # resizing a tensor to a larger size is only currently allowed 1747 # if the tensor-to-resize is not a view / has no outstanding views. 1748 # See Note [resize_() in functionalization pass] 1749 y.resize_(5, 5) 1750 y2 = y.view(25) 1751 # Do a mutation to ensure that aliases of the output of resize_() 1752 # propagate mutations correctly. 1753 # I'm using fill_ specifically because I want to guarantee that 1754 # none of the output has uninitialized memory at the end 1755 # (since these tests compare the data output against a reference impl) 1756 y2.fill_(1) 1757 out = y + 1 1758 return y, out 1759 1760 self.assert_functionalization(f, torch.ones(8, 2)) 1761 logs = self.get_logs(f, torch.ones(8, 2)) 1762 self.assertExpectedInline( 1763 logs, 1764 """\ 1765 1766 1767 1768def forward(self, arg0_1): 1769 add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None 1770 resize = torch.ops.aten.resize.default(add, [5, 5]); add = None 1771 view_copy = torch.ops.aten.view_copy.default(resize, [25]); resize = None 1772 fill = torch.ops.aten.fill.Scalar(view_copy, 1); view_copy = None 1773 view_copy_1 = torch.ops.aten.view_copy.default(fill, [5, 5]); fill = None 1774 view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [25]); view_copy_2 = None 1775 add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1) 1776 return (view_copy_1, add_1) 1777 """, 1778 ) 1779 1780 reinplaced_logs = self.get_logs( 1781 f, torch.ones(8, 2), reapply_views=True, run_reinplace=True 1782 ) 1783 self.assertExpectedInline( 1784 reinplaced_logs, 1785 """\ 1786 1787 1788 1789def forward(self, arg0_1): 1790 add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None 1791 resize = torch.ops.aten.resize_.default(add, [5, 5]); resize = None 1792 view = torch.ops.aten.view.default(add, [25]); add = None 1793 fill = torch.ops.aten.fill_.Scalar(view, 1); fill = None 1794 view_1 = torch.ops.aten.view.default(view, [5, 5]); view = None 1795 view_2 = torch.ops.aten.view.default(view_1, [25]); view_2 = None 1796 add_1 = torch.ops.aten.add.Tensor(view_1, 1) 1797 return (view_1, add_1) 1798 """, 1799 ) 1800 1801 def test_resize_larger_invalid(self): 1802 def f(x): 1803 y = x + 1 1804 z = y.view(4, 4) 1805 # resizing a tensor to a larger size is only currently allowed 1806 # if the tensor-to-resize is not a view / has no outstanding views. 1807 # See Note [resize_() in functionalization pass] 1808 # This should fail 1809 z.resize_(5, 5) 1810 z2 = z.view(25) 1811 z2.fill_(1) 1812 out = z + 1 1813 return y, out 1814 1815 with self.assertRaisesRegex( 1816 RuntimeError, 1817 r"Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass", 1818 ): 1819 self.assert_functionalization(f, torch.ones(8, 2)) 1820 1821 def test_nested_functions_propagate_updates(self): 1822 def g(x): 1823 # Create a view of x 1824 y = x[0] 1825 y.add_(1) 1826 # The view, y, gets deallocated at the end of this function 1827 1828 def f(x): 1829 # Calling g(x) should mutate x 1830 g(x) 1831 # We expect x to be synced here, even though the alias created in g() has been deallocated! 1832 y = x + x 1833 return y 1834 1835 self.assert_functionalization(f, torch.ones(2, 2)) 1836 1837 def test_mixed_wrappers_valid(self): 1838 def f(x, y): 1839 z = x + y 1840 z.add_(1) 1841 return z 1842 1843 x1_not_functional = LoggingTensor(torch.ones(4)) 1844 x2_functional = torch._to_functional_tensor(LoggingTensor(torch.ones(4))) 1845 1846 with capture_logs() as logs: 1847 y = f(x1_not_functional, x2_functional) 1848 1849 # Make sure that functionalization ran the "+" kernel 1850 # with a functional + non-functional tensor, and wrapped the output appropriately. 1851 self.assertExpectedInline( 1852 "\n".join(logs), 1853 """\ 1854$2: f32[4] = torch._ops.aten.add.Tensor($0, $1) 1855$3: f32[4] = torch._ops.aten.add.Tensor($2, 1)""", 1856 ) 1857 1858 def test_mixed_wrappers_invalid(self): 1859 x1_not_functional = torch.ones(4) 1860 x2_functional = torch._to_functional_tensor(torch.ones(4)) 1861 1862 # When dealing with mixed functional + non functional tensors, 1863 # normal_tensor.add_(functional_tensor) is not valid 1864 # because normal_tensor would need to be "promoted" to a functional tensor. 1865 with self.assertRaises(RuntimeError): 1866 x1_not_functional.add_(x2_functional) 1867 1868 def test_index_mutation_on_non_input(self): 1869 def f(x): 1870 tmp = torch.zeros(10) 1871 tmp[5].fill_(1) 1872 return tmp 1873 1874 self.assert_functionalization(f, torch.ones(2)) 1875 logs = self.get_logs(f, torch.ones(2)) 1876 self.assertExpectedInline( 1877 logs, 1878 """\ 1879 1880 1881 1882def forward(self, arg0_1): 1883 zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False) 1884 select_copy = torch.ops.aten.select_copy.int(zeros, 0, 5) 1885 fill = torch.ops.aten.fill.Scalar(select_copy, 1); select_copy = None 1886 select_scatter = torch.ops.aten.select_scatter.default(zeros, fill, 0, 5); zeros = fill = None 1887 select_copy_1 = torch.ops.aten.select_copy.int(select_scatter, 0, 5); select_copy_1 = None 1888 return select_scatter 1889 """, 1890 ) # noqa: B950 1891 1892 reinplaced_logs = self.get_logs( 1893 f, torch.ones(2), reapply_views=True, run_reinplace=True 1894 ) 1895 self.assertExpectedInline( 1896 reinplaced_logs, 1897 """\ 1898 1899 1900 1901def forward(self, arg0_1): 1902 zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False) 1903 select = torch.ops.aten.select.int(zeros, 0, 5) 1904 fill = torch.ops.aten.fill_.Scalar(select, 1); select = fill = None 1905 select_1 = torch.ops.aten.select.int(zeros, 0, 5); select_1 = None 1906 return zeros 1907 """, 1908 ) 1909 1910 def test_instance_norm(self): 1911 size = 100 1912 1913 def f(x, running_mean, running_var): 1914 with enable_python_dispatcher(): 1915 return torch.instance_norm( 1916 x, 1917 None, 1918 None, 1919 running_mean, 1920 running_var, 1921 use_input_stats=True, 1922 momentum=0.1, 1923 eps=1e-5, 1924 cudnn_enabled=False, 1925 ) 1926 1927 self.assert_functionalization( 1928 f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size) 1929 ) 1930 # On Windows, for instance_norm, the alias_copy's are reordered to come right before they need to be used 1931 # whereas on other platforms, the alias_copy's are before the view_copy's. 1932 # e.g., the alias_copy after the getitem_4 assignment would be moved to be right before the copy assignment. 1933 if not IS_WINDOWS: 1934 logs = self.get_logs( 1935 f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size) 1936 ) 1937 self.assertExpectedInline( 1938 logs, 1939 """\ 1940 1941 1942 1943def forward(self, arg0_1, arg1_1, arg2_1): 1944 repeat = torch.ops.aten.repeat.default(arg1_1, [20]) 1945 repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20]) 1946 view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None 1947 empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')); empty = None 1948 _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None 1949 getitem = _native_batch_norm_legit_functional[0] 1950 getitem_1 = _native_batch_norm_legit_functional[1]; getitem_1 = None 1951 getitem_2 = _native_batch_norm_legit_functional[2]; getitem_2 = None 1952 getitem_3 = _native_batch_norm_legit_functional[3] 1953 getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None 1954 alias_copy = torch.ops.aten.alias_copy.default(arg1_1) 1955 view_copy_1 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]); view_copy_1 = None 1956 view_copy_2 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]); getitem_3 = None 1957 mean = torch.ops.aten.mean.dim(view_copy_2, [0]); view_copy_2 = None 1958 copy = torch.ops.aten.copy.default(alias_copy, mean); alias_copy = mean = None 1959 alias_copy_1 = torch.ops.aten.alias_copy.default(copy); copy = None 1960 alias_copy_2 = torch.ops.aten.alias_copy.default(alias_copy_1); alias_copy_2 = None 1961 alias_copy_3 = torch.ops.aten.alias_copy.default(arg2_1) 1962 view_copy_3 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]); view_copy_3 = None 1963 view_copy_4 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]); getitem_4 = None 1964 mean_1 = torch.ops.aten.mean.dim(view_copy_4, [0]); view_copy_4 = None 1965 copy_1 = torch.ops.aten.copy.default(alias_copy_3, mean_1); alias_copy_3 = mean_1 = None 1966 alias_copy_4 = torch.ops.aten.alias_copy.default(copy_1); copy_1 = None 1967 alias_copy_5 = torch.ops.aten.alias_copy.default(alias_copy_4); alias_copy_5 = None 1968 view_copy_5 = torch.ops.aten.view_copy.default(getitem, [20, 100, 35, 45]); getitem = None 1969 copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_1); arg1_1 = alias_copy_1 = copy_ = None 1970 copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_4); arg2_1 = alias_copy_4 = copy__1 = None 1971 return view_copy_5 1972 """, # noqa: B950 1973 ) 1974 1975 reinplaced_logs = self.get_logs( 1976 f, 1977 torch.randn(20, size, 35, 45), 1978 torch.zeros(size), 1979 torch.ones(size), 1980 reapply_views=True, 1981 run_reinplace=True, 1982 ) 1983 self.assertExpectedInline( 1984 reinplaced_logs, 1985 """\ 1986 1987 1988 1989def forward(self, arg0_1, arg1_1, arg2_1): 1990 repeat = torch.ops.aten.repeat.default(arg1_1, [20]) 1991 repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20]) 1992 view = torch.ops.aten.view.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None 1993 empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')); empty = None 1994 _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None 1995 getitem = _native_batch_norm_legit_functional[0] 1996 getitem_1 = _native_batch_norm_legit_functional[1]; getitem_1 = None 1997 getitem_2 = _native_batch_norm_legit_functional[2]; getitem_2 = None 1998 getitem_3 = _native_batch_norm_legit_functional[3] 1999 getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None 2000 alias = torch.ops.aten.alias.default(arg1_1) 2001 view_1 = torch.ops.aten.view.default(getitem_3, [20, 100]); view_1 = None 2002 view_2 = torch.ops.aten.view.default(getitem_3, [20, 100]); getitem_3 = None 2003 mean = torch.ops.aten.mean.dim(view_2, [0]); view_2 = None 2004 copy = torch.ops.aten.copy.default(alias, mean); alias = mean = None 2005 alias_1 = torch.ops.aten.alias.default(copy); copy = None 2006 alias_2 = torch.ops.aten.alias.default(alias_1); alias_2 = None 2007 alias_3 = torch.ops.aten.alias.default(arg2_1) 2008 view_3 = torch.ops.aten.view.default(getitem_4, [20, 100]); view_3 = None 2009 view_4 = torch.ops.aten.view.default(getitem_4, [20, 100]); getitem_4 = None 2010 mean_1 = torch.ops.aten.mean.dim(view_4, [0]); view_4 = None 2011 copy_1 = torch.ops.aten.copy.default(alias_3, mean_1); alias_3 = mean_1 = None 2012 alias_4 = torch.ops.aten.alias.default(copy_1); copy_1 = None 2013 alias_5 = torch.ops.aten.alias.default(alias_4); alias_5 = None 2014 view_5 = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]); getitem = None 2015 copy_ = torch.ops.aten.copy_.default(arg1_1, alias_1); arg1_1 = alias_1 = copy_ = None 2016 copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_4); arg2_1 = alias_4 = copy__1 = None 2017 return view_5 2018 """, # noqa: B950 2019 ) 2020 2021 def test_mutation_overlapping_mem(self): 2022 def fn(x): 2023 # x: (1, 5) 2024 t1 = torch.add(x, x) 2025 t2 = t1.unfold(1, 3, 2) 2026 t3 = t2.abs_() 2027 return t3 2028 2029 with self.assertRaisesRegex( 2030 RuntimeError, 2031 r"encountered a tensor being mutated that has internal overlap", 2032 ): 2033 x = torch.ones(1, 5) 2034 out = _functionalize(fn, reapply_views=True, crossref=False)(x) 2035 2036 def test_batch_norm(self): 2037 def f(x, running_mean, running_var): 2038 with enable_python_dispatcher(): 2039 return torch.batch_norm( 2040 x, None, None, running_mean, running_var, True, 0.1, 1e-5, False 2041 ) 2042 2043 self.assert_functionalization( 2044 f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100) 2045 ) 2046 logs = self.get_logs( 2047 f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100) 2048 ) 2049 self.assertExpectedInline( 2050 logs, 2051 """\ 2052 2053 2054 2055def forward(self, arg0_1, arg1_1, arg2_1): 2056 empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')); empty = None 2057 _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None 2058 getitem = _native_batch_norm_legit_functional[0] 2059 getitem_1 = _native_batch_norm_legit_functional[1]; getitem_1 = None 2060 getitem_2 = _native_batch_norm_legit_functional[2]; getitem_2 = None 2061 getitem_3 = _native_batch_norm_legit_functional[3] 2062 getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None 2063 copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = copy_ = None 2064 copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = copy__1 = None 2065 return getitem 2066 """, # noqa: B950 2067 ) 2068 2069 reinplaced_logs = self.get_logs( 2070 f, 2071 torch.randn(20, 100, 35, 45), 2072 torch.zeros(100), 2073 torch.ones(100), 2074 reapply_views=True, 2075 run_reinplace=True, 2076 ) 2077 self.assertExpectedInline( 2078 reinplaced_logs, 2079 """\ 2080 2081 2082 2083def forward(self, arg0_1, arg1_1, arg2_1): 2084 empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')); empty = None 2085 _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None 2086 getitem = _native_batch_norm_legit_functional[0] 2087 getitem_1 = _native_batch_norm_legit_functional[1]; getitem_1 = None 2088 getitem_2 = _native_batch_norm_legit_functional[2]; getitem_2 = None 2089 getitem_3 = _native_batch_norm_legit_functional[3] 2090 getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None 2091 copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = copy_ = None 2092 copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = copy__1 = None 2093 return getitem 2094 """, # noqa: B950 2095 ) 2096 2097 # This tests our python shims around C++ Functionalization: FunctionalTensor and FunctionalTensorMode 2098 def test_python_functionalization(self): 2099 def f(x): 2100 x_view = x.view(-1) 2101 x.mul_(2) 2102 return x_view + 1 2103 2104 def f_functionalized(x): 2105 # Note [Disabling Functionalize TLS Above Python Functionalization] 2106 # This UX is pretty annoying (although python functionalization's main customer is AOTAutograd, 2107 # and is not really advertised as a user API). 2108 # We need to explicitly disable functionalization when using python FunctionalTensor and FunctionalTensorMode. 2109 # Why? FunctionalTensor is a wrapper tensor that holds an inner FunctionalTensorWrapper. 2110 # Since the inner tensor has `DispatchKey.Functionalize` in its keyset, then by default, 2111 # our FunctionalTensor will inherit the same keyset. 2112 # We don't have an easy way of directly mutating a tensor's keyset from python, 2113 # so globally disabling functionalization here is easier. 2114 maybe_disable = torch._C._ExcludeDispatchKeyGuard( 2115 torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) 2116 ) 2117 with maybe_disable, FunctionalTensorMode(): 2118 x_wrapped = FunctionalTensor.to_functional(x) 2119 out_wrapped = f(x_wrapped) 2120 out_unwrapped = out_wrapped.elem 2121 torch._sync(out_unwrapped) 2122 return torch._from_functional_tensor(out_unwrapped) 2123 2124 # Make a non-leaf 2125 x = torch.randn(2, requires_grad=True) + 1 2126 fx_g = make_fx(f_functionalized)(x) 2127 # NB: view_1 below is expected (though unused) due to view replay. AOTAutograd runs a 2128 # DCE pass that will remove nodes like this later on. 2129 self.assertExpectedInline( 2130 fx_g.code.strip(), 2131 """\ 2132def forward(self, x_1): 2133 view = torch.ops.aten.view.default(x_1, [-1]); view = None 2134 mul = torch.ops.aten.mul.Tensor(x_1, 2); x_1 = None 2135 view_1 = torch.ops.aten.view.default(mul, [-1]); view_1 = None 2136 view_2 = torch.ops.aten.view.default(mul, [-1]); mul = None 2137 add = torch.ops.aten.add.Tensor(view_2, 1); view_2 = None 2138 return add""", 2139 ) 2140 2141 def test_python_functionalization_zero_tensor(self): 2142 def f(x): 2143 y = torch.ops.aten._efficientzerotensor([4]) 2144 out = x + y 2145 out.mul_(2) 2146 return out 2147 2148 x = torch.randn(4) 2149 out_ref = f(x) 2150 out_test = dispatch_functionalize(f)(x) 2151 out_test_cpp = _functionalize( 2152 f, reapply_views=True, crossref=False, skip_input_mutations=True 2153 )(x) 2154 self.assertEqual(out_ref, out_test) 2155 self.assertEqual(out_ref, out_test_cpp) 2156 fx_g = make_fx(dispatch_functionalize(f))(x) 2157 fx_g_cpp = make_fx( 2158 _functionalize( 2159 f, reapply_views=True, crossref=False, skip_input_mutations=True 2160 ) 2161 )(x) 2162 self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip()) 2163 2164 def test_python_functionalization_is_conj(self): 2165 def f(x): 2166 out = x.conj() 2167 return out, out.is_conj() 2168 2169 x = torch.randn(4, dtype=torch.complex64) 2170 out_ref = f(x) 2171 out_test = dispatch_functionalize(f)(x) 2172 out_test_cpp = _functionalize(f, reapply_views=True, crossref=False)(x) 2173 self.assertEqual(out_ref[0], out_test[0]) 2174 self.assertEqual(out_ref[1], out_test[1]) 2175 self.assertEqual(out_ref[0], out_test_cpp[0]) 2176 self.assertEqual(out_ref[1], out_test_cpp[1]) 2177 2178 def test_python_functionalization_is_neg(self): 2179 def f(x): 2180 out = x.neg() 2181 return out, out.is_neg() 2182 2183 x = torch.randn(4, dtype=torch.complex64) 2184 out_ref = f(x) 2185 out_test = dispatch_functionalize(f)(x) 2186 out_test_cpp = _functionalize(f, reapply_views=True, crossref=False)(x) 2187 self.assertEqual(out_ref[0], out_test[0]) 2188 self.assertEqual(out_ref[1], out_test[1]) 2189 self.assertEqual(out_ref[0], out_test_cpp[0]) 2190 self.assertEqual(out_ref[1], out_test_cpp[1]) 2191 2192 def test_python_functionalization_conj(self): 2193 def f(x): 2194 y = x.clone().conj() 2195 y.mul_(2) 2196 return torch.view_as_real(y.resolve_conj()) 2197 2198 x = torch.randn(4, dtype=torch.complex64) 2199 out_ref = f(x) 2200 out_test = dispatch_functionalize(f)(x) 2201 out_test_cpp = _functionalize( 2202 f, reapply_views=True, crossref=False, skip_input_mutations=True 2203 )(x) 2204 self.assertEqual(out_ref, out_test) 2205 self.assertEqual(out_test, out_test_cpp) 2206 fx_g = make_fx(dispatch_functionalize(f))(x) 2207 fx_g_cpp = make_fx( 2208 _functionalize( 2209 f, reapply_views=True, crossref=False, skip_input_mutations=True 2210 ) 2211 )(x) 2212 self.assertExpectedInline( 2213 fx_g.code.strip(), 2214 """\ 2215def forward(self, arg0_1): 2216 clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 2217 _conj = torch.ops.aten._conj.default(clone); clone = None 2218 clone_1 = torch.ops.aten.clone.default(_conj) 2219 mul = torch.ops.aten.mul.Tensor(clone_1, 2); clone_1 = None 2220 clone_2 = torch.ops.aten.clone.default(_conj); _conj = None 2221 copy = torch.ops.aten.copy.default(clone_2, mul); clone_2 = mul = None 2222 _conj_1 = torch.ops.aten._conj.default(copy); copy = None 2223 _conj_2 = torch.ops.aten._conj.default(_conj_1); _conj_1 = None 2224 clone_3 = torch.ops.aten.clone.default(_conj_2); _conj_2 = None 2225 view_as_real = torch.ops.aten.view_as_real.default(clone_3); clone_3 = None 2226 return view_as_real""", 2227 ) 2228 self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip()) 2229 2230 def test_python_functionalization_neg(self): 2231 def f(x): 2232 y = x._neg_view() 2233 z = y.resolve_neg() 2234 return z + 1 2235 2236 x = torch.randn(4) 2237 out_ref = f(x) 2238 out_test = dispatch_functionalize(f)(x) 2239 out_test_cpp = _functionalize( 2240 f, reapply_views=True, crossref=False, skip_input_mutations=True 2241 )(x) 2242 self.assertEqual(out_ref, out_test) 2243 self.assertEqual(out_ref, out_test_cpp) 2244 fx_g = make_fx(dispatch_functionalize(f))(x) 2245 fx_g_cpp = make_fx( 2246 _functionalize( 2247 f, reapply_views=True, crossref=False, skip_input_mutations=True 2248 ) 2249 )(x) 2250 self.assertExpectedInline( 2251 fx_g.code.strip(), 2252 """\ 2253def forward(self, arg0_1): 2254 _neg_view = torch.ops.aten._neg_view.default(arg0_1); arg0_1 = None 2255 clone = torch.ops.aten.clone.default(_neg_view); _neg_view = None 2256 add = torch.ops.aten.add.Tensor(clone, 1); clone = None 2257 return add""", 2258 ) 2259 self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip()) 2260 2261 def test_python_functionalization_lift_fresh_storage(self): 2262 unlifted = torch.tensor([0.0]) 2263 2264 maybe_disable = torch._C._ExcludeDispatchKeyGuard( 2265 torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) 2266 ) 2267 with maybe_disable, FunctionalTensorMode(): 2268 lifted = torch.ops.aten.lift_fresh.default(unlifted) 2269 2270 self.assertNotEqual(unlifted.untyped_storage(), lifted.untyped_storage()) 2271 2272 def test_python_functionalization_lift_fresh(self): 2273 def f(x): 2274 tmp = torch.tensor([0.0]) 2275 return tmp + x 2276 2277 x = torch.randn(4) 2278 out_ref = f(x) 2279 out_test = dispatch_functionalize(f)(x) 2280 out_test_cpp = _functionalize( 2281 f, reapply_views=True, crossref=False, skip_input_mutations=True 2282 )(x) 2283 self.assertEqual(out_ref, out_test) 2284 self.assertEqual(out_ref, out_test_cpp) 2285 fx_g = make_fx(dispatch_functionalize(f))(x) 2286 fx_g_cpp = make_fx( 2287 _functionalize( 2288 f, reapply_views=True, crossref=False, skip_input_mutations=True 2289 ) 2290 )(x) 2291 self.assertExpectedInline( 2292 fx_g.code.strip(), 2293 """\ 2294def forward(self, arg0_1): 2295 _tensor_constant0 = self._tensor_constant0 2296 lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None 2297 add = torch.ops.aten.add.Tensor(lift_fresh_copy, arg0_1); lift_fresh_copy = arg0_1 = None 2298 return add""", 2299 ) 2300 self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip()) 2301 2302 2303@xfail_inherited_tests( 2304 [ 2305 "test_as_strided", 2306 "test_copy_", 2307 "test_diagonal", 2308 "test_diagonal_mutated_input", 2309 "test_everything", 2310 "test_fill_", 2311 "test_slice", 2312 "test_split", 2313 "test_split_with_sizes", 2314 "test_unbind", 2315 "test_view_clone_view_inplace", 2316 "test_view_inplace", 2317 ] 2318) 2319@unittest.skipIf( 2320 TEST_WITH_TORCHDYNAMO, "dynamo-ing code with proxy + fake doesnt work well" 2321) 2322class TestCrossRefFunctionalization(TestFunctionalization): 2323 crossref = True 2324 2325 2326if __name__ == "__main__": 2327 run_tests() 2328