1# Owner(s): ["module: functorch"] 2import contextlib 3import functools 4import unittest 5 6import torch 7import torch.utils._pytree as pytree 8from functorch.experimental import control_flow 9from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException 10from torch._higher_order_ops.associative_scan import associative_scan 11from torch._higher_order_ops.while_loop import while_loop 12from torch._subclasses.functional_tensor import ( 13 CppFunctionalizeAPI, 14 FunctionalTensor, 15 FunctionalTensorMode, 16 PythonFunctionalizeAPI, 17) 18from torch.fx.experimental.proxy_tensor import make_fx 19from torch.testing._internal.common_cuda import SM70OrLater 20from torch.testing._internal.common_quantization import skipIfNoDynamoSupport 21from torch.testing._internal.common_utils import ( 22 decorateIf, 23 instantiate_parametrized_tests, 24 IS_WINDOWS, 25 parametrize, 26 run_tests, 27 skipIfRocm, 28 skipIfTorchDynamo, 29 TEST_WITH_TORCHDYNAMO, 30 TestCase, 31 xfailIfTorchDynamo, 32) 33 34 35# TODO: pull these helpers from AOTAutograd later 36def to_fun(t): 37 if isinstance(t, torch.Tensor): 38 return FunctionalTensor.to_functional(t) 39 return t 40 41 42def from_fun(t): 43 if not isinstance(t, FunctionalTensor): 44 # quick sanity assert 45 if isinstance(t, torch.Tensor): 46 assert not torch._is_functional_tensor(t) 47 return t 48 torch._sync(t) 49 return torch._from_functional_tensor(t.elem) 50 51 52def to_fun_old(t): 53 if isinstance(t, torch.Tensor) and not torch._is_functional_tensor(t): 54 out = torch._to_functional_tensor(t) 55 torch._mirror_autograd_meta_to(t, out) 56 return out 57 return t 58 59 60def from_fun_old(t): 61 # quick sanity assert 62 if isinstance(t, torch.Tensor): 63 assert torch._is_functional_tensor(t) 64 torch._sync(t) 65 return torch._from_functional_tensor(t) 66 return t 67 68 69def _fake_map(f, x, *args): 70 from functorch.experimental.control_flow import _stack_pytree, _unstack_pytree 71 72 x_pytrees = _unstack_pytree(x) 73 zs = [] 74 for xp in x_pytrees: 75 zs.append(f(xp, *args)) 76 return _stack_pytree(zs) 77 78 79def _fake_while_loop(cond_fn, body_fn, operands): 80 while cond_fn(*operands): 81 operands = body_fn(*operands) 82 return operands 83 84 85def _fake_associative_scan(combine_fn, input, dim, reverse=False): 86 inp_leaves, spec = pytree.tree_flatten(input) 87 result_flat = [] 88 num_leaves = len(inp_leaves) 89 op = reversed if reverse else lambda x: x 90 91 for ind in op(range(inp_leaves[0].size(dim))): 92 r = [ 93 inp_leaves[leave_ind][(slice(None),) * dim + (ind,)] 94 for leave_ind in range(num_leaves) 95 ] 96 if (ind > 0 and not reverse) or ( 97 ind < (inp_leaves[0].size(dim) - 1) and reverse 98 ): 99 r = combine_fn( 100 pytree.tree_unflatten(result_flat[-1], spec), 101 pytree.tree_unflatten(r, spec), 102 ) 103 r_flat, _ = pytree.tree_flatten(r) 104 result_flat.append(r_flat) 105 106 results = [ 107 torch.stack([e[leave_ind] for e in op(result_flat)], dim) 108 for leave_ind in range(num_leaves) 109 ] 110 return pytree.tree_unflatten(results, spec) 111 112 113def _while_loop_tests(): 114 def simple(x): 115 def cond_fn(x): 116 return x.sum() < 10 117 118 def body_fn(x): 119 return (x + 1,) 120 121 return while_loop(cond_fn, body_fn, (x,)) 122 123 def simple_with_mutation(x): 124 def cond_fn(x): 125 y = x.clone().add_(1).add_(-1) 126 return y.sum() < 10 127 128 def body_fn(x): 129 y = x.clone().add_(1).add_(-1) 130 return (y + 1,) 131 132 return while_loop(cond_fn, body_fn, (x,)) 133 134 def nested(out_iter, it, y): 135 def 
cond_fn(out_iter, it, y): 136 return it.sum() < 10 137 138 def body_fn(out_iter, it, y): 139 return (out_iter.clone(), it + y, y + 1) 140 141 def outer_cond_fn(out_iter, it, y): 142 return out_iter.sum() < 2 143 144 def outer_body_fn(out_iter, it, y): 145 out_iter, it, y = while_loop(cond_fn, body_fn, (out_iter, it, y)) 146 return (out_iter + 1, it, y) 147 148 return while_loop(outer_cond_fn, outer_body_fn, (out_iter, it, y)) 149 150 class Nested(torch.nn.Module): 151 def forward(self, ci, cj, a, b): 152 def cond_fn(i1, j1, x1, y1): 153 return i1 > 0 154 155 def body_fn(i1, j1, x1, y1): 156 def cond_fn_nested(i2, j2, x2, y2): 157 return j2 > 0 158 159 def body_fn_nested(i2, j2, x2, y2): 160 return i2.clone(), j2 - 1, x2 + 3.14, y2 - 2.71 161 162 i1, j1, x1, y1 = while_loop( 163 cond_fn_nested, body_fn_nested, [i1, j1, x1, y1] 164 ) 165 return i1 - 1, j1.clone(), x1 * 2, y1 / 2 166 167 return while_loop(cond_fn, body_fn, (ci, cj, a, b)) 168 169 class SimpleWithLinear(torch.nn.Module): 170 def __init__(self) -> None: 171 super().__init__() 172 self.linear = torch.nn.Linear(2, 2) 173 self.dec = torch.nn.Buffer(torch.tensor(1)) 174 175 def forward(self, iter, x): 176 def cond_fn(it, x): 177 return it - self.dec > 0 178 179 def body_fn(it, x): 180 return it - 1, self.linear(x) 181 182 return while_loop(cond_fn, body_fn, (iter, x)) 183 184 class NestedWithLinear(torch.nn.Module): 185 def __init__(self) -> None: 186 super().__init__() 187 self.mod = SimpleWithLinear() 188 self.outer_linear = torch.nn.Linear(2, 2) 189 self.dec = torch.nn.Buffer(torch.tensor(1)) 190 191 def forward(self, iter, x): 192 def cond_fn(it, x): 193 return it - self.dec > 0 194 195 def body_fn(it, x): 196 return it - 1, self.outer_linear(self.mod(it, x)[1]) 197 198 return while_loop(cond_fn, body_fn, (iter, x)) 199 200 nested2 = Nested() 201 simple_with_linear = SimpleWithLinear() 202 nested_with_linear = NestedWithLinear() 203 204 x = torch.zeros(1) 205 y = torch.zeros(1) 206 z = torch.zeros(1) 207 return { 208 "simple": (simple, (x,)), 209 "nested": (nested, (x, y, z)), 210 "nested2": ( 211 nested2, 212 (torch.tensor(2), torch.tensor(2), torch.ones(2, 2), torch.ones(2, 2)), 213 ), 214 "simple_with_mutation": (simple_with_mutation, (x,)), 215 "simple_with_linear": ( 216 simple_with_linear, 217 (torch.tensor(3), torch.randn(2, 2)), 218 ), 219 "nested_with_linear": ( 220 nested_with_linear, 221 (torch.tensor(3), torch.randn(2, 2)), 222 ), 223 } 224 225 226WHILE_LOOP_TESTS = _while_loop_tests() 227 228 229def collect_meta_for_filtered_nodes( 230 gm: torch.fx.GraphModule, node_names, meta_field_name 231): 232 ret = [] 233 for mod in gm.modules(): 234 for node in mod.graph.nodes: 235 if node.name in node_names: 236 for field_name in meta_field_name: 237 ret.append(node.meta.get(field_name)) 238 return ret 239 240 241def reduce_func(*operands): 242 acc = 0 243 for operand in operands: 244 acc += operand 245 return acc 246 247 248class ReduceObj: 249 def __call__(self, *operands): 250 return reduce_func(*operands) 251 252 253class ReduceMod(torch.nn.Module): 254 def _reduce(self, *operands): 255 return reduce_func(*operands) 256 257 def forward(self, *operands): 258 return self._reduce(*operands) 259 260 261@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") 262@skipIfNoDynamoSupport 263class TestControlFlow(TestCase): 264 def setUp(self): 265 torch._dynamo.reset() 266 super().setUp() 267 268 def test_cond_no_trace(self): 269 def true_fn(x): 270 return x.sin() 271 272 def false_fn(x): 273 return x.cos() 274 275 x 
= torch.randn(4) 276 result = cond(False, true_fn, false_fn, [x]) 277 self.assertEqual(result, torch.cos(x)) 278 279 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") 280 def test_cond_gpu(self): 281 def true_fn(x): 282 return x.sin() 283 284 def false_fn(x): 285 return x.cos() 286 287 x = torch.randn(4, device="cuda") 288 pred = torch.tensor(False, device="cuda") 289 result = cond(pred, true_fn, false_fn, [x]) 290 self.assertEqual(result, torch.cos(x)) 291 292 def test_cond_autograd_simple(self): 293 def true_fn(x): 294 return x.sin() 295 296 def false_fn(x): 297 return x.cos() 298 299 for pred, fn in zip( 300 [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] 301 ): 302 x = torch.randn(4, requires_grad=True) 303 result = cond(pred, true_fn, false_fn, (x,)) 304 self.assertEqual(result, fn(x)) 305 306 grad_out = torch.ones_like(result) 307 grads = torch.autograd.grad(result, (x,), grad_out) 308 expected_grads = torch.autograd.grad(fn(x), (x,), grad_out) 309 self.assertEqual(expected_grads, grads) 310 311 def f(pred, x): 312 result = cond(pred, true_fn, false_fn, (x,)) 313 grad_out = torch.ones_like(result) 314 return torch.autograd.grad(result, (x,), grad_out) 315 316 gm = make_fx(f, tracing_mode="symbolic")(pred, x) 317 318 self.assertExpectedInline( 319 gm.code.strip(), 320 """\ 321def forward(self, pred_1, x_1): 322 true_graph_0 = self.true_graph_0 323 false_graph_0 = self.false_graph_0 324 cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None 325 getitem = cond[0]; cond = None 326 ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None 327 true_graph_1 = self.true_graph_1 328 false_graph_1 = self.false_graph_1 329 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None 330 getitem_1 = cond_1[0]; cond_1 = None 331 return (getitem_1,)""", # noqa: B950 332 ) 333 334 def test_cond_autograd_complex(self): 335 def true_fn(x): 336 return torch.abs((x**2).sin()) 337 338 def false_fn(x): 339 return (x + 42).cos() 340 341 for pred, fn in zip( 342 [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] 343 ): 344 x = torch.randn(4, requires_grad=True) 345 result = cond(pred, true_fn, false_fn, (x,)) 346 self.assertEqual(result, fn(x)) 347 348 grad_out = torch.ones_like(result) 349 grads = torch.autograd.grad(result, (x,), grad_out) 350 expected_grads = torch.autograd.grad(fn(x), (x,), grad_out) 351 self.assertEqual(expected_grads, grads) 352 353 def f(pred, x): 354 result = cond(pred, true_fn, false_fn, (x,)) 355 grad_out = torch.ones_like(result) 356 return torch.autograd.grad(result, (x,), grad_out) 357 358 gm = make_fx(f, tracing_mode="symbolic")(pred, x) 359 self.assertExpectedInline( 360 gm.code.strip(), 361 """\ 362def forward(self, pred_1, x_1): 363 true_graph_0 = self.true_graph_0 364 false_graph_0 = self.false_graph_0 365 cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None 366 getitem = cond[0]; cond = None 367 ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None 368 true_graph_1 = self.true_graph_1 369 false_graph_1 = self.false_graph_1 370 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None 371 getitem_1 = cond_1[0]; cond_1 = None 372 return 
(getitem_1,)""", # noqa: B950 373 ) 374 375 @skipIfTorchDynamo("Skip due to graph break when run with dynamo") 376 def test_cond_autograd_nested(self): 377 class Nested(torch.nn.Module): 378 def forward(self, p0, p1, p2, a, b, c): 379 def true_fn(x0, y0, z0): 380 def true_true_fn(x1, y1, z1): 381 return (x1 - y1 * z1) * 3.14 382 383 def true_false_fn(x1, y1, z1): 384 def true_false_true_fn(x2, y2, z2): 385 return (x2 * y2 * z2) / 2.71 386 387 def true_false_false_fn(x2, y2, z2): 388 return (x2 + y2 + z2) * 1.23 389 390 return torch.cond( 391 p2, true_false_true_fn, true_false_false_fn, [x1, y1, z1] 392 ) 393 394 return torch.cond(p1, true_true_fn, true_false_fn, [x0, y0, z0]) 395 396 def false_fn(x0, y0, z0): 397 def false_true_fn(x1, y1, z1): 398 def false_true_true_fn(x2, y2, z2): 399 return (x2 - y2 - z2) + 1.23 400 401 def false_true_false_fn(x2, y2, z2): 402 return (x2 / y2 / z2) - 3.14 403 404 return torch.cond( 405 p2, false_true_true_fn, false_true_false_fn, [x1, y1, z1] 406 ) 407 408 def false_false_fn(x1, y1, z1): 409 return (x1 - y1 * z1) / 2.71 410 411 return torch.cond(p1, false_true_fn, false_false_fn, [x0, y0, z0]) 412 413 return torch.cond(p0, true_fn, false_fn, [a, b, c]) 414 415 nn_module = Nested() 416 417 def true_fn(x): 418 return nn_module( 419 torch.tensor(False), torch.tensor(True), torch.tensor(False), x, x, x 420 ) 421 422 def false_fn(x): 423 return nn_module( 424 torch.tensor(True), torch.tensor(False), torch.tensor(True), x, x, x 425 ) 426 427 x = torch.randn(4, requires_grad=True) 428 429 for pred, fn in zip( 430 [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] 431 ): 432 result = cond(pred, true_fn, false_fn, (x,)) 433 self.assertEqual(result, fn(x)) 434 435 grad_out = torch.ones_like(result) 436 grads = torch.autograd.grad(result, (x,), grad_out) 437 expected_grads = torch.autograd.grad(fn(x), (x,), grad_out) 438 self.assertEqual(expected_grads, grads) 439 440 @skipIfTorchDynamo("Skip due to graph break when run with dynamo") 441 def test_cond_autograd_mixed_require_grad(self): 442 def true_fn(x, y, z): 443 return x * y * z 444 445 def false_fn(x, y, z): 446 return x + y + z 447 448 x = torch.randn(4, requires_grad=True) 449 y = torch.randn(4, requires_grad=False) 450 451 for pred, fn in zip( 452 [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] 453 ): 454 result = cond(pred, true_fn, false_fn, (x, y, x)) 455 self.assertEqual(result, fn(x, y, x)) 456 457 grad_out = torch.ones_like(result) 458 grads = torch.autograd.grad(result, (x,), grad_out) 459 expected_grads = torch.autograd.grad(fn(x, y, x), (x,), grad_out) 460 self.assertEqual(expected_grads, grads) 461 462 def f(pred, x, y, z): 463 result = cond(pred, true_fn, false_fn, (x, y, z)) 464 grad_out = torch.ones_like(result) 465 return torch.autograd.grad(result, (x,), grad_out) 466 467 gm = make_fx(f, tracing_mode="symbolic")(pred, x, y, x) 468 self.assertExpectedInline( 469 gm.code.strip(), 470 """\ 471def forward(self, pred_1, x_1, y_1, z_1): 472 true_graph_0 = self.true_graph_0 473 false_graph_0 = self.false_graph_0 474 cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (z_1, y_1)); true_graph_0 = false_graph_0 = None 475 getitem = cond[0]; cond = None 476 ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None 477 true_graph_1 = self.true_graph_1 478 false_graph_1 = self.false_graph_1 479 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, z_1, y_1)); pred_1 = true_graph_1 = false_graph_1 
= ones_like = z_1 = y_1 = None 480 getitem_1 = cond_1[0] 481 getitem_2 = cond_1[1]; cond_1 = getitem_2 = None 482 return (getitem_1,)""", # noqa: B950 483 ) 484 485 @skipIfTorchDynamo("Skip due to graph break when run with dynamo") 486 def test_cond_autograd_grad_through_cond(self): 487 nn_module = torch.nn.Linear(4, 4) 488 489 def true_fn(x): 490 return nn_module(x) 491 492 def false_fn(X): 493 return x * nn_module(x) 494 495 x = torch.randn(4, requires_grad=True) 496 497 for pred, fn in zip( 498 [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] 499 ): 500 result = cond(pred, true_fn, false_fn, (x,)) 501 self.assertEqual(result, fn(x)) 502 503 grad_out = torch.ones_like(result) 504 grads = torch.autograd.grad(result, (nn_module.weight,), grad_out) 505 expected_grads = torch.autograd.grad( 506 fn( 507 x, 508 ), 509 (nn_module.weight,), 510 grad_out, 511 ) 512 self.assertEqual(expected_grads, grads) 513 514 def f(pred, x): 515 result = cond(pred, true_fn, false_fn, (x,)) 516 grad_out = torch.ones_like(result) 517 return torch.autograd.grad(result, (nn_module.weight,), grad_out) 518 519 # need to set _allow_non_fake_inputs = True because model parameters don't 520 # get fakified. 521 gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred, x) 522 self.assertExpectedInline( 523 gm.code.strip(), 524 """\ 525def forward(self, pred_1, x_1): 526 true_graph_0 = self.true_graph_0 527 false_graph_0 = self.false_graph_0 528 _param_constant0 = self._param_constant0 529 _param_constant1 = self._param_constant1 530 _tensor_constant0 = self._tensor_constant0 531 cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_param_constant0, _param_constant1, x_1, _tensor_constant0)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _tensor_constant0 = None 532 getitem = cond[0]; cond = None 533 ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None 534 true_graph_1 = self.true_graph_1 535 false_graph_1 = self.false_graph_1 536 _param_constant0_1 = self._param_constant0 537 _param_constant1_1 = self._param_constant1 538 _tensor_constant0_1 = self._tensor_constant0 539 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _param_constant0_1, _param_constant1_1, x_1, _tensor_constant0_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _param_constant0_1 = _param_constant1_1 = x_1 = _tensor_constant0_1 = None 540 getitem_1 = cond_1[0]; getitem_1 = None 541 getitem_2 = cond_1[1] 542 getitem_3 = cond_1[2]; getitem_3 = None 543 getitem_4 = cond_1[3]; cond_1 = getitem_4 = None 544 return (getitem_2,)""", # noqa: B950 545 ) 546 547 def test_cond_in_forloop(self): 548 def for_loop_fake(x): 549 for i in range(3): 550 x = x * x + 1 551 return x 552 553 def for_loop_test(x): 554 for i in range(3): 555 pred = i < 3 556 557 def true_fn(x): 558 return x * x + 1 559 560 def false_fn(x): 561 return x 562 563 x = cond(pred, true_fn, false_fn, (x,)) 564 565 return x 566 567 x = torch.ones(4, requires_grad=True) 568 x_new = for_loop_test(x) 569 x_exp = for_loop_fake(x) 570 571 self.assertEqual(x_new, x_exp) 572 573 grad_out = torch.ones_like(x_new) 574 grads = torch.autograd.grad(x_new, (x,), grad_out) 575 expected_grads = torch.autograd.grad(x_exp, (x,), grad_out) 576 self.assertEqual(expected_grads, grads) 577 578 def f(x): 579 x_new = for_loop_test(x) 580 grad_out = torch.ones_like(x_new) 581 return torch.autograd.grad(x_new, (x,), grad_out) 582 583 gm = make_fx(f, 
tracing_mode="symbolic")(x) 584 self.assertExpectedInline( 585 gm.code.strip(), 586 """\ 587def forward(self, x_1): 588 mul = torch.ops.aten.mul.Tensor(x_1, x_1) 589 add = torch.ops.aten.add.Tensor(mul, 1); mul = None 590 mul_1 = torch.ops.aten.mul.Tensor(add, add) 591 add_1 = torch.ops.aten.add.Tensor(mul_1, 1); mul_1 = None 592 mul_2 = torch.ops.aten.mul.Tensor(add_1, add_1) 593 add_2 = torch.ops.aten.add.Tensor(mul_2, 1); mul_2 = None 594 ones_like = torch.ops.aten.ones_like.default(add_2, pin_memory = False); add_2 = None 595 mul_3 = torch.ops.aten.mul.Tensor(ones_like, add_1) 596 mul_4 = torch.ops.aten.mul.Tensor(ones_like, add_1); ones_like = add_1 = None 597 add_3 = torch.ops.aten.add.Tensor(mul_4, mul_3); mul_4 = mul_3 = None 598 mul_5 = torch.ops.aten.mul.Tensor(add_3, add) 599 mul_6 = torch.ops.aten.mul.Tensor(add_3, add); add_3 = add = None 600 add_4 = torch.ops.aten.add.Tensor(mul_6, mul_5); mul_6 = mul_5 = None 601 mul_7 = torch.ops.aten.mul.Tensor(add_4, x_1) 602 mul_8 = torch.ops.aten.mul.Tensor(add_4, x_1); add_4 = x_1 = None 603 add_5 = torch.ops.aten.add.Tensor(mul_8, mul_7); mul_8 = mul_7 = None 604 return (add_5,)""", # noqa: B950 605 ) 606 607 @skipIfTorchDynamo("Skip due to graph break when run with dynamo") 608 def test_cond_autograd_pytree_not_all_inputs_used(self): 609 def true_fn(x): 610 return x["t"][0] + x["t"][1]["b"] 611 612 def false_fn(x): 613 return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"]) 614 615 a = torch.randn(4, requires_grad=True) 616 b = torch.randn(4, requires_grad=True) 617 c = torch.randn(4, requires_grad=True) 618 619 for pred, fn in zip( 620 [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] 621 ): 622 result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},)) 623 self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]})) 624 625 grad_out = torch.ones_like(result) 626 if pred: 627 with self.assertRaisesRegex(Exception, r"."): 628 grads = torch.autograd.grad(result, (a, b, c), grad_out) 629 expected_grads = torch.autograd.grad( 630 fn({"t": [a, {"b": b}, (c,)]}), (a, b, c), grad_out 631 ) 632 self.assertEqual(expected_grads, grads) 633 634 def f(pred, a, b, c): 635 result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},)) 636 grad_out = torch.ones_like(result) 637 return torch.autograd.grad(result, (a, b), grad_out) 638 639 gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)( 640 pred, a, b, c 641 ) 642 self.assertExpectedInline( 643 gm.code.strip(), 644 """\ 645def forward(self, pred_1, a_1, b_1, c_1): 646 true_graph_0 = self.true_graph_0 647 false_graph_0 = self.false_graph_0 648 cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (a_1, b_1, c_1)); true_graph_0 = false_graph_0 = None 649 getitem = cond[0]; cond = None 650 ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None 651 true_graph_1 = self.true_graph_1 652 false_graph_1 = self.false_graph_1 653 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, a_1, b_1, c_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = a_1 = b_1 = c_1 = None 654 getitem_1 = cond_1[0] 655 getitem_2 = cond_1[1] 656 getitem_3 = cond_1[2]; cond_1 = getitem_3 = None 657 return (getitem_1, getitem_2)""", # noqa: B950 658 ) 659 # Forward 660 self.assertExpectedInline( 661 gm.true_graph_0.code.strip(), 662 """\ 663def forward(self, arg0_1, arg1_1, arg2_1): 664 add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None 665 return (add,)""", 666 ) 667 # Backward 
668 self.assertExpectedInline( 669 gm.true_graph_1.code.strip(), 670 """\ 671def forward(self, arg0_1, arg1_1, arg2_1, arg3_1): 672 add = torch.ops.aten.add.Tensor(arg1_1, arg2_1); arg1_1 = arg2_1 = add = None 673 clone = torch.ops.aten.clone.default(arg0_1) 674 clone_1 = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 675 return [clone, clone_1, None]""", 676 ) 677 678 def test_cond_autograd_pytree_input(self): 679 def true_fn(x): 680 return x["t"][0] + x["t"][1]["b"] * x["t"][2][0] 681 682 def false_fn(x): 683 return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"]) 684 685 a = torch.randn(4, requires_grad=True) 686 b = torch.randn(4, requires_grad=True) 687 c = torch.randn(4, requires_grad=True) 688 689 for pred, fn in zip( 690 [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] 691 ): 692 result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},)) 693 self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]})) 694 695 grad_out = torch.ones_like(result) 696 grads = torch.autograd.grad(result, (a, b), grad_out) 697 expected_grads = torch.autograd.grad( 698 fn({"t": [a, {"b": b}, (c,)]}), (a, b), grad_out 699 ) 700 self.assertEqual(expected_grads, grads) 701 702 def f(pred): 703 result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},)) 704 grad_out = torch.ones_like(result) 705 return torch.autograd.grad(result, (a, b), grad_out) 706 707 # need to set _allow_non_fake_inputs = True because model parameters don't 708 # get fakified. 709 gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred) 710 self.assertExpectedInline( 711 gm.code.strip(), 712 """\ 713def forward(self, pred_1): 714 true_graph_0 = self.true_graph_0 715 false_graph_0 = self.false_graph_0 716 _tensor_constant0 = self._tensor_constant0 717 _tensor_constant1 = self._tensor_constant1 718 _tensor_constant2 = self._tensor_constant2 719 cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None 720 getitem = cond[0]; cond = None 721 ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None 722 true_graph_1 = self.true_graph_1 723 false_graph_1 = self.false_graph_1 724 _tensor_constant0_1 = self._tensor_constant0 725 _tensor_constant1_1 = self._tensor_constant1 726 _tensor_constant2_1 = self._tensor_constant2 727 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _tensor_constant0_1, _tensor_constant1_1, _tensor_constant2_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _tensor_constant0_1 = _tensor_constant1_1 = _tensor_constant2_1 = None 728 getitem_1 = cond_1[0] 729 getitem_2 = cond_1[1] 730 getitem_3 = cond_1[2]; cond_1 = getitem_3 = None 731 return (getitem_1, getitem_2)""", # noqa: B950 732 ) 733 734 def test_cond_autograd_different_pytree_output(self): 735 def true_fn(x): 736 return x["t"][0], {"r": x["t"][2][0] / x["t"][1]["b"]}, [x["t"][2][0]] 737 738 def false_fn(x): 739 return {"res": [x["t"][0] * x["t"][1]["b"], x["t"][2][0]]} 740 741 a = torch.randn(4, requires_grad=True) 742 b = torch.randn(4, requires_grad=True) 743 c = torch.randn(4, requires_grad=True) 744 745 for pred, fn in zip( 746 [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] 747 ): 748 with self.assertRaisesRegex( 749 torch._dynamo.exc.UncapturedHigherOrderOpError, 750 "Cond doesn't work unless it is captured completely with torch.compile", 751 ): 752 
cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},)) 753 754 @skipIfTorchDynamo("Skip due to graph break when run with dynamo") 755 def test_cond_autograd_same_pytree_output(self): 756 def true_fn(x): 757 return {"res": [x["t"][0], (x["t"][2][0],)]} 758 759 def false_fn(x): 760 return {"res": [x["t"][1]["b"], (x["t"][2][0],)]} 761 762 a = torch.randn(4, requires_grad=True) 763 b = torch.randn(4, requires_grad=True) 764 c = torch.randn(4, requires_grad=True) 765 766 for pred, fn in zip( 767 [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] 768 ): 769 result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},)) 770 result_exp = fn({"t": [a, {"b": b}, (c,)]}) 771 self.assertEqual(result, result_exp) 772 773 result_flat, _ = pytree.tree_flatten(result) 774 result_exp_flat, _ = pytree.tree_flatten(result_exp) 775 776 grad_out = [torch.ones_like(g) for g in result_flat] 777 expected_grads = torch.autograd.grad(result_exp_flat, (c,), grad_out) 778 grads = torch.autograd.grad(result_flat, (c,), grad_out) 779 self.assertEqual(expected_grads, grads) 780 781 def f(pred): 782 result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},)) 783 return result 784 785 gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred) 786 self.assertExpectedInline( 787 gm.code.strip(), 788 """\ 789def forward(self, pred_1): 790 true_graph_0 = self.true_graph_0 791 false_graph_0 = self.false_graph_0 792 _tensor_constant0 = self._tensor_constant0 793 _tensor_constant1 = self._tensor_constant1 794 _tensor_constant2 = self._tensor_constant2 795 cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); pred_1 = true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None 796 getitem = cond[0] 797 getitem_1 = cond[1]; cond = None 798 view = torch.ops.aten.view.default(getitem, [4]); getitem = None 799 view_1 = torch.ops.aten.view.default(getitem_1, [4]); getitem_1 = None 800 return {'res': [view, (view_1,)]}""", # noqa: B950 801 ) 802 803 @skipIfTorchDynamo("Skip due to graph break when run with dynamo") 804 def test_cond_autograd_torch_nn_module(self): 805 nn_module_true = torch.nn.Linear(4, 4) 806 807 def true_fn(x): 808 return nn_module_true(torch.abs((x**2).sin())) 809 810 nn_module_false = torch.nn.GRUCell(4, 4) 811 812 def false_fn(x): 813 return nn_module_false((x + 42).cos()) 814 815 for pred, fn in zip( 816 [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] 817 ): 818 x = torch.randn(4, requires_grad=True) 819 result = cond(pred, true_fn, false_fn, (x,)) 820 self.assertEqual(result, fn(x)) 821 822 grad_out = torch.ones_like(result) 823 grads = torch.autograd.grad(result, (x,), grad_out) 824 expected_grads = torch.autograd.grad(fn(x), (x,), grad_out) 825 self.assertEqual(expected_grads, grads) 826 827 def f(pred, x): 828 result = cond(pred, true_fn, false_fn, (x,)) 829 grad_out = torch.ones_like(result) 830 return torch.autograd.grad(result, (x,), grad_out) 831 832 gm = make_fx(f)(pred, x) 833 self.assertExpectedInline( 834 gm.code.strip(), 835 """\ 836def forward(self, pred_1, x_1): 837 true_graph_0 = self.true_graph_0 838 false_graph_0 = self.false_graph_0 839 _param_constant0 = self._param_constant0 840 _param_constant1 = self._param_constant1 841 _param_constant2 = self._param_constant2 842 _param_constant3 = self._param_constant3 843 _param_constant4 = self._param_constant4 844 _param_constant5 = self._param_constant5 845 cond = 
torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1, _param_constant0, _param_constant1, _param_constant2, _param_constant3, _param_constant4, _param_constant5)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _param_constant2 = _param_constant3 = _param_constant4 = _param_constant5 = None 846 getitem = cond[0]; cond = None 847 ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None 848 true_graph_1 = self.true_graph_1 849 false_graph_1 = self.false_graph_1 850 _param_constant0_1 = self._param_constant0 851 _param_constant1_1 = self._param_constant1 852 _param_constant2_1 = self._param_constant2 853 _param_constant3_1 = self._param_constant3 854 _param_constant4_1 = self._param_constant4 855 _param_constant5_1 = self._param_constant5 856 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1, _param_constant0_1, _param_constant1_1, _param_constant2_1, _param_constant3_1, _param_constant4_1, _param_constant5_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = _param_constant0_1 = _param_constant1_1 = _param_constant2_1 = _param_constant3_1 = _param_constant4_1 = _param_constant5_1 = None 857 getitem_1 = cond_1[0] 858 getitem_2 = cond_1[1]; getitem_2 = None 859 getitem_3 = cond_1[2]; getitem_3 = None 860 getitem_4 = cond_1[3]; getitem_4 = None 861 getitem_5 = cond_1[4]; getitem_5 = None 862 getitem_6 = cond_1[5]; getitem_6 = None 863 getitem_7 = cond_1[6]; cond_1 = getitem_7 = None 864 return (getitem_1,)""", # noqa: B950 865 ) 866 867 def test_cond_autograd_user_nn_module(self): 868 class User_nn_module(torch.nn.Module): 869 def __init__(self) -> None: 870 super().__init__() 871 872 def forward(self, input): 873 return input * input 874 875 nn_module_true = User_nn_module() 876 877 def true_fn(x): 878 return nn_module_true(torch.abs((x**2).sin())) 879 880 nn_module_false = torch.nn.ReLU(inplace=False) 881 882 def false_fn(x): 883 return nn_module_false((x + 42).cos()) 884 885 for pred, fn in zip( 886 [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] 887 ): 888 x = torch.randn(4, requires_grad=True) 889 result = cond(pred, true_fn, false_fn, (x,)) 890 self.assertEqual(result, fn(x)) 891 892 grad_out = torch.ones_like(result) 893 grads = torch.autograd.grad(result, (x,), grad_out) 894 expected_grads = torch.autograd.grad(fn(x), (x,), grad_out) 895 self.assertEqual(expected_grads, grads) 896 897 def f(pred, x): 898 result = cond(pred, true_fn, false_fn, (x,)) 899 grad_out = torch.ones_like(result) 900 return torch.autograd.grad(result, (x,), grad_out) 901 902 gm = make_fx(f)(pred, x) 903 self.assertExpectedInline( 904 gm.code.strip(), 905 """\ 906def forward(self, pred_1, x_1): 907 true_graph_0 = self.true_graph_0 908 false_graph_0 = self.false_graph_0 909 cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None 910 getitem = cond[0]; cond = None 911 ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None 912 true_graph_1 = self.true_graph_1 913 false_graph_1 = self.false_graph_1 914 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None 915 getitem_1 = cond_1[0]; cond_1 = None 916 return (getitem_1,)""", # noqa: B950 917 ) 918 919 def test_cond_autograd_inner_fn(self): 920 def true_fn(x): 921 return torch.abs((x**2).sin()) 922 923 def false_fn(x): 924 def inner_fn(x): 
925 return x**2 926 927 return torch.abs(inner_fn(x).sin()) 928 929 x = torch.randn(4, requires_grad=True) 930 pred = torch.tensor(False) 931 fn = false_fn 932 result_false = cond(pred, true_fn, false_fn, (x,)) 933 self.assertEqual(result_false, fn(x)) 934 935 grad_out = torch.ones_like(result_false) 936 grads_false = torch.autograd.grad(result_false, (x,), grad_out) 937 expected_grads = torch.autograd.grad(fn(x), (x,), grad_out) 938 self.assertEqual(expected_grads, grads_false) 939 940 pred = torch.tensor(True) 941 fn = true_fn 942 result_true = cond(pred, true_fn, false_fn, (x,)) 943 self.assertEqual(result_true, fn(x)) 944 self.assertEqual(result_false, result_true) 945 946 grad_out = torch.ones_like(result_true) 947 grads_true = torch.autograd.grad(result_true, (x,), grad_out) 948 expected_grads = torch.autograd.grad(fn(x), (x,), grad_out) 949 self.assertEqual(expected_grads, grads_true) 950 self.assertEqual(grads_false, grads_true) 951 952 def f(pred, x): 953 result = cond(pred, true_fn, false_fn, (x,)) 954 grad_out = torch.ones_like(result) 955 return torch.autograd.grad(result, (x,), grad_out) 956 957 gm = make_fx(f)(pred, x) 958 self.assertExpectedInline( 959 gm.code.strip(), 960 """\ 961def forward(self, pred_1, x_1): 962 true_graph_0 = self.true_graph_0 963 false_graph_0 = self.false_graph_0 964 cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None 965 getitem = cond[0]; cond = None 966 ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None 967 true_graph_1 = self.true_graph_1 968 false_graph_1 = self.false_graph_1 969 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None 970 getitem_1 = cond_1[0]; cond_1 = None 971 return (getitem_1,)""", # noqa: B950 972 ) 973 974 def test_cond_autograd_inner_tensor(self): 975 def true_fn(x): 976 return torch.abs((x**2).sin()) 977 978 def false_fn(x): 979 y = torch.ones(4, requires_grad=False) * 42 980 return (x * y).cos() 981 982 for pred, fn in zip( 983 [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] 984 ): 985 x = torch.randn(4, requires_grad=True) 986 result = cond(pred, true_fn, false_fn, (x,)) 987 self.assertEqual(result, fn(x)) 988 989 grad_out = torch.ones_like(result) 990 grads = torch.autograd.grad(result, (x,), grad_out) 991 expected_grads = torch.autograd.grad(fn(x), (x,), grad_out) 992 self.assertEqual(expected_grads, grads) 993 994 def f(pred, x): 995 result = cond(pred, true_fn, false_fn, (x,)) 996 grad_out = torch.ones_like(result) 997 return torch.autograd.grad(result, (x,), grad_out) 998 999 gm = make_fx(f, tracing_mode="symbolic")(pred, x) 1000 self.assertExpectedInline( 1001 gm.code.strip(), 1002 """\ 1003def forward(self, pred_1, x_1): 1004 true_graph_0 = self.true_graph_0 1005 false_graph_0 = self.false_graph_0 1006 cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None 1007 getitem = cond[0]; cond = None 1008 ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None 1009 true_graph_1 = self.true_graph_1 1010 false_graph_1 = self.false_graph_1 1011 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None 1012 getitem_1 = cond_1[0]; cond_1 = None 1013 return (getitem_1,)""", # noqa: B950 1014 ) 1015 1016 
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") 1017 def test_cond_autograd_gpu(self): 1018 def true_fn(x): 1019 return x.sin() 1020 1021 def false_fn(x): 1022 return x.cos() 1023 1024 for pred, fn in zip( 1025 [torch.tensor(False, device="cuda"), torch.tensor(True, device="cuda")], 1026 [false_fn, true_fn], 1027 ): 1028 x = torch.randn(4, requires_grad=True, device="cuda") 1029 result = cond(pred, true_fn, false_fn, (x,)) 1030 self.assertEqual(result, fn(x)) 1031 1032 grad_out = torch.ones_like(result) 1033 grads = torch.autograd.grad(result, (x,), grad_out) 1034 expected_grads = torch.autograd.grad(fn(x), (x,), grad_out) 1035 self.assertEqual(expected_grads, grads) 1036 1037 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") 1038 def test_map_gpu(self): 1039 def f(x, y): 1040 return x + y 1041 1042 xs = torch.ones(3, 2, 2, device="cuda") 1043 y = torch.ones(2, device="cuda") 1044 res = control_flow.map(f, xs, y) 1045 expected = _fake_map(f, xs, y) 1046 self.assertEqual(expected, res) 1047 1048 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") 1049 def test_while_loop_gpu(self): 1050 def cond_fn(x): 1051 return x.sum() < 10 1052 1053 def body_fn(x): 1054 return (x + 1,) 1055 1056 x = torch.zeros(1, device="cuda") 1057 res = while_loop(cond_fn, body_fn, (x,)) 1058 expected = _fake_while_loop(cond_fn, body_fn, (x,)) 1059 self.assertEqual(expected, res) 1060 1061 def test_map_illegal_inputs(self): 1062 def f(x, y): 1063 return x[0] + x[1] + y 1064 1065 with self.assertRaisesRegex( 1066 RuntimeError, 1067 r"Mapped xs can only consist of tensors\. Got xs \[3, tensor\(\[1\., 1\.\]\)\]\.", 1068 ): 1069 _ = control_flow.map(f, (3, torch.ones(2)), torch.ones(2)) 1070 1071 with self.assertRaisesRegex( 1072 RuntimeError, r"Leading dimensions of mapped xs cannot be 0\." 1073 ): 1074 _ = control_flow.map( 1075 f, (torch.ones(0, 1, 2), torch.ones(0, 1, 2)), torch.ones(2) 1076 ) 1077 1078 with self.assertRaisesRegex( 1079 RuntimeError, 1080 r"Leading dimensions of mapped xs must be consistent\. " 1081 r"Got shapes \[torch\.Size\(\[3, 4, 5\]\), torch\.Size\(\[4, 4, 5\]\)\]\.", 1082 ): 1083 _ = control_flow.map( 1084 f, (torch.ones(3, 4, 5), torch.ones(4, 4, 5)), torch.ones(5) 1085 ) 1086 1087 def test_map_illegal_outputs(self): 1088 def f(x, y): 1089 return x.item() 1090 1091 def f1(x, y): 1092 return y.size() 1093 1094 def f2(x, y): 1095 return None 1096 1097 x = torch.ones([3]) 1098 y = torch.ones([1, 2, 3]) 1099 with self.assertRaisesRegex( 1100 RuntimeError, r"Expect outputs of map only contains tensors or None\." 1101 ): 1102 _ = control_flow.map(f, x, y) 1103 1104 with self.assertRaisesRegex( 1105 RuntimeError, r"Expect outputs of map only contains tensors or None\." 
1106 ): 1107 out = control_flow.map(f1, x, y) 1108 1109 # return None is OK 1110 _ = control_flow.map(f2, x, y) 1111 1112 def test_map_list_in_out(self): 1113 def f(x, y): 1114 return [[x[0][0] + y]] 1115 1116 xs = [[torch.ones(3, 2, 2)]] 1117 y = torch.ones(2) 1118 res = control_flow.map(f, xs, y) 1119 expected = _fake_map(f, xs, y) 1120 self.assertEqual(len(res), 1) 1121 self.assertEqual(len(res[0]), 1) 1122 self.assertEqual(expected, res) 1123 1124 def test_map_dict_in_out(self): 1125 def f(x, y): 1126 return {"c": x["a"]["b"] + y} 1127 1128 xs = {"a": {"b": torch.ones(3, 2, 2)}} 1129 y = torch.ones(2) 1130 res = control_flow.map(f, xs, y) 1131 expected = _fake_map(f, xs, y) 1132 self.assertEqual(len(res), 1) 1133 self.assertTrue("c" in res) 1134 self.assertEqual(expected, res) 1135 1136 def test_map_autograd_simple(self): 1137 def f(x, y): 1138 return x.sin().cos() * y.cos().sin() 1139 1140 xs = torch.ones(3, 2, 2, requires_grad=True) 1141 y = torch.ones(2, requires_grad=True) 1142 res = control_flow.map(f, xs, y) 1143 expected_res = _fake_map(f, xs, y) 1144 grad_out = torch.ones_like(res) 1145 grads = torch.autograd.grad(res, (xs, y), grad_out) 1146 expected_grads = torch.autograd.grad(expected_res, (xs, y), grad_out) 1147 self.assertEqual(expected_res, res) 1148 self.assertEqual(expected_grads, grads) 1149 1150 def test_map_autograd_simple_partial_grad(self): 1151 def f(x, y): 1152 return x.sin().cos() * y.cos().sin() 1153 1154 xs = torch.ones(3, 2, 2, requires_grad=True) 1155 # Disable the gradient computation for y 1156 y = torch.ones(2, requires_grad=False) 1157 res = control_flow.map(f, xs, y) 1158 expected_res = _fake_map(f, xs, y) 1159 grad_out = torch.ones_like(res) 1160 grads = torch.autograd.grad(res, (xs,), grad_out) 1161 expected_grads = torch.autograd.grad(expected_res, (xs,), grad_out) 1162 self.assertEqual(expected_res, res) 1163 self.assertEqual(expected_grads, grads) 1164 1165 def test_map_autograd_no_grad_output(self): 1166 def f(x, y): 1167 return x[0].sin().cos() + y, y.cos().sin() 1168 1169 xs = [torch.ones(3, 2, 2, requires_grad=True), torch.ones(3, 3)] 1170 # Disable the gradient computation for y 1171 y = torch.ones(2, requires_grad=False) 1172 res = control_flow.map(f, xs, y) 1173 expected_res = _fake_map(f, xs, y) 1174 grad_out = torch.ones_like(res[0]) 1175 grads = torch.autograd.grad(res[0], (xs[0],), grad_out) 1176 expected_grads = torch.autograd.grad(expected_res[0], (xs[0],), grad_out) 1177 self.assertEqual(expected_res, res) 1178 self.assertEqual(expected_grads, grads) 1179 1180 def test_map_autograd_nested_list(self): 1181 import torch.utils._pytree as pytree 1182 1183 def f(x, y): 1184 a, b = x 1185 c, d = a 1186 return [[b.sin() * c.cos()], d.sin() * y.cos()] 1187 1188 def fwbw(map_op, f, x, y): 1189 z = map_op(f, x, y) 1190 flat_x = pytree.tree_leaves(x) 1191 flat_z = pytree.tree_leaves(z) 1192 grads = torch.autograd.grad( 1193 flat_z, flat_x, [torch.ones_like(z) for z in flat_z] 1194 ) 1195 return z, grads 1196 1197 x = [ 1198 [ 1199 torch.randn(3, 2, 2, requires_grad=True), 1200 torch.randn(3, 2, 1, requires_grad=True), 1201 ], 1202 torch.ones(3, 1, 2, requires_grad=True), 1203 ] 1204 y = torch.ones(1, requires_grad=True) 1205 true_outs = fwbw(control_flow.map, f, x, y) 1206 fake_outs = fwbw(_fake_map, f, x, y) 1207 self.assertEqual(true_outs, fake_outs) 1208 1209 @unittest.skipIf(not SM70OrLater, "triton") 1210 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") 1211 @parametrize("reverse", [False, True]) 1212 
@parametrize("combine_mode", ["pointwise", "generic"]) 1213 @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) 1214 # Skipping the combination of combine_mode=pointwise and device=cpu 1215 # as the current implementation of pointwise does only support CUDA device 1216 @decorateIf( 1217 unittest.skip, 1218 lambda params: ( 1219 params["combine_mode"] == "pointwise" 1220 and (params["device"] == torch.device("cpu") or torch.version.hip) 1221 ), 1222 ) 1223 def test_pointwise_associative_scan_simple(self, reverse, combine_mode, device): 1224 def add(x: torch.Tensor, y: torch.Tensor): 1225 return x + y 1226 1227 def mul(x: torch.Tensor, y: torch.Tensor): 1228 return x * y 1229 1230 x = torch.randn(3, 10, 2, device=device) 1231 1232 for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]: 1233 result = associative_scan( 1234 op, x, 0, reverse=reverse, combine_mode=combine_mode 1235 ) 1236 result_exp = _fake_associative_scan(op, x, 0, reverse=reverse) 1237 self.assertEqual(result, result_exp) 1238 1239 # Jax Examples 1240 x = torch.arange(0, 4, device=device) 1241 cumsum1 = associative_scan( 1242 add, x, 0, reverse=reverse, combine_mode=combine_mode 1243 ) 1244 cumsum_exp = _fake_associative_scan(add, x, 0, reverse=reverse) 1245 if not reverse: 1246 self.assertEqual( 1247 cumsum1, torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64) 1248 ) 1249 else: 1250 self.assertEqual( 1251 cumsum1, torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64) 1252 ) 1253 self.assertEqual(cumsum1, cumsum_exp) 1254 1255 @unittest.skipIf(not SM70OrLater, "triton") 1256 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") 1257 @parametrize("reverse", [False, True]) 1258 @parametrize("combine_mode", ["pointwise", "generic"]) 1259 @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) 1260 # Skipping the combination of combine_mode=pointwise and device=cpu 1261 # as the current implementation of pointwise does only support CUDA device 1262 @decorateIf( 1263 unittest.skip, 1264 lambda params: ( 1265 params["combine_mode"] == "pointwise" 1266 and (params["device"] == torch.device("cpu") or torch.version.hip) 1267 ), 1268 ) 1269 def test_pointwise_associative_scan_dim(self, reverse, combine_mode, device): 1270 import random 1271 1272 def add(x: torch.Tensor, y: torch.Tensor): 1273 return x + y 1274 1275 def mul(x: torch.Tensor, y: torch.Tensor): 1276 return x * y 1277 1278 num_dims = [random.randint(2, 5) for _ in range(10)] 1279 for num_dim in num_dims: 1280 shapes = [random.randint(1, 10) for _ in range(num_dim)] 1281 rnd_scan_dim = random.randint(0, num_dim - 1) 1282 x = torch.randn(*shapes, device=device) 1283 1284 for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]: 1285 result = associative_scan( 1286 op, x, rnd_scan_dim, reverse=reverse, combine_mode=combine_mode 1287 ) 1288 result_exp = _fake_associative_scan( 1289 op, x, rnd_scan_dim, reverse=reverse 1290 ) 1291 self.assertEqual(result, result_exp) 1292 if not reverse: 1293 result_exp_PT = op_pt(x, rnd_scan_dim) 1294 self.assertEqual(result, result_exp_PT) 1295 1296 @unittest.skipIf(not SM70OrLater, "triton") 1297 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") 1298 @parametrize("reverse", [False, True]) 1299 @parametrize("combine_mode", ["pointwise", "generic"]) 1300 @parametrize("compile_mode", ["compile", "compile_dynamic_shape"]) 1301 @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) 1302 # Skipping the combination of combine_mode=pointwise and device=cpu 
1303 # as the current implementation of pointwise does only support CUDA device 1304 @decorateIf( 1305 unittest.skip, 1306 lambda params: ( 1307 params["combine_mode"] == "pointwise" 1308 and (params["device"] == torch.device("cpu") or torch.version.hip) 1309 ), 1310 ) 1311 def test_pointwise_associative_scan_compile( 1312 self, reverse, combine_mode, compile_mode, device 1313 ): 1314 def add(x: torch.Tensor, y: torch.Tensor): 1315 return x + y 1316 1317 def mul(x: torch.Tensor, y: torch.Tensor): 1318 return x * y 1319 1320 x = torch.randn(3, 10, 2, device=device) 1321 torch.compiler.reset() 1322 if compile_mode == "compile": 1323 associative_scan_fct = torch.compile( 1324 associative_scan, fullgraph=True, dynamic=False 1325 ) 1326 else: 1327 associative_scan_fct = torch.compile( 1328 associative_scan, fullgraph=True, dynamic=True 1329 ) 1330 1331 for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]: 1332 result = associative_scan_fct( 1333 op, x, 0, reverse=reverse, combine_mode=combine_mode 1334 ) 1335 result_exp = _fake_associative_scan(op, x, 0, reverse=reverse) 1336 self.assertEqual(result, result_exp) 1337 if not reverse: 1338 result_exp_PT = op_pt(x, 0) 1339 self.assertEqual(result, result_exp_PT) 1340 1341 # Jax Examples 1342 x = torch.arange(0, 4, device=device) 1343 cumsum1 = associative_scan( 1344 add, x, 0, reverse=reverse, combine_mode=combine_mode 1345 ) 1346 cumsum_exp = _fake_associative_scan(add, x, 0, reverse=reverse) 1347 if not reverse: 1348 self.assertEqual( 1349 cumsum1, torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64) 1350 ) 1351 else: 1352 self.assertEqual( 1353 cumsum1, torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64) 1354 ) 1355 self.assertEqual(cumsum1, cumsum_exp) 1356 1357 @skipIfRocm(msg="Unsupported on ROCM yet") 1358 @unittest.skipIf(not SM70OrLater, "triton") 1359 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") 1360 @parametrize("reverse", [False, True]) 1361 @parametrize("combine_mode", ["pointwise", "generic"]) 1362 @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) 1363 # Skipping the combination of combine_mode=pointwise and device=cpu 1364 # as the current implementation of pointwise does only support CUDA device 1365 @decorateIf( 1366 unittest.skip, 1367 lambda params: ( 1368 params["combine_mode"] == "pointwise" 1369 and (params["device"] == torch.device("cpu") or torch.version.hip) 1370 ), 1371 ) 1372 def test_pointwise_associative_scan_binary_operator( 1373 self, reverse, combine_mode, device 1374 ): 1375 def fct(x, y): 1376 A_i, Bu_i = x 1377 A_j, Bu_j = y 1378 return A_j * A_i, A_j * Bu_i + Bu_j 1379 1380 torch.compiler.reset() 1381 associative_scan1 = torch.compile(associative_scan, fullgraph=True) 1382 associative_scan2 = associative_scan 1383 1384 state_dim = 20 1385 timesteps = 10 1386 projected_inputs = torch.randn( 1387 timesteps, state_dim, requires_grad=True, device=device 1388 ) 1389 A = torch.randn(state_dim, requires_grad=True, device=device) 1390 elements = (A.repeat((timesteps, 1)), projected_inputs) 1391 1392 result1 = associative_scan1( 1393 fct, elements, 0, combine_mode=combine_mode, reverse=reverse 1394 ) 1395 result2 = associative_scan2( 1396 fct, elements, 0, combine_mode=combine_mode, reverse=reverse 1397 ) 1398 expected_result = _fake_associative_scan(fct, elements, 0, reverse=reverse) 1399 self.assertEqual( 1400 result1, 1401 expected_result, 1402 ) 1403 self.assertEqual([r.device.type for r in result1], [device.type] * len(result1)) 1404 self.assertEqual( 1405 
result2, 1406 expected_result, 1407 ) 1408 self.assertEqual([r.device.type for r in result2], [device.type] * len(result2)) 1409 1410 @skipIfRocm(msg="Unsupported on ROCM yet") 1411 @unittest.skipIf(not SM70OrLater, "triton") 1412 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") 1413 @parametrize("reverse", [False, True]) 1414 @parametrize("combine_mode", ["pointwise", "generic"]) 1415 @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) 1416 # Skipping the combination of combine_mode=pointwise and device=cpu 1417 # as the current implementation of pointwise does only support CUDA device 1418 @decorateIf( 1419 unittest.skip, 1420 lambda params: ( 1421 params["combine_mode"] == "pointwise" 1422 and (params["device"] == torch.device("cpu") or torch.version.hip) 1423 ), 1424 ) 1425 def test_pointwise_associative_scan_tuple(self, reverse, combine_mode, device): 1426 def fct(x, y): 1427 return (x[0] + y[0], x[1] * y[1]) 1428 1429 x = torch.randn(3, 2, 2, device=device, requires_grad=True) 1430 y = torch.randn(3, 2, 2, device=device, requires_grad=True) 1431 inp = (x, y) 1432 1433 result1 = associative_scan( 1434 fct, inp, 0, reverse=reverse, combine_mode=combine_mode 1435 ) 1436 expected_result = _fake_associative_scan(fct, inp, 0, reverse=reverse) 1437 self.assertEqual(result1, expected_result) 1438 1439 @skipIfRocm(msg="Unsupported on ROCM yet") 1440 @unittest.skipIf(not SM70OrLater, "triton") 1441 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") 1442 @parametrize("reverse", [False, True]) 1443 @parametrize("combine_mode", ["pointwise", "generic"]) 1444 @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) 1445 # Skipping the combination of combine_mode=pointwise and device=cpu 1446 # as the current implementation of pointwise does only support CUDA device 1447 @decorateIf( 1448 unittest.skip, 1449 lambda params: ( 1450 params["combine_mode"] == "pointwise" 1451 and (params["device"] == torch.device("cpu") or torch.version.hip) 1452 ), 1453 ) 1454 def test_pointwise_associative_scan_complex_pytree( 1455 self, reverse, combine_mode, device 1456 ): 1457 def fct_wrong_pytree(x, y): 1458 return { 1459 "i": x["i"] * y["j"][0][0], 1460 "k": 0.0, 1461 "j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]), 1462 } 1463 1464 def fct_pointwise(x, y): 1465 return { 1466 "i": x["i"] * y["i"], 1467 "j": ( 1468 [x["j"][0][0] * y["j"][0][0]], 1469 [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], 1470 ), 1471 } 1472 1473 x = torch.randn(3, 2, 2, device=device, requires_grad=True) 1474 y = torch.randn(3, 2, 2, device=device, requires_grad=True) 1475 z = torch.randn(3, 2, 2, device=device, requires_grad=True) 1476 inp = {"i": x, "j": ([y], [{"o": z}])} 1477 1478 with self.assertRaisesRegex(Exception, r"."): 1479 result = associative_scan(fct_wrong_pytree, inp, 0, combine_mode="generic") 1480 1481 torch.compiler.reset() 1482 associative_scan1 = torch.compile(associative_scan, fullgraph=True) 1483 associative_scan2 = associative_scan 1484 1485 result1 = associative_scan1( 1486 fct_pointwise, inp, 0, combine_mode=combine_mode, reverse=reverse 1487 ) 1488 result2 = associative_scan2( 1489 fct_pointwise, inp, 0, combine_mode=combine_mode, reverse=reverse 1490 ) 1491 expected_result = _fake_associative_scan(fct_pointwise, inp, 0, reverse=reverse) 1492 self.assertEqual(result1, expected_result) 1493 self.assertEqual(result2, expected_result) 1494 1495 @unittest.skipIf(not SM70OrLater, "triton") 1496 @unittest.skipIf(not torch.cuda.is_available(), 
"Test requires CUDA.") 1497 @parametrize("reverse", [False, True]) 1498 @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) 1499 def test_generic_associative_scan_generic_simple(self, reverse, device): 1500 def non_pointwise(x: torch.Tensor, y: torch.Tensor): 1501 W = torch.diag(torch.ones(2, device=device)) 1502 return x @ W + y @ W 1503 1504 x = torch.randn(3, 10, 2, device=device) 1505 with self.assertRaisesRegex(Exception, ".*"): 1506 out = associative_scan( 1507 non_pointwise, x, 0, reverse=reverse, combine_mode="pointwise" 1508 ) 1509 1510 result1 = associative_scan( 1511 non_pointwise, x, 0, reverse=reverse, combine_mode="generic" 1512 ) 1513 result_expected = _fake_associative_scan(non_pointwise, x, 0, reverse=reverse) 1514 self.assertEqual(result1, result_expected) 1515 1516 1517@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") 1518@skipIfNoDynamoSupport 1519class TestControlFlowTraced(TestCase): 1520 def setUp(self): 1521 torch._dynamo.reset() 1522 super().setUp() 1523 1524 def _check_tracing(self, fn, args, allow_non_fake_inputs=False): 1525 graphs = {} 1526 eager_res = fn(*args) 1527 for tracing_mode in ["symbolic", "real", "fake"]: 1528 graph = make_fx( 1529 fn, 1530 tracing_mode=tracing_mode, 1531 _allow_non_fake_inputs=allow_non_fake_inputs, 1532 )(*args) 1533 graphs[tracing_mode] = graph 1534 self.assertEqual(graph(*args), eager_res) 1535 return graphs 1536 1537 def _check_compile(self, fn, args, *, backend="eager"): 1538 eager_res = fn(*args) 1539 compiled_fn = torch.compile(fn, backend=backend) 1540 self.assertEqual(compiled_fn(*args), eager_res) 1541 1542 def test_cond_traced_not_nested(self): 1543 def true_fn(x): 1544 return x.sin() 1545 1546 def false_fn(x): 1547 return x.cos() 1548 1549 def f(x, y): 1550 return cond(y, true_fn, false_fn, [x]) 1551 1552 x = torch.randn(4) 1553 graph = make_fx(f)(x, torch.tensor(False)) 1554 result_true = graph.forward(x, torch.tensor(True)) 1555 result_false = graph.forward(x, torch.tensor(False)) 1556 self.assertFalse(torch.allclose(result_true, result_false)) 1557 self.assertEqual(result_true, torch.sin(x)) 1558 self.assertEqual(result_false, torch.cos(x)) 1559 1560 graph = make_fx(f, tracing_mode="symbolic")(x, torch.tensor(False)) 1561 self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True))) 1562 1563 @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") 1564 def test_cond_simple_with_linear_compile_check_graph(self): 1565 from torch._dynamo.testing import EagerAndRecordGraphs 1566 1567 def true_fn(x): 1568 return x.sin() 1569 1570 def false_fn(x): 1571 return x.cos() 1572 1573 x = torch.randn(4, requires_grad=True) 1574 1575 def f(pred, x): 1576 result = cond(pred, true_fn, false_fn, (x,)) 1577 grad_out = torch.ones_like(result) 1578 return torch.autograd.grad(result, (x,), grad_out) 1579 1580 backend = EagerAndRecordGraphs() 1581 torch.compile(f, backend=backend)(torch.tensor(False), x) 1582 self.assertEqual(len(backend.graphs), 2) 1583 gm = backend.graphs[0] 1584 1585 self.assertExpectedInline( 1586 gm.code.strip(), 1587 """\ 1588def forward(self, L_pred_ : torch.Tensor, L_x_ : torch.Tensor): 1589 l_pred_ = L_pred_ 1590 l_x_ = L_x_ 1591 cond_true_0 = self.cond_true_0 1592 cond_false_0 = self.cond_false_0 1593 cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_x_]); l_pred_ = cond_true_0 = cond_false_0 = l_x_ = None 1594 result = cond[0]; cond = None 1595 grad_out = torch.ones_like(result) 1596 return (result, grad_out)""", # noqa: B950 
1597 ) 1598 1599 self.assertExpectedInline( 1600 gm.cond_true_0.code.strip(), 1601 """\ 1602def forward(self, l_x_): 1603 l_x__1 = l_x_ 1604 sin = l_x__1.sin(); l_x__1 = None 1605 return (sin,)""", # noqa: B950 1606 ) 1607 self.assertExpectedInline( 1608 gm.cond_false_0.code.strip(), 1609 """\ 1610def forward(self, l_x_): 1611 l_x__1 = l_x_ 1612 cos = l_x__1.cos(); l_x__1 = None 1613 return (cos,)""", # noqa: B950 1614 ) 1615 1616 backward_gm = backend.graphs[1] 1617 self.assertExpectedInline( 1618 backward_gm.code.strip(), 1619 """\ 1620def forward(self, L_ctx_saved_tensors_0_ : torch.Tensor, L_ctx_pred : torch.Tensor, L_flat_grads_0_ : torch.Tensor): 1621 l_ctx_saved_tensors_0_ = L_ctx_saved_tensors_0_ 1622 l_ctx_pred = L_ctx_pred 1623 l_flat_grads_0_ = L_flat_grads_0_ 1624 cond_true_0 = self.cond_true_0 1625 cond_false_0 = self.cond_false_0 1626 cond = torch.ops.higher_order.cond(l_ctx_pred, cond_true_0, cond_false_0, [l_ctx_saved_tensors_0_, l_flat_grads_0_]); l_ctx_pred = cond_true_0 = cond_false_0 = l_ctx_saved_tensors_0_ = l_flat_grads_0_ = None 1627 getitem = cond[0]; cond = None 1628 return (getitem,)""", # noqa: B950 1629 ) 1630 1631 def test_while_loop_nested_traced(self): 1632 fn, inp = WHILE_LOOP_TESTS["nested"] 1633 graphs = self._check_tracing(fn, inp) 1634 self.assertExpectedInline( 1635 graphs["symbolic"].code.strip("\n"), 1636 """\ 1637def forward(self, out_iter_1, it_1, y_1): 1638 while_loop_cond_graph_0 = self.while_loop_cond_graph_0 1639 while_loop_body_graph_0 = self.while_loop_body_graph_0 1640 while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (out_iter_1, it_1, y_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = out_iter_1 = it_1 = y_1 = None 1641 getitem = while_loop[0] 1642 getitem_1 = while_loop[1] 1643 getitem_2 = while_loop[2]; while_loop = None 1644 return (getitem, getitem_1, getitem_2) 1645 """, # noqa: B950 1646 ) 1647 self.assertExpectedInline( 1648 graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"), 1649 """\ 1650def forward(self, arg0_1, arg1_1, arg2_1): 1651 sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None 1652 lt = torch.ops.aten.lt.Scalar(sum_1, 2); sum_1 = None 1653 return lt 1654 """, 1655 ) 1656 self.assertExpectedInline( 1657 graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"), 1658 """\ 1659def forward(self, arg0_1, arg1_1, arg2_1): 1660 while_loop_cond_graph_0 = self.while_loop_cond_graph_0 1661 while_loop_body_graph_0 = self.while_loop_body_graph_0 1662 while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = None 1663 getitem = while_loop[0] 1664 getitem_1 = while_loop[1] 1665 getitem_2 = while_loop[2]; while_loop = None 1666 add = torch.ops.aten.add.Tensor(getitem, 1); getitem = None 1667 return (add, getitem_1, getitem_2) 1668 """, # noqa: B950 1669 ) 1670 1671 def _wrap_with_functionalize(self, fn, func_type): 1672 mode = None 1673 if func_type == "cpp": 1674 fn = CppFunctionalizeAPI().functionalize(fn) 1675 elif func_type == "python": 1676 fn = PythonFunctionalizeAPI().functionalize(fn) 1677 mode = FunctionalTensorMode() 1678 elif func_type == "functorch": 1679 fn = torch.func.functionalize(fn) 1680 else: 1681 assert func_type == "no" 1682 return fn, mode 1683 1684 @parametrize("func_type", ["no", "cpp", "python", "functorch"]) 1685 def test_while_loop_simple_functionalize_check_graph(self, func_type): 1686 
fn, inp = WHILE_LOOP_TESTS["simple_with_mutation"] 1687 fn, mode = self._wrap_with_functionalize(fn, func_type) 1688 mode = mode if mode is not None else contextlib.nullcontext() 1689 with mode: 1690 graphs = self._check_tracing(fn, inp) 1691 if func_type == "no": 1692 self.assertExpectedInline( 1693 graphs["symbolic"].code.strip("\n"), 1694 """\ 1695def forward(self, x_1): 1696 while_loop_cond_graph_0 = self.while_loop_cond_graph_0 1697 while_loop_body_graph_0 = self.while_loop_body_graph_0 1698 while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (x_1,), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = x_1 = None 1699 getitem = while_loop[0]; while_loop = None 1700 return (getitem,) 1701 """, # noqa: B950 1702 ) 1703 self.assertExpectedInline( 1704 graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"), 1705 """\ 1706def forward(self, arg0_1): 1707 clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 1708 add_ = torch.ops.aten.add_.Tensor(clone, 1); clone = None 1709 add__1 = torch.ops.aten.add_.Tensor(add_, -1); add_ = None 1710 sum_1 = torch.ops.aten.sum.default(add__1); add__1 = None 1711 lt = torch.ops.aten.lt.Scalar(sum_1, 10); sum_1 = None 1712 return lt 1713 """, 1714 ) 1715 self.assertExpectedInline( 1716 graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"), 1717 """\ 1718def forward(self, arg0_1): 1719 clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 1720 add_ = torch.ops.aten.add_.Tensor(clone, 1); clone = None 1721 add__1 = torch.ops.aten.add_.Tensor(add_, -1); add_ = None 1722 add = torch.ops.aten.add.Tensor(add__1, 1); add__1 = None 1723 return (add,) 1724 """, 1725 ) 1726 elif func_type == "python": 1727 self.assertExpectedInline( 1728 graphs["symbolic"].code.strip("\n"), 1729 """\ 1730def forward(self, arg0_1): 1731 while_loop_cond_graph_0 = self.while_loop_cond_graph_0 1732 while_loop_body_graph_0 = self.while_loop_body_graph_0 1733 while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1,), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = None 1734 getitem = while_loop[0]; while_loop = None 1735 return (getitem,) 1736 """, # noqa: B950 1737 ) 1738 self.assertExpectedInline( 1739 graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"), 1740 """\ 1741def forward(self, arg0_1): 1742 clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 1743 add = torch.ops.aten.add.Tensor(clone, 1); clone = None 1744 add_1 = torch.ops.aten.add.Tensor(add, -1); add = None 1745 sum_1 = torch.ops.aten.sum.default(add_1); add_1 = None 1746 lt = torch.ops.aten.lt.Scalar(sum_1, 10); sum_1 = None 1747 return lt 1748 """, 1749 ) 1750 self.assertExpectedInline( 1751 graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"), 1752 """\ 1753def forward(self, arg0_1): 1754 clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 1755 add = torch.ops.aten.add.Tensor(clone, 1); clone = None 1756 add_1 = torch.ops.aten.add.Tensor(add, -1); add = None 1757 add_2 = torch.ops.aten.add.Tensor(add_1, 1); add_1 = None 1758 return (add_2,) 1759 """, 1760 ) 1761 else: 1762 self.assertExpectedInline( 1763 graphs["symbolic"].code.strip("\n"), 1764 """\ 1765def forward(self, x_1): 1766 while_loop_cond_graph_0 = self.while_loop_cond_graph_0 1767 while_loop_body_graph_0 = self.while_loop_body_graph_0 1768 while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (x_1,), ()); while_loop_cond_graph_0 = 
while_loop_body_graph_0 = x_1 = None 1769 getitem = while_loop[0]; while_loop = None 1770 return (getitem,) 1771 """, # noqa: B950 1772 ) 1773 self.assertExpectedInline( 1774 graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"), 1775 """\ 1776def forward(self, arg0_1): 1777 clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 1778 add = torch.ops.aten.add.Tensor(clone, 1); clone = None 1779 add_1 = torch.ops.aten.add.Tensor(add, -1); add = None 1780 sum_1 = torch.ops.aten.sum.default(add_1); add_1 = None 1781 lt = torch.ops.aten.lt.Scalar(sum_1, 10); sum_1 = None 1782 return lt 1783 """, 1784 ) 1785 self.assertExpectedInline( 1786 graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"), 1787 """\ 1788def forward(self, arg0_1): 1789 clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 1790 add = torch.ops.aten.add.Tensor(clone, 1); clone = None 1791 add_1 = torch.ops.aten.add.Tensor(add, -1); add = None 1792 add_2 = torch.ops.aten.add.Tensor(add_1, 1); add_1 = None 1793 return (add_2,) 1794 """, 1795 ) 1796 1797 @parametrize("func_type", ["no", "cpp", "python", "functorch"]) 1798 @parametrize("while_loop_test", list(WHILE_LOOP_TESTS.keys())) 1799 def test_while_loop_functionalize(self, func_type, while_loop_test): 1800 # simple_with_linear doesn't work because parameters and buffers 1801 # are not inputs so they're not wrapped by functionalization and tracing. 1802 if while_loop_test not in ("simple_with_linear", "nested_with_linear"): 1803 fn, inp = WHILE_LOOP_TESTS[while_loop_test] 1804 fn, mode = self._wrap_with_functionalize(fn, func_type) 1805 mode = mode if mode is not None else contextlib.nullcontext() 1806 with mode: 1807 self._check_tracing(fn, inp) 1808 1809 @parametrize("while_loop_test", list(WHILE_LOOP_TESTS.keys())) 1810 def test_while_loop_tracing(self, while_loop_test): 1811 fn, inp = WHILE_LOOP_TESTS[while_loop_test] 1812 allow_non_fake_inputs = ( 1813 False 1814 if while_loop_test not in ("simple_with_linear", "nested_with_linear") 1815 else True 1816 ) 1817 self._check_tracing(fn, inp, allow_non_fake_inputs) 1818 1819 @parametrize("backend", ["eager", "aot_eager"]) 1820 @parametrize("while_loop_test", list(WHILE_LOOP_TESTS.keys())) 1821 def test_while_loop_compile(self, backend, while_loop_test): 1822 fn, inp = WHILE_LOOP_TESTS[while_loop_test] 1823 self._check_compile(fn, inp, backend=backend) 1824 1825 @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") 1826 def test_while_loop_simple_with_linear_compile_check_graph(self): 1827 fn, inp = WHILE_LOOP_TESTS["simple_with_linear"] 1828 from torch._dynamo.testing import EagerAndRecordGraphs 1829 1830 backend = EagerAndRecordGraphs() 1831 torch.compile(fn, backend=backend)(*inp) 1832 self.assertEqual(len(backend.graphs), 1) 1833 gm = backend.graphs[0] 1834 if torch._dynamo.config.inline_inbuilt_nn_modules: 1835 self.assertExpectedInline( 1836 gm.code.strip(), 1837 """\ 1838def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor, L_self_buffers_dec_ : torch.Tensor, L_self_modules_linear_parameters_weight_ : torch.nn.parameter.Parameter, L_self_modules_linear_parameters_bias_ : torch.nn.parameter.Parameter): 1839 l_iter_ = L_iter_ 1840 l_x_ = L_x_ 1841 l_self_buffers_dec_ = L_self_buffers_dec_ 1842 l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_ 1843 l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_ 1844 cond_fn_0 = self.cond_fn_0 1845 body_fn_0 = self.body_fn_0 1846 while_loop =
torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l_self_buffers_dec_, l_self_modules_linear_parameters_bias_, l_self_modules_linear_parameters_weight_)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l_self_buffers_dec_ = l_self_modules_linear_parameters_bias_ = l_self_modules_linear_parameters_weight_ = None 1847 getitem = while_loop[0] 1848 getitem_1 = while_loop[1]; while_loop = None 1849 return (getitem, getitem_1)""", # noqa: B950 1850 ) 1851 self.assertExpectedInline( 1852 gm.cond_fn_0.code.strip(), 1853 """\ 1854def forward(self, l_iter_, l_x_, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): 1855 sub = l_iter_ - l_self_buffers_dec__cond_fn; l_iter_ = l_self_buffers_dec__cond_fn = None 1856 gt = sub > 0; sub = None 1857 return gt""", # noqa: B950 1858 ) 1859 self.assertExpectedInline( 1860 gm.body_fn_0.code.strip(), 1861 """\ 1862def forward(self, l_iter_, l_x_, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): 1863 child = l_iter_ - 1; l_iter_ = None 1864 child_1 = torch._C._nn.linear(l_x_, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); l_x_ = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None 1865 return (child, child_1)""", # noqa: B950 1866 ) 1867 else: 1868 self.assertExpectedInline( 1869 gm.code.strip(), 1870 """\ 1871def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor): 1872 l_iter_ = L_iter_ 1873 l_x_ = L_x_ 1874 l__self___dec = self.L__self___dec 1875 l__self___linear_weight = self.L__self___linear_weight 1876 l__self___linear_bias = self.L__self___linear_bias 1877 cond_fn_0 = self.cond_fn_0 1878 body_fn_0 = self.body_fn_0 1879 while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l__self___dec, l__self___linear_bias, l__self___linear_weight)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l__self___dec = l__self___linear_bias = l__self___linear_weight = None 1880 getitem = while_loop[0] 1881 getitem_1 = while_loop[1]; while_loop = None 1882 return (getitem, getitem_1)""", # noqa: B950 1883 ) 1884 self.assertExpectedInline( 1885 gm.cond_fn_0.code.strip(), 1886 """\ 1887def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn): 1888 sub = l_iter_ - l__self___dec_cond_fn; l_iter_ = l__self___dec_cond_fn = None 1889 gt = sub > 0; sub = None 1890 return gt""", # noqa: B950 1891 ) 1892 self.assertExpectedInline( 1893 gm.body_fn_0.code.strip(), 1894 """\ 1895def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn): 1896 child = l_iter_ - 1; l_iter_ = None 1897 child_1 = torch._C._nn.linear(l_x_, l__self___linear_weight_body_fn, l__self___linear_bias_body_fn); l_x_ = l__self___linear_weight_body_fn = l__self___linear_bias_body_fn = None 1898 return (child, child_1)""", # noqa: B950 1899 ) 1900 1901 def test_while_loop_nested2_traced(self): 1902 fn, inp = WHILE_LOOP_TESTS["nested2"] 1903 graphs = self._check_tracing(fn, inp) 1904 gm = graphs["symbolic"] 1905 outer_body = gm.while_loop_body_graph_0 1906 outer_cond = gm.while_loop_cond_graph_0 1907 inner_body = outer_body.while_loop_body_graph_0 1908 inner_cond = outer_body.while_loop_cond_graph_0 1909 self.assertExpectedInline( 1910 gm.code.strip("\n"), 1911 """\ 1912def forward(self, arg0_1, 
arg1_1, arg2_1, arg3_1): 1913 while_loop_cond_graph_0 = self.while_loop_cond_graph_0 1914 while_loop_body_graph_0 = self.while_loop_body_graph_0 1915 while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None 1916 getitem = while_loop[0] 1917 getitem_1 = while_loop[1] 1918 getitem_2 = while_loop[2] 1919 getitem_3 = while_loop[3]; while_loop = None 1920 return (getitem, getitem_1, getitem_2, getitem_3) 1921 """, # noqa: B950 1922 ) 1923 self.assertExpectedInline( 1924 outer_body.code.strip("\n"), 1925 """\ 1926def forward(self, arg0_1, arg1_1, arg2_1, arg3_1): 1927 while_loop_cond_graph_0 = self.while_loop_cond_graph_0 1928 while_loop_body_graph_0 = self.while_loop_body_graph_0 1929 while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None 1930 getitem = while_loop[0] 1931 getitem_1 = while_loop[1] 1932 getitem_2 = while_loop[2] 1933 getitem_3 = while_loop[3]; while_loop = None 1934 sub = torch.ops.aten.sub.Tensor(getitem, 1); getitem = None 1935 clone = torch.ops.aten.clone.default(getitem_1); getitem_1 = None 1936 mul = torch.ops.aten.mul.Tensor(getitem_2, 2); getitem_2 = None 1937 div = torch.ops.aten.div.Tensor(getitem_3, 2); getitem_3 = None 1938 return (sub, clone, mul, div) 1939 """, # noqa: B950 1940 ) 1941 self.assertExpectedInline( 1942 outer_body.code.strip("\n"), 1943 """\ 1944def forward(self, arg0_1, arg1_1, arg2_1, arg3_1): 1945 while_loop_cond_graph_0 = self.while_loop_cond_graph_0 1946 while_loop_body_graph_0 = self.while_loop_body_graph_0 1947 while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None 1948 getitem = while_loop[0] 1949 getitem_1 = while_loop[1] 1950 getitem_2 = while_loop[2] 1951 getitem_3 = while_loop[3]; while_loop = None 1952 sub = torch.ops.aten.sub.Tensor(getitem, 1); getitem = None 1953 clone = torch.ops.aten.clone.default(getitem_1); getitem_1 = None 1954 mul = torch.ops.aten.mul.Tensor(getitem_2, 2); getitem_2 = None 1955 div = torch.ops.aten.div.Tensor(getitem_3, 2); getitem_3 = None 1956 return (sub, clone, mul, div) 1957 """, # noqa: B950 1958 ) 1959 self.assertExpectedInline( 1960 inner_body.code.strip("\n"), 1961 """\ 1962def forward(self, arg0_1, arg1_1, arg2_1, arg3_1): 1963 clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None 1964 sub = torch.ops.aten.sub.Tensor(arg1_1, 1); arg1_1 = None 1965 add = torch.ops.aten.add.Tensor(arg2_1, 3.14); arg2_1 = None 1966 sub_1 = torch.ops.aten.sub.Tensor(arg3_1, 2.71); arg3_1 = None 1967 return (clone, sub, add, sub_1) 1968 """, 1969 ) 1970 self.assertExpectedInline( 1971 inner_cond.code.strip("\n"), 1972 """\ 1973def forward(self, arg0_1, arg1_1, arg2_1, arg3_1): 1974 gt = torch.ops.aten.gt.Scalar(arg1_1, 0); arg1_1 = None 1975 return gt 1976 """, 1977 ) 1978 1979 def test_cond_nested_traced(self): 1980 def true_nested(y): 1981 return y * y 1982 1983 def false_nested(y): 1984 return y + y 1985 1986 def true_fn(x, pred2): 1987 z = cond(pred2, true_nested, false_nested, [x]) 1988 return x + z 1989 1990 def false_fn(x, _): 1991 return x.cos() 1992 1993 def f(x, pred, pred2): 1994 return cond(pred, 
true_fn, false_fn, [x, pred2]) 1995 1996 x = torch.randn(4) 1997 graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False)) 1998 1999 result_true_true = graph.forward( 2000 x, torch.tensor(True), torch.tensor(True) 2001 ) # True + True -> x * x 2002 result_true_false = graph.forward( 2003 x, torch.tensor(True), torch.tensor(False) 2004 ) # True + False -> x + x 2005 result_false_true = graph.forward( 2006 x, torch.tensor(False), torch.tensor(True) 2007 ) # False + either -> cos 2008 result_false_false = graph.forward( 2009 x, torch.tensor(False), torch.tensor(False) 2010 ) # False + either -> cos 2011 2012 self.assertNotEqual(result_true_true, result_true_false) 2013 self.assertFalse(torch.allclose(result_false_true, result_true_true)) 2014 2015 self.assertEqual(result_false_true, result_false_false) 2016 2017 self.assertEqual(result_true_true, (x * x) + x) 2018 self.assertEqual(result_true_false, x + x + x) 2019 2020 self.assertEqual(result_false_true, torch.cos(x)) 2021 2022 graph = make_fx(f, tracing_mode="symbolic")( 2023 x, torch.tensor(False), torch.tensor(False) 2024 ) 2025 self.assertEqual( 2026 graph(x, torch.tensor(True), torch.tensor(True)), 2027 f(x, torch.tensor(True), torch.tensor(True)), 2028 ) 2029 2030 def test_cond_functionalized(self): 2031 def true_fn(x): 2032 y = x.sin() 2033 y.add_(4) 2034 return x.sin().max() + y.sum() 2035 2036 def false_fn(x): 2037 return x.cos().min() 2038 2039 def f(x): 2040 pred = x.shape[0] == 1 2041 return cond(pred, true_fn, false_fn, [x]) 2042 2043 example_inputs = (torch.ones(4, 5),) 2044 functional_f = torch.func.functionalize(f) 2045 self.assertEqual(functional_f(*example_inputs), f(*example_inputs)) 2046 2047 graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")( 2048 *example_inputs 2049 ) 2050 self.assertEqual(graph_module(*example_inputs), f(*example_inputs)) 2051 2052 all_ops_in_true_branch = [] 2053 for node in graph_module.true_graph_0.graph.nodes: 2054 if node.op == "call_function": 2055 all_ops_in_true_branch.append(node.target) 2056 2057 self.assertFalse(any(op._schema.is_mutable for op in all_ops_in_true_branch)) 2058 2059 self.assertEqual(graph_module(*example_inputs), f(*example_inputs)) 2060 2061 def test_cond_accepts_torch_function_as_inputs(self): 2062 a = torch.randn(3, 4) 2063 b = torch.randn(3, 4) 2064 2065 def f(a, b): 2066 return cond(a.sum() > 0, torch.add, torch.mul, (a, b)) 2067 2068 gm = self._check_tracing(f, (a, b))["symbolic"] 2069 self.assertExpectedInline( 2070 gm.code.strip(), 2071 """\ 2072def forward(self, a_1, b_1): 2073 sum_1 = torch.ops.aten.sum.default(a_1) 2074 gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None 2075 true_graph_0 = self.true_graph_0 2076 false_graph_0 = self.false_graph_0 2077 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None 2078 getitem = cond[0]; cond = None 2079 return getitem""", # noqa: B950 2080 ) 2081 self.assertExpectedInline( 2082 gm.true_graph_0.code.strip(), 2083 """\ 2084def forward(self, arg0_1, arg1_1): 2085 add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None 2086 return (add,)""", 2087 ) 2088 self.assertExpectedInline( 2089 gm.false_graph_0.code.strip(), 2090 """\ 2091def forward(self, arg0_1, arg1_1): 2092 mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None 2093 return (mul,)""", 2094 ) 2095 2096 def test_cond_retrace_functionalized(self): 2097 def true_fn(x): 2098 return x.sin() 2099 2100 def false_fn(x): 2101 return
x.cos() 2102 2103 def f(x): 2104 return cond(x.all(), true_fn, false_fn, (x,)) 2105 2106 inp = torch.ones(1, 2) 2107 gm_non_functional = make_fx(f, tracing_mode="real")(inp) 2108 gm_functional = make_fx( 2109 torch.func.functionalize(gm_non_functional), tracing_mode="real" 2110 )(inp) 2111 self.assertEqual(gm_functional(torch.zeros(1, 2)), f(torch.zeros(1, 2))) 2112 2113 def test_cond_subgraph_same_shape_env_as_parent(self): 2114 def true_fn(x): 2115 return x.sin() + 10 2116 2117 def false_fn(x): 2118 return x.cos() - 20 2119 2120 def f(x, pred): 2121 y = cond(pred, true_fn, false_fn, [x]) 2122 z = torch.add(y, y) 2123 return z 2124 2125 symbolic_traced_graph = self._check_tracing( 2126 f, (torch.ones(4), torch.Tensor([True])) 2127 )["symbolic"] 2128 graph_shape_env = symbolic_traced_graph.shape_env 2129 2130 def _node_shape_env_iter(gm): 2131 for node in gm.graph.nodes: 2132 if node.op == "call_function": 2133 val = node.meta.get("val") 2134 if isinstance(val, tuple): 2135 for v in val: 2136 yield v.fake_mode.shape_env 2137 else: 2138 yield val.fake_mode.shape_env 2139 2140 for shape_env in _node_shape_env_iter(symbolic_traced_graph): 2141 self.assertTrue(shape_env is graph_shape_env) 2142 2143 for shape_env in _node_shape_env_iter(symbolic_traced_graph.true_graph_0): 2144 self.assertTrue(shape_env is graph_shape_env) 2145 2146 for shape_env in _node_shape_env_iter(symbolic_traced_graph.false_graph_0): 2147 self.assertTrue(shape_env is graph_shape_env) 2148 2149 def test_cond_functionalized_nested(self): 2150 def true_true_fn(x): 2151 y = x.cos() 2152 y.add_(4) 2153 return x.sin().max() + y.sin().max() 2154 2155 def true_false_fn(x): 2156 return x.cos().min() 2157 2158 def true_fn(x): 2159 pred = x.shape[0] == 1 2160 return cond(pred, true_true_fn, true_false_fn, [x]) 2161 2162 def false_fn(x): 2163 return x.sum() 2164 2165 def f(x): 2166 pred = x.shape[0] == 1 2167 return cond(pred, true_fn, false_fn, [x]) 2168 2169 example_inputs = (torch.ones(4, 5),) 2170 functional_f = torch.func.functionalize(f) 2171 self.assertEqual(functional_f(*example_inputs), f(*example_inputs)) 2172 2173 graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")( 2174 *example_inputs 2175 ) 2176 self.assertEqual(graph_module(*example_inputs), f(*example_inputs)) 2177 2178 gm_true_true_branch = graph_module.true_graph_0.true_graph_0 2179 2180 self.assertEqual(graph_module(*example_inputs), f(*example_inputs)) 2181 2182 all_ops = [] 2183 for node in gm_true_true_branch.graph.nodes: 2184 if node.op == "call_function": 2185 all_ops.append(node.target) 2186 2187 self.assertFalse(any(op._schema.is_mutable for op in all_ops)) 2188 2189 def test_cond_functionalized_data_dependent_pred(self): 2190 def true_fn(x): 2191 return x.sin().sum() 2192 2193 def false_fn(x): 2194 return x.cos().sum() 2195 2196 def f(x): 2197 pred = x.nonzero().shape[0] == 1 2198 return cond(pred, true_fn, false_fn, [x]) 2199 2200 example_inputs = (torch.ones(4, 5),) 2201 functional_f = torch.func.functionalize(f) 2202 self.assertEqual(functional_f(*example_inputs), f(*example_inputs)) 2203 2204 graph_module = make_fx(torch.func.functionalize(f))(*example_inputs) 2205 self.assertEqual(graph_module(*example_inputs), f(*example_inputs)) 2206 2207 # https://github.com/pytorch/pytorch/issues/126988 2208 def test_cond_functionalized_input_mutation_on_true_branch(self): 2209 def true_fn(x): 2210 view_x = x.view(x.shape) 2211 view_x.add_(1) 2212 return view_x.sin().sum() 2213 2214 def false_fn(x): 2215 return
x.cos().sum() 2216 2217 def f(x): 2218 pred = x.shape[0] == 4 2219 return cond(pred, true_fn, false_fn, [x]) 2220 2221 example_inputs = (torch.ones(4, 5),) 2222 # torch.cond inlines into one of the branches because the predicate 2223 # is a constant. 2224 gm = make_fx(torch.func.functionalize(f))(*example_inputs) 2225 self.assertExpectedInline( 2226 gm.code.strip(), 2227 """\ 2228def forward(self, x_1): 2229 view = torch.ops.aten.view.default(x_1, [4, 5]) 2230 add = torch.ops.aten.add.Tensor(view, 1); view = None 2231 view_1 = torch.ops.aten.view.default(add, [4, 5]); add = None 2232 view_2 = torch.ops.aten.view.default(view_1, [4, 5]) 2233 sin = torch.ops.aten.sin.default(view_2); view_2 = None 2234 sum_1 = torch.ops.aten.sum.default(sin); sin = None 2235 copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = copy_ = None 2236 return sum_1""", 2237 ) 2238 2239 # torch.cond triggers the check of the branches because the predicate 2240 # is a SymBool. 2241 with self.assertRaisesRegex( 2242 UnsupportedAliasMutationException, "One of torch.cond branch" 2243 ): 2244 make_fx(torch.func.functionalize(f), tracing_mode="symbolic")( 2245 *example_inputs 2246 ) 2247 2248 # https://github.com/pytorch/pytorch/issues/126988 2249 def test_cond_functionalized_input_mutation_on_false_branch(self): 2250 def true_fn(x): 2251 return x.sin().sum() 2252 2253 def false_fn(x): 2254 view_x = x.view(x.shape) 2255 view_x.add_(1) 2256 return view_x.cos().sum() 2257 2258 def f(x): 2259 pred = x.shape[0] == 4 2260 return cond(pred, true_fn, false_fn, [x]) 2261 2262 example_inputs = (torch.ones(5, 5),) 2263 gm = make_fx(torch.func.functionalize(f))(*example_inputs) 2264 # torch.cond inlines into one of the branches because the predicate 2265 # is a constant. 2266 self.assertExpectedInline( 2267 gm.code.strip(), 2268 """\ 2269def forward(self, x_1): 2270 view = torch.ops.aten.view.default(x_1, [5, 5]) 2271 add = torch.ops.aten.add.Tensor(view, 1); view = None 2272 view_1 = torch.ops.aten.view.default(add, [5, 5]); add = None 2273 view_2 = torch.ops.aten.view.default(view_1, [5, 5]) 2274 cos = torch.ops.aten.cos.default(view_2); view_2 = None 2275 sum_1 = torch.ops.aten.sum.default(cos); cos = None 2276 copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = copy_ = None 2277 return sum_1""", 2278 ) 2279 2280 # torch.cond triggers the check of the branches because the predicate 2281 # is a SymBool. 2282 with self.assertRaisesRegex( 2283 UnsupportedAliasMutationException, "One of torch.cond branch" 2284 ): 2285 make_fx(torch.func.functionalize(f), tracing_mode="symbolic")( 2286 *example_inputs 2287 ) 2288 2289 # https://github.com/pytorch/pytorch/issues/126988 2290 def test_cond_functionalized_output_alias_input(self): 2291 def true_fn(x): 2292 return x 2293 2294 def false_fn(x): 2295 view_x = x.view(x.shape) 2296 return view_x 2297 2298 def f(x): 2299 pred = x.shape[0] == 4 2300 return cond(pred, true_fn, false_fn, [x]) 2301 2302 example_inputs = (torch.ones(5, 5),) 2303 gm = make_fx(torch.func.functionalize(f))(*example_inputs) 2304 # torch.cond inlines into one of the branches because the predicate 2305 # is a constant. 2306 self.assertExpectedInline( 2307 gm.code.strip(), 2308 """\ 2309def forward(self, x_1): 2310 view = torch.ops.aten.view.default(x_1, [5, 5]); x_1 = None 2311 return view""", 2312 ) 2313 2314 # torch.cond triggers the check of the branches because the predicate 2315 # is a SymBool. 
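# (With the static shapes used here, x.shape[0] == 4 is a plain Python bool, so
# cond is inlined into the taken branch and the aliasing in false_fn is never
# inspected; a SymBool predicate keeps both branches live, so the alias check
# below fires.)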
2316 with self.assertRaisesRegex( 2317 UnsupportedAliasMutationException, "One of torch.cond branch" 2318 ): 2319 make_fx(torch.func.functionalize(f), tracing_mode="symbolic")( 2320 *example_inputs 2321 ) 2322 2323 # https://github.com/pytorch/pytorch/issues/126988 2324 def test_cond_functionalized_nested_input_mutation(self): 2325 def true_true_fn(x): 2326 x.add_(4) 2327 return x.sin().max() 2328 2329 def true_false_fn(x): 2330 return x.cos().min() 2331 2332 def true_fn(x): 2333 pred = x.shape[0] == 1 2334 return cond(pred, true_true_fn, true_false_fn, [x]) 2335 2336 def false_fn(x): 2337 return x.sum() 2338 2339 def f(x): 2340 pred = x.shape[0] == 1 2341 return cond(pred, true_fn, false_fn, [x]) 2342 2343 example_inputs = (torch.ones(4, 5),) 2344 with self.assertRaisesRegex( 2345 UnsupportedAliasMutationException, "One of torch.cond branch" 2346 ): 2347 make_fx(torch.func.functionalize(f), tracing_mode="symbolic")( 2348 *example_inputs 2349 ) 2350 2351 # https://github.com/pytorch/pytorch/issues/126988 2352 def test_cond_functionalized_nested_input_mutation_with_aot_func(self): 2353 def true_true_fn(x): 2354 x.add_(4) 2355 return x.sin().max() 2356 2357 def true_false_fn(x): 2358 return x.cos().min() 2359 2360 def true_fn(x): 2361 pred = x.shape[0] == 1 2362 return cond(pred, true_true_fn, true_false_fn, [x]) 2363 2364 def false_fn(x): 2365 return x.sum() 2366 2367 def f(x): 2368 pred = x.shape[0] == 1 2369 return cond(pred, true_fn, false_fn, [x]) 2370 2371 example_input = torch.ones(4, 5) 2372 try: 2373 example_input_func = to_fun_old(example_input) 2374 torch._enable_functionalization(reapply_views=False) 2375 f(example_input_func) 2376 2377 with self.assertRaisesRegex( 2378 UnsupportedAliasMutationException, "One of torch.cond branch" 2379 ): 2380 make_fx(f, tracing_mode="symbolic")(example_input_func) 2381 finally: 2382 torch._disable_functionalization() 2383 2384 def f_wrapper(func): 2385 @functools.wraps(func) 2386 def wrapper(*args, **kwargs): 2387 torch._enable_functionalization(reapply_views=False) 2388 try: 2389 return func(*args, **kwargs) 2390 finally: 2391 torch._disable_functionalization() 2392 2393 return wrapper 2394 2395 with self.assertRaisesRegex( 2396 UnsupportedAliasMutationException, "One of torch.cond branch" 2397 ): 2398 make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func) 2399 2400 # https://github.com/pytorch/pytorch/issues/126988 2401 @xfailIfTorchDynamo 2402 def test_cond_functionalized_input_aliasing_with_aot_func(self): 2403 def true_fn(x): 2404 return x 2405 2406 def false_fn(x): 2407 view_x = x.view(x.shape) 2408 return view_x 2409 2410 def f(x): 2411 pred = x.sum() > 0 2412 return cond(pred, true_fn, false_fn, [x]) 2413 2414 example_input = torch.ones(5, 5) 2415 try: 2416 example_input_func = to_fun_old(example_input) 2417 torch._enable_functionalization(reapply_views=False) 2418 with self.assertRaisesRegex( 2419 UnsupportedAliasMutationException, 2420 "One of torch.cond branch might be aliasing", 2421 ): 2422 f(example_input_func) 2423 finally: 2424 torch._disable_functionalization() 2425 2426 def f_wrapper(func): 2427 @functools.wraps(func) 2428 def wrapper(*args, **kwargs): 2429 torch._enable_functionalization(reapply_views=False) 2430 try: 2431 func_args = pytree.tree_map( 2432 lambda x: torch._to_functional_tensor(x) 2433 if isinstance(x, torch.Tensor) 2434 else x, 2435 args, 2436 ) 2437 func_kwargs = pytree.tree_map( 2438 lambda x: torch._to_functional_tensor(x) 2439 if isinstance(x, torch.Tensor) 2440 else x, 2441 kwargs, 2442 ) 2443 
return func(*func_args, **func_kwargs) 2444 finally: 2445 torch._disable_functionalization() 2446 2447 return wrapper 2448 2449 with self.assertRaisesRegex( 2450 UnsupportedAliasMutationException, 2451 "One of torch.cond branch might be aliasing", 2452 ): 2453 make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input) 2454 2455 def test_cond_functionalized_aot_func_check_functional(self): 2456 def true_fn(x): 2457 return x.cos() 2458 2459 def false_fn(x): 2460 y = x.sin() 2461 y.add_(5) 2462 return y 2463 2464 def f(x): 2465 pred = x.shape[0] == 4 2466 return cond(pred, true_fn, false_fn, [x]) 2467 2468 example_input = torch.ones(5, 5) 2469 2470 def f_wrapper(func): 2471 @functools.wraps(func) 2472 def wrapper(*args, **kwargs): 2473 torch._enable_functionalization(reapply_views=False) 2474 try: 2475 func_args = pytree.tree_map( 2476 lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x, 2477 args, 2478 ) 2479 func_kwargs = pytree.tree_map( 2480 lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x, 2481 kwargs, 2482 ) 2483 return pytree.tree_map( 2484 from_fun_old, func(*func_args, **func_kwargs) 2485 ) 2486 finally: 2487 torch._disable_functionalization() 2488 2489 return wrapper 2490 2491 result_gm = make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input) 2492 for node in result_gm.true_graph_0.graph.nodes: 2493 if node.op == "call_function": 2494 self.assertTrue(not node.target._schema.is_mutable) 2495 2496 for node in result_gm.false_graph_0.graph.nodes: 2497 if node.op == "call_function": 2498 self.assertTrue(not node.target._schema.is_mutable) 2499 2500 self.assertEqual(result_gm(torch.ones(5, 5)), f(torch.ones(5, 5))) 2501 2502 def test_cond_nested_traced_other_inputs(self): 2503 def true_nested(y): 2504 return y * y 2505 2506 def false_nested(y): 2507 return y + y 2508 2509 def true_fn(k, pred2): 2510 z = cond(pred2, true_nested, false_nested, [k]) 2511 return torch.add(torch.tensor([0.25, 0.25]), z) 2512 2513 def false_fn(k, _): 2514 return k.cos() 2515 2516 def f(k, pred, pred2): 2517 return cond(pred, true_fn, false_fn, [k, pred2]) 2518 2519 x = torch.tensor([0.5, 0.5]) 2520 graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False)) 2521 2522 a = torch.tensor([1.0, 1.0]) 2523 result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True)) 2524 self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25])) 2525 2526 b = torch.tensor([2.0, 2.0]) 2527 result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True)) 2528 self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25])) 2529 2530 def test_cond_nested_traced_multi(self): 2531 def true_a(y): 2532 return y * y 2533 2534 def false_a(y): 2535 return y + y 2536 2537 def true_b(y, z): 2538 return y + z 2539 2540 def false_b(y, z): 2541 return y * z 2542 2543 def f(x, pred, pred2): 2544 a_out = cond(pred, true_a, false_a, [x]) 2545 b_out = cond(pred2, true_b, false_b, [x, x]) 2546 return a_out + b_out 2547 2548 x = torch.randn(4) 2549 graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False)) 2550 2551 self.assertExpectedInline( 2552 graph.code.strip(), 2553 """\ 2554def forward(self, x_1, pred_1, pred2_1): 2555 true_graph_0 = self.true_graph_0 2556 false_graph_0 = self.false_graph_0 2557 cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None 2558 getitem = cond[0]; cond = None 2559 true_graph_1 = self.true_graph_1 2560 false_graph_1 = self.false_graph_1 2561 cond_1 = 
torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None 2562 getitem_1 = cond_1[0]; cond_1 = None 2563 add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None 2564 return add""", # noqa: B950 2565 ) 2566 self.assertExpectedInline( 2567 graph.true_graph_0.code.strip(), 2568 """\ 2569def forward(self, arg0_1): 2570 mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None 2571 return (mul,)""", 2572 ) 2573 2574 def test_raise_error_on_mismatch_type_size(self): 2575 def true_fn(x): 2576 return x.sin() 2577 2578 def false_fn(x): 2579 return (x, x) 2580 2581 def f(x, y): 2582 return cond(y, true_fn, false_fn, [x]) 2583 2584 x = torch.randn(4) 2585 with self.assertRaisesRegex( 2586 torch._dynamo.exc.CondOpArgsMismatchError, 2587 "Expected to return same number of outputs but got:", 2588 ): 2589 make_fx(f)(x, torch.tensor(False)) 2590 2591 def test_raise_error_on_mismatch_tensor_size(self): 2592 def true_fn(x): 2593 return x.sin() 2594 2595 def false_fn(x): 2596 return torch.zeros([10, 10]) 2597 2598 def f(x, y): 2599 return cond(y, true_fn, false_fn, [x]) 2600 2601 x = torch.randn(4) 2602 with self.assertRaisesRegex( 2603 torch._dynamo.exc.UncapturedHigherOrderOpError, 2604 "Cond doesn't work unless it is captured completely with torch.compile", 2605 ): 2606 make_fx(f)(x, torch.tensor(False)) 2607 2608 def test_cond_traced_not_nested_fake_tensor(self): 2609 def true_fn(x): 2610 return x.sin() 2611 2612 def false_fn(x): 2613 return x.cos() 2614 2615 def f(x, y): 2616 return cond(y, true_fn, false_fn, [x]) 2617 2618 x = torch.randn(4) 2619 graph = make_fx(f, tracing_mode="fake")(x, torch.tensor(False)) 2620 result_true = graph.forward(x, torch.tensor(True)) 2621 result_false = graph.forward(x, torch.tensor(False)) 2622 self.assertFalse(torch.allclose(result_true, result_false)) 2623 self.assertEqual(result_true, torch.sin(x)) 2624 self.assertEqual(result_false, torch.cos(x)) 2625 2626 def test_cond_nested_traced_fake_tensor(self): 2627 def true_nested(y): 2628 return y * y 2629 2630 def false_nested(y): 2631 return y + y 2632 2633 def true_fn(x, pred2): 2634 z = cond(pred2, true_nested, false_nested, [x]) 2635 return x + z 2636 2637 def false_fn(x, _): 2638 return x.cos() 2639 2640 def f(x, pred, pred2): 2641 return cond(pred, true_fn, false_fn, [x, pred2]) 2642 2643 x = torch.randn(4) 2644 graph = make_fx(f, tracing_mode="fake")( 2645 x, torch.tensor(False), torch.tensor(False) 2646 ) 2647 2648 result_true_true = graph.forward( 2649 x, torch.tensor(True), torch.tensor(True) 2650 ) # True + True -> x * x 2651 result_true_false = graph.forward( 2652 x, torch.tensor(True), torch.tensor(False) 2653 ) # True + False -> x + x 2654 result_false_true = graph.forward( 2655 x, torch.tensor(False), torch.tensor(True) 2656 ) # False + either -> cos 2657 result_false_false = graph.forward( 2658 x, torch.tensor(False), torch.tensor(False) 2659 ) # False + either -> cos 2660 2661 self.assertNotEqual(result_true_true, result_true_false) 2662 self.assertFalse(torch.allclose(result_false_true, result_true_true)) 2663 2664 self.assertEqual(result_false_true, result_false_false) 2665 2666 self.assertEqual(result_true_true, (x * x) + x) 2667 self.assertEqual(result_true_false, x + x + x) 2668 2669 self.assertEqual(result_false_true, torch.cos(x)) 2670 2671 def test_cond_nested_traced_other_inputs_fake_tensor(self): 2672 def true_nested(y): 2673 return y * y 2674 2675 def false_nested(y): 2676 return y + y 2677 2678 def
true_fn(k, pred2): 2679 z = cond(pred2, true_nested, false_nested, [k]) 2680 return torch.add(torch.tensor([0.25, 0.25]), z) 2681 2682 def false_fn(k, _): 2683 return k.cos() 2684 2685 def f(k, pred, pred2): 2686 return cond(pred, true_fn, false_fn, [k, pred2]) 2687 2688 x = torch.tensor([0.5, 0.5]) 2689 graph = make_fx(f, tracing_mode="fake")( 2690 x, torch.tensor(False), torch.tensor(False) 2691 ) 2692 2693 a = torch.tensor([1.0, 1.0]) 2694 result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True)) 2695 self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25])) 2696 2697 b = torch.tensor([2.0, 2.0]) 2698 result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True)) 2699 self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25])) 2700 2701 def test_cond_nested_traced_multi_fake_tensor(self): 2702 def true_a(y): 2703 return y * y 2704 2705 def false_a(y): 2706 return y + y 2707 2708 def true_b(y, z): 2709 return y + z 2710 2711 def false_b(y, z): 2712 return y * z 2713 2714 def f(x, pred, pred2): 2715 a_out = cond(pred, true_a, false_a, [x]) 2716 b_out = cond(pred2, true_b, false_b, [x, x]) 2717 return a_out + b_out 2718 2719 x = torch.randn(4) 2720 graph = make_fx(f, tracing_mode="fake")( 2721 x, torch.tensor(False), torch.tensor(False) 2722 ) 2723 2724 self.assertExpectedInline( 2725 graph.code.strip(), 2726 """\ 2727def forward(self, x_1, pred_1, pred2_1): 2728 true_graph_0 = self.true_graph_0 2729 false_graph_0 = self.false_graph_0 2730 cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None 2731 getitem = cond[0]; cond = None 2732 true_graph_1 = self.true_graph_1 2733 false_graph_1 = self.false_graph_1 2734 cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None 2735 getitem_1 = cond_1[0]; cond_1 = None 2736 add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None 2737 return add""", # noqa: B950 2738 ) 2739 self.assertExpectedInline( 2740 graph.true_graph_0.code.strip(), 2741 """\ 2742def forward(self, arg0_1): 2743 mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None 2744 return (mul,)""", 2745 ) 2746 2747 def test_raise_error_on_mismatch_type_size_fake_tensor(self): 2748 def true_fn(x): 2749 return x.sin() 2750 2751 def false_fn(x): 2752 return (x, x) 2753 2754 def f(x, y): 2755 return cond(y, true_fn, false_fn, [x]) 2756 2757 x = torch.randn(4) 2758 with self.assertRaisesRegex( 2759 torch._dynamo.exc.CondOpArgsMismatchError, 2760 "Expected to return same number of outputs but got:", 2761 ): 2762 make_fx(f, tracing_mode="fake")(x, torch.tensor(False)) 2763 2764 def test_raise_error_on_mismatch_tensor_size_fake_tensor(self): 2765 def true_fn(x): 2766 return x.sin() 2767 2768 def false_fn(x): 2769 return torch.zeros([10, 10]) 2770 2771 def f(x, y): 2772 return cond(y, true_fn, false_fn, [x]) 2773 2774 x = torch.randn(4) 2775 with self.assertRaisesRegex( 2776 torch._dynamo.exc.UncapturedHigherOrderOpError, 2777 "Cond doesn't work unless it is captured completely with torch.compile", 2778 ): 2779 make_fx(f, tracing_mode="fake")(x, torch.tensor(False)) 2780 2781 def check_map_count(self, gm, op_count): 2782 i = 0 2783 for m in gm.modules(): 2784 for node in m.graph.nodes: 2785 if ( 2786 node.op == "call_function" 2787 and node.target == torch.ops.higher_order.map_impl 2788 ): 2789 i += 1 2790 self.assertEqual(i, op_count) 2791 2792 def 
test_tracing_map_real(self): 2793 def f(x, y): 2794 return x + y 2795 2796 def g(xs, y): 2797 return control_flow.map(f, xs, y) 2798 2799 gm = make_fx(g, tracing_mode="real")(torch.ones(3, 2, 2), torch.ones(2)) 2800 x = torch.randn(3, 2, 2) 2801 y = torch.randn(2) 2802 res = gm(x, y) 2803 self.assertEqual(res, g(x, y)) 2804 self.check_map_count(gm, 1) 2805 2806 def test_tracing_map_symbolic_simple(self): 2807 def f(x, y): 2808 return x + y 2809 2810 def g(xs, y): 2811 return control_flow.map(f, xs, y) 2812 2813 gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 2, 4), torch.ones(4)) 2814 x = torch.randn(3, 2, 2) 2815 y = torch.randn(2) 2816 res = gm(x, y) 2817 self.assertEqual(res, g(x, y)) 2818 self.check_map_count(gm, 1) 2819 2820 def test_tracing_map_symbolic_list(self): 2821 def f(x, y): 2822 return [x[0][0] + y, x[1] * y] 2823 2824 def g(xs, y, z): 2825 out = control_flow.map(f, xs, y) 2826 return out[0] + z, out[1] * z 2827 2828 example_x = [[torch.ones(3, 4, 5)], torch.ones(3, 4, 5)] 2829 gm = make_fx(g, tracing_mode="symbolic")( 2830 example_x, torch.ones(5), torch.ones(5) 2831 ) 2832 x = [[torch.randn(4, 5, 6)], torch.ones(4, 5, 6)] 2833 y = torch.randn(6) 2834 z = torch.ones(6) 2835 res = gm(x, y, z) 2836 self.assertEqual(res, g(x, y, z)) 2837 self.check_map_count(gm, 1) 2838 2839 def test_tracing_map_symbolic_dict(self): 2840 def f(x, y): 2841 return {"d": x["b"]["a"] + y, "e": x["c"] * y} 2842 2843 def g(xs, y, z): 2844 out = control_flow.map(f, xs, y) 2845 return {"f": out["d"] + z, "g": out["e"] * z} 2846 2847 example_x = {"b": {"a": torch.ones(3, 4, 5)}, "c": torch.ones(3, 4, 5)} 2848 gm = make_fx(g, tracing_mode="symbolic")( 2849 example_x, torch.ones(5), torch.ones(5) 2850 ) 2851 x = {"b": {"a": torch.randn(4, 5, 6)}, "c": torch.ones(4, 5, 6)} 2852 y = torch.randn(6) 2853 z = torch.ones(6) 2854 res = gm(x, y, z) 2855 self.assertEqual(res, g(x, y, z)) 2856 self.check_map_count(gm, 1) 2857 2858 def test_tracing_map_autograd_symbolic_simple(self): 2859 def f(x, y): 2860 return x + y 2861 2862 def g(xs, y): 2863 out = control_flow.map(f, xs, y) 2864 return torch.autograd.grad(out, (xs, y), torch.ones_like(out)) 2865 2866 gm = make_fx(g, tracing_mode="symbolic")( 2867 torch.ones(3, 4, 5, requires_grad=True), torch.ones(5, requires_grad=True) 2868 ) 2869 x = torch.randn(4, 5, 6, requires_grad=True) 2870 y = torch.randn(6, requires_grad=True) 2871 res = gm(x, y) 2872 self.assertEqual(res, g(x, y)) 2873 self.check_map_count(gm, 2) 2874 2875 def test_tracing_map_autograd_symbolic_list(self): 2876 import torch.utils._pytree as pytree 2877 2878 def f(x, y): 2879 return [x[0].cos() + y.sin(), x[1].sin() * y.cos()] 2880 2881 def g(xs, y): 2882 out = control_flow.map(f, xs, y) 2883 flat_out = pytree.tree_leaves(out) 2884 flat_inp = pytree.tree_leaves((xs, y)) 2885 requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad] 2886 return torch.autograd.grad( 2887 flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out] 2888 ) 2889 2890 gm = make_fx(g, tracing_mode="symbolic")( 2891 [torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)], 2892 torch.ones(5, requires_grad=True), 2893 ) 2894 x = [torch.randn(4, 5, 6), torch.ones(4, 5, 6, requires_grad=True)] 2895 y = torch.randn(6, requires_grad=True) 2896 res = gm(x, y) 2897 self.assertEqual(res, g(x, y)) 2898 self.check_map_count(gm, 2) 2899 2900 def test_tracing_map_autograd_symbolic_dict(self): 2901 def f(x, y): 2902 return [x["a"] + y, x["b"] * y] 2903 2904 def g(xs, y): 2905 out = control_flow.map(f, xs, 
y) 2906 flat_out = pytree.tree_leaves(out) 2907 flat_inp = pytree.tree_leaves((xs, y)) 2908 requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad] 2909 return torch.autograd.grad( 2910 flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out] 2911 ) 2912 2913 traced_x = { 2914 "a": torch.ones(3, 4, 5, requires_grad=True), 2915 "b": torch.ones(3, 4, 5, requires_grad=True), 2916 } 2917 gm = make_fx(g, tracing_mode="symbolic")( 2918 traced_x, torch.ones(5, requires_grad=True) 2919 ) 2920 x = { 2921 "a": torch.randn(4, 5, 6, requires_grad=True), 2922 "b": torch.ones(4, 5, 6, requires_grad=True), 2923 } 2924 y = torch.randn(6, requires_grad=True) 2925 res = gm(x, y) 2926 self.assertEqual(res, g(x, y)) 2927 self.check_map_count(gm, 2) 2928 2929 def test_tracing_map_autograd_aot_functionalized(self): 2930 def inner(x, y): 2931 z = x - 1 2932 z.add_(1) 2933 return z * y 2934 2935 def f(xs, y): 2936 res = control_flow.map(inner, xs, y) 2937 grads = torch.autograd.grad(res, (xs, y), torch.ones_like(res)) 2938 return grads 2939 2940 def f_wrapper(func): 2941 @functools.wraps(func) 2942 def wrapper(*args, **kwargs): 2943 torch._enable_functionalization(reapply_views=False) 2944 try: 2945 return pytree.tree_map(from_fun_old, func(*args, **kwargs)) 2946 finally: 2947 torch._disable_functionalization() 2948 2949 return wrapper 2950 2951 example_inputs = ( 2952 torch.ones(3, 2, 4, requires_grad=True), 2953 torch.ones(2, 4, requires_grad=True), 2954 ) 2955 gm = make_fx(f, tracing_mode="symbolic")(*example_inputs) 2956 fgm = make_fx(f_wrapper(f), tracing_mode="symbolic")(*example_inputs) 2957 xs = torch.ones(3, 4, 5, requires_grad=True) 2958 y = torch.ones(4, 5, requires_grad=True) 2959 2960 self.assertEqual(gm(xs, y), f(xs, y)) 2961 2962 def count_mutable(gm): 2963 c = 0 2964 for node in gm.graph.nodes: 2965 if node.op == "call_function": 2966 if node.target == torch.ops.higher_order.map_impl: 2967 c += count_mutable(getattr(gm, str(node.args[0]))) 2968 elif schema := getattr(node.target, "_schema", None): 2969 c += int(schema.is_mutable) 2970 return c 2971 2972 self.assertEqual(count_mutable(fgm), 0) 2973 # One for forward, one for recomputation logic in backward 2974 self.assertEqual(count_mutable(gm), 2) 2975 2976 def test_map_functionalized(self): 2977 def map_fn(x, y): 2978 z = x + y 2979 z.add_(4) 2980 return z 2981 2982 def f(xs, y): 2983 return control_flow.map(map_fn, xs, y) 2984 2985 example_inputs = (torch.ones(3, 2, 4), torch.ones(4)) 2986 functional_f = torch.func.functionalize(f) 2987 self.assertEqual(functional_f(*example_inputs), f(*example_inputs)) 2988 2989 gm = make_fx(torch.func.functionalize(f))(*example_inputs) 2990 self.assertEqual(gm(*example_inputs), f(*example_inputs)) 2991 2992 gm = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")( 2993 *example_inputs 2994 ) 2995 self.assertEqual(gm(*example_inputs), f(*example_inputs)) 2996 2997 for node in gm.body_graph_0.graph.nodes: 2998 if node.op == "call_function": 2999 self.assertTrue(not node.target._schema.is_mutable) 3000 self.check_map_count(gm, 1) 3001 3002 def test_map_functionalized_aot_func(self): 3003 def map_fn(x, y): 3004 z = x + y 3005 z.add_(4) 3006 return z 3007 3008 def f(xs, y): 3009 return control_flow.map(map_fn, xs, y) 3010 3011 def f_wrapper(func): 3012 @functools.wraps(func) 3013 def wrapper(*args, **kwargs): 3014 torch._enable_functionalization(reapply_views=False) 3015 try: 3016 return pytree.tree_map(from_fun_old, func(*args, **kwargs)) 3017 finally: 3018 
torch._disable_functionalization() 3019 3020 return wrapper 3021 3022 example_inputs = (torch.ones(3, 2, 4), torch.ones(4)) 3023 3024 gm = make_fx(f_wrapper(f))(*example_inputs) 3025 3026 for node in gm.body_graph_0.graph.nodes: 3027 if node.op == "call_function": 3028 self.assertTrue(not node.target._schema.is_mutable) 3029 3030 self.assertEqual(gm(*example_inputs), f(*example_inputs)) 3031 3032 # https://github.com/pytorch/pytorch/issues/126988 3033 @xfailIfTorchDynamo 3034 def test_map_functionalized_arg_mutation(self): 3035 def map_fn(x, y): 3036 y.add_(4) 3037 return x + y 3038 3039 def f(xs, y): 3040 return control_flow.map(map_fn, xs, y) 3041 3042 example_inputs = (torch.ones(3, 2, 4), torch.ones(4)) 3043 functional_f = torch.func.functionalize(f) 3044 with self.assertRaisesRegex( 3045 UnsupportedAliasMutationException, "torch.map is mutating the input!" 3046 ): 3047 functional_f(*example_inputs) 3048 3049 # https://github.com/pytorch/pytorch/issues/126988 3050 @xfailIfTorchDynamo 3051 def test_map_functionalized_elem_mutation(self): 3052 def map_fn(x, y): 3053 x.add_(4) 3054 return x + y 3055 3056 def f(xs, y): 3057 return control_flow.map(map_fn, xs, y) 3058 3059 example_inputs = (torch.ones(3, 2, 4), torch.ones(4)) 3060 functional_f = torch.func.functionalize(f) 3061 with self.assertRaisesRegex( 3062 UnsupportedAliasMutationException, "torch.map is mutating the input!" 3063 ): 3064 functional_f(*example_inputs) 3065 3066 def test_cond_autograd_backward(self): 3067 def true_fn(x): 3068 return x.cos() 3069 3070 def false_fn(x): 3071 return x.sin() 3072 3073 def f(x, y): 3074 return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [y]) 3075 3076 example_inputs = ( 3077 torch.ones(3, 2, 4, requires_grad=True), 3078 torch.ones(4, requires_grad=True), 3079 ) 3080 f(*example_inputs).sum().backward() 3081 3082 # Ensure no error is thrown when not running backward 3083 res = f(*example_inputs) 3084 3085 # Ensure no error is thrown when not running backward 3086 res_compiled = torch.compile(f)(*example_inputs) 3087 self.assertEqual(res, res_compiled) 3088 3089 # https://github.com/pytorch/pytorch/issues/126988 3090 @xfailIfTorchDynamo 3091 def test_map_functionalized_elem_alias(self): 3092 def map_fn(x): 3093 x.view(x.shape) 3094 return x 3095 3096 def f(xs): 3097 return control_flow.map(map_fn, xs) 3098 3099 example_inputs = (torch.ones(3, 2, 4),) 3100 functional_f = torch.func.functionalize(f) 3101 with self.assertRaisesRegex( 3102 UnsupportedAliasMutationException, "torch.map is aliasing the input!" 
3103 ): 3104 functional_f(*example_inputs) 3105 3106 def test_nested_map_cond_real(self): 3107 def true_fn(x, y): 3108 return x * y 3109 3110 def false_fn(x, y): 3111 return x + y 3112 3113 def f(x, pred, y): 3114 return cond(pred, true_fn, false_fn, [x, y]) 3115 3116 def g(pred, xs, y): 3117 return control_flow.map(f, xs, pred, y) 3118 3119 gm = make_fx(g, tracing_mode="real")( 3120 torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4) 3121 ) 3122 pred = torch.tensor(False) 3123 x = torch.randn(3, 2, 4) 3124 y = torch.randn(4) 3125 res = gm(pred, x, y) 3126 self.assertEqual(res, g(pred, x, y)) 3127 self.check_map_count(gm, 1) 3128 3129 def test_nested_map_cond_symbolic(self): 3130 def true_fn(x, y): 3131 return x * y 3132 3133 def false_fn(x, y): 3134 return x + y 3135 3136 def f(x, pred, y): 3137 return cond(pred, true_fn, false_fn, [x, y]) 3138 3139 def g(pred, xs, y): 3140 return control_flow.map(f, xs, pred, y) 3141 3142 gm = make_fx(g, tracing_mode="symbolic")( 3143 torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4) 3144 ) 3145 pred = torch.tensor(False) 3146 x = torch.randn(3, 2, 2) 3147 y = torch.randn(2) 3148 res = gm(pred, x, y) 3149 self.assertEqual(res, g(pred, x, y)) 3150 self.check_map_count(gm, 1) 3151 3152 def test_nested_cond_map_cond_symbolic(self): 3153 def true_fn(x, y): 3154 return x * y 3155 3156 def false_fn(x, y): 3157 return x + y 3158 3159 def f(x, pred, y): 3160 return cond(pred, true_fn, false_fn, [x, y]) 3161 3162 def g(pred, xs, y): 3163 return control_flow.map(f, xs, pred, y) 3164 3165 def main_true_fn(pred, xs, y): 3166 return g(pred, xs, y) * 2 3167 3168 def main_false_fn(pred, xs, y): 3169 return g(pred, xs, y) + 1 3170 3171 def main(p, pred, xs, y): 3172 return cond(p, main_true_fn, main_false_fn, [pred, xs, y]) 3173 3174 gm = make_fx(main, tracing_mode="symbolic")( 3175 torch.tensor(True), torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4) 3176 ) 3177 p = torch.tensor(False) 3178 pred = torch.tensor(False) 3179 xs = torch.randn(3, 2, 2) 3180 y = torch.randn(2) 3181 res = gm(p, pred, xs, y) 3182 self.assertEqual(res, main(p, pred, xs, y)) 3183 self.check_map_count(gm, 2) 3184 3185 def test_cond_with_sym_pred(self): 3186 def true_fn(x): 3187 return x + x 3188 3189 def false_fn(x): 3190 return x * x 3191 3192 def foo(x): 3193 return cond(x.shape[0] == 4, true_fn, false_fn, [x]) 3194 3195 gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 2, 1)) 3196 # The symbols in make_fx's shape_env should not be specialized. 3197 self.assertEqual(len(gm.shape_env.guards), 0) 3198 3199 self.assertExpectedInline( 3200 gm.code.strip(), 3201 """\ 3202def forward(self, x_1): 3203 sym_size_int = torch.ops.aten.sym_size.int(x_1, 0) 3204 eq = sym_size_int == 4; sym_size_int = None 3205 true_graph_0 = self.true_graph_0 3206 false_graph_0 = self.false_graph_0 3207 cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None 3208 getitem = cond[0]; cond = None 3209 return getitem""", # noqa: B950 3210 ) 3211 3212 # We expect the traced graph module to work even if input size changes. 
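# The predicate is recomputed from the runtime size inside the graph
# (sym_size_int == 4 feeding torch.ops.higher_order.cond above), so an input
# with 4 rows now takes the true branch even though tracing used a size-3 input.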
3213 x = torch.ones(4, 3, 2) 3214 self.assertEqual(gm(x), true_fn(x)) 3215 self.assertEqual(foo(x), true_fn(x)) 3216 3217 def test_cond_with_unbacked_sym_pred(self): 3218 def foo(x): 3219 def true_fn(x): 3220 return x + x 3221 3222 def false_fn(x): 3223 return x * x 3224 3225 az = x.nonzero() 3226 return cond(az.shape[0] > 3, true_fn, false_fn, (x,)) 3227 3228 gm = make_fx(foo, tracing_mode="symbolic")(torch.randn(7)) 3229 self.assertExpectedInline( 3230 gm.code.strip(), 3231 """\ 3232def forward(self, x_1): 3233 nonzero = torch.ops.aten.nonzero.default(x_1) 3234 sym_size_int = torch.ops.aten.sym_size.int(nonzero, 0); nonzero = None 3235 gt = sym_size_int > 3; sym_size_int = None 3236 true_graph_0 = self.true_graph_0 3237 false_graph_0 = self.false_graph_0 3238 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x_1]); gt = true_graph_0 = false_graph_0 = x_1 = None 3239 getitem = cond[0]; cond = None 3240 return getitem""", 3241 ) 3242 3243 def _check_closure_correctly_lifted(self, f, *, args, exp_res, exp_arg_num): 3244 assert isinstance(args, (tuple, list)) 3245 self.assertEqual(f(*args), exp_res) 3246 gm = make_fx(f)(*args) 3247 self.assertEqual(gm(*args), exp_res) 3248 3249 def cnt_placeholder(gm): 3250 return len([node for node in gm.graph.nodes if node.op == "placeholder"]) 3251 3252 placeholder_cnts = [cnt_placeholder(mod) for mod in gm.children()] 3253 self.assertTrue(all(cnt == exp_arg_num for cnt in placeholder_cnts)) 3254 3255 def _check_closure_correctly_lifted_with_mutation( 3256 self, f, closures_to_be_mutated, *, args, exp_arg_num 3257 ): 3258 exp_res = f(*args) 3259 self._check_closure_correctly_lifted( 3260 f, args=args, exp_res=exp_res, exp_arg_num=exp_arg_num 3261 ) 3262 3263 for closure in closures_to_be_mutated: 3264 closure.add(-1) 3265 new_exp_res = f(*args) 3266 3267 self._check_closure_correctly_lifted( 3268 f, args=args, exp_res=new_exp_res, exp_arg_num=exp_arg_num 3269 ) 3270 3271 def test_cond_with_tensor_closure(self): 3272 a = torch.ones(2, 3) 3273 b = torch.ones(2, 3) + 1 3274 3275 def true_fn(x): 3276 return x + a 3277 3278 def false_fn(x): 3279 return x + b 3280 3281 def foo(x): 3282 return cond(x.shape[0] == 4, true_fn, false_fn, [x]) 3283 3284 # expected branches takes [x, a, b] as input 3285 inp = torch.randn(2, 3) 3286 self._check_closure_correctly_lifted_with_mutation( 3287 foo, (a, b), args=(inp,), exp_arg_num=3 3288 ) 3289 3290 def test_cond_with_tensor_closure_graph_module(self): 3291 a = torch.ones(2, 3) 3292 b = torch.ones(2, 3) + 1 3293 3294 def true_fn(x): 3295 return x + a 3296 3297 def false_fn(x): 3298 return x + b 3299 3300 def foo(x): 3301 return cond(x.shape[0] == 4, true_fn, false_fn, [x]) 3302 3303 # expected branches takes [x, a, b] as input 3304 inp = torch.randn(2, 3) 3305 3306 gm = make_fx(foo, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp) 3307 3308 self.assertExpectedInline( 3309 gm.code.strip(), 3310 """\ 3311def forward(self, x_1): 3312 sym_size_int = torch.ops.aten.sym_size.int(x_1, 0) 3313 eq = sym_size_int == 4; sym_size_int = None 3314 true_graph_0 = self.true_graph_0 3315 false_graph_0 = self.false_graph_0 3316 _tensor_constant0 = self._tensor_constant0 3317 _tensor_constant1 = self._tensor_constant1 3318 cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None 3319 getitem = cond[0]; cond = None 3320 return getitem""", # noqa: B950 3321 ) 3322 
self.assertExpectedInline( 3323 gm.true_graph_0.code.strip(), 3324 """\ 3325def forward(self, arg0_1, arg1_1, arg2_1): 3326 add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None 3327 return (add,)""", 3328 ) 3329 3330 def test_cond_with_module_param_closure(self): 3331 class Mod(torch.nn.Module): 3332 def __init__(self) -> None: 3333 super().__init__() 3334 self.register_parameter( 3335 "param", torch.nn.Parameter(torch.ones(2, 3), requires_grad=False) 3336 ) 3337 self.buffer = torch.nn.Buffer(torch.ones(2, 3) + 1) 3338 3339 my_mode = Mod() 3340 3341 def true_fn(x): 3342 return x + my_mode.param 3343 3344 def false_fn(x): 3345 return x + my_mode.buffer 3346 3347 def foo(x): 3348 return cond(x.shape[0] == 4, true_fn, false_fn, [x]) 3349 3350 inp = torch.ones(2, 3) 3351 # expected both branches take (x, param, buffer) 3352 self._check_closure_correctly_lifted_with_mutation( 3353 foo, (my_mode.param, my_mode.buffer), args=(inp,), exp_arg_num=3 3354 ) 3355 3356 def test_cond_with_module_python_scalar_closure(self): 3357 def foo(x): 3358 a = torch.ones(1, 1) 3359 b = 1 3360 3361 def true_fn(x): 3362 return x + a 3363 3364 def false_fn(x): 3365 return x + b 3366 3367 return cond(x.shape[0] == 4, true_fn, false_fn, [x]) 3368 3369 inp = torch.ones(2, 3) 3370 res = inp + 1 3371 # python scalar b is not lifted as input, so both branches take (x, a) 3372 self._check_closure_correctly_lifted( 3373 foo, args=(inp,), exp_res=res, exp_arg_num=2 3374 ) 3375 3376 def test_cond_nested_with_closure(self): 3377 a = torch.ones(1, 1) 3378 b = torch.ones(1, 1) + 1 3379 3380 def inner_true_fn(x): 3381 return x + a 3382 3383 def inner_false_fn(x): 3384 return x + b 3385 3386 def foo(x): 3387 def true_fn(x): 3388 return cond(x.shape[0] == 2, inner_true_fn, inner_false_fn, [x]) 3389 3390 def false_fn(x): 3391 return cond(x.shape[0] > 4, inner_true_fn, inner_false_fn, [x]) 3392 3393 return cond(x.shape[0] == 4, true_fn, false_fn, [x]) 3394 3395 inp = torch.ones(2, 3) 3396 # For top-level cond, it takes 3 arguments (x, a, b). Dynamo should 3397 # realize that the nonlocal variables are the same for the true and false 3398 # branches, so it should de-dupe them.
        # The second-level conds also take (x, a, b).
        self._check_closure_correctly_lifted_with_mutation(
            foo, (a, b), args=(inp,), exp_arg_num=3
        )

    def test_cond_nested_with_closure_graph_module(self):
        a = torch.ones(1, 1)
        b = torch.ones(1, 1) + 1

        def inner_true_fn(x):
            return x + a

        def inner_false_fn(x):
            return x + b

        def foo(x):
            def true_fn(x):
                return cond(x.shape[0] == 2, inner_true_fn, inner_false_fn, [x])

            def false_fn(x):
                return cond(x.shape[0] > 4, inner_true_fn, inner_false_fn, [x])

            return cond(x.shape[0] == 4, true_fn, false_fn, [x])

    def test_map_unfunc_boolean_tensor_for_nested_map_cond(self):
        def map_fn(pred, x):
            def fn(x, pred):
                return control_flow.cond(pred, lambda x: x * 2, lambda x: x / 2, (x,))

            return control_flow.map(fn, x, pred)

        def f_wrapper(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                torch._enable_functionalization(reapply_views=False)
                try:
                    func_args = pytree.tree_map(
                        lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
                        args,
                    )
                    func_kwargs = pytree.tree_map(
                        lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
                        kwargs,
                    )
                    return pytree.tree_map(
                        from_fun_old, func(*func_args, **func_kwargs)
                    )
                finally:
                    torch._disable_functionalization()

            return wrapper

        gm = make_fx(f_wrapper(map_fn))(
            torch.tensor(True), torch.ones([2, 3], requires_grad=False)
        )
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, pred_1, x_1):
    body_graph_0 = self.body_graph_0
    map_impl = torch.ops.higher_order.map_impl(body_graph_0, [x_1], [pred_1]); body_graph_0 = x_1 = pred_1 = None
    getitem = map_impl[0]; map_impl = None
    return getitem""",
        )
        self.assertExpectedInline(
            gm.body_graph_0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
    getitem = cond[0]; cond = None
    return [getitem]""",  # noqa: B950
        )

    def test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph(self):
        def true_fn(x):
            return x + x.cos()

        def false_fn(x):
            return x * x.sin()

        def foo(x):
            return cond(x.shape[0] == 4, true_fn, false_fn, (x,))

        inp = torch.randn([4, 3])
        gm, _ = torch._dynamo.export(foo)(inp)

        def run_with_interpreter(*args):
            with torch.fx.traceback.preserve_node_meta():
                return torch.fx.Interpreter(gm).run(*args)

        new_gm = make_fx(run_with_interpreter)(inp)

        checked_ops = {"add", "mul", "sin", "cos"}
        checked_meta = ["source_fn_stack", "stack_trace"]
        all_source_fns = collect_meta_for_filtered_nodes(gm, checked_ops, checked_meta)
        new_source_fns = collect_meta_for_filtered_nodes(
            new_gm, checked_ops, checked_meta
        )
        self.assertEqual(all_source_fns, new_source_fns)

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO,
        "triggers cache limit for foo and changes unique_graphs count.",
    )
    def test_cond_no_dynamo_cache_limit(self):
        torch._dynamo.reset()
        counters = torch._dynamo.utils.counters
        counters.clear()

        def foo(x, true_fn, false_fn):
            return cond(x.sum() < 0, true_fn, false_fn, (x,))

        inp = torch.ones(3, 4)
        exp_out = inp.sin()
        iter_n = torch._dynamo.config.cache_size_limit + 1

        # Needed because Dynamo guards on the lambda's code object ID, not the object itself.
        def make_dummy_fn(op):
            exec(f"temp = lambda x: x.{op}()")
            return locals()["temp"]

        for _ in range(iter_n):
            # each lambda has a different code object ID and thus fails the guard
            self.assertEqual(
                foo(inp, make_dummy_fn("cos"), make_dummy_fn("sin")), exp_out
            )

        # each iteration captures a cond and a getitem from the tuple output
        self.assertEqual(counters["stats"]["calls_captured"], iter_n * 2)
        self.assertEqual(counters["stats"]["unique_graphs"], iter_n)

    def test_cond_with_consecutive_make_fx_symbolic(self):
        def true_fn(x):
            return x - x.cos()

        def false_fn(x):
            return x + x.sin()

        def foo(x):
            return cond(x.shape[0] == 4, true_fn, false_fn, [x])

        inps = (torch.ones(3, 4), torch.ones(3, 5), torch.ones(5, 4), torch.ones(5, 3))
        for inp in inps:
            # re-trace for each input; the symbolic graphs should be identical
            gm = make_fx(foo, tracing_mode="symbolic")(inp)
            self.assertExpectedInline(
                gm.code.strip(),
                """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    eq = sym_size_int == 4; sym_size_int = None
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
    getitem = cond[0]; cond = None
    return getitem""",  # noqa: B950
            )

            self.assertExpectedInline(
                gm.true_graph_0.code.strip(),
                """\
def forward(self, arg0_1):
    cos = torch.ops.aten.cos.default(arg0_1)
    sub = torch.ops.aten.sub.Tensor(arg0_1, cos); arg0_1 = cos = None
    return (sub,)""",
            )

            self.assertExpectedInline(
                gm.false_graph_0.code.strip(),
                """\
def forward(self, arg0_1):
    sin = torch.ops.aten.sin.default(arg0_1)
    add = torch.ops.aten.add.Tensor(arg0_1, sin); arg0_1 = sin = None
    return (add,)""",
            )

    def _create_test_fns_for_cond(
        self, pred, inner_most_fn, operands, closure_list, nested_level
    ):
        if nested_level == 0:
            if len(closure_list) > 0:

                def true_fn(*operands):
                    return inner_most_fn(*operands) + inner_most_fn(*closure_list)

                def false_fn(*operands):
                    return inner_most_fn(*operands) - inner_most_fn(*closure_list)

            else:

                def true_fn(*operands):
                    return inner_most_fn(*operands)

                def false_fn(*operands):
                    return inner_most_fn(*operands)

            def fn(*operands):
                if len(operands) == 0 and len(closure_list) == 0:
                    return torch.zeros(1)
                return cond(pred, true_fn, false_fn, operands)

            return operands, fn
        else:
            args, inner_fn = self._create_test_fns_for_cond(
                pred <= 0, inner_most_fn, operands, closure_list, nested_level - 1
            )

            def true_fn(*operands):
                return inner_most_fn(*operands) + inner_fn(*args)

            def false_fn(*operands):
                return inner_most_fn(*operands) - inner_fn(*args)

            def fn(*operands):
                if len(operands) == 0 and len(closure_list) == 0:
                    return torch.ones(1)
                return cond(pred, true_fn, false_fn, operands)

            return operands, fn

    def _init_predicate(self, pred_type):
        if pred_type == "bool":
            return True
        elif pred_type == "intTensor":
            return torch.tensor(1)
        elif pred_type == "floatTensor":
            return torch.tensor(1.0)
        elif pred_type == "boolTensor":
            return torch.tensor(False)
        else:
            raise NotImplementedError

    def _init_fn(self, inner_fn_type):
        if inner_fn_type == "function":
            return reduce_func
        elif inner_fn_type == "module":
            return ReduceMod()
        elif inner_fn_type == "object":
            return ReduceObj()
        else:
            raise NotImplementedError

    @parametrize("predType", ["bool", "intTensor", "floatTensor", "boolTensor"])
    @parametrize("innerFnType", ["function", "module", "object"])
    @parametrize("nOperands", [0, 1])
    @parametrize("nClosure", [0, 1])
    @parametrize("nesting", [0, 2])
    def test_cond_tracing_with_valid_inputs(
        self, predType, innerFnType, nOperands, nClosure, nesting
    ):
        pred = self._init_predicate(predType)
        inner_fn = self._init_fn(innerFnType)
        operands = [torch.ones(2, 3) + i for i in range(nOperands)]
        closure = [torch.ones(2, 3) - i for i in range(nClosure)]
        args, fn = self._create_test_fns_for_cond(
            pred, inner_fn, operands, closure, nesting
        )
        eager_res = fn(*args)
        for tracing_mode in ["symbolic", "fake", "real"]:
            # set _allow_non_fake_inputs = True to allow fake prop through closures
            with self.subTest(tracing_mode=tracing_mode):
                gm = make_fx(
                    fn, tracing_mode=tracing_mode, _allow_non_fake_inputs=True
                )(*args)
                self.assertEqual(gm(*args), eager_res)

    @parametrize("predType", ["boolTensor"])
    @parametrize("innerFnType", ["function", "module", "object"])
    @parametrize("nOperands", [1, 2])
    @parametrize("nClosure", [0, 1])
    @parametrize("nesting", [0])
    def test_cond_vmap(self, predType, innerFnType, nOperands, nClosure, nesting):
        pred = self._init_predicate(predType)
        inner_fn = self._init_fn(innerFnType)
        operands = [torch.ones(2, 3) + i for i in range(nOperands)]
        closure = [torch.ones(2, 3) - i for i in range(nClosure)]
        args, fn = self._create_test_fns_for_cond(
            pred, inner_fn, operands, closure, nesting
        )
        eager_res = fn(*args)
        out = torch.vmap(fn)(*args)
        if nClosure == 0:
            self.assertEqual(eager_res, out)
        else:
            self.assertEqual(eager_res, out[0])
            self.assertEqual(eager_res, out[1])

    def test_cond_vmap_simple(self):
        def fn(x):
            return torch.cond(
                pred=torch.tensor([True]),
                true_fn=lambda x: x + 100,
                false_fn=lambda x: x,
                operands=(x,),
            )

        a = torch.arange(15).reshape((3, 5))
        res = torch.vmap(fn, in_dims=(0,))(a)
        self.assertEqual(res.shape, (3, 5))
        self.assertEqual(res, a + 100)

    def test_cond_vmap_multiple_inputs(self):
        def fn(x, y):
            return torch.cond(
                pred=x.sum() < y.sum(),
                true_fn=lambda x, y: x + 100,
                false_fn=lambda x, y: y,
                operands=(x, y),
            )

        a = torch.arange(15).reshape(3, 5)
        b = torch.ones_like(a) + 3
        res = torch.vmap(fn, in_dims=(0, 0))(a, b)
        expected = torch.tensor(
            [[100, 101, 102, 103, 104], [4, 4, 4, 4, 4], [4, 4, 4, 4, 4]]
        )
        self.assertEqual(res.shape, (3, 5))
        self.assertEqual(expected, res)

    def test_cond_vmap_single_input_with_closure(self):
        a = torch.ones((3, 5)) + 3
        c = torch.arange(5)

        def fn(x):
            return torch.cond(
                pred=torch.tensor([True]),
                true_fn=lambda x: x + c,
                false_fn=lambda x: x - c,
                operands=(x,),
            )

        res = torch.vmap(fn, in_dims=(0,))(
            a,
        )
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            res = torch.vmap(fn, in_dims=(0,))(
                a,
            )
        self.assertEqual(a + c, res)

    def test_cond_vmap_multiple_args_with_closure(self):
        a = torch.ones((3, 5), dtype=torch.int64) + 3
        b = torch.arange(15).reshape(3, 5)
        c = torch.arange(5)

        def fn(x, y):
            return torch.cond(
                pred=torch.tensor([False]),
                true_fn=lambda x, y: x + c,
                false_fn=lambda x, y: y - c,
                operands=(x, y),
            )

        res = torch.vmap(fn)(a, b)
        self.assertEqual(b - c, res)

    @parametrize("nClosure", [0, 1])
    def test_cond_vmap_multiple_outputs(self, nClosure):
        if nClosure:
            c = torch.ones(5, dtype=torch.int64) + 5

            def fn(x):
                return torch.cond(
                    pred=torch.tensor([True]),
                    true_fn=lambda x: (x + c, x - c),
                    false_fn=lambda x: (x, x),
                    operands=(x,),
                )

        else:

            def fn(x):
                return torch.cond(
                    pred=torch.tensor([True]),
                    true_fn=lambda x: (x + 1, x - 1),
                    false_fn=lambda x: (x, x),
                    operands=(x,),
                )

        a = torch.arange(15).reshape(3, 5)
        res = torch.vmap(fn)(
            a,
        )
        self.assertEqual(len(res), 2)
        if nClosure:
            self.assertEqual(res, (a + c, a - c))
        else:
            self.assertEqual(res, (a + 1, a - 1))

    def test_vmap_vmap(self):
        def fn(x):
            return torch.cond(
                pred=torch.tensor([True]),
                true_fn=lambda x: x + 1,
                false_fn=lambda x: x - 1,
                operands=(x,),
            )

        def wrapper(x):
            return torch.vmap(fn)(x)

        a = torch.ones((3, 4, 5))
        res = torch.vmap(wrapper)(a)
        self.assertEqual(res, a + 1)

    def test_cond_trace_set__and_mutate_input(self):
        def f(a, tmp):
            a_view = a.view(-1)
            with torch.no_grad():
                a.set_(tmp)
                a_view.mul_(2)
            return a + tmp

        inp = torch.ones(3, 3, requires_grad=True)
        tmp = torch.ones(3, 3, requires_grad=True)
        # graph break: torch._dynamo.exc.Unsupported: call_function DelayGraphBreakVariable() [TensorVariable()] {}
        # due to set_
        with self.assertRaisesRegex(
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "Cond doesn't work unless it is captured completely with torch.compile",
        ):
            torch.cond(inp.sum() > 0, f, f, (inp, tmp))

    def test_cond_trace_set__and_mutate_intermediate(self):
        def f(a, tmp):
            a = a.clone()
            a_view = a.view(-1)
            tmp = tmp.clone()
            with torch.no_grad():
                a.set_(tmp)
                a_view.mul_(2)
            return a + tmp

        inp = torch.ones(3, 3, requires_grad=True)
        tmp = torch.ones(3, 3, requires_grad=True)

        class Mod(torch.nn.Module):
            def forward(self, inp: torch.Tensor, tmp: torch.Tensor) -> torch.Tensor:
                return torch.cond(inp.sum() > 0, f, f, (inp, tmp))

        with self.assertRaisesRegex(
            RuntimeError, "cannot mutate tensors with frozen storage"
        ):
            out = torch.compile(Mod(), backend="aot_eager")(inp, tmp)

        with self.assertRaisesRegex(
            RuntimeError, "cannot mutate tensors with frozen storage"
        ):
            out = torch.compile(Mod(), backend="inductor")(inp, tmp)

        from torch._dynamo.testing import EagerAndRecordGraphs

        backend = EagerAndRecordGraphs()
        out = torch.compile(Mod(), backend=backend)(inp, tmp)
        self.assertExpectedInline(
            backend.graphs[0].cond_true_0.code.strip("\n"),
            """\
def forward(self, l_inp_, l_tmp_):
    l_inp__1 = l_inp_
    l_tmp__1 = l_tmp_
    a = l_inp__1.clone(); l_inp__1 = None
    a_view = a.view(-1)
    tmp = l_tmp__1.clone(); l_tmp__1 = None
    _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
    set_ = a.set_(tmp); set_ = None
    mul_ = a_view.mul_(2); a_view = mul_ = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
    add = a + tmp; a = tmp = None
    return (add,)
""",
        )
        self.assertEqual(out, f(inp, tmp))

    def test_two_hops_not_sharing_code_obj(self):
        pred, args = torch.tensor(True), (torch.ones(3, 3),)

        def fn1(x):
            return x + 1

        def fn2(x):
            return x - 1

        from torch._dynamo.testing import CompileCounter

        # Tests rely on automatic_dynamic = True
        with torch._dynamo.config.patch(automatic_dynamic_shapes=True):
            cnt = CompileCounter()
            torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, args)
            self.assertEqual(cnt.frame_count, 1)

            args = (torch.randn(3, 3),)
            # No recompilation
            torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, args)
            self.assertEqual(cnt.frame_count, 1)

            def cond_fn(x):
                return x.sum() > 0

            args = (torch.randn(4, 4),)
            torch.compile(torch.while_loop, backend=cnt)(cond_fn, fn2, args)
            # recompilation
            self.assertEqual(cnt.frame_count, 2)

            args = (torch.randn(4, 4),)
            torch.compile(torch.while_loop, backend=cnt)(cond_fn, fn2, args)
            self.assertEqual(cnt.frame_count, 2)

            # With recompilation due to automatic dynamic
            # This also proves that while_loop doesn't share code obj with cond
            torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, (torch.randn(4, 4),))
            self.assertEqual(cnt.frame_count, 3)

    def test_hop_raises_if_not_overriding_call(self):
        class WrongHop(torch._ops.HigherOrderOperator):
            pass

        with self.assertRaisesRegex(TypeError, "WrongHop"):
            wrong_hop = WrongHop("wrong_hop")


_hop_schema_test_schema_types = [
    "bool",
    "int",
    "float",
    "str",
    "Tensor",
    "SymInt",
    "SymBool",
    "GraphModule",
    "ScriptObj",
]


@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
class TestHopSchema(TestCase):
    def _get_example_val(self, ty: str):
        from torch.fx.experimental.sym_node import SymNode
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        def create_symtype(cls, pytype, shape_env, val):
            from torch._dynamo.source import ConstantSource

            symbol = shape_env.create_symbol(
                val,
                source=ConstantSource(
                    f"__testing_hop_schema{len(shape_env.var_to_val)}"
                ),
            )
            return cls(SymNode(symbol, shape_env, pytype, hint=val))

        if ty == "bool":
            return True
        elif ty == "int":
            return 1
        elif ty == "float":
            return 1.0
        elif ty == "str":
            return "foo"
        elif ty == "Tensor":
            return torch.tensor(1)
        elif ty == "SymInt":
            shape_env = ShapeEnv()
            return create_symtype(torch.SymInt, int, shape_env, 1)
        elif ty == "SymBool":
            shape_env = ShapeEnv()
            return create_symtype(torch.SymBool, bool, shape_env, True)
        elif ty == "GraphModule":

            def f(x):
                return x.sin()

            return make_fx(f)(torch.ones(1))
        elif ty == "ScriptObj":
            from torch.testing._internal.torchbind_impls import (
                init_torchbind_implementations,
            )

            init_torchbind_implementations()
            foo = torch.classes._TorchScriptTesting._Foo(3, 4)
            return foo
        else:
            raise NotImplementedError(ty)

    @parametrize("schema_type", _hop_schema_test_schema_types)
    def test_type_gen(self, schema_type):
        from torchgen.gen_schema_utils import TypeGen

        example_val = self._get_example_val(schema_type)
        ty = TypeGen.from_example(example_val)
        # Test that the generated type can be parsed back
        self.assertEqual(ty.parse(str(ty)), ty)

    @parametrize("schema_type", _hop_schema_test_schema_types)
    def test_list_gen(self, schema_type):
        from torchgen.gen_schema_utils import TypeGen

        example_val = self._get_example_val(schema_type)
        li1 = [example_val]
        li2 = [example_val, example_val]
        ty1 = TypeGen.from_example(li1)
        # use li2 here so the two-element list type is actually exercised
        ty2 = TypeGen.from_example(li2)
        self.assertEqual(ty1.parse(str(ty1)), ty1)
        self.assertEqual(ty2.parse(str(ty2)), ty2)

    def test_function_schema_gen(self):
        from torchgen.gen_schema_utils import FunctionSchemaGen

        inps = [
            (schema_type + "_v", self._get_example_val(schema_type))
            for schema_type in _hop_schema_test_schema_types
        ]
        schema1 = FunctionSchemaGen.from_example("test_op1", inps, torch.ones(1))
        schema2 = FunctionSchemaGen.from_example(
            "test_op2",
            inps,
            [
                torch.ones(1),
            ],
        )
        schema3 = FunctionSchemaGen.from_example(
            "test_op3", inps, [torch.ones(1), torch.ones(1)]
        )
        self.assertExpectedInline(
            str(schema1),
            """test_op1(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(schema2),
            """test_op2(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(schema3),
            """test_op3(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> (Tensor, Tensor)""",  # noqa: B950
        )
        self.assertEqual(schema1.parse(str(schema1)), schema1)
        self.assertEqual(schema2.parse(str(schema2)), schema2)
        self.assertEqual(schema3.parse(str(schema3)), schema3)

    def test_while_loop_schema_gen(self):
        fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
        graph = make_fx(fn)(*inp).graph
        while_loop_node = next(
            node
            for node in graph.nodes
            if node.op == "call_function"
            and node.target is torch.ops.higher_order.while_loop
        )
        schema = torch._library.utils.hop_schema_from_fx_node(while_loop_node)
        self.assertExpectedInline(
            str(schema),
            """while_loop(GraphModule cond_fn, GraphModule body_fn, Tensor[2] carried_inputs, Tensor[3] additional_inputs) -> Tensor[2]""",  # noqa: B950
        )
        self.assertEqual(schema.parse(str(schema)), schema)


instantiate_parametrized_tests(TestHopSchema)
instantiate_parametrized_tests(TestControlFlowTraced)
instantiate_parametrized_tests(TestControlFlow)

if __name__ == "__main__":
    run_tests()