# Owner(s): ["module: inductor"]

from typing import List

import torch
import torch._inductor.config as inductor_config
from functorch import make_fx
from torch import Tensor
from torch._dynamo.utils import counters
from torch._higher_order_ops.auto_functionalize import (
    auto_functionalized,
    auto_functionalized_v2,
)
from torch._inductor.fx_passes.reinplace import reinplace_inplaceable_ops_core
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_LINUX,
    parametrize,
    subtest,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.testing._internal.logging_utils import logs_to_string


aten = torch.ops.aten


const = torch.tensor(0.0)
device = GPU_TYPE


def num_reinplacing_failures():
    return counters["inductor"]["possibly_missed_reinplacing_opportunities"]


@torch.library.custom_op("_reinplacing::sin", mutates_args={"result"})
def sin(x: torch.Tensor, result: torch.Tensor) -> None:
    result.copy_(x.sin())


@torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"})
def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None:
    out_sin.copy_(x.sin())
    out_cos.copy_(x.cos())


if HAS_GPU:
    import triton
    import triton.language as tl

    @triton.jit
    def sin_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        output = tl.sin(x)
        tl.store(out_ptr + offsets, output, mask=mask)

    def sin_triton(x, out):
        n_elements = x.numel()
        sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)

else:

    def sin_triton(x, out):
        return


@torch.library.custom_op("test_view::boo", mutates_args={"x"})
def boo(x: torch.Tensor) -> None:
    x.sin_()


class TestReinplacingPassCorrectness(InductorTestCase):
    def setUp(self):
        counters.clear()
        return super().setUp()

    def _test(self, f):
        nf = torch.compile(f)
        inp = (
            torch.randn(4, device=device),
            torch.ones(2, device=device, dtype=torch.int),
        )
        inp2 = (inp[0].clone(), inp[1].clone())
        self.assertEqual(f(*inp), nf(*inp2))
        self.assertEqual(inp, inp2)

    def test_dont_modify_live(self):
        def f(x, y):
            x = x.cos()
            x2 = x.index_put((y,), const)
            return x2, x

        self._test(f)

    def test_dont_modify_view_of_live(self):
        def f(x, y):
            x = x.cos()
            x2 = aten.alias(x)
            x2 = x2.index_put((y,), const)
            y = x2 + x.cos()
            return y

        self._test(f)

    def test_dont_modify_input(self):
        def f(x, y):
            return x.index_put((y,), const)

        self._test(f)

    def test_should_modify_inner(self):
        def f(x, y):
            x = x.cos()
            x = x.index_put((y,), const)
            return x

        self._test(f)

    def test_should_modify_input(self):
        def f(x, y):
            x = x.index_put_((y,), const)
            return x

        self._test(f)

    def test_counters_functionalize_old(self):
        counters.clear()

        def f(x):
            out = torch.empty_like(x)
            _, new_out = auto_functionalized(sin._opoverload, x=x, result=out)
            y = out * new_out
            return new_out, y

        x = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x)
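        # Run Inductor's reinplacing pass directly on the traced FX graph so we
        # can check the missed-reinplacing counter in isolation.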
        reinplace_inplaceable_ops_core(gm.graph)

        # We shouldn't have been able to reinplace `out` because it was used after
        # auto_functionalized. Note that this usually doesn't happen in practice;
        # we're artificially creating this example to test the counter.
        # IF THIS NUMBER GOES TO ZERO, PLEASE FIND ANOTHER EXAMPLE
        self.assertEqual(num_reinplacing_failures(), 1)

    def test_counters_functionalize_v2(self):
        counters.clear()

        def f(x):
            out = torch.empty_like(x)
            _, new_out = auto_functionalized_v2(
                sin._opoverload,
                x=x,
                _result_base_index=0,
                _result_size=(3,),
                _result_stride=(1,),
                _result_storage_offset=0,
                _all_bases=[out],
            )
            y = out * new_out
            return new_out, y

        x = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x)
        reinplace_inplaceable_ops_core(gm.graph)

        # We shouldn't have been able to reinplace `out` because it was used after
        # auto_functionalized. Note that this usually doesn't happen in practice;
        # we're artificially creating this example to test the counter.
        # IF THIS NUMBER GOES TO ZERO, PLEASE FIND ANOTHER EXAMPLE
        self.assertEqual(num_reinplacing_failures(), 1)

    def get_not_inplaced_count(self, graph):
        counter = 0
        auto_functionalized_found = False
        for node in graph.nodes:
            if (node.target == torch.ops.higher_order.auto_functionalized) or (
                node.target == torch.ops.higher_order.auto_functionalized_v2
            ):
                auto_functionalized_found = True
                counter += len(node.meta["only_clone_these_tensors"])
        assert auto_functionalized_found
        return counter

    def test_view_inplaced_functionalize_v2(self):
        def f(arg0_1):
            select = torch.ops.aten.select.int(arg0_1, 0, 0)
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(3,),
                _x_stride=(1,),
                _x_storage_offset=0,
                _all_bases=[arg0_1],
            )
            getitem_1 = auto_functionalized[1]
            copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1)
            return ()

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 0)

    # introduce a view another_view that is used `after` the copy
    def test_view_inplaced2_functionalize_v2(self):
        def f(arg0_1):
            select = torch.ops.aten.select.int(arg0_1, 0, 0)
            another_view = arg0_1[2]
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(3,),
                _x_stride=(1,),
                _x_storage_offset=0,
                _all_bases=[arg0_1],
            )
            getitem_1 = auto_functionalized[1]
            copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1)
            return another_view

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 0)

    # introduce a view another_view that is used `before` the copy
    def test_views_not_inplaced_functionalize_v2(self):
        def f(arg0_1):
            select = torch.ops.aten.select.int(arg0_1, 0, 0)
            another_view = arg0_1[2]
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(3,),
                _x_stride=(1,),
                _x_storage_offset=0,
                _all_bases=[arg0_1],
            )
            getitem_1 = auto_functionalized[1]
            use_another_view = another_view * 10
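            # The mutated result is copied back into the base only after
            # `another_view` has been read above; the assertion below expects
            # one base to be left un-inplaced as a result.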
            copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1)
            return use_another_view

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 1)

    # a view over input without copy node, inplace not allowed
    def test_views_not_inplaced2_functionalize_v2(self):
        def f(arg0_1):
            select = torch.ops.aten.select.int(arg0_1, 0, 0)
            another_view = arg0_1[2]
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(3,),
                _x_stride=(1,),
                _x_storage_offset=0,
                _all_bases=[arg0_1],
            )
            getitem_1 = auto_functionalized[1]
            return

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 1)

    # no copy nodes, view over local, with a use for another view
    def test_views_not_inplaced3_functionalize_v2(self):
        def f(arg0_1):
            a = torch.ones(10)
            another_view = a[2]
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(),
                _x_stride=(),
                _x_storage_offset=0,
                _all_bases=[a],
            )
            getitem_1 = auto_functionalized[1]
            return another_view

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 1)

    def test_multi_output_intermediate(self):
        for requires_grad in [False, True]:
            for enable_v2 in [False, True]:
                with inductor_config.patch(
                    {"enable_auto_functionalized_v2": enable_v2}
                ):
                    counters.clear()

                    def f(x):
                        out1 = torch.empty_like(x)
                        out2 = torch.empty_like(x)
                        sin_cos(x, out1, out2)
                        return out1, out2, x**2

                    x = torch.randn(3, device=device, requires_grad=requires_grad)
                    res1, res2, _ = torch.compile(f)(x)
                    self.assertEqual(res1, x.sin())
                    self.assertEqual(res2, x.cos())
                    self.assertEqual(num_reinplacing_failures(), 0)

    def test_multiple_mutations(self):
        counters.clear()

        def f(x, out):
            sin(x, out)
            sin(out, out)
            sin(out, out)
            return out

        x = torch.randn(3, device=device)
        out = torch.randn(3, device=device)
        result = torch.compile(f)(x, out)
        self.assertEqual(result, x.sin().sin().sin())
        self.assertEqual(result, out)
        self.assertEqual(num_reinplacing_failures(), 0)

    def test_multiple_intermediate(self):
        counters.clear()

        def f(x):
            out = torch.empty_like(x)
            sin(x, out)
            sin(out, out)
            sin(out, out)
            return out

        x = torch.randn(3, device=device)
        result = torch.compile(f)(x)
        self.assertEqual(result, x.sin().sin().sin())
        self.assertEqual(num_reinplacing_failures(), 0)

    def test_lists_functionalize_v2(self):
        with inductor_config.patch({"enable_auto_functionalized_v2": True}):

            @torch.library.custom_op("mylib::mutate_op", mutates_args={"y"})
            def mutate_op(y: List[Tensor]) -> None:
                y[0].add_(2)
                y[1].add_(3)

            @torch.compile(fullgraph=True, dynamic=False, backend="inductor")
            def f(b):
                mutate_op([b[0], b[1]])

            x1 = torch.tensor([0.3, 0.4], device=device)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
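                # Compile and run f while post-grad graph logging is active so
                # we can count clone nodes in the captured graph below.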
                torch.compile(f, backend="inductor", fullgraph=True)(x1)
            post_grad_graphs = "\n".join(
                log_stream.getvalue().strip().split("\n")[3:]
            ).strip()

            # We can inplace the base y. no clones emitted.
            self.assertEqual(num_reinplacing_failures(), 0)
            self.assertEqual(post_grad_graphs.count("aten.clone"), 0)

    def test_lists_old_functionalize(self):
        with inductor_config.patch({"enable_auto_functionalized_v2": False}):

            @torch.library.custom_op("mylib::mutate_op", mutates_args={"y"})
            def mutate_op(y: List[Tensor]) -> None:
                y[0].add_(2)
                y[1].add_(3)

            @torch.compile(fullgraph=True, dynamic=False, backend="inductor")
            def f(b):
                mutate_op([b[0], b[1]])

            x1 = torch.tensor([0.3, 0.4], device=device)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(x1)
            post_grad_graphs = "\n".join(
                log_stream.getvalue().strip().split("\n")[3:]
            ).strip()

            # Can't reinplace on views yet (1 for the "entire list" failing to reinplace)
            self.assertEqual(num_reinplacing_failures(), 1)

            # Both list inputs failed to reinplace. So we should have emitted clones for them.
            self.assertEqual(post_grad_graphs.count("aten.clone"), 2)

    @parametrize(
        "factory_op",
        [
            subtest(torch.ones_like, name="ones_like"),
            subtest(torch.empty_like, name="empty_like"),
        ],
    )
    @parametrize(
        "sin_op",
        [
            subtest(sin, name="sin_op"),
            subtest(sin_triton, name="sin_triton"),
        ],
    )
    def test_partitioner_recomputes_factory(self, factory_op, sin_op):
        class MySin(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                out = factory_op(x)
                sin_op(x, out)
                ctx.save_for_backward(out)
                return out

            @staticmethod
            def backward(ctx, grad):
                (saved,) = ctx.saved_tensors
                out = factory_op(grad)
                sin_op(saved, out)
                return out

        @torch.compile(backend="inductor")
        def f(x):
            return MySin.apply(x)

        x = torch.randn(3, requires_grad=True, device=device)
        y = f(x)
        self.assertEqual(num_reinplacing_failures(), 0)


instantiate_parametrized_tests(TestReinplacingPassCorrectness)


if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests(needs="filelock")