# Owner(s): ["module: fx"]

import builtins
import contextlib
import copy
import functools
import inspect
import math
import numbers
import io
import operator
import os
import pickle
import sys
import torch
import traceback
import typing
import types
import warnings
import unittest
from math import sqrt
from functorch.experimental import control_flow
from torch.multiprocessing import Process
from torch.testing import FileCheck
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
import torch.utils._pytree as pytree
import torch.fx._pytree as fx_pytree
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
from torch.fx.node import Target, Argument, _format_arg
from torch.fx.passes import shape_prop
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.operator_schemas import get_signature_for_torch_op
from copy import deepcopy
from collections import namedtuple

from torch.fx.proxy import TraceError
from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMPATIBILITY
from torch.fx._symbolic_trace import PHBase, PHWithMeta
# The TestCase classes below are imported but not referenced by name in this
# module (hence `noqa: F401`). Presumably they are pulled into this module's
# namespace so that running this file also runs those suites -- confirm
# against the test runner configuration before removing any of them.
from fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
from fx.test_dce_pass import TestDCE  # noqa: F401
from fx.test_fx_const_fold import TestConstFold  # noqa: F401
from fx.test_fx_param_shape_control_flow import TestConstParamShapeInControlFlow  # noqa: F401
from fx.test_pass_infra import TestPassManager  # noqa: F401
from fx.test_common_passes import TestCommonPass  # noqa: F401
from fx.test_cse_pass import TestCSEPass  # noqa: F401
from fx.test_matcher_utils import TestMatcher  # noqa: F401
from fx.test_source_matcher_utils import TestSourceMatcher  # noqa: F401

from fx.test_gradual_type import AnnotationsTest  # noqa: F401
from fx.test_gradual_type import TypeCheckerTest  # noqa: F401
from typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_MACOS,
    IS_WINDOWS,
    find_library_location,
    run_tests,
    skipIfTorchDynamo,
)
from torch.testing._internal.jit_utils import JitTestCase

from fx.named_tup import MyNamedTup

# torchvision is optional: tests that need it are decorated with
# skipIfNoTorchVision and are skipped when the import fails.
try:
    from torchvision import models as torchvision_models
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport

class SimpleTest(torch.nn.Module):
    """Small tracing fixture whose forward computes ``relu(x + 3.0)``."""

    def forward(self, x):
        return torch.relu(x + 3.0)

def a_non_torch_leaf(a, b):
    """Plain-Python addition (not a torch op).

    Tests insert it into a graph via ``graph.call_function`` to check that
    code generation and pickling handle functions from a non-torch module.
    """
    return a + b

# Used for test_autowrap_function. Autowrapped functions need to be global
def fx_int(x: float) -> int:
    """Return ``int(x)``."""
    return int(x)

def fx_int_x2(x: float) -> int:
    """Return ``int(x) * 2``."""
    return int(x) * 2

# used in test_pytree.
# It's all the way out here because pickling a GraphModule
# that uses Point errors out if Point is local to the function
Point = namedtuple('Point', ['x', 'y'])

# Test wrap() passing both a function name as well as a function
# directly
def a_lifted_leaf(a, b):
    return a[0] + a[1] + b

wrap('a_lifted_leaf')
# Test wrapping twice doesn't break anything
wrap('a_lifted_leaf')

def a_lifted_leaf2(a, b):
    return a[0] + a[1] + b

wrap(a_lifted_leaf2)

# Builtins can be wrapped by name as well.
wrap('len')

wrap('getattr')

# Wrapped function with a keyword-only parameter.
def wrapped_named_tup(p1, *, p2):
    return p1.x + p2.y

wrap(wrapped_named_tup)

# wrap() used as a decorator.
@wrap
def wrapped_via_decorator(a):
    return a + 1

# Registered by name ahead of its definition; test_wrap_with_submodule passes
# an nn.Module (a BatchNorm1d) through this wrapped function.
wrap('wrapped_with_submodule')

def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d):
    return batchnorm1d(x)

def my_decorator(f):
    # Pass-through decorator. functools.wraps keeps __name__ equal to
    # f.__name__, which is the name test_wrap_decorated_function expects to
    # find in the traced code.
    @functools.wraps(f)
    def wrapper_inside_decorator(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper_inside_decorator

@wrap
@my_decorator
def wrapped_decorated_fn(x):
    return x

# Original function objects, captured so tests can assertIs() that the module
# globals are left unpatched after tracing. ("lifed" is a long-standing
# misspelling; the names are kept because they may be referenced elsewhere.)
real_wrapped_via_decorator = wrapped_via_decorator
real_a_lifed_leaf = a_lifted_leaf
real_a_lifed_leaf2 = a_lifted_leaf2
_sqrt = sqrt

wrap('wrapper_fn')

# NOTE(review): `torch.foo` presumably does not exist, so calling this eagerly
# would fail; as a wrapped leaf it is only recorded during tracing. Confirm
# against the test that uses it.
def wrapper_fn(x):
    return torch.foo(x)

class Pair(NamedTuple):
    x : torch.Tensor
    y : torch.Tensor

    # Custom repr hook, presumably consulted by FX's argument formatting
    # (_format_arg) when a Pair appears as a node argument.
    def _custom_fx_repr_fn(self) -> str:
        return f"Pair(x={_format_arg(self.x)}, y={_format_arg(self.y)})"

# for testing pytrees
class Foo:  # noqa: B209
    def __init__(self, a, b):
        self.a = a
        self.b = b

class Add(torch.nn.Module):
    def forward(self, x):
        return x + x

# Wrapped leaf marked as side-effecting (it only prints); presumably so that
# passes which drop unused nodes keep calls to it.
@torch.fx.has_side_effect
@torch.fx.wrap
def side_effect_func(x: torch.Tensor):
    print(x)

class TestFX(JitTestCase):
    """Tests for torch.fx symbolic tracing, Graph/GraphModule and Proxy."""

    def setUp(self):
        super().setUp()
        # Load the TorchBind test library *before* touching global state:
        # unittest does not call tearDown() when setUp() raises, so flipping
        # the flag first would leak it into every later test if loading fails.
        if not (IS_FBCODE or IS_WINDOWS or IS_MACOS):
            lib_file_path = find_library_location('libtorchbind_test.so')
            torch.ops.load_library(str(lib_file_path))

        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        super().tearDown()
        # Restore the global flag saved in setUp().
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
        """Check that an nn.Module's results match the GraphModule version
        for a given set of args/kwargs.
        """
        kwargs = kwargs if kwargs else {}
        ref_outs = m(*args, **kwargs)
        gm = symbolic_trace(m)
        gm.graph.lint()
        test_outs = gm(*args, **kwargs)
        self.assertEqual(ref_outs, test_outs)

    def test_graph_module(self):
        # Trace (and script) a nested module mixing many operators, then trace
        # a tuple-returning module and a *args/**kwargs signature.
        class MySub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.nn.Parameter(torch.rand(4, 3))

            def forward(self, x):
                return self.w + x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = torch.nn.Linear(4, 3)
                self.sub_mod = MySub()
                self.w = torch.nn.Parameter(torch.rand(3))

            def forward(self, A, B, c):
                t = torch.sigmoid(A) + self.lin(c)
                return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))

        m = MyModule()
        gm = symbolic_trace(m)

        # Scripting the traced module must succeed.
        torch.jit.script(gm)

        class M2(torch.nn.Module):
            def forward(self, A):
                m, idx = torch.max(A, 0)
                return m + 1, idx + 1

        m2 = M2()
        symbolic_trace(m2)

        class T(torch.nn.Module):

            def forward(self, A, b=4, *args, c=5, **kwargs):
                x = A + 1 + args[0] + kwargs['3']
                return x

        t = T()
        symbolic_trace(t)

        # test for issue described at https://github.com/pytorch/pytorch/issues/63883
        class M3(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        m3 = M3()
        gm3 = symbolic_trace(m3)
        # Re-running __init__ on a bare instance must yield a working module.
        new_instance = gm3.__new__(type(gm3))
        new_instance.__init__(gm3, gm3.graph)

        x = torch.randn(5, 3)
        torch.testing.assert_close(new_instance(x), torch.relu(x))

    def test_informative_co_filename(self):
        # The generated forward() should report this test file as its filename.
        class MyModule(torch.nn.Module):
            def forward(self, a):
                return a * 2

        gm = symbolic_trace(MyModule())
        self.assertIn(os.path.basename(__file__), gm.forward.__code__.co_filename)

    def test_custom_import(self):
        # A hand-built graph calling a non-torch function must still run.
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)
        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(torch.sin(x + y), gm(x, y))

    def test_args_kwargs(self):
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + kwargs['foo']
                return x

        t = T()
        self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_varargs_concrete(self):
        # PH placeholders passed as concrete_args keep *args inputs symbolic.
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + args[1]
                return x

        args = (torch.rand(1), torch.rand(1))

        t = T()
        ref_outs = t(*args)
        gm = symbolic_trace(t, concrete_args=(torch.fx.PH, torch.fx.PH))
        gm.graph.lint()
        test_outs = gm(*args)
        self.assertEqual(ref_outs, test_outs)

    def test_args_kwargs_no_self(self):
        # A forward() that receives `self` through *args cannot be traced.
        class T(torch.nn.Module):
            def forward(*args, **kwargs):  # noqa: B902
                self = args[0]
                return torch.relu(args[1])

        t = T()
        with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'):
            self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_fx_shifts(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x << 3, x >> 3

        input = torch.LongTensor(10).random_(0, 1024)

        m = MyModule()
        self.checkGraphModule(m, (input,))

    def test_fx_and_or(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x & x, x | x

        input = torch.LongTensor(10).random_(0, 1024)

        m = MyModule()
        self.checkGraphModule(m, (input,))

    def test_dict(self):
        # Dict inputs and dict outputs trace through.
        class MyDictMod(torch.nn.Module):
            def forward(self, d):
                return d['3'].relu(), {'4' : d['3'].neg()}

        input_dict = {'3': torch.rand(3, 4)}
        m = MyDictMod()

        self.checkGraphModule(m, (input_dict,))

    def test_matmul_tracing(self):
        # Both `x @ const` and the reflected `const @ x` must trace.
        const = torch.randn(3)

        def matmul_f(x):
            return x @ const

        mod = symbolic_trace(matmul_f)
        inp = torch.randn(3)
        self.assertEqual(mod(inp), matmul_f(inp))

        def rmatmul_f(x):
            return const @ x

        mod = symbolic_trace(rmatmul_f)
        inp = torch.randn(3)
        self.assertEqual(mod(inp), rmatmul_f(inp))

    @skipIfNoDynamoSupport
    def test_control_flow_tracing(self):
        # cond() with a Proxy predicate is rejected during symbolic tracing.
        def true(x, y):
            return x + y

        def false(x, y):
            return x - y

        def f(x, y):
            x = control_flow.cond(x[0] == 0, true, false, [x, y])

        with self.assertRaisesRegex(RuntimeError, r"Expected pred to be bool or tensor, but got Proxy\(eq\)"):
            _ = symbolic_trace(f)

    def test_disallow_override(self):
        # Custom delegate to disallow in-place tensor operations
        class NoMutableCallTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                # Reject ops whose name ends in '_' (the in-place naming
                # convention). A separate local keeps the caller's `name`
                # and `type_expr` intact so they can be forwarded unchanged.
                target_name = target if isinstance(target, str) else torch.typename(target)
                if target_name[-1] == '_':
                    raise RuntimeError('In-place operations are not supported')
                return super().create_node(kind, target, args, kwargs, name, type_expr)

        # Test method
        class MyInplaceMod(torch.nn.Module):
            def forward(self, x):
                x.add_(3.0)
                return x

        m = MyInplaceMod()

        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m)

        # Test free function
        class MyInplaceMod2(torch.nn.Module):
            def forward(self, x):
                torch.log_(x)
                return x
        m2 = MyInplaceMod2()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m2)

        # Test symbolic node as an arg
        class MyInplaceMod3(torch.nn.Module):
            def forward(self, x):
                y = torch.ones(3, 4)
                y.add_(x)
                return x
        m3 = MyInplaceMod3()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m3)

    def test_leaf_module(self):
        # Custom delegate to make it so that there are no leaf modules, everything
        # should get traced through
        class NoLeafModulesTracer(Tracer):
            def is_leaf_module(self, m, qualname):
                return False

        class MyReluMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        mrm = MyReluMod()
        sym = NoLeafModulesTracer().trace(mrm)
        for node in sym.nodes:
            self.assertNotEqual(node.op, 'call_module')
        sym.lint()

    def test_wrap(self):
        # A function wrapped by name shows up as a call in the generated code,
        # and the module global is the original function again afterwards.
        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))

        def to_trace(y):
            return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y)

        m = symbolic_trace(to_trace)
        self.assertIn('a_lifted_leaf', m.code)
        self.assertEqual(27, m(2))
        self.assertIs(a_lifted_leaf, real_a_lifed_leaf)

    def test_wrap_fn_directly(self):
        # Same as test_wrap, but wrap() was given the function object.
        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))

        def to_trace(y):
            return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y)

        m = symbolic_trace(to_trace)
        self.assertIn('a_lifted_leaf2', m.code)
        self.assertEqual(27, m(2))
        self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)

    def test_wrapped_via_decorator(self):
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    def test_wrapped_via_decorator_and_transformed(self):
        # The wrapped call must survive a Transformer pass as well.
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

        transformed = torch.fx.Transformer(m).transform()
        self.assertIn('wrapped_via_decorator', transformed.code)
        self.assertEqual(transformed(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    def test_wrap_with_submodule(self):
        # A wrapped function may receive a submodule as an argument.

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)

            def forward(self, x: torch.Tensor):
                return wrapped_with_submodule(x, self.batchnorm1d)

        m = symbolic_trace(M())

        self.assertIn("wrapped_with_submodule", m.code)

        input = torch.rand(3, 2)
        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
        self.assertEqual(ref_batchnorm1d(input), m(input))

    def test_wrapped_retrace(self):
        # Re-tracing a GraphModule keeps the wrapped call as a leaf.
        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)

        retraced = symbolic_trace(m)
        self.assertIn('wrapped_via_decorator', retraced.code)
        self.assertEqual(retraced(0), 1)

    def test_wrap_decorated_function(self):
        def to_trace(y):
            return wrapped_decorated_fn(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_decorated_fn', m.code)
        self.assertEqual(m(1), 1)

    def test_graph_edit_with_proxy(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b
        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        gm.graph.lint()
        self.assertEqual(gm(3, 4), 14)

    def test_proxy_deepcopy_without_tracer(self):
        # Deep-copying a tracer-less Proxy also copies its node.
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return 2 * x

        module = MyModule()
        traced = symbolic_trace(module)
        node = list(traced.graph.nodes)[-2]
        p = torch.fx.Proxy(node, None)
        node.proxy = p
        p2 = copy.deepcopy(p)
        self.assertIsInstance(p2, torch.fx.Proxy)
        self.assertEqual(p2.node.name, node.name)
        self.assertEqual(p2.node.target, node.target)
        self.assertNotEqual(id(p2.node), id(node))

    def test_proxy_deepcopy_with_tracer(self):
        # Deep-copying a Proxy also deep-copies its tracer, keeping its attributes.
        class TestTracer(Tracer):
            def __init__(self, name):
                super().__init__()
                self.name = name

            def is_leaf_module(self, module, name):
                return True

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return 2 * x

        module = MyModule()
        tracer = TestTracer("mytracer")
        traced = symbolic_trace(module)
        node = list(traced.graph.nodes)[-2]
        p = torch.fx.Proxy(node, tracer)
        node.proxy = p
        p2 = copy.deepcopy(p)
        self.assertIsInstance(p2, torch.fx.Proxy)
        self.assertIsInstance(p2.tracer, torch.fx._symbolic_trace.Tracer)
        self.assertEqual(p2.tracer.name, "mytracer")
        self.assertEqual(p2.node.name, node.name)
        self.assertEqual(p2.node.target, node.target)
        self.assertNotEqual(id(p2.node), id(node))
        self.assertNotEqual(id(p2.tracer), id(tracer))

    def test_concrete_arg_none_assert(self):
        # After specializing `val` to None, passing another value must fail.
        class Foo(torch.nn.Module):
            def forward(self, x, val=None):
                return x if val is None else x + val

        f = Foo()
        traced = torch.fx.symbolic_trace(f, concrete_args={'val' : None})
        with self.assertRaisesRegex(AssertionError, 'val has been specialized to have value None'):
            traced(torch.randn(5), torch.randn(5))

        x = torch.randn(5)
        torch.testing.assert_close(traced(x), f(x))

    def test_trace_multiple_funcs(self):
        # Tracer.traced_func_name selects which method of the root is traced.
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x + y

            def minus_forward(self, x, y):
                return x - y

            def multiply_forward(self, x, y):
                return x * y

        f = Foo()
        x, y = torch.randn(5), torch.randn(5)

        tracer = Tracer()
        torch.testing.assert_close(GraphModule(f, tracer.trace(f))(x, y), f(x, y))

        tracer.traced_func_name = "minus_forward"
        torch.testing.assert_close(
            GraphModule(f, tracer.trace(f))(x, y),
            f.minus_forward(x, y),
        )

        tracer.traced_func_name = "multiply_forward"
        torch.testing.assert_close(
            GraphModule(f, tracer.trace(f))(x, y),
            f.multiply_forward(x, y),
        )

        tracer.traced_func_name = "add_forward"
        with self.assertRaisesRegex(AssertionError, "doesn't exist in"):
            tracer.trace(f)

    def test_graph_unique_names(self):
        # graph_copy followed by new nodes must not produce duplicate names.
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b
        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        seen_names : Set[str] = set()
        for node in gm.graph.nodes:
            self.assertNotIn(node.name, seen_names)
            seen_names.add(node.name)

    def test_stack_traces(self):
        # With record_stack_traces, nodes carry a stack trace pointing into
        # this file, and node_copy preserves it.
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        tracer = torch.fx.Tracer()
        tracer.record_stack_traces = True

        graph = tracer.trace(M())
        # saving the original list because we will insert new nodes as a part of a test
        orig_graph_nodes = list(graph.nodes)
        for node in orig_graph_nodes:
            if node.op == 'output':
                continue
            self.assertIsNotNone(node.stack_trace)
            self.assertIn('test_fx.py', node.stack_trace)

            # verify that copying the node does not lose the stack trace
            new_node = graph.node_copy(node)
            self.assertIsNotNone(new_node.stack_trace)
            self.assertIn('test_fx.py', new_node.stack_trace)

    def test_stack_traces_with_transformer(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        tracer = torch.fx.Tracer()
        tracer.record_stack_traces = True

        graph = tracer.trace(M())
        gm = GraphModule(tracer.root, graph)
        new_gm = Transformer(gm).transform()

        # nodes after Transformer should still preserve the original node's stack trace
        for node in new_gm.graph.nodes:
            if node.op in {'placeholder', 'output'}:
                continue
            self.assertIsNotNone(node.stack_trace)
            self.assertIn('test_fx.py', node.stack_trace)

    def test_lineno_map(self):
        # Checks entries of gm._lineno_map, and that the keys shift by one
        # when a custom codegen hook prepends a line to the generated code.
        class M(torch.nn.Module):
            def forward(self, a, b):
                a = torch.sin(a)
                b = torch.cos(b)
                return a + b

        tracer = torch.fx.Tracer()
        graph = tracer.trace(M())
        gm = GraphModule(tracer.root, graph)
        expected = {1: 2, 2: 3, 3: 4, 4: 5}
        self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))

        # test custom codegen
        def transform_code(code):
            return ["print('hello!')\n", *code]
        gm.graph.on_generate_code(lambda _: transform_code)
        gm.recompile()
        expected = {2: 2, 3: 3, 4: 4, 5: 5}
        self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))

    def test_graph_unique_names_manual(self):
        # Explicit names that collide with generated ones are de-duplicated on copy.
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1')
        c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)
        graph2 = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        graph2.graph_copy(graph, val_map)
        seen_names : Set[str] = set()
        for node in graph2.nodes:
            self.assertNotIn(node.name, seen_names)
            seen_names.add(node.name)

    def test_unpack(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                c, d = a
                return c + d + b

        a = (torch.rand(1), torch.rand(1))
        b = torch.rand(1)
        m = M()
        self.checkGraphModule(m, (a, b))

    def test_native_callable(self):
        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        # This test exercises the case where we use FX to translate from Python
        # code to some native callable object
        #
        # For the purposes of testing, we use ElementwiseInterpreter defined
        # in test_custom_class.cpp.
        #
        # We test that we can
        # 1) Construct a native callable from FX IR
        # 2) Construct a drop-in replacement module that delegates to the
        #    native callable rather than the original code
        # 3) Run both the original code and native callable wrapper with
        #    equivalent results
        # 4) TorchScript compile the native callable wrapper and confirm
        #    equivalent results with the reference
        # 5) TorchScript serialize and deserialize the native callable
        #    and confirm equivalent results with the reference

        # We use this simple Module as a reference computation
        class MySimpleMod(torch.nn.Module):
            def forward(self, x):
                return 3.0 * x + x

        msm = MySimpleMod()

        # This is what a lowering pass might look like: a function that takes
        # a valid nn.Module, symbolically traces it, lowers the Module to some
        # representation, and wraps that representation up into another
        # nn.Module instance that handles dispatch to the compiled/lowered code.
        def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module:
            # ===== Stage 1: Symbolic trace the module =====
            mod = symbolic_trace(orig_mod)

            # ===== Stage 2: Lower GraphModule representation to the C++
            #       interpreter's instruction format ======
            instructions = []
            constant_idx = 0
            constants = {}
            fn_input_names = []

            target_to_name = {
                operator.add : "add",
                operator.mul : "mul"
            }

            output_node : Optional[Node] = None
            # For each instruction, create a triple
            # (instruction_name : str, inputs : List[str], output : str)
            # to feed into the C++ interpreter
            for n in mod.graph.nodes:
                target, args, out_name = n.target, n.args, n.name
                assert len(n.kwargs) == 0, "kwargs currently not supported"

                if n.op == 'placeholder':
                    # Placeholders specify function argument names. Save these
                    # for later when we generate the wrapper GraphModule
                    fn_input_names.append(target)
                elif n.op == 'call_function':
                    # `target` is a callable here, so the message must be
                    # formatted; `str + callable` would raise TypeError and
                    # hide the real failure.
                    assert target in target_to_name, f"Unsupported call target {target}"
                    arg_names = []
                    for arg in args:
                        if not isinstance(arg, Node):
                            # Pull out constants. These constants will later be
                            # fed to the interpreter C++ object via add_constant()
                            arg_name = f'constant_{constant_idx}'
                            constants[arg_name] = torch.tensor(
                                [arg] if isinstance(arg, numbers.Number) else arg)
                            arg_names.append(arg_name)
                            constant_idx += 1
                        else:
                            arg_names.append(arg.name)
                    instructions.append((target_to_name[target], arg_names, out_name))
                elif n.op == 'output':
                    if output_node is not None:
                        raise RuntimeError('Multiple output nodes!')
                    output_node = n
                else:
                    raise RuntimeError('Unsupported opcode ' + n.op)

            interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
            # Load constants
            for k, v in constants.items():
                interpreter.add_constant(k, v)
            # Specify names for positional input arguments
            interpreter.set_input_names(fn_input_names)
            # Load instructions
            interpreter.set_instructions(instructions)
            # Specify name for single output
            assert output_node is not None, "Graph has no output node"
            assert isinstance(output_node.args[0], torch.fx.Node)
            interpreter.set_output_name(output_node.args[0].name)

            # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
            class WrapperModule(torch.nn.Module):
                def __init__(self, interpreter):
                    super().__init__()
                    self.interpreter = interpreter

            wrapper = WrapperModule(interpreter)

            # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
            # 3) Returns the specified return value

            # FIXME: The following code could be greatly simplified by symbolic_trace'ing
            # the wrapper with a Tracer that considers the Wrapper instance a root
            # module, however, I can't get `__call__` exposed on TorchBind classes
            # without it messing up Python `hasattr` for some reason. More digging
            # into CPython's implementation of hasattr is probably in order...

            graph = torch.fx.Graph()
            # Add placeholders for fn inputs
            placeholder_nodes = []
            for name in fn_input_names:
                placeholder_nodes.append(graph.create_node('placeholder', name))

            # Get the interpreter object
            interpreter_node = graph.create_node('get_attr', 'interpreter')

            # Add a node to call the interpreter instance
            output_node = graph.create_node(
                op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))

            # Register output
            graph.output(output_node)

            graph.lint()

            # Return final GraphModule!!!
            return GraphModule(wrapper, graph)

        # Lower GraphModule to C++ interpreter
        lowered = lower_to_elementwise_interpreter(msm)

        # Compare correctness with original module
        x = torch.rand(3, 4)
        ref_out = msm(x)
        test_out = lowered(x)
        torch.testing.assert_close(test_out, ref_out)

        # Test TorchScript compilation
        scripted_lowered = torch.jit.script(lowered)
        script_out = scripted_lowered(x)
        torch.testing.assert_close(script_out, ref_out)

        # Test TorchScript ser/de
        import_copy = self.getExportImportCopy(scripted_lowered)
        imported_out = import_copy(x)
        torch.testing.assert_close(imported_out, ref_out)

    def test_reserved_getattr(self):
        """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
        class M(torch.nn.Module):
            def forward(self, a):
                return a.foo.bar.baz

        m = M()
        m_g = symbolic_trace(m)
        m_g.graph.lint()
        for node in m_g.graph.nodes:
            self.assertNotEqual(node.name, "getattr")

    @unittest.skip("Hotfix for SEV remediation")
    def test_trace_buffer_slice(self):
        bs, d_hid = 10, 23

        class ExampleCode(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.lin = torch.nn.Linear(d_hid, d_hid)
                self.buffer = torch.nn.Buffer(torch.randn(bs + 100, d_hid))

            def forward(self, x):
                x = torch.mm(x, self.mm_param)
                skip_connection = x
                x = torch.relu(x)
                # Slice the buffer by the (symbolic) batch size of x.
                x = torch.mm(x, self.mm_param) + self.buffer[:x.shape[0]]
                x = self.lin(x)
                x = torch.relu(x)
                x = x + skip_connection
                x = torch.mm(x, self.mm_param2)
                x = self.lin(x)
                return x

        ec = ExampleCode()

        traced = torch.fx.symbolic_trace(ec)

        x = torch.randn(bs, d_hid)
        torch.testing.assert_close(ec(x), traced(x))

    def test_node_tagging(self):
        # Attributes set on nodes inside a create_node override persist.
        class TaggingTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                # Forward every argument (including type_expr) to the base class.
                n = super().create_node(kind, target, args, kwargs, name, type_expr)
                n.tag = 'foo'
                return n

        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = M()
        g = TaggingTracer().trace(m)
        g.lint()
        for n in g.nodes:
            self.assertTrue(hasattr(n, 'tag'))
            self.assertEqual(n.tag, 'foo')

    def test_tensor_attribute(self):
        # Plain tensor attributes (direct and on a submodule) survive tracing.
        class TensorAttribute(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tensor = torch.rand(3, 4)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.tensor)

        ta = TensorAttribute()
        traced = symbolic_trace(ta)
        traced(torch.rand(4, 4))

        class WrapperForQualname(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.ta = TensorAttribute()

            def forward(self, x):
                return torch.nn.functional.linear(x, self.ta.tensor)

        wfq = WrapperForQualname()
        traced2 = symbolic_trace(wfq)
        traced2.graph.lint()
        traced2(torch.rand(4, 4))

    def test_tensor_attribute_coalseced(self):
        # The same captured tensor used twice becomes a single get_attr;
        # two distinct tensors become two.

        def count_attrs(fx_module):
            # Number of distinct get_attr targets in `fx_module`'s graph.
            return len({node.target for node in fx_module.graph.nodes if node.op == 'get_attr'})

        val = torch.tensor(5)

        def f(x):
            return x + val + val
        traced = symbolic_trace(f)
        traced.graph.lint()
        self.assertEqual(count_attrs(traced), 1)

        val2 = torch.tensor(5)

        def f(x):
            val = torch.tensor(5)
            return x + val + val2

        traced = symbolic_trace(f)
        traced.graph.lint()
        self.assertEqual(count_attrs(traced), 2)

    def test_symbolic_trace_sequential(self):
        class Simple(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x)

        seq = torch.nn.Sequential(
            Simple(),
            Simple(),
            Simple()
        )
        traced = symbolic_trace(seq)
        traced.graph.lint()
        x = torch.rand(3, 4)
        self.assertEqual(traced(x), seq(x))

    def test_tensor_constant(self):
        # A tensor created inside forward() is captured as a constant.
        class ConstTensor(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.linear(x, torch.zeros(3, 4))

        ct = ConstTensor()
        traced = symbolic_trace(ct)
        traced.graph.lint()
        traced(torch.rand(4, 4))

    def test_pickle_graphmodule(self):
        class Nested(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.st = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.st(x)

        n = Nested()
        traced = symbolic_trace(n)
        traced.graph.lint()
        pickled = pickle.dumps(traced)
        loaded = pickle.loads(pickled)
        loaded.graph.lint()
        x = torch.rand(3, 4)
        self.assertEqual(loaded(x), traced(x))

    def test_pickle_custom_import(self):
        # Pickling keeps the import needed for a non-torch call_function target.
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)
        pickled = pickle.dumps(gm)
        loaded = pickle.loads(pickled)
        loaded.graph.lint()
        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(loaded(x, y), gm(x, y))

    def test_all_input_nodes(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.placeholder('x')
        b : torch.fx.Node = graph.call_module('linear_mod', args=(a,))
        c : torch.fx.Node = graph.get_attr('y_attr')
        d : torch.fx.Node = graph.call_function(operator.add, args=(b, c))
        e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0))
        graph.output(e)
        graph.lint()

        self.assertEqual(b.all_input_nodes, [a])
        self.assertEqual(c.all_input_nodes, [])
        self.assertEqual(d.all_input_nodes, [b, c])
        self.assertEqual(e.all_input_nodes, [d])

    def test_deepcopy_graphmodule_with_transform(self):
        # A deep copy of a transformed GraphModule gets its own generated
        # class and computes the same result.
        st = SimpleTest()
        traced = symbolic_trace(st)
        traced.graph.lint()

        def transform(traced):
            new_graph = torch.fx.Graph()
            val_map : Dict[Node, Node] = {}
            output_value = new_graph.graph_copy(traced.graph, val_map)
            neg_out = new_graph.create_node(
                op='call_method', target='neg', args=(output_value,), kwargs={})
            new_graph.output(neg_out)
            return GraphModule(traced, new_graph)
        transformed = transform(traced)
        transformed.graph.lint()
        copied = copy.deepcopy(transformed)
        self.assertNotEqual(id(type(transformed)), id(type(copied)))
        x = torch.randn(3, 4)
        self.assertEqual(copied(x), transformed(x))

    def test_deepcopy_with_submods_params(self):
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))

            def forward(self, x):
                return torch.relu(x) + self.param

        class Baz(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.bar = Bar()

            def forward(self, x):
                return self.bar(x) - self.param

        baz = Baz()
        traced = symbolic_trace(baz)
        traced.graph.lint()
        copied = copy.deepcopy(traced)
        copied.graph.lint()

    def test_deepcopy_graph_with_tracer_cls(self):
        # Graph deep copies preserve the tracer_cls they were built with.
        class TestTracer(Tracer):
            def is_leaf_module(self, module, name):
                return True

        g = Graph(tracer_cls=TestTracer)
        x = g.placeholder("x")
        g.output(x)

        h = copy.deepcopy(g)
        self.assertIsNotNone(h._tracer_cls)
        self.assertEqual(g._tracer_cls, h._tracer_cls)

    def test_unpack_list_better_error(self):
        # *-unpacking a Proxy raises a TraceError with a clear message.
        class SomeArgs(torch.nn.Module):
            def forward(self, a, b):
                return torch.rand(3, 4)

        class UnpacksList(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sa = SomeArgs()

            def forward(self, x : list):
                return self.sa(*x)

        ul = UnpacksList()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ul)

    def test_unpack_dict_better_error(self):
        # **-unpacking a Proxy raises the same TraceError.
        class SomeKwargs(torch.nn.Module):
            def forward(self, x=3, y=4):
                return torch.rand(3, 4)

        class UnpacksDict(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sk = SomeKwargs()

            def forward(self, x : dict):
                return self.sk(**x)

        ud = UnpacksDict()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ud)

    def test_pretty_print_targets(self):
        # Test that Graph pretty-print prints friendly name for targets
        # in `operator` and `builtins`

        class SomeMod(torch.nn.Module):
            def forward(self, x):
                return torch.add(x.foo + x.bar, 3.0)

        traced = symbolic_trace(SomeMod())
        graph_str = str(traced.graph)
        self.assertIn('builtins.getattr', graph_str)
1175 self.assertIn('operator.add', graph_str) 1176 self.assertIn('torch.add', graph_str) 1177 1178 def test_pretty_print_node(self): 1179 class M(torch.nn.Module): 1180 def __init__(self) -> None: 1181 super().__init__() 1182 self.param: torch.nn.Parameter = torch.nn.Parameter( 1183 torch.rand(3, 4)) 1184 self.linear = torch.nn.Linear(4, 5) 1185 1186 def forward(self, x: torch.Tensor, y: int = 2): 1187 return self.linear(x[y] + self.param).clamp(min=0.0, max=1.0) 1188 1189 traced = symbolic_trace(M()) 1190 1191 all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes]) 1192 1193 FileCheck().check("x").check("placeholder") \ 1194 .check("y").check("placeholder") \ 1195 .check("getitem").check("call_function") \ 1196 .check("param").check("get_attr") \ 1197 .check("add").check("call_function") \ 1198 .check("linear").check("call_module") \ 1199 .check("clamp").check("call_method") \ 1200 .run(all_formatted) 1201 1202 def test_script_tensor_constant(self): 1203 # TorchScript seems to ignore attributes that start with `__`. 1204 # We used to call anonymous Tensor values `__tensor_constant*`, but 1205 # they were getting ignored by script. 
Now they're called 1206 # `_tensor_constant*` 1207 class IHaveATensorConstant(torch.nn.Module): 1208 def forward(self, x): 1209 return x + torch.rand(3, 4) 1210 1211 traced = torch.fx.symbolic_trace(IHaveATensorConstant()) 1212 torch.jit.script(traced) 1213 1214 def test_autowrap_functions(self): 1215 class AutowrapFnTest(torch.nn.Module): 1216 def forward(self, x): 1217 return fx_int(x.shape[0] / 2) 1218 1219 class AutowrapFnTest2(torch.nn.Module): 1220 def forward(self, x): 1221 return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2) 1222 1223 # Check function(s) are wrapped 1224 # `int` would normally throw a TypeError as argument can't be `Proxy` 1225 tracer = Tracer(autowrap_functions=(fx_int,)) 1226 graph = tracer.trace(AutowrapFnTest()) 1227 traced = GraphModule(tracer.root, graph, 'test') 1228 tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2)) 1229 tracer_2.trace(AutowrapFnTest2()) 1230 1231 # Test scriptability 1232 traced_scripted = torch.jit.script(traced) 1233 self.assertEqual(traced_scripted(torch.rand(4)), 2) 1234 1235 def test_tuple_no_subscript(self): 1236 def foo(x : Tuple): 1237 return x[0] 1238 1239 traced = torch.fx.symbolic_trace(foo) 1240 x = (torch.randn(5, 3),) 1241 torch.testing.assert_close(traced(x), x[0]) 1242 1243 bio = io.BytesIO() 1244 1245 torch.save(traced, bio) 1246 1247 bio.seek(0) 1248 1249 # weights_only=False as this loads a GraphModule 1250 # GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default 1251 loaded = torch.load(bio, weights_only=False) 1252 1253 torch.testing.assert_close(loaded(x), x[0]) 1254 1255 def test_torch_fx_len(self): 1256 class FXLenTest(torch.nn.Module): 1257 def forward(self, x): 1258 return len(x) 1259 1260 traced = symbolic_trace(FXLenTest()) 1261 self.assertEqual(traced(torch.rand(3, 4)), 3) 1262 1263 # Test scriptability 1264 scripted = torch.jit.script(FXLenTest()) 1265 self.assertEqual(scripted(torch.rand(3)), 3) 1266 1267 traced_scripted = 
torch.jit.script(traced) 1268 self.assertEqual(traced_scripted(torch.rand(3)), 3) 1269 1270 # Test non-proxy len 1271 class FXLenTest2(torch.nn.Module): 1272 def __init__(self) -> None: 1273 super().__init__() 1274 self.l = [3, 4, 5] 1275 1276 def forward(self, x): 1277 return x + len(self.l) 1278 1279 traced2 = symbolic_trace(FXLenTest2()) 1280 inp = torch.rand(3, 4) 1281 self.assertEqual(traced2(inp), inp + 3.0) 1282 self.assertIs(len, builtins.len) 1283 1284 def test_torch_fx_getattr(self): 1285 class FXGetattrTest(torch.nn.Module): 1286 def forward(self, x): 1287 return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3])) 1288 1289 traced = symbolic_trace(FXGetattrTest()) 1290 self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3])) 1291 1292 def test_sqrt(self): 1293 class Sqrt1(torch.nn.Module): 1294 def forward(self, x): 1295 return sqrt(x.size(0)) 1296 1297 class Sqrt2(torch.nn.Module): 1298 def forward(self, x): 1299 return math.sqrt(x.size(0)) 1300 1301 class Sqrt3(torch.nn.Module): 1302 def forward(self, x): 1303 return x + math.sqrt(2) + sqrt(2) 1304 1305 self.checkGraphModule(Sqrt1(), [torch.zeros(8)]) 1306 self.checkGraphModule(Sqrt2(), [torch.zeros(8)]) 1307 self.checkGraphModule(Sqrt3(), [torch.zeros(8)]) 1308 self.assertIs(sqrt, _sqrt) 1309 self.assertIs(math.sqrt, _sqrt) 1310 1311 def test_torch_custom_ops(self): 1312 class M(torch.nn.Module): 1313 def forward(self, a): 1314 b = torch.ops.aten.sigmoid(a) 1315 c = torch.ops.aten.cat([a, b]) 1316 return torch.ops.aten.cat((c, c)) 1317 m = M() 1318 input = torch.randn(3) 1319 ref_out = m(input) 1320 gm = symbolic_trace(m) 1321 gm.graph.lint() 1322 out = gm(input) 1323 self.assertEqual(out, ref_out) 1324 1325 def test_torch_op_overloads(self): 1326 class M(torch.nn.Module): 1327 def forward(self, a): 1328 b = torch.ops.aten.add.Tensor(a, a) 1329 return b 1330 m = M() 1331 input = torch.randn(3) 1332 ref_out = m(input) 1333 gm = symbolic_trace(m) 1334 gm.graph.lint() 1335 out = gm(input) 1336 
self.assertEqual(out, ref_out) 1337 1338 for node in gm.graph.nodes: 1339 if node.op == 'call_function': 1340 assert isinstance(node.target, torch._ops.OpOverload) 1341 assert node.target.__name__ == 'add.Tensor' 1342 1343 def test_pickle_torch_custom_ops(self): 1344 class M(torch.nn.Module): 1345 def forward(self, a): 1346 b = torch.ops.aten.sigmoid(a) 1347 c = torch.ops.aten.cat([a, b]) 1348 return torch.ops.aten.cat((c, c)) 1349 m = M() 1350 input = torch.randn(3) 1351 ref_out = m(input) 1352 gm = symbolic_trace(m) 1353 gm.graph.lint() 1354 pickled = pickle.dumps(gm) 1355 loaded = pickle.loads(pickled) 1356 self.assertEqual(loaded(input), gm(input)) 1357 1358 def test_pretty_print(self): 1359 st = SimpleTest() 1360 traced = symbolic_trace(st) 1361 traced.graph.lint() 1362 printed = str(traced) 1363 assert 'SimpleTest()' in printed 1364 assert 'torch.relu' in printed 1365 1366 def test_pretty_print_graph(self): 1367 class KwargPrintTest(torch.nn.Module): 1368 def forward(self, x): 1369 return torch.squeeze(x + 3.0, dim=2) 1370 st = KwargPrintTest() 1371 traced = symbolic_trace(st) 1372 traced.graph.lint() 1373 stringed = str(traced.graph) 1374 for s in ['args', 'kwargs', 'num_users']: 1375 assert s in stringed 1376 1377 def test_custom_proxy_type(self): 1378 class TensorPair: 1379 def __init__(self, left, right): 1380 self.left, self.right = left, right 1381 1382 def add(self, other): 1383 l = self.left + other.left 1384 r = self.right + other.right 1385 return TensorPair(l, r) 1386 1387 def mul(self, other): 1388 l = self.left * other.left 1389 r = self.right * other.right 1390 return TensorPair(l, r) 1391 1392 def use_tensor_pair(x : TensorPair, y : TensorPair): 1393 s = x.add(y) 1394 return s.mul(x) 1395 1396 x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) 1397 y = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) 1398 1399 ref_out = use_tensor_pair(x, y) 1400 1401 traced = symbolic_trace(use_tensor_pair) 1402 1403 traced_out = traced(x, y) 1404 
self.assertEqual(traced_out.left, ref_out.left) 1405 self.assertEqual(traced_out.right, ref_out.right) 1406 1407 def test_custom_proxy_type_literal(self): 1408 class TensorPair(metaclass=torch.fx.ProxyableClassMeta): 1409 def __init__(self, left, right): 1410 self.left, self.right = left, right 1411 1412 def add(self, other): 1413 l = self.left + other.left 1414 r = self.right + other.right 1415 return TensorPair(l, r) 1416 1417 def mul(self, other): 1418 l = self.left * other.left 1419 r = self.right * other.right 1420 return TensorPair(l, r) 1421 1422 def use_tensor_pair_literal(x : TensorPair): 1423 s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3))) 1424 return s.mul(x) 1425 1426 x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) 1427 1428 ref_out = use_tensor_pair_literal(x) 1429 1430 traced = symbolic_trace(use_tensor_pair_literal) 1431 1432 traced_out = traced(x) 1433 self.assertEqual(traced_out.left, ref_out.left) 1434 self.assertEqual(traced_out.right, ref_out.right) 1435 1436 def test_custom_proxy_dynamic_value(self): 1437 class TensorPair(metaclass=torch.fx.ProxyableClassMeta): 1438 def __init__(self, left, right): 1439 self.left, self.right = left, right 1440 1441 def add(self, other): 1442 l = self.left + other.left 1443 r = self.right + other.right 1444 return TensorPair(l, r) 1445 1446 def mul(self, other): 1447 l = self.left * other.left 1448 r = self.right * other.right 1449 return TensorPair(l, r) 1450 1451 def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor): 1452 s = x.add(TensorPair(y, y)) 1453 return s.mul(x) 1454 1455 x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) 1456 y = torch.randn(5, 3) 1457 ref_out = use_tensor_pair_ctor(x, y) 1458 1459 traced = symbolic_trace(use_tensor_pair_ctor) 1460 1461 traced_out = traced(x, y) 1462 self.assertEqual(traced_out.left, ref_out.left) 1463 self.assertEqual(traced_out.right, ref_out.right) 1464 1465 def test_custom_proxy_input_dependent_control_flow(self): 1466 class 
ZeroTensor(metaclass=torch.fx.ProxyableClassMeta): 1467 def __init__(self, inp): 1468 if inp.sum() == 0: 1469 self.is_zero = True 1470 self.tensor = torch.tensor([]) 1471 else: 1472 self.is_zero = False 1473 self.tensor = inp 1474 1475 def add(self, other): 1476 if self.is_zero: 1477 return ZeroTensor(other.tensor) 1478 elif other.is_zero: 1479 return self 1480 1481 def use_zero_tensor(x : torch.Tensor, y : torch.Tensor): 1482 return ZeroTensor(x + y) 1483 1484 x, y = torch.randn(5, 3), torch.randn(5, 3) 1485 1486 ref_out = use_zero_tensor(x, y) 1487 1488 traced = symbolic_trace(use_zero_tensor) 1489 1490 traced_out = traced(x, y) 1491 1492 self.assertEqual(traced_out.is_zero, ref_out.is_zero) 1493 self.assertEqual(traced_out.tensor, ref_out.tensor) 1494 1495 def test_graph_fns(self): 1496 g = Graph() 1497 a = g.placeholder('a') 1498 b = g.call_module('linear', (a,)) 1499 c = g.get_attr('bias') 1500 d = g.call_method('add', (b, c)) 1501 e = g.call_function(torch.sin, (d,)) 1502 g.output(e) 1503 mod = torch.nn.Module() 1504 mod.linear = torch.nn.Linear(3, 4) 1505 mod.bias = torch.rand(4) 1506 gm = GraphModule(mod, g) 1507 gm.graph.lint() 1508 input = torch.rand(3) 1509 r = gm(input) 1510 ref = torch.sin(mod.linear(input) + mod.bias) 1511 self.assertEqual(r, ref) 1512 1513 def test_remove_uses(self): 1514 g : torch.fx.Graph = Graph() 1515 x : torch.fx.Node = g.placeholder('x') 1516 relu : torch.fx.Node = g.call_function(torch.relu, (x,)) 1517 neg : torch.fx.Node = g.call_function(torch.neg, (relu,)) 1518 g.output(neg) 1519 1520 neg.replace_all_uses_with(relu) 1521 g.erase_node(neg) 1522 1523 self.assertTrue(neg not in relu.users) 1524 1525 def test_remove_uses_with_custom_filter(self): 1526 g : torch.fx.Graph = Graph() 1527 x : torch.fx.Node = g.placeholder('x') 1528 relu : torch.fx.Node = g.call_function(torch.relu, (x,)) 1529 neg : torch.fx.Node = g.call_function(torch.neg, (relu,)) 1530 g.output(neg) 1531 1532 neg.replace_all_uses_with(relu, lambda x: x != neg) 
1533 1534 self.assertTrue(neg in relu.users) 1535 1536 def test_nonetype_annotation(self): 1537 eb = torch.nn.EmbeddingBag(3, 4) 1538 symbolic_trace(eb) 1539 1540 def test_pickle_nonetype_annotation(self): 1541 eb = torch.nn.EmbeddingBag(10, 3, mode='sum') 1542 traced = symbolic_trace(eb) 1543 pickled = pickle.dumps(traced) 1544 loaded = pickle.loads(pickled) 1545 loaded.graph.lint() 1546 input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9]) 1547 offsets = torch.LongTensor([0, 4]) 1548 self.assertEqual(loaded(input, offsets), traced(input, offsets)) 1549 1550 def test_return_tuple(self): 1551 class M(torch.nn.Module): 1552 def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 1553 return (x, x + x) 1554 1555 original = M() 1556 traced = symbolic_trace(original) 1557 self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1))) 1558 1559 def test_construct_root_dict(self): 1560 graph : torch.fx.Graph = torch.fx.Graph() 1561 a : torch.fx.Node = graph.create_node('placeholder', 'x') 1562 b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,)) 1563 c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam') 1564 d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) 1565 graph.output(d) 1566 1567 linear_mod : torch.nn.Module = torch.nn.Linear(3, 4) 1568 add_param : torch.Tensor = torch.rand(3, 4) 1569 gm : torch.fx.GraphModule = torch.fx.GraphModule( 1570 {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph) 1571 gm.graph.lint() 1572 1573 assert 'self.foo.bar.baz' in gm.code 1574 1575 x : torch.Tensor = torch.rand(3, 3) 1576 out : torch.Tensor = gm(x) 1577 ref_out : torch.Tensor = linear_mod(x) + add_param 1578 self.assertEqual(out, ref_out) 1579 1580 def test_symbolic_trace_assert(self): 1581 1582 class AssertsTensorShape(torch.nn.Module): 1583 def forward(self, x): 1584 torch._assert(x.shape[1] > 4, "assert_foobar") 1585 return x 1586 1587 m = AssertsTensorShape() 1588 # verify 
traceability 1589 traced = symbolic_trace(m) 1590 # verify assertion on traced model works correctly at runtime 1591 traced(torch.rand(4, 5)) 1592 with self.assertRaisesRegex(AssertionError, "assert_foobar"): 1593 traced(torch.rand(4, 3)) 1594 # verify the symbolically traced module is scriptable 1595 ms = torch.jit.script(m) 1596 with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"): 1597 ms(torch.rand(4, 3)) 1598 1599 def test_fx_create_arg(self): 1600 class CustomArgObject: 1601 def __init__(self, x, y): 1602 self.x = x 1603 self.y = y 1604 1605 def __fx_create_arg__(self, tracer: torch.fx.Tracer): 1606 return tracer.create_node( 1607 "call_function", 1608 CustomArgObject, 1609 args=( 1610 tracer.create_arg(self.x), 1611 tracer.create_arg(self.y), 1612 ), 1613 kwargs={}, 1614 ) 1615 1616 class HasCustomArgObjectWhenLeaf(torch.nn.Module): 1617 def forward(self, o: CustomArgObject): 1618 # Not normally traceable; good reason to make 1619 # this module a leaf. 1620 for x in o.x: 1621 o.y += x 1622 return o.y 1623 1624 class Root(torch.nn.Module): 1625 def __init__(self) -> None: 1626 super().__init__() 1627 self.inner = HasCustomArgObjectWhenLeaf() 1628 1629 def forward(self, x, y): 1630 o = CustomArgObject(x, y) 1631 return self.inner(o) 1632 1633 class CreateArgTracer(torch.fx.Tracer): 1634 def is_leaf_module(self, m, module_qualified_name): 1635 return type(m) is HasCustomArgObjectWhenLeaf 1636 1637 m = Root() 1638 graph = CreateArgTracer().trace(m) 1639 gm = torch.fx.GraphModule(m, graph) 1640 assert "CustomArgObject(" in gm.code 1641 1642 def test_trace_fn_constant(self): 1643 some_constant = torch.rand(3, 4) 1644 1645 def add_const(x): 1646 return some_constant + x 1647 1648 traced = symbolic_trace(add_const) 1649 1650 input = torch.rand(3, 4) 1651 self.assertEqual(traced(input), add_const(input)) 1652 1653 def test_copy_no_remap(self): 1654 traced = symbolic_trace(SimpleTest()) 1655 g = traced.graph 1656 copied = torch.fx.Graph() 1657 for node in 
g.nodes: 1658 copied.node_copy(node) 1659 with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'): 1660 copied.lint() 1661 1662 def test_wrong_topo(self): 1663 graph : torch.fx.Graph = torch.fx.Graph() 1664 a : torch.fx.Node = graph.create_node('placeholder', 'x') 1665 b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,)) 1666 c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam') 1667 d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) 1668 graph.output(d) 1669 nodes = list(graph.nodes) 1670 nodes[3].append(nodes[2]) 1671 with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'): 1672 graph.lint() 1673 1674 def test_wrong_target_type(self): 1675 graph : torch.fx.Graph = torch.fx.Graph() 1676 with self.assertRaises(ValueError): 1677 n = torch.fx.Node(graph=graph, name='foo', op='call_function', target='foo', 1678 args=(), kwargs={}) 1679 1680 def test_example_shape_prop(self): 1681 class TestCase(torch.nn.Module): 1682 def __init__(self) -> None: 1683 super().__init__() 1684 self.attr = torch.randn(3, 4) 1685 self.submod = torch.nn.Linear(4, 4) 1686 1687 def forward(self, x): 1688 return torch.neg(self.submod(x.relu() + self.attr)) 1689 tc = TestCase() 1690 tc_traced = symbolic_trace(tc) 1691 ref_out = tc_traced(torch.rand(3, 4)) 1692 shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4)) 1693 1694 # Make sure we're testing all opcodes 1695 opcodes = set() 1696 output_shape : Optional[torch.Shape] = None 1697 output_stride : Optional[Tuple[int]] = None 1698 for node in tc_traced.graph.nodes: 1699 opcodes.add(node.op) 1700 if node.op == 'output': 1701 output_shape = node.args[0].meta['tensor_meta'].shape 1702 output_stride = node.args[0].meta['tensor_meta'].stride 1703 self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method', 1704 'call_module', 'output'}) 1705 1706 # Test shape propagation and make sure results match actual 
1707 self.assertEqual(output_shape, ref_out.shape) 1708 self.assertEqual(output_stride, ref_out.stride()) 1709 1710 def test_shape_prop_layout(self): 1711 class ConvTest(torch.nn.Module): 1712 def __init__(self) -> None: 1713 super().__init__() 1714 self.conv_mod = torch.nn.Conv2d(5, 5, 3) 1715 1716 def forward(self, x): 1717 return self.conv_mod(x) 1718 1719 # contiguous layout 1720 test_mod = ConvTest() 1721 traced = symbolic_trace(test_mod) 1722 x = torch.randn(5, 5, 224, 224) 1723 shape_prop.ShapeProp(traced).propagate(x) 1724 1725 assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format 1726 for node in traced.graph.nodes) 1727 1728 x_channels_last = x.contiguous(memory_format=torch.channels_last) 1729 traced.to(memory_format=torch.channels_last) 1730 shape_prop.ShapeProp(traced).propagate(x_channels_last) 1731 for node in traced.graph.nodes: 1732 # NB: the implementation of conv may not preserve the memory format, 1733 # unfortunately. The best we can do is just check that the placeholder 1734 # node is channels-last 1735 if node.op in {'placeholder'}: 1736 self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last) 1737 1738 def test_shape_prop_aggregate(self): 1739 class ReturnTwo(torch.nn.Module): 1740 def forward(self, x): 1741 return (3, torch.sum(x)) 1742 1743 class UnderTest(torch.nn.Module): 1744 def __init__(self) -> None: 1745 super().__init__() 1746 self.rt = ReturnTwo() 1747 1748 def forward(self, x): 1749 return self.rt(x) 1750 1751 ut = UnderTest() 1752 1753 class RTTracer(torch.fx.Tracer): 1754 def is_leaf_module(self, m, module_qualified_name): 1755 return type(m) is ReturnTwo 1756 1757 graph = RTTracer().trace(ut) 1758 mod = torch.fx.GraphModule(ut, graph) 1759 1760 shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4)) 1761 1762 for node in mod.graph.nodes: 1763 if node.op == 'call_module': 1764 assert 'tensor_meta' in node.meta 1765 tensor_meta = node.meta['tensor_meta'] 1766 assert tensor_meta[0] == 3 
1767 assert tensor_meta[1].shape == torch.Size([]) 1768 1769 def test_shape_prop_layout_3d(self): 1770 class ConvTest3d(torch.nn.Module): 1771 def __init__(self) -> None: 1772 super().__init__() 1773 self.conv_mod = torch.nn.Conv3d(5, 5, 3) 1774 1775 def forward(self, x): 1776 return self.conv_mod(x) 1777 1778 test_mod_3d = ConvTest3d() 1779 traced_3d = symbolic_trace(test_mod_3d) 1780 x_3d = torch.randn(5, 5, 224, 224, 15) 1781 shape_prop.ShapeProp(traced_3d).propagate(x_3d) 1782 assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format 1783 for node in traced_3d.graph.nodes) 1784 1785 x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d) 1786 traced_3d.to(memory_format=torch.channels_last_3d) 1787 shape_prop.ShapeProp(traced_3d).propagate(x_channels_last_3d) 1788 for node in traced_3d.graph.nodes: 1789 # NB: the implementation of conv may not preserve the memory format, 1790 # unfortunately. The best we can do is just check that the placeholder 1791 # node is channels-last 1792 if node.op in {'placeholder'}: 1793 self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d) 1794 1795 def test_nn_module_stack(self): 1796 class SubModule(torch.nn.Module): 1797 def __init__(self) -> None: 1798 super().__init__() 1799 self.conv_mod = torch.nn.Conv2d(64, 64, (3, 3), padding=1, bias=False) 1800 1801 def forward(self, x): 1802 return self.conv_mod(x) 1803 1804 class MyModule(torch.nn.Module): 1805 def __init__(self) -> None: 1806 super().__init__() 1807 self.sub_mod = SubModule() 1808 1809 def forward(self, x): 1810 return self.sub_mod(x) 1811 1812 m = MyModule() 1813 gm = torch.fx.symbolic_trace(m) 1814 1815 mod_stack = {} 1816 expected_stack = [('sub_mod', ('sub_mod', type(m.sub_mod))), 1817 ('sub_mod.conv_mod', ('sub_mod.conv_mod', type(m.sub_mod.conv_mod)))] 1818 for node in gm.graph.nodes: 1819 mod_stack = node.meta.get('nn_module_stack', {}) 1820 if mod_stack: 1821 break 1822 stack_list = 
list(mod_stack.items()) 1823 self.assertEqual(stack_list, expected_stack) 1824 1825 def test_transformer_preserves_nn_module_stack_for_get_attr(self): 1826 class M(torch.nn.Module): 1827 def __init__(self) -> None: 1828 super().__init__() 1829 self.weight = torch.nn.Parameter(torch.ones(1, 1)) 1830 1831 def forward(self, x): 1832 return self.weight + x 1833 1834 tracer = torch.fx.Tracer() 1835 graph = tracer.trace(M()) 1836 gm = GraphModule(tracer.root, graph) 1837 for node in gm.graph.nodes: 1838 if node.op == 'get_attr': 1839 node.meta["nn_module_stack"] = "self" 1840 node.meta["stack_trace"] = "stack_trace" 1841 node.meta["source_fn_stack"] = "source_fn_stack" 1842 new_gm = Transformer(gm).transform() 1843 for node in new_gm.graph.nodes: 1844 if node.op == 'get_attr': 1845 self.assertEqual(node.meta["nn_module_stack"], "self") 1846 self.assertEqual(node.meta["stack_trace"], "stack_trace") 1847 self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack") 1848 1849 def test_interpreter(self): 1850 class MyModule(torch.nn.Module): 1851 def __init__(self) -> None: 1852 super().__init__() 1853 self.param = torch.nn.Parameter(torch.rand(3, 4)) 1854 self.linear = torch.nn.Linear(4, 5) 1855 1856 def forward(self, x): 1857 return self.linear(x + self.param).clamp(min=0.0, max=1.0) 1858 1859 m = MyModule() 1860 gm = torch.fx.symbolic_trace(m) 1861 1862 interpreter = Interpreter(gm) 1863 input = torch.randn(3, 4) 1864 self.assertEqual(interpreter.run(input), gm(input)) 1865 self.assertEqual(interpreter.run(input), m(input)) 1866 1867 def test_interpreter_other_graph(self): 1868 class MyModule(torch.nn.Module): 1869 def __init__(self) -> None: 1870 super().__init__() 1871 self.param = torch.nn.Parameter(torch.rand(3, 4)) 1872 self.linear = torch.nn.Linear(4, 5) 1873 1874 def forward(self, x): 1875 return self.linear(x + self.param).clamp(min=0.0, max=1.0) 1876 1877 m = MyModule() 1878 gm = torch.fx.symbolic_trace(m) 1879 1880 interpreter = Interpreter(gm, 
graph=gm.graph) 1881 input = torch.randn(3, 4) 1882 self.assertEqual(interpreter.run(input), gm(input)) 1883 self.assertEqual(interpreter.run(input), m(input)) 1884 1885 def test_interpreter_run_node_override(self): 1886 class MyModule(torch.nn.Module): 1887 def __init__(self) -> None: 1888 super().__init__() 1889 self.param = torch.nn.Parameter(torch.rand(3, 4)) 1890 self.linear = torch.nn.Linear(4, 5) 1891 1892 def forward(self, x): 1893 return self.linear(x + self.param).clamp(min=0.0, max=1.0) 1894 1895 m = MyModule() 1896 gm = torch.fx.symbolic_trace(m) 1897 1898 class RunNodeInterpreter(Interpreter): 1899 def __init__(self, module): 1900 super().__init__(module) 1901 1902 def run_node(self, n : Node) -> Any: 1903 result = super().run_node(n) 1904 n.cached_value = result 1905 return result 1906 1907 input = torch.randn(3, 4) 1908 RunNodeInterpreter(gm).run(input) 1909 for node in gm.graph.nodes: 1910 assert hasattr(node, 'cached_value') 1911 1912 def test_interpreter_onthefly_swap(self): 1913 1914 def fn(x): 1915 return torch.sigmoid(x).neg() 1916 1917 gm = torch.fx.symbolic_trace(fn) 1918 1919 class NegSigmSwapInterpreter(Interpreter): 1920 def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any: 1921 if target == torch.sigmoid: 1922 return torch.neg(*args, **kwargs) 1923 return super().call_function(n) # noqa: F821 1924 1925 def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any: 1926 if target == 'neg': 1927 call_self, *args_tail = args 1928 return call_self.sigmoid(*args_tail, **kwargs) 1929 return super().call_method(n) # noqa: F821 1930 1931 input = torch.randn(3, 4) 1932 result = NegSigmSwapInterpreter(gm).run(input) 1933 self.assertEqual(result, torch.neg(input).sigmoid()) 1934 1935 def test_interpreter_partial_eval(self): 1936 class MyModule(torch.nn.Module): 1937 def __init__(self) -> None: 1938 super().__init__() 1939 self.param = torch.nn.Parameter(torch.rand(3, 4)) 1940 self.linear = torch.nn.Linear(4, 5) 
1941 1942 def forward(self, x): 1943 return self.linear(x + self.param).clamp(min=0.0, max=1.0) 1944 1945 gm = torch.fx.symbolic_trace(MyModule()) 1946 interp = Interpreter(gm) 1947 env = {} 1948 for node in gm.graph.nodes: 1949 if node.op == 'call_module' and node.target == 'linear': 1950 env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0 1951 break 1952 assert len(env) == 1 1953 x = torch.randn(3, 4) 1954 result = interp.run(x, initial_env=env) 1955 self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0)) 1956 1957 def test_interpreter_star_args(self): 1958 def with_star_args(x, *args): 1959 return x + args[0] 1960 1961 gm = torch.fx.symbolic_trace(with_star_args) 1962 interp = Interpreter(gm) 1963 result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4)) 1964 self.assertEqual(result, torch.ones(3, 4) * 2.0) 1965 1966 @skipIfNoTorchVision 1967 def test_interpreter_noop_resnet18(self): 1968 rn18 = torchvision_models.resnet18() 1969 transformed = torch.fx.Transformer(symbolic_trace(rn18)).transform() 1970 inp = torch.randn(5, 3, 224, 224) 1971 self.assertEqual(transformed(inp), rn18(inp)) 1972 1973 @skipIfNoTorchVision 1974 def test_interpreter_gc_values(self): 1975 rn18 = torchvision_models.resnet18() 1976 interp = Interpreter(symbolic_trace(rn18)) 1977 inp = torch.rand(5, 3, 224, 224) 1978 out = interp.run(inp) 1979 env_key_names = {n.name for n in interp.env.keys()} 1980 self.assertEqual(env_key_names, {'output'}) 1981 1982 def test_interpreter_default_args(self): 1983 class Model(torch.nn.Module): 1984 def forward(self, x, y=3.14159): 1985 return x + y 1986 1987 model = Model() 1988 gm = torch.fx.symbolic_trace(model) 1989 1990 interp = Interpreter(gm) 1991 x = torch.randn(5, 3) 1992 out = interp.run(x) 1993 torch.testing.assert_close(out, x + 3.14159) 1994 1995 def test_interpreter_not_enough_args(self): 1996 class Model(torch.nn.Module): 1997 def forward(self, x, y): 1998 return x + y 1999 2000 model = Model() 
2001 gm = torch.fx.symbolic_trace(model) 2002 2003 interp = Interpreter(gm) 2004 x = torch.randn(5, 3) 2005 with self.assertRaisesRegex(RuntimeError, 2006 'Expected positional argument for parameter y, but one was not passed in'): 2007 out = interp.run(x) 2008 2009 def test_transformer_noop(self): 2010 class MyModule(torch.nn.Module): 2011 def __init__(self) -> None: 2012 super().__init__() 2013 self.param = torch.nn.Parameter(torch.rand(3, 4)) 2014 self.linear = torch.nn.Linear(4, 5) 2015 2016 def forward(self, x): 2017 return self.linear(x + self.param).clamp(min=0.0, max=1.0) 2018 2019 m = MyModule() 2020 gm = torch.fx.symbolic_trace(m) 2021 2022 new_gm = Transformer(gm).transform() 2023 2024 input = torch.randn(3, 4) 2025 self.assertEqual(new_gm(input), gm(input)) 2026 2027 def test_transformer_op_swap(self): 2028 2029 def fn(x): 2030 return torch.sigmoid(x).neg() 2031 2032 gm = torch.fx.symbolic_trace(fn) 2033 2034 class NegSigmSwapXformer(Transformer): 2035 def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any: 2036 if target == torch.sigmoid: 2037 return torch.neg(*args, **kwargs) 2038 return super().call_function(n) # noqa: F821 2039 2040 def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any: 2041 if target == 'neg': 2042 call_self, *args_tail = args 2043 return call_self.sigmoid(*args_tail, **kwargs) 2044 return super().call_method(n) # noqa: F821 2045 2046 transformed = NegSigmSwapXformer(gm).transform() 2047 input = torch.randn(3, 4) 2048 self.assertEqual(transformed(input), torch.neg(input).sigmoid()) 2049 2050 def test_transformer_multi_outputs(self): 2051 class MyModule(torch.nn.Module): 2052 def __init__(self) -> None: 2053 super().__init__() 2054 self.param = torch.nn.Parameter(torch.rand(3, 4)) 2055 self.linear = torch.nn.Linear(4, 5) 2056 2057 def forward(self, x): 2058 x = x + self.param 2059 out = self.linear(x) 2060 return x, out 2061 2062 m = MyModule() 2063 gm = torch.fx.symbolic_trace(m) 2064 2065 
new_gm = Transformer(gm).transform() 2066 2067 input = torch.randn(3, 4) 2068 self.assertEqual(new_gm(input), gm(input)) 2069 2070 def test_fn_type_annotations(self): 2071 class Foo(torch.nn.Module): 2072 def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]: 2073 return {'a': p.x + p.y + z + i} 2074 2075 foo_scripted = torch.jit.script(Foo()) 2076 foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) 2077 2078 fxed = symbolic_trace(Foo()) 2079 fxed_scripted = torch.jit.script(fxed) 2080 fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) 2081 2082 def test_fn_type_annotation_empty(self): 2083 def forward(a : List[torch.Tensor]): 2084 return a[0] 2085 torch.jit.script(symbolic_trace(forward)) 2086 2087 def test_wrapped_method(self): 2088 def wrap_with_relu(fn): 2089 @functools.wraps(fn) 2090 def wrapper(*args, **kwargs): 2091 return torch.relu(fn(*args, **kwargs)) 2092 return wrapper 2093 2094 class Foo(torch.nn.Module): 2095 @wrap_with_relu 2096 def forward(self, x, w): 2097 return torch.matmul(x, w) 2098 2099 f = Foo() 2100 traced = symbolic_trace(f) 2101 x, w = torch.rand(3, 4), torch.rand(4, 4) 2102 self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes)) 2103 2104 def test_empty_graph_codegen(self): 2105 graph = torch.fx.Graph() 2106 gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2107 self.assertEqual(gm(), None) 2108 2109 def test_sequential(self): 2110 m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)) 2111 gm = torch.fx.symbolic_trace(m) 2112 gm_copy = copy.deepcopy(gm) 2113 2114 def test_ctx_mgr(self): 2115 @contextlib.contextmanager 2116 def do_nothing(): 2117 yield 2118 2119 class M(torch.nn.Module): 2120 @do_nothing() 2121 def forward(self, x): 2122 return torch.relu(x) 2123 2124 m = M() 2125 self.checkGraphModule(m, (torch.rand(3, 4),)) 2126 2127 def test_typename_print(self): 2128 graph : torch.fx.Graph = torch.fx.Graph() 2129 x : torch.fx.Node = 
graph.create_node('placeholder', 'x') 2130 b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,), 2131 type_expr=List[float]) 2132 output : torch.fx.Node = graph.output(b) 2133 2134 self.assertTrue('typing.List[float]' in str(graph)) 2135 2136 def test_layout(self): 2137 class M(torch.nn.Module): 2138 def forward(self, x): 2139 return torch.empty_like(x, layout=torch.strided, pin_memory=False).fill_(0) 2140 2141 traced = symbolic_trace(M()) 2142 x = torch.rand(5, 9, 3, 4) 2143 self.assertEqual(traced(x), torch.zeros_like(x)) 2144 2145 def test_ellipsis(self): 2146 class M(torch.nn.Module): 2147 def forward(self, x, y): 2148 return x + y[:, 1:10, ...] 2149 2150 traced = symbolic_trace(M()) 2151 x, y = torch.rand(5, 9, 3, 4), torch.rand(5, 15, 3, 4) 2152 self.assertEqual(traced(x, y), x + y[:, 1:10, ...]) 2153 2154 def test_inf_nan(self): 2155 class FooMod(torch.nn.Module): 2156 def forward(self, x): 2157 return x + float('inf'), x + float('-inf'), x + float('nan') 2158 2159 fm = FooMod() 2160 self.checkGraphModule(fm, (torch.rand(3, 4),)) 2161 2162 def test_inf_nan_kwds(self): 2163 graph : torch.fx.Graph = torch.fx.Graph() 2164 x : torch.fx.Node = graph.create_node('placeholder', 'x') 2165 b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf') 2166 c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan') 2167 graph.output((b, c)) 2168 2169 gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2170 x = torch.rand(3, 4) 2171 self.assertEqual(gm(x), (x + float('inf'), x + float('nan'))) 2172 2173 def test_deepcopy_recursion_depth(self): 2174 depth = sys.getrecursionlimit() + 20 2175 2176 g = torch.fx.Graph() 2177 x = g.placeholder('x') 2178 for i in range(depth): 2179 x = g.call_function(torch.relu, (x,)) 2180 g.output(x) 2181 2182 copied_graph = copy.deepcopy(g) 2183 2184 val_map = {} 2185 for orig_node, new_node in zip(g.nodes, 
copied_graph.nodes): 2186 val_map[orig_node] = new_node 2187 2188 for orig_node, new_node in zip(g.nodes, copied_graph.nodes): 2189 orig_users = set(orig_node.users.keys()) 2190 orig_users_equiv = {val_map[u] for u in orig_users} 2191 new_users = set(new_node.users.keys()) 2192 self.assertEqual(orig_users_equiv, new_users) 2193 2194 @skipIfNoTorchVision 2195 def test_replace_uses(self): 2196 rn18 = torchvision_models.resnet18() 2197 2198 class LowerReluTracer(torch.fx.Tracer): 2199 def is_leaf_module(self, m : torch.nn.Module, qualname : str): 2200 if isinstance(m, torch.nn.ReLU): 2201 return False 2202 return super().is_leaf_module(m, qualname) 2203 2204 rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18)) 2205 2206 to_erase = [] 2207 for node in rn18_traced.graph.nodes: 2208 if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]: 2209 kwargs = node.kwargs.copy() 2210 # Neg doesn't have in-place 2211 kwargs.pop('inplace') 2212 with rn18_traced.graph.inserting_before(node): 2213 new_node = rn18_traced.graph.call_function( 2214 the_function=torch.neg, args=node.args, kwargs=node.kwargs) 2215 node.replace_all_uses_with(replace_with=new_node) 2216 to_erase.append(node) 2217 2218 for node in to_erase: 2219 rn18_traced.graph.erase_node(node) 2220 2221 def test_replace_input(self): 2222 graph : torch.fx.Graph = torch.fx.Graph() 2223 x : torch.fx.Node = graph.create_node('placeholder', 'x') 2224 y : torch.fx.Node = graph.create_node('placeholder', 'y') 2225 b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) 2226 output : torch.fx.Node = graph.output(b) 2227 2228 b.replace_input_with(x, y) 2229 2230 gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2231 2232 input_x = torch.randn(33, 44) 2233 input_y = torch.randn(11, 22) 2234 self.assertEqual(gm(input_x, input_y), torch.relu(input_y)) 2235 2236 def test_insertion_point(self): 2237 graph : torch.fx.Graph = torch.fx.Graph() 2238 x : torch.fx.Node 
= graph.create_node('placeholder', 'x') 2239 b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) 2240 output : torch.fx.Node = graph.output(b) 2241 2242 with graph.inserting_before(b): 2243 neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,)) 2244 _, *relu_args = b.args 2245 b.args = (neg, *relu_args) 2246 2247 gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2248 2249 input = torch.randn(33, 44) 2250 self.assertEqual(gm(input), torch.relu(torch.neg(input))) 2251 2252 def test_update_args_api(self): 2253 graph : torch.fx.Graph = torch.fx.Graph() 2254 x : torch.fx.Node = graph.create_node('placeholder', 'x') 2255 y : torch.fx.Node = graph.create_node('placeholder', 'y') 2256 b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) 2257 output : torch.fx.Node = graph.output(b) 2258 2259 orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2260 inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5) 2261 self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x)) 2262 2263 b.update_arg(0, y) 2264 new_gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2265 self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y)) 2266 2267 def test_update_kwargs_api(self): 2268 graph : torch.fx.Graph = torch.fx.Graph() 2269 x : torch.fx.Node = graph.create_node('placeholder', 'x') 2270 y : torch.fx.Node = graph.create_node('placeholder', 'y') 2271 b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, kwargs={'input': x}) 2272 output : torch.fx.Node = graph.output(b) 2273 2274 orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2275 inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5) 2276 self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x)) 2277 2278 b.update_kwarg('input', y) 2279 new_gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2280 self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y)) 2281 2282 def test_immutable_list_pytree_ops(self): 2283 rand_tensor = 
torch.randn(5, 3) 2284 l = immutable_list([3, [rand_tensor, 42]]) 2285 2286 flattened, spec = pytree.tree_flatten(l) 2287 assert flattened == [3, rand_tensor, 42] 2288 2289 unflattened = pytree.tree_unflatten(flattened, spec) 2290 assert unflattened == l 2291 assert isinstance(unflattened, immutable_list) 2292 2293 def test_immutable_dict_pytree_ops(self): 2294 rand_tensor = torch.randn(5, 3) 2295 d = immutable_dict({'a': 3, 'b': [rand_tensor, 42]}) 2296 2297 flattened, spec = pytree.tree_flatten(d) 2298 assert flattened == [3, rand_tensor, 42] 2299 2300 unflattened = pytree.tree_unflatten(flattened, spec) 2301 assert unflattened == d 2302 assert isinstance(unflattened, immutable_dict) 2303 2304 def test_move_before(self): 2305 graph : torch.fx.Graph = torch.fx.Graph() 2306 x : torch.fx.Node = graph.create_node('placeholder', 'x') 2307 b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) 2308 output : torch.fx.Node = graph.output(b) 2309 2310 neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,)) 2311 _, *relu_args = b.args 2312 b.args = (neg, *relu_args) 2313 b.prepend(neg) 2314 2315 gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2316 2317 input = torch.randn(33, 44) 2318 self.assertEqual(gm(input), torch.relu(torch.neg(input))) 2319 2320 def test_prepend_self(self): 2321 graph : torch.fx.Graph = torch.fx.Graph() 2322 x : torch.fx.Node = graph.create_node('placeholder', 'x') 2323 b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) 2324 output : torch.fx.Node = graph.output(b) 2325 2326 b.prepend(b) 2327 x.append(b) 2328 self.assertEqual(len(graph.nodes), 3) 2329 2330 def test_erase_node_error(self): 2331 st = SimpleTest() 2332 traced = symbolic_trace(st) 2333 2334 for node in traced.graph.nodes: 2335 # Test deleting with uses both in another Node and at the output 2336 if node.target in [operator.add, torch.relu]: 2337 with self.assertRaisesRegex(RuntimeError, 'but it still 
had .* users in the graph'): 2338 traced.graph.erase_node(node) 2339 2340 def test_copy_it(self): 2341 d = immutable_dict([(3, 4), (5, 6)]) 2342 l = immutable_list([(3, 4), (5, 6)]) 2343 2344 self.assertEqual(d, deepcopy(d)) 2345 self.assertEqual(l, deepcopy(l)) 2346 2347 def test_get_torch_func_signature(self): 2348 for key in dir(torch): 2349 obj = getattr(torch, key) 2350 if callable(obj): 2351 schemas = get_signature_for_torch_op(obj) 2352 2353 def test_find_uses(self): 2354 graph = torch.fx.Graph() 2355 x = torch.fx.Proxy(graph.placeholder('x')) 2356 2357 y = torch.relu(x) 2358 z = x + x 2359 u = torch.neg(x) 2360 graph.output((y + z + u).node) 2361 graph.lint() 2362 2363 users_of_x = x.node.users 2364 self.assertEqual(len(users_of_x), 3) 2365 expected_ops = {'relu', 'add', 'neg'} 2366 for use in users_of_x: 2367 assert any(use.name.startswith(prefix) for prefix in expected_ops) 2368 2369 def test_inline_graph(self): 2370 class InlineInto(torch.nn.Module): 2371 def forward(self, x): 2372 return torch.relu(x) 2373 2374 class ToInline(torch.nn.Module): 2375 def forward(self, x): 2376 return torch.neg(x) 2377 2378 inline_into = symbolic_trace(InlineInto()) 2379 to_inline = symbolic_trace(ToInline()) 2380 2381 combined_graph = torch.fx.Graph() 2382 output_node = combined_graph.graph_copy(inline_into.graph, {}) 2383 2384 input_node = next(iter(to_inline.graph.nodes)) 2385 assert input_node and input_node.op == 'placeholder' 2386 2387 val_map = {input_node : output_node} 2388 output = combined_graph.graph_copy(to_inline.graph, val_map) 2389 combined_graph.output(output) 2390 2391 combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph) 2392 2393 input = torch.rand(3, 4) 2394 self.assertEqual(combined_module(input), input.relu().neg()) 2395 2396 def test_multi_insert_point(self): 2397 graph = torch.fx.Graph() 2398 x = torch.fx.Proxy(graph.placeholder('x')) 2399 relu = torch.relu(x) 2400 2401 with graph.inserting_before(relu.node): 2402 y = 
torch.neg(x) 2403 z = torch.tanh(y) 2404 2405 graph.output((relu.node, z.node)) 2406 graph.lint() 2407 2408 expected_ops = ['x', 'neg', 'tanh', 'relu'] 2409 for node, expected in zip(graph.nodes, expected_ops): 2410 assert expected in node.name 2411 2412 def test_reassign_args_kwargs_uses(self): 2413 graph = torch.fx.Graph() 2414 x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y')) 2415 z = x + y 2416 zed = z + z + z 2417 graph.output(zed.node) 2418 graph.lint() 2419 2420 # zed = z + z + z -> zed = z + z + x 2421 zed.node.args = (zed.node.args[0], x.node) 2422 self.assertEqual(list(x.node.users.keys()), [z.node, zed.node]) 2423 2424 # z = x + y -> z = y + y 2425 z.node.args = (y.node, y.node) 2426 self.assertEqual(list(x.node.users.keys()), [zed.node]) 2427 2428 def test_trace_function(self): 2429 def foo(x, y): 2430 return torch.relu(x) + y 2431 2432 x, y = torch.randn(3, 4), torch.randn(3, 4) 2433 self.checkGraphModule(foo, (x, y)) 2434 2435 def test_trace_return_dataclass(self): 2436 """ 2437 Test case for Module that return dataclass 2438 """ 2439 from dataclasses import dataclass 2440 2441 @dataclass 2442 class MyOutput: 2443 foo: torch.Tensor 2444 bar: torch.Tensor 2445 2446 class ModuleReturnDataclass(torch.nn.Module): 2447 def forward(self, d : torch.Tensor): 2448 return MyOutput(foo=d + d, bar=d * 3) 2449 2450 module = ModuleReturnDataclass() 2451 traced_graph = symbolic_trace(module).graph 2452 print(traced_graph) 2453 2454 gm = GraphModule(module, traced_graph) 2455 x = torch.rand(1) 2456 2457 self.assertEqual(module(x), gm(x)) 2458 2459 def test_trace_return_dataclass_nested(self): 2460 """ 2461 Test case for Module that return dataclass 2462 """ 2463 from dataclasses import dataclass 2464 2465 @dataclass 2466 class MyOutput: 2467 foo: torch.Tensor 2468 bar: torch.Tensor 2469 2470 class ModuleReturnDataclass(torch.nn.Module): 2471 def forward(self, d : torch.Tensor): 2472 return MyOutput(foo=d + d, bar=d * 3) 2473 2474 class 
CallsModule(torch.nn.Module): 2475 def __init__(self) -> None: 2476 super().__init__() 2477 self.m = ModuleReturnDataclass() 2478 2479 def forward(self, x): 2480 tmp = self.m(x) 2481 return MyOutput(foo=tmp.foo, bar=tmp.bar) 2482 2483 module = CallsModule() 2484 traced_graph = symbolic_trace(module).graph 2485 print(traced_graph) 2486 2487 gm = GraphModule(module, traced_graph) 2488 x = torch.rand(1) 2489 2490 self.assertEqual(module(x), gm(x)) 2491 2492 def test_trace_return_namedtuple(self): 2493 """ 2494 Test case for Module that return namedtuple 2495 """ 2496 class MyOutput(NamedTuple): 2497 foo: torch.Tensor 2498 bar: torch.Tensor 2499 2500 class ModuleReturnNamedTuple(torch.nn.Module): 2501 def forward(self, d : torch.Tensor): 2502 return MyOutput(foo=d, bar=d) 2503 2504 module = ModuleReturnNamedTuple() 2505 2506 traced_graph = symbolic_trace(module).graph 2507 print(traced_graph) 2508 2509 gm = GraphModule(module, traced_graph) 2510 x = torch.rand(1) 2511 2512 self.assertEqual(module(x), gm(x)) 2513 2514 def test_trace_dict_int_keys(self): 2515 class ModWithDictArg(torch.nn.Module): 2516 def forward(self, d : Dict[int, torch.Tensor]): 2517 return d[42] 2518 2519 class CallsModWithDict(torch.nn.Module): 2520 def __init__(self) -> None: 2521 super().__init__() 2522 self.m = ModWithDictArg() 2523 2524 def forward(self, x): 2525 return self.m({42: x}) 2526 2527 class MyTracer(torch.fx.Tracer): 2528 def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: 2529 return isinstance(m, ModWithDictArg) 2530 2531 traced_graph = MyTracer().trace(CallsModWithDict()) 2532 2533 def test_trace_dict_proxy_keys(self): 2534 class ModWithDictArg(torch.nn.Module): 2535 def forward(self, d : Dict[torch.Tensor, torch.Tensor]): 2536 return d[42] 2537 2538 class CallsModWithDict(torch.nn.Module): 2539 def __init__(self) -> None: 2540 super().__init__() 2541 self.m = ModWithDictArg() 2542 2543 def forward(self, x): 2544 return self.m({x: x}) 2545 2546 class 
MyTracer(torch.fx.Tracer): 2547 def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: 2548 return isinstance(m, ModWithDictArg) 2549 2550 with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'): 2551 traced_graph = MyTracer().trace(CallsModWithDict()) 2552 2553 def test_module_deepcopy_edit_nodes(self): 2554 class Foo(torch.nn.Module): 2555 def forward(self, x): 2556 return torch.relu(x) 2557 2558 traced1 = symbolic_trace(Foo()) 2559 copied = copy.deepcopy(traced1) 2560 2561 for node in copied.graph.nodes: 2562 if node.target == torch.relu: 2563 node.target = torch.neg 2564 2565 copied.recompile() 2566 traced1.recompile() 2567 2568 x = torch.randn(15, 15) 2569 torch.testing.assert_close(traced1(x), torch.relu(x)) 2570 torch.testing.assert_close(copied(x), torch.neg(x)) 2571 2572 def test_direct_param_use(self): 2573 class TransposeTest(torch.nn.Module): 2574 def __init__(self) -> None: 2575 super().__init__() 2576 self.b = torch.nn.Parameter(torch.rand(4, 3)) 2577 2578 def forward(self, x): 2579 return self.b 2580 2581 class Foo(torch.nn.Module): 2582 def __init__(self) -> None: 2583 super().__init__() 2584 self.a = TransposeTest() 2585 2586 def forward(self, x): 2587 return self.a.b, self.a.b.t(), self.a.b.view(12) 2588 2589 traced = torch.fx.symbolic_trace(Foo()) 2590 assert all('constant' not in node.target for node in traced.graph.nodes) 2591 2592 def test_single_default_arg(self): 2593 class M(torch.nn.Module): 2594 def forward(self, y=1): 2595 return y 2596 2597 m = M() 2598 self.checkGraphModule(m, ()) 2599 self.checkGraphModule(m, (3,)) 2600 2601 def test_multiple_default_args(self): 2602 class M(torch.nn.Module): 2603 def forward(self, y=1, z=2): 2604 return y + z 2605 2606 m = M() 2607 self.checkGraphModule(m, ()) 2608 self.checkGraphModule(m, (3,)) 2609 self.checkGraphModule(m, (3, 4)) 2610 2611 def test_regular_and_default_args(self): 2612 class M(torch.nn.Module): 2613 def forward(self, x, y=1): 2614 return x + y 
2615 2616 m = M() 2617 self.checkGraphModule(m, (2,)) 2618 self.checkGraphModule(m, (2, 3)) 2619 2620 def test_string_literal_return(self): 2621 class M(torch.nn.Module): 2622 def forward(self): 2623 return "foo" 2624 2625 m = M() 2626 self.checkGraphModule(m, ()) 2627 2628 def test_namedtuple_return_qualname(self): 2629 class NamedTupReturn(torch.nn.Module): 2630 def forward(self, x): 2631 return MyNamedTup(x, x) 2632 2633 traced = symbolic_trace(NamedTupReturn()) 2634 input = torch.rand(3, 4) 2635 self.assertEqual(traced(input), MyNamedTup(input, input)) 2636 2637 def test_update_args_kwargs_yells_at_you(self): 2638 symtraced = symbolic_trace(SimpleTest()) 2639 node = next(iter(symtraced.graph.nodes)) 2640 with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'): 2641 node.__update_args_kwargs((), {}) 2642 2643 def test_torchbind_class_attribute_in_fx(self): 2644 if IS_FBCODE or IS_WINDOWS or IS_MACOS: 2645 self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping") 2646 2647 class FooBar1234(torch.nn.Module): 2648 def __init__(self) -> None: 2649 super().__init__() 2650 self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"]) 2651 2652 def forward(self): 2653 return self.f.top() 2654 2655 m = FooBar1234() 2656 self.checkGraphModule(m, ()) 2657 2658 def test_torchbind_class_attribute_in_fx_tensor_arg(self): 2659 if IS_FBCODE or IS_WINDOWS or IS_MACOS: 2660 self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping") 2661 2662 class FooBar2341(torch.nn.Module): 2663 def __init__(self) -> None: 2664 super().__init__() 2665 self.f = torch.classes._TorchScriptTesting._ReLUClass() 2666 2667 def forward(self, x): 2668 return self.f.run(x) 2669 2670 m = FooBar2341() 2671 2672 traced = symbolic_trace(m) 2673 input = torch.randn(3, 4) 2674 self.assertEqual(traced(input), m(input)) 2675 2676 self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes)) 2677 2678 def 
test_script_method_trace(self): 2679 class Scripted(torch.nn.Module): 2680 def forward(self, x): 2681 return torch.relu(x) 2682 2683 class Holder(torch.nn.Module): 2684 def __init__(self) -> None: 2685 super().__init__() 2686 self.s = torch.jit.script(Scripted()) 2687 2688 def forward(self, x): 2689 return self.s(x) 2690 2691 h = Holder() 2692 traced = symbolic_trace(h) 2693 input = torch.randn(3, 4) 2694 self.assertEqual(traced(input), h(input)) 2695 2696 self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes)) 2697 2698 def test_namedtuple_return_trace(self): 2699 class NamedTupReturn(torch.nn.Module): 2700 def forward(self, x): 2701 return Pair(x, x) 2702 2703 traced = symbolic_trace(NamedTupReturn()) 2704 input = torch.rand(3, 4) 2705 self.assertEqual(traced(input), Pair(input, input)) 2706 2707 def test_named_tuple_inlined(self): 2708 class NamedTupMod(torch.nn.Module): 2709 def forward(self, inp): 2710 return wrapped_named_tup(Pair(inp, 1.2), p2=Pair(3.4, inp)) 2711 2712 m = NamedTupMod() 2713 input = torch.rand(3, 4) 2714 ref = m(input) 2715 traced = symbolic_trace(m) 2716 2717 res = traced(input) 2718 self.assertEqual(ref, res) 2719 2720 # Check Pair NamedTuple works when inlined into the function call. 
2721 ph = call_func = None 2722 for node in traced.graph.nodes: 2723 if node.op == "placeholder": 2724 ph = node 2725 elif node.op == "call_function" and node.target == wrapped_named_tup: 2726 node.update_arg(0, Pair(ph, 1.2)) 2727 node.update_kwarg("p2", Pair(3.4, ph)) 2728 call_func = node 2729 break 2730 self.assertTrue(call_func is not None) 2731 self.assertTrue(isinstance(call_func.args[0], Pair)) 2732 self.assertTrue(isinstance(call_func.kwargs["p2"], Pair)) 2733 self.assertEqual(_format_arg(call_func.args[0]), "Pair(x=%inp, y=1.2)") 2734 self.assertEqual(_format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)") 2735 2736 traced.graph.eliminate_dead_code() 2737 traced.recompile() 2738 res = traced(input) 2739 self.assertEqual(ref, res) 2740 2741 def test_return_type_exists(self): 2742 class ReturnTypeModule(torch.nn.Module): 2743 def other(self, x: List[str]) -> List[str]: 2744 return x 2745 2746 def forward(self, x: List[str]) -> List[str]: 2747 return self.other(x) 2748 2749 traced = symbolic_trace(ReturnTypeModule()) 2750 self.assertIn("-> typing_List[str]", traced._code) 2751 scripted = torch.jit.script(traced) 2752 self.assertIn("-> List[str]", scripted.code) 2753 2754 def getitem_inner(self): 2755 class GetItemBase(torch.nn.Module): 2756 def __init__(self) -> None: 2757 super().__init__() 2758 self.pe = torch.nn.Buffer(torch.randn(8, 8)) 2759 2760 class GetItem1(GetItemBase): 2761 def forward(self, x): 2762 return self.pe[:, :x.size(0)] 2763 2764 class GetItem2(GetItemBase): 2765 def forward(self, x): 2766 return self.pe[x.size(0)] 2767 2768 class GetItem3(GetItemBase): 2769 def forward(self, x): 2770 return self.pe[4] # fx creates `self._tensor_constant0` here 2771 2772 self.checkGraphModule(GetItem1(), [torch.zeros(4)]) 2773 self.checkGraphModule(GetItem2(), [torch.zeros(4)]) 2774 self.checkGraphModule(GetItem3(), [torch.zeros(4)]) 2775 2776 @unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1", 2777 "Will be checked in 
test_getitem_subproc") 2778 def test_getitem(self): 2779 self.getitem_inner() 2780 2781 def test_getitem_subproc(self): 2782 # need to run this test in a subproc to work around: 2783 # https://github.com/pytorch/pytorch/issues/50710 2784 proc = Process(target=run_getitem_target) 2785 proc.start() 2786 proc.join() 2787 self.assertEqual(proc.exitcode, 0) 2788 2789 def test_user_friendly_call_provenance_with_function(self): 2790 def fn(x): 2791 return wrapper_fn(x) 2792 2793 traced = torch.fx.symbolic_trace(fn) 2794 2795 with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is " 2796 "being compiled since it was called" 2797 " from 'fn.forward'"): 2798 scripted = torch.jit.script(traced) 2799 2800 def test_user_friendly_call_provenance_with_module(self): 2801 class M(torch.nn.Module): 2802 def forward(self, x): 2803 return wrapper_fn(x) 2804 2805 traced = torch.fx.symbolic_trace(M()) 2806 2807 with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is " 2808 "being compiled since it was called" 2809 " from 'M.forward'"): 2810 scripted = torch.jit.script(traced) 2811 2812 def test_snake_case(self): 2813 class M(torch.nn.Module): 2814 def __init__(self) -> None: 2815 super().__init__() 2816 self.activations = torch.nn.ModuleDict([ 2817 ["snake_case", torch.nn.ReLU()], 2818 ["PascalCase", torch.nn.LeakyReLU()], 2819 ["ALL_CAPS", torch.nn.PReLU()] 2820 ]) 2821 2822 def forward(self, x): 2823 a = self.activations["snake_case"](x) 2824 b = self.activations["PascalCase"](x) 2825 c = self.activations["ALL_CAPS"](x) 2826 return a, b, c 2827 2828 traced = symbolic_trace(M()) 2829 2830 check = [ 2831 ("activations_snake_case", "activations.snake_case"), 2832 ("activations_pascal_case", "activations.PascalCase"), 2833 ("activations_all_caps", "activations.ALL_CAPS") 2834 ] 2835 2836 i = 0 2837 for node in traced.graph.nodes: 2838 if node.op == "placeholder" or node.op == "output": 2839 continue 2840 name = check[i][0] 2841 target = check[i][1] 2842 self.assertEqual(name, 
node.name) 2843 self.assertEqual(target, node.target) 2844 i += 1 2845 self.assertEqual(i, 3) 2846 2847 def test_no_mutation(self): 2848 from torch.fx.immutable_collections import immutable_list 2849 x = immutable_list([3, 4]) 2850 with self.assertRaisesRegex(NotImplementedError, "new_args"): 2851 x[0] = 4 2852 2853 def test_partial_trace(self): 2854 class Foo(torch.nn.Module): 2855 def forward(self, x, y): 2856 if y: 2857 return 2 * x 2858 else: 2859 return x 2860 mod = Foo() 2861 mod_true = symbolic_trace(mod, concrete_args={'y': True}) 2862 mod_false = symbolic_trace(mod, concrete_args={'y': False}) 2863 self.assertEqual(mod_true(3, True), 6) 2864 print(mod_true.code) 2865 assert any(i.target == torch._assert for i in mod_true.graph.nodes) 2866 with self.assertRaises(AssertionError): 2867 mod_true(3, False) 2868 self.assertEqual(mod_false(3, False), 3) 2869 with self.assertRaises(AssertionError): 2870 mod_false(3, True) 2871 2872 def f_higher(a, f): 2873 return f(a) 2874 2875 nf = symbolic_trace(f_higher, concrete_args={'f': lambda x: x * 2}) 2876 self.assertEqual(nf(3, lambda x: x * 2), 6) 2877 2878 def test_custom_traceback_raised_when_exception_source_is_graphmodule(self): 2879 class M(torch.nn.Module): 2880 def __init__(self) -> None: 2881 super().__init__() 2882 self.W = torch.nn.Parameter(torch.randn(5)) 2883 2884 def forward(self, x): 2885 return torch.dot(self.W, x) 2886 2887 traced = torch.fx.symbolic_trace(M()) 2888 2889 out = [n for n in traced.graph.nodes if n.op == "output"][-1] 2890 with traced.graph.inserting_before(out): 2891 relu_out = traced.graph.call_method(method_name='relu', 2892 args=(out.args[0],)) 2893 out.args = (relu_out,) 2894 2895 traced.recompile() 2896 2897 with self.capture_stderr() as captured: 2898 with self.assertRaises(TypeError): 2899 traced(5) 2900 2901 self.assertRegex(captured[0], 2902 r"Call using an FX-traced Module, line .* of the " 2903 r"traced Module's generated forward function:") 2904 2905 def 
test_custom_traceback_not_raised_when_exception_source_is_submodule(self): 2906 class M(torch.nn.Module): 2907 def __init__(self) -> None: 2908 super().__init__() 2909 self.linear = torch.nn.Linear(3, 4) 2910 2911 def forward(self, x): 2912 return self.linear(x) 2913 2914 traced = torch.fx.symbolic_trace(M()) 2915 2916 # Do not change this to `capture_stderr` or another context 2917 # manager without ensuring that the output is as expected 2918 try: 2919 traced(torch.rand(5, 5)) 2920 except RuntimeError: 2921 captured = traceback.format_exc() 2922 2923 self.assertNotRegex(captured, 2924 r"Call using an FX-traced Module, line .* of the " 2925 r"traced Module's generated forward function:") 2926 2927 def test_graph_module_replicate_for_dp(self): 2928 class Foo(torch.nn.Module): 2929 def forward(self, x): 2930 return torch.relu(x) 2931 2932 gm = torch.fx.symbolic_trace(Foo()) 2933 2934 x = torch.randn(5, 3) 2935 out = gm(x) 2936 2937 replica = gm._replicate_for_data_parallel() 2938 out_replica = replica(x) 2939 2940 torch.testing.assert_close(out_replica, out) 2941 2942 def test_ast_rewriter_rewrites_assert(self): 2943 class M(torch.nn.Module): 2944 def forward(self, x: torch.Tensor, y: int, z: int): 2945 assert y == z 2946 return torch.add(x, x) 2947 2948 ast_rewriter = RewritingTracer() 2949 graph = ast_rewriter.trace(M()) 2950 traced = GraphModule(ast_rewriter.root, graph, "gm") 2951 2952 traced.graph.lint() 2953 2954 def test_ast_rewriter_rewrites_assert_with_message(self): 2955 class M(torch.nn.Module): 2956 def forward(self, x: torch.Tensor, y: int, z: int): 2957 assert y == z, "msg" 2958 return torch.add(x, x) 2959 2960 ast_rewriter = RewritingTracer() 2961 graph = ast_rewriter.trace(M()) 2962 traced = GraphModule(ast_rewriter.root, graph, "gm") 2963 2964 traced.graph.lint() 2965 2966 def test_throw_out_variant(self): 2967 def foo(x): 2968 y = torch.rand_like(x) 2969 torch.sigmoid(x, out=y) 2970 return y 2971 2972 class MyTracer(torch.fx.Tracer): 2973 
check_mutable_operations = True 2974 2975 tracer = MyTracer() 2976 with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'): 2977 traced_graph = tracer.trace(foo) 2978 2979 def test_ast_rewriter_reassigns_submodules(self): 2980 class M(torch.nn.Module): 2981 def __init__(self) -> None: 2982 super().__init__() 2983 self.bn = torch.nn.BatchNorm2d(100) 2984 2985 def forward(self, x: torch.Tensor): 2986 return torch.add(x, x) 2987 2988 ast_rewriter = RewritingTracer() 2989 graph = ast_rewriter.trace(M()) 2990 traced = GraphModule(ast_rewriter.root, graph, "gm") 2991 2992 traced.graph.lint() 2993 2994 def test_ast_rewriter_wrap(self): 2995 self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5)) 2996 2997 def to_trace(y): 2998 return ( 2999 a_lifted_leaf((4, y), 3) 3000 + a_lifted_leaf((3, 4), 5) 3001 + a_lifted_leaf((y, y), y) 3002 ) 3003 3004 ast_rewriter = RewritingTracer() 3005 graph = ast_rewriter.trace(to_trace) 3006 traced = GraphModule(ast_rewriter.root, graph, "gm") 3007 3008 self.assertIn("a_lifted_leaf", traced.code) 3009 self.assertEqual(27, traced(2)) 3010 self.assertIs(a_lifted_leaf, real_a_lifed_leaf) 3011 3012 def test_ast_rewriter_wrap_fn_directly(self): 3013 self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5)) 3014 3015 def to_trace(y): 3016 return ( 3017 a_lifted_leaf2((4, y), 3) 3018 + a_lifted_leaf2((3, 4), 5) 3019 + a_lifted_leaf2((y, y), y) 3020 ) 3021 3022 ast_rewriter = RewritingTracer() 3023 graph = ast_rewriter.trace(to_trace) 3024 traced = GraphModule(ast_rewriter.root, graph, "gm") 3025 3026 self.assertIn("a_lifted_leaf2", traced.code) 3027 self.assertEqual(27, traced(2)) 3028 self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2) 3029 3030 def test_profiler_ranges_side_effect(self): 3031 g = torch.fx.Graph() 3032 handle = g.call_function(torch.ops.profiler._record_function_enter_new, ('test_range',)) 3033 g.call_function(torch.ops.profiler._record_function_exit, (handle,)) 3034 g.output(None) 3035 3036 found_targets = {} 
3037 for node in g.nodes: 3038 if node.op == 'call_function': 3039 found_targets.setdefault(node.target) 3040 self.assertEqual( 3041 list(found_targets.keys()), 3042 [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit] 3043 ) 3044 3045 g.eliminate_dead_code() 3046 found_targets = {} 3047 for node in g.nodes: 3048 if node.op == 'call_function': 3049 found_targets.setdefault(node.target) 3050 self.assertEqual( 3051 list(found_targets.keys()), 3052 [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit] 3053 ) 3054 3055 def test_ast_rewriter_wrapped_via_decorator(self): 3056 class F(torch.nn.Module): 3057 def forward(self, x): 3058 return wrapped_via_decorator(x) 3059 3060 ast_rewriter = RewritingTracer() 3061 graph = ast_rewriter.trace(F()) 3062 traced = GraphModule(ast_rewriter.root, graph, "gm") 3063 3064 self.assertIn("wrapped_via_decorator", traced.code) 3065 self.assertEqual(traced(0), 1) 3066 self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) 3067 self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) 3068 3069 def test_ast_rewriter_wrapped_via_decorator_and_transformed(self): 3070 self.assertEqual(wrapped_via_decorator(0), 1) 3071 3072 def to_trace(y): 3073 return wrapped_via_decorator(y) 3074 3075 ast_rewriter = RewritingTracer() 3076 graph = ast_rewriter.trace(to_trace) 3077 traced = GraphModule(ast_rewriter.root, graph, "gm") 3078 3079 self.assertIn("wrapped_via_decorator", traced.code) 3080 self.assertEqual(traced(0), 1) 3081 self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) 3082 self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) 3083 3084 transformed = torch.fx.Transformer(traced).transform() 3085 self.assertIn("wrapped_via_decorator", transformed.code) 3086 self.assertEqual(transformed(0), 1) 3087 self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) 3088 self.assertFalse(hasattr(wrapped_via_decorator, 
"__fx_already_patched")) 3089 3090 def test_ast_rewriter_wrap_with_submodule(self): 3091 class M(torch.nn.Module): 3092 def __init__(self) -> None: 3093 super().__init__() 3094 self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False) 3095 3096 def forward(self, x: torch.Tensor): 3097 return wrapped_with_submodule(x, self.batchnorm1d) 3098 3099 ast_rewriter = RewritingTracer() 3100 graph = ast_rewriter.trace(M()) 3101 traced = GraphModule(ast_rewriter.root, graph, "gm") 3102 3103 self.assertIn("wrapped_with_submodule", traced.code) 3104 3105 input = torch.rand(3, 2) 3106 ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False) 3107 self.assertEqual(ref_batchnorm1d(input), traced(input)) 3108 3109 def test_submodule_manipulation_API(self): 3110 class C(torch.nn.Module): 3111 def __init__(self) -> None: 3112 super().__init__() 3113 self.conv = torch.nn.Conv2d(16, 33, 3, stride=2) 3114 self.param = torch.nn.Parameter(torch.rand(2, 3)) 3115 3116 def forward(self, x): 3117 return self.conv(torch.cat([self.param, x])) 3118 3119 class B(torch.nn.Module): 3120 def __init__(self) -> None: 3121 super().__init__() 3122 self.linear = torch.nn.Linear(100, 200) 3123 self.buf = torch.nn.Buffer(torch.randn(2, 3)) 3124 self.net_c = C() 3125 3126 def forward(self, x): 3127 return self.linear(torch.cat([self.buf, self.net_c(x)])) 3128 3129 class A(torch.nn.Module): 3130 def __init__(self) -> None: 3131 super().__init__() 3132 self.net_b = B() 3133 self.param = torch.nn.Parameter(torch.rand(2, 3)) 3134 3135 def forward(self, x): 3136 return self.net_b(x) + self.param 3137 3138 a = symbolic_trace(A()) 3139 3140 a.add_submodule("net_b.net_c.dropout", torch.nn.Dropout(p=0.2)) 3141 3142 conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1] 3143 with a.graph.inserting_before(conv): 3144 with warnings.catch_warnings(record=True) as w: 3145 dropout = a.graph.call_module(module_name="net_b.net_c.dropout", 3146 args=conv.args) 3147 self.assertEqual(len(w), 0) 3148 3149 
conv.replace_all_uses_with(dropout) 3150 a.graph.erase_node(conv) 3151 a.recompile() 3152 3153 def module_exists(gm: GraphModule, path: str) -> bool: 3154 return any(path == name for name, _ in gm.named_modules()) 3155 3156 def parameter_exists(gm: GraphModule, path: str) -> bool: 3157 return (any(path == name for name, _ in gm.named_parameters()) 3158 and any(path == name for name in gm.state_dict().keys())) 3159 3160 def buffer_exists(gm: GraphModule, path: str) -> bool: 3161 return (any(path == name for name, _ in gm.named_buffers()) 3162 and any(path == name for name in gm.state_dict().keys())) 3163 3164 # Test that we added the "dropout" submodule 3165 self.assertTrue(module_exists(a, "net_b.net_c.dropout")) 3166 3167 # Test `get_submodule` with an added submodule 3168 self.assertIsNotNone(a.get_submodule("net_b.net_c.dropout")) 3169 3170 # Test that the "conv" submodule is still there 3171 self.assertTrue(module_exists(a, "net_b.net_c.conv")) 3172 3173 # Test `get_submodule` with an original module 3174 self.assertIsNotNone(a.get_submodule("net_b.net_c.conv")) 3175 3176 # Test that the "conv" node is NOT still there 3177 conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"] 3178 self.assertEqual(conv, []) 3179 3180 a.delete_submodule("net_b.net_c.conv") 3181 3182 # Test that the "conv" submodule is now gone 3183 self.assertFalse(module_exists(a, "net_b.net_c.conv")) 3184 3185 # Test `get_submodule` with a deleted submodule 3186 with self.assertRaisesRegex(AttributeError, "has no attribute " 3187 "`conv`"): 3188 self.assertIsNone(a.get_submodule("net_b.net_c.conv")) 3189 3190 # Test `get_attr` warnings 3191 cat = [n for n in a.graph.nodes if n.target == torch.cat][-1] 3192 3193 with a.graph.inserting_before(cat): 3194 3195 with warnings.catch_warnings(record=True) as w: 3196 param = a.graph.get_attr(qualified_name="net_b.net_c.param") 3197 self.assertEqual(len(w), 0) 3198 3199 with self.assertWarnsRegex(UserWarning, "Attempted to " 3200 "insert a 
get_attr Node with no " 3201 "underlying reference in the " 3202 "owning GraphModule"): 3203 bad_param = a.graph.get_attr(qualified_name="net_b.param") 3204 a.graph.erase_node(bad_param) 3205 3206 cat.args = (*cat.args, param) 3207 3208 a.recompile() 3209 3210 a.graph.lint() 3211 3212 # Test `get_parameter` 3213 a.get_parameter("net_b.net_c.param") 3214 with self.assertRaisesRegex(AttributeError, "is not an " 3215 "nn.Parameter"): 3216 a.get_parameter("net_b.buf") 3217 with self.assertRaisesRegex(AttributeError, "has no attribute " 3218 "`param`"): 3219 a.get_parameter("net_b.param") 3220 3221 # Test `get_buffer` 3222 a.get_buffer("net_b.buf") 3223 with self.assertRaisesRegex(AttributeError, "is not a " 3224 "buffer"): 3225 a.get_buffer("net_b.net_c.param") 3226 with self.assertRaisesRegex(AttributeError, "has no attribute " 3227 "`buf`"): 3228 a.get_buffer("net_b.net_c.buf") 3229 3230 # Test non-nested attributes 3231 a.get_submodule("") 3232 a.get_parameter("param") 3233 3234 # Insert some unused submodules 3235 a.add_submodule("net_b.embedding", torch.nn.Embedding(10, 3)) 3236 a.add_submodule("net_b.net_c.embedding", torch.nn.Embedding(10, 3)) 3237 a.add_submodule("net_b.net_c.rnn", torch.nn.RNN(10, 20, 2)) 3238 a.add_submodule("batch_norm_2d", torch.nn.BatchNorm2d(100)) 3239 3240 # Garbage collection 3241 a.delete_all_unused_submodules() 3242 3243 # Test that all the unused submodules are gone 3244 self.assertFalse(module_exists(a, "net_b.embedding")) 3245 self.assertFalse(module_exists(a, "net_b.net_c.embedding")) 3246 self.assertFalse(module_exists(a, "net_b.net_c.rnn")) 3247 self.assertFalse(module_exists(a, "batch_norm_2d")) 3248 3249 # Test that we didn't delete any unused Parameters or buffers 3250 self.assertTrue(parameter_exists(a, "net_b.net_c.param")) 3251 self.assertTrue(buffer_exists(a, "net_b.buf")) 3252 3253 a.graph.lint() 3254 3255 def test_delete_unused_submodules_leaf(self): 3256 class SubModule(torch.nn.Module): 3257 def __init__(self) -> 
None: 3258 super().__init__() 3259 self.linear = torch.nn.Linear(10, 10) 3260 self.relu = torch.nn.ReLU() 3261 3262 def forward(self, x): 3263 x = self.linear(x) 3264 x = self.relu(x) 3265 return x 3266 3267 class Model(torch.nn.Module): 3268 def __init__(self) -> None: 3269 super().__init__() 3270 self.submod = SubModule() 3271 3272 def forward(self, x): 3273 x = self.submod(x) 3274 return x 3275 3276 model = Model() 3277 3278 class MyCustomTracer(torch.fx.Tracer): 3279 def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: 3280 return module_qualified_name == "submod" 3281 3282 inputs = torch.randn(1, 10) 3283 traced_graph = MyCustomTracer().trace(model) 3284 gm2 = torch.fx.GraphModule(model, traced_graph) 3285 gm2.delete_all_unused_submodules() 3286 torch.testing.assert_close(gm2(inputs), model(inputs)) 3287 3288 def test_fx_stateless(self): 3289 class MockModule(torch.nn.Module): 3290 def __init__(self) -> None: 3291 super().__init__() 3292 self.l1 = torch.nn.Linear(1, 1) 3293 self.buffer = torch.nn.Buffer(torch.ones(1)) 3294 3295 def forward(self, x): 3296 return self.l1(x) + self.buffer 3297 3298 module = MockModule() 3299 x = torch.rand((1, 1)) 3300 weight = torch.tensor([[1.0]], requires_grad=True) 3301 bias = torch.tensor([0.0], requires_grad=True) 3302 buffer = torch.tensor([0.0]) 3303 parameters = {'l1.weight': weight, 3304 'l1.bias': bias, 3305 'buffer': buffer} 3306 fx_module = torch.fx.symbolic_trace(module) 3307 res = torch.func.functional_call(fx_module, parameters, x) 3308 res.backward() 3309 self.assertIsNotNone(weight.grad) 3310 self.assertIsNotNone(bias.grad) 3311 self.assertIsNone(buffer.grad) 3312 # Gradient was not calculated for the module stated and buffers 3313 self.assertIsNone(module.l1.weight.grad) 3314 self.assertIsNone(module.l1.bias.grad) 3315 self.assertIsNone(module.buffer.grad) 3316 3317 def test_tracing_graphmodules_as_leaf_submodules(self): 3318 class A(torch.nn.Module): 3319 def forward(self, t): 
3320 return t + t 3321 3322 class B(torch.nn.Module): 3323 def __init__(self) -> None: 3324 super(type(self), self).__init__() 3325 self.calling = False 3326 self.called = False 3327 3328 def forward(self, t): 3329 if self.calling: 3330 return t - t 3331 else: 3332 return t + t 3333 3334 def __call__(self, *args): 3335 self.called = True 3336 self.calling = True 3337 return super(type(self), self).__call__(*args) 3338 self.calling = False 3339 3340 class M(torch.nn.Module): 3341 def __init__(self, a, b): 3342 super().__init__() 3343 self.a = a 3344 self.b = b 3345 3346 def forward(self, t): 3347 x = self.a(t) 3348 y = self.b(t) 3349 return x + y 3350 3351 class LeafTracer(Tracer): 3352 def is_leaf_module(self, module, name): 3353 return True 3354 3355 class LeafTracerNotB(Tracer): 3356 def is_leaf_module(self, module, name): 3357 return False if "b" in name else True 3358 3359 # Recompile calls added "for fun", since they 3360 # chain __call__ wrappers. 3361 3362 # 3363 # Test: B as a regular, non-leaf module 3364 # 3365 a = symbolic_trace(A()) 3366 a.recompile() 3367 m = M(a, B()) 3368 graph = LeafTracerNotB().trace(m) 3369 gm = GraphModule(m, graph) 3370 gm.recompile() 3371 3372 # Test graphmodule/submodule a is not inlined. 3373 self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) 3374 match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"] 3375 self.assertTrue(len(match) == 1) 3376 3377 # Test submodule b is not treated as leaf. 3378 self.assertFalse(hasattr(gm, "b")) 3379 3380 # Test assert custom __call__ on submodule b was honored. 
3381 match = [ 3382 n 3383 for n in gm.graph.nodes 3384 if n.op == "call_function" and n.target == operator.sub 3385 ] 3386 self.assertTrue(len(match) == 1) 3387 3388 # 3389 # Test: B as a regular, leaf module 3390 # symbolic_trace should only patch torch.nn.Module.__call__, 3391 # which means B.__call__ should still execute 3392 # 3393 a = symbolic_trace(A()) 3394 a.recompile() 3395 b = B() 3396 m = M(a, b) 3397 graph = LeafTracer().trace(m) 3398 gm = GraphModule(m, graph) 3399 gm.recompile() 3400 3401 # Test graphmodule/submodule a is not inlined. 3402 self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) 3403 match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"] 3404 self.assertTrue(len(match) == 1) 3405 3406 # Test submodule b is leaf: 3407 self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module)) 3408 match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"] 3409 self.assertTrue(len(match) == 1) 3410 3411 # Test b.__call__ was run 3412 self.assertTrue(b.called) 3413 self.assertTrue(gm.get_submodule("b").called) 3414 3415 # 3416 # Test: B as GraphModule leaf 3417 # __call__ not honored since symbolic_trace directly invokes forward() 3418 # 3419 a = symbolic_trace(A()) 3420 a.recompile() 3421 b = symbolic_trace(B()) 3422 b.recompile() 3423 m = M(a, b) 3424 graph = LeafTracer().trace(m) 3425 gm = GraphModule(m, graph) 3426 gm.recompile() 3427 3428 self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) 3429 match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"] 3430 self.assertTrue(len(match) == 1) 3431 3432 self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module)) 3433 match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"] 3434 self.assertTrue(len(match) == 1) 3435 3436 def _test_graph_module_init_buffer_param_copied(self, use_dict_init: bool): 3437 class MyModule(torch.nn.Module): 3438 def __init__(self) -> None: 3439 
super().__init__() 3440 self.my_buff = torch.nn.Buffer(torch.rand(3, 4)) 3441 self.register_parameter( 3442 "my_param", torch.nn.Parameter(torch.rand(3, 4)) 3443 ) 3444 3445 def forward(self, x): 3446 return x + self.my_buff + self.my_param 3447 3448 mod = MyModule() 3449 mod_traced = symbolic_trace(mod) 3450 3451 # Create new GraphModule based on original, either w/ dict or root module. 3452 orig_buff = mod_traced.get_buffer("my_buff") 3453 orig_param = mod_traced.get_parameter("my_param") 3454 mod_traced_new = GraphModule( 3455 {"my_buff": orig_buff, "my_param": orig_param} if use_dict_init else mod, 3456 mod_traced.graph, 3457 ) 3458 3459 # Check that both my_buff and my_param are found and the same. 3460 try: 3461 new_buff = mod_traced_new.get_buffer("my_buff") 3462 except Exception: 3463 self.fail("Did not find my_buff") 3464 self.assertEqual(orig_buff, new_buff) 3465 3466 try: 3467 new_param = mod_traced_new.get_parameter("my_param") 3468 except Exception: 3469 self.fail("Did not find my_param") 3470 self.assertEqual(orig_param, new_param) 3471 3472 x = torch.rand(3, 4) 3473 orig_out = mod_traced(x) 3474 submodules_out = mod_traced_new(x) 3475 3476 self.assertEqual(orig_out, submodules_out) 3477 3478 def test_graph_module_init_buffer_param_copied_dict_init(self): 3479 self._test_graph_module_init_buffer_param_copied(use_dict_init=True) 3480 3481 def test_graph_module_init_buffer_param_copied_mod_init(self): 3482 self._test_graph_module_init_buffer_param_copied(use_dict_init=False) 3483 3484 def test_annotations_with_no_forward_references(self): 3485 class A: 3486 def __call__(self, x: torch.Tensor): 3487 return torch.add(x, x) 3488 3489 class M(torch.nn.Module): 3490 def forward(self, x: torch.Tensor, a: A) -> torch.Tensor: 3491 return a(x) 3492 3493 self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) 3494 3495 def test_annotations_with_forward_references(self): 3496 class A: 3497 def __call__(self, x: torch.Tensor): 3498 return torch.add(x, x) 
3499 3500 class M(torch.nn.Module): 3501 def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor': 3502 return a(x) 3503 3504 self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) 3505 3506 def test_annotations_with_non_torch_reference_and_no_internal_forward_references(self): 3507 class A: 3508 def __call__(self, x: torch.Tensor): 3509 return torch.add(x, x) 3510 3511 class M(torch.nn.Module): 3512 def forward(self, x: List[torch.Tensor], a: A) -> torch.Tensor: 3513 return a(x[0]) 3514 3515 self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) 3516 3517 def test_annotations_with_non_torch_reference_and_internal_forward_references(self): 3518 class A: 3519 def __call__(self, x: torch.Tensor): 3520 return torch.add(x, x) 3521 3522 class M(torch.nn.Module): 3523 def forward(self, x: List['torch.Tensor'], a: A) -> 'torch.Tensor': 3524 return a(x)[0] 3525 3526 self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) 3527 3528 @unittest.skipIf(sys.version_info < (3, 7), "`__future__` feature " 3529 "`annotations` is not defined in Python <3.7") 3530 def test_annotation_with_future(self): 3531 try: 3532 import fx.test_future # noqa: F401 3533 finally: 3534 del sys.modules["__future__"] 3535 3536 @unittest.skipIf(sys.version_info > (3, 11), "Does not work in 3.11") 3537 def test_annotations_empty_tuple(self): 3538 class Foo(torch.nn.Module): 3539 def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]): 3540 return "foo" 3541 3542 traced = torch.fx.symbolic_trace(Foo()) 3543 3544 x = () 3545 y = ("bar", ()) 3546 3547 traced(x, y) 3548 3549 FileCheck().check("_Tuple[()]") \ 3550 .check("typing_Tuple[str,typing_Tuple[()]]") \ 3551 .run(traced.code) 3552 3553 scripted = torch.jit.script(traced) 3554 3555 scripted(x, y) 3556 3557 FileCheck().check("Tuple[()]") \ 3558 .check("Tuple[str, Tuple[()]]") \ 3559 .run(scripted.code) 3560 3561 @unittest.skipIf(IS_WINDOWS, "Python Windows bug? 
https://bugs.python.org/issue45108") 3562 @unittest.skipIf(sys.version_info >= (3, 10), "Does not work on Python-3.10") 3563 def test_assert(self): 3564 def f(x): 3565 assert x > 1 3566 return x + 1 3567 try: 3568 torch.fx.proxy.TracerBase.trace_asserts = True 3569 traced = symbolic_trace(f) 3570 finally: 3571 torch.fx.proxy.TracerBase.trace_asserts = False 3572 3573 self.assertEqual(f(2), traced(2)) 3574 with self.assertRaises(AssertionError): 3575 traced(0) 3576 3577 def test_pytree(self): 3578 # Used to test that you can use your own placeholder class 3579 class PHTest(PHBase): 3580 pass 3581 3582 def f_sum(x): 3583 return sum(x) 3584 3585 def f_sum_dict(x): 3586 out = 0 3587 for v in x.values(): 3588 out += v 3589 return out 3590 3591 def f_dict_list_map(x): 3592 new_dict = {} 3593 for k, v in x.items(): 3594 new_dict[k] = [i + 1 for i in v] 3595 return new_dict 3596 3597 def f_dict_add(x): 3598 return x['a'] + sum(x['z']) 3599 3600 def f_namedtuple_add(x): 3601 return x.x + x.y 3602 3603 pytree.register_pytree_node( 3604 Foo, 3605 lambda x: ([x.a, x.b], None), 3606 lambda x, _: Foo(x[0], x[1]), 3607 ) 3608 fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b]) 3609 3610 def f_custom(x): 3611 return x.a + x.b 3612 3613 def f_custom_dict(x): 3614 return f_sum_dict(x.a) + x.b 3615 3616 def f_return_custom(x): 3617 return Foo(x.b, x.a) 3618 3619 tests = [ 3620 (f_sum, [PH, PH, PH]), 3621 (f_sum, []), 3622 (f_sum, [PHTest(), PHTest(), PHTest()]), 3623 (f_sum_dict, {'a': PH, 'b': PH, 'c': PH}), 3624 (f_dict_list_map, {'a': (PH, PH), 'b': [PH], 'c': []}), 3625 (f_dict_list_map, {5: (PH, PH, PH)}), 3626 (f_dict_add, {'a': PH, 'z': (PH, PH, PH)}), 3627 (f_dict_add, {'a': PH, 'z': []}), 3628 (f_custom, Foo(PH, PH)), 3629 (f_custom, Foo(PH, 3)), 3630 (f_custom_dict, Foo({'a': PH, 'b': PH}, PH)), 3631 # (f_return_custom, Foo(PH, PH)), # Don't currently support output pytrees 3632 (f_namedtuple_add, Point(PH, PH)), 3633 ] 3634 3635 def verify_pytree(f, inp): 
3636 val = pytree.tree_map(lambda x: torch.randn(3) if isinstance(x, PHBase) else x, inp) 3637 num_flat_args = len(pytree.tree_leaves(inp)) 3638 orig_out = f(val) 3639 nf = symbolic_trace(f, concrete_args={'x': inp}) 3640 self.assertEqual(nf(val), orig_out) 3641 3642 bare_fx = GraphModule({}, copy.deepcopy(nf.graph)) 3643 bare_fx.graph.set_codegen(CodeGen()) 3644 bare_fx.recompile() 3645 self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out) 3646 3647 assert num_flat_args == 0 or "tree_flatten_spec" in nf.code 3648 assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args 3649 3650 nf = symbolic_trace(nf) 3651 self.assertEqual(nf(val), orig_out) 3652 assert "tree_flatten_spec" not in nf.code 3653 assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == 1 3654 3655 nf = symbolic_trace(nf, concrete_args={'x': inp}) 3656 self.assertEqual(nf(val), orig_out) 3657 assert num_flat_args == 0 or "tree_flatten_spec" in nf.code 3658 assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args 3659 3660 pickled = pickle.dumps(nf) 3661 nf = pickle.loads(pickled) 3662 self.assertEqual(nf(val), orig_out) 3663 3664 for f, inp in tests: 3665 verify_pytree(f, inp) 3666 3667 def test_pytree_concrete(self): 3668 def f(b, a): 3669 if b: 3670 return a['a'] 3671 else: 3672 return a['z'] 3673 3674 inp = {'a': {'a': PH, 'z': PH}, 'b': True} 3675 nf = symbolic_trace(f, concrete_args=inp) 3676 val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp) 3677 self.assertEqual(nf(**val), f(**val)) 3678 3679 nf = symbolic_trace(nf) 3680 self.assertEqual(nf(**val), f(**val)) 3681 3682 def test_metadata_on_ph(self): 3683 def f_sum(a: int, b: int) -> int: 3684 return a + b 3685 3686 # Due to unflattening of dict, the batch argument 3687 # will be split into two separate nodes with the names 3688 # "batch_1" and "batch_2", referring to the keys 3689 # "f1" and "f2" respectively in the dict. 
3690 def f_dict(a: Dict[str, str]) -> bool: 3691 return a["f1"] == a["f2"] 3692 3693 def verify_metadata(gm: GraphModule, arg_names: List[str], metadata: List[str]): 3694 for node in gm.graph.nodes: 3695 if node.op == "placeholder": 3696 self.assertTrue(node.name in arg_names) 3697 self.assertTrue(node.ph_key in metadata) 3698 3699 verify_metadata( 3700 gm=symbolic_trace( 3701 f_sum, 3702 concrete_args={"a": PHWithMeta(ph_key="a"), "b": PHWithMeta(ph_key="b")} 3703 ), 3704 arg_names=["a_1", "b_1"], 3705 metadata=["a", "b"] 3706 ) 3707 verify_metadata( 3708 gm=symbolic_trace( 3709 f_dict, 3710 concrete_args={"a": {"f1": PHWithMeta(ph_key="f1"), "f2": PHWithMeta(ph_key="f2")}} 3711 ), 3712 arg_names=["a_1", "a_2"], 3713 metadata=["f1", "f2"] 3714 ) 3715 3716 # Ensures that tags on nodes are NOT overwritten by PH attributes with same attr name (tag) 3717 class TaggingTracer(Tracer): 3718 def create_node(self, kind : str, target : Union[str, Callable], 3719 args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None, 3720 type_expr : Optional[Any] = None) -> Node: 3721 n = super().create_node(kind, target, args, kwargs, name) 3722 n.tag = "foo" 3723 return n 3724 3725 class PHWithTag(PHBase): 3726 def __init__(self, tag: str): 3727 super().__init__() 3728 3729 self.tag = tag 3730 3731 g = TaggingTracer().trace(f_sum, concrete_args={"a": PHWithTag(tag="bar"), "b": PHWithTag(tag="bar")}) 3732 for n in g.nodes: 3733 self.assertTrue(hasattr(n, "tag")) 3734 # Ensure that tag is still "foo" and not "bar" (from PHWithTag) 3735 self.assertEqual(n.tag, "foo") 3736 3737 def test_custom_codegen(self): 3738 class ListCodeGen(CodeGen): 3739 def gen_fn_def(self, free_vars, maybe_return_annotation): 3740 lst_unpack = f""" 3741def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: 3742 {', '.join(free_vars)} = args_list""" 3743 return lst_unpack 3744 3745 def additional_globals(self): 3746 return [('List', typing.List)] 3747 3748 def 
process_inputs(self, *inputs): 3749 assert len(inputs) == 1 3750 return inputs[0] 3751 3752 def f(a, b): 3753 return a + b 3754 3755 nf = symbolic_trace(f) 3756 vals = [torch.randn(3), torch.randn(3)] 3757 self.assertEqual(nf(*vals), f(*vals)) 3758 3759 nf.graph.set_codegen(ListCodeGen()) 3760 nf.recompile() 3761 3762 bare_fx = GraphModule({}, copy.deepcopy(nf.graph)) 3763 bare_fx.graph.set_codegen(CodeGen()) 3764 bare_fx.recompile() 3765 3766 self.assertEqual(nf(vals), f(*vals)) 3767 self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), f(*vals)) 3768 3769 ts_f = torch.jit.script(nf) 3770 self.assertEqual(nf(vals), ts_f(vals)) 3771 3772 def test_custom_codegen_with_transformer(self): 3773 class ListCodeGen(CodeGen): 3774 def gen_fn_def(self, free_vars, maybe_return_annotation): 3775 lst_unpack = f""" 3776def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: 3777 {', '.join(free_vars)} = args_list""" 3778 return lst_unpack 3779 3780 def additional_globals(self): 3781 return [('List', typing.List)] 3782 3783 def process_inputs(self, *inputs): 3784 assert len(inputs) == 1 3785 return inputs[0] 3786 3787 def f(a, b): 3788 return a + b 3789 3790 nf = symbolic_trace(f) 3791 vals = [torch.randn(3), torch.randn(3)] 3792 self.assertEqual(nf(*vals), f(*vals)) 3793 3794 nf.graph.set_codegen(ListCodeGen()) 3795 nf.recompile() 3796 self.assertEqual(nf(vals), f(*vals)) 3797 3798 transformed_gm = Transformer(nf).transform() 3799 self.assertEqual(nf(vals), transformed_gm(vals)) 3800 3801 def test_interpreter_with_codegen(self): 3802 class ListCodeGen(CodeGen): 3803 def gen_fn_def(self, free_vars, maybe_return_annotation): 3804 lst_unpack = f""" 3805def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: 3806 {', '.join(free_vars)} = args_list""" 3807 return lst_unpack 3808 3809 def additional_globals(self): 3810 return [('List', typing.List)] 3811 3812 def process_inputs(self, *inputs): 3813 assert len(inputs) 
== 1 3814 return inputs[0] 3815 3816 def generate_output(self, output_args): 3817 return f'return list({repr(output_args)})' 3818 3819 def process_outputs(self, outputs): 3820 return list(outputs) 3821 3822 def f(a, b): 3823 a = a + b 3824 b = a + b 3825 return a, b 3826 3827 nf = symbolic_trace(f) 3828 vals = [torch.randn(3), torch.randn(3)] 3829 nf.graph.set_codegen(ListCodeGen()) 3830 nf.recompile() 3831 self.assertEqual(Interpreter(nf).run(vals), nf(vals)) 3832 3833 def test_imul_code_print(self): 3834 graph = torch.fx.Graph() 3835 a = graph.placeholder("a") 3836 b = graph.placeholder("b") 3837 graph.call_function(operator.imul, (a, b), {}) 3838 graph.output(a) 3839 gm = torch.fx.GraphModule({}, graph) 3840 gm.recompile() 3841 self.assertEqual(gm(2, 3), 6) 3842 self.assertIn("a *= b", gm.code) 3843 3844 def test_deepcopy_tracer(self): 3845 def fn(x, y): 3846 return (x + y).relu().sin() 3847 3848 tracer = Tracer() 3849 tracer_before = copy.deepcopy(tracer) 3850 tracer.trace(fn) 3851 tracer_after = copy.deepcopy(tracer) 3852 3853 self.assertEqual(str(tracer.graph), str(tracer_after.graph)) 3854 self.assertTrue(not hasattr(tracer_before, 'graph') or str(tracer.graph) != str(tracer_before.graph)) 3855 3856 def test_deepcopy_graphmodule(self): 3857 m = symbolic_trace(SimpleTest()) 3858 m.meta['hello'] = 'world' 3859 copy_m = copy.deepcopy(m) 3860 self.assertEqual(copy_m.meta['hello'], 'world') 3861 3862 def test_deepcopy_no_recursion(self): 3863 m = symbolic_trace(SimpleTest()) 3864 m.meta['hello'] = m # circular reference 3865 copy_m = copy.deepcopy(m) # finishes 3866 self.assertEqual(id(copy_m), id(copy_m.meta['hello'])) 3867 3868 def test_enum(self): 3869 from enum import Enum 3870 3871 class Foo(Enum): 3872 A = 1 3873 B = 2 3874 3875 def leaf_fn(arr, enum_val): 3876 # Use the raw enum. 3877 arr.append(enum_val) 3878 return arr[-1].value 3879 3880 def foo(x): 3881 # Pass the enum as argument. 
3882 return leaf_fn(x, Foo.A) 3883 3884 traced = torch.fx.symbolic_trace(foo) 3885 self.assertEqual(foo([]), traced([])) 3886 3887 def test_insert_arg(self): 3888 m = symbolic_trace(SimpleTest()) 3889 m.buf = torch.nn.Buffer(torch.tensor(0)) 3890 output_node = next(iter(reversed(m.graph.nodes))) 3891 with m.graph.inserting_before(output_node): 3892 a = m.graph.get_attr("buf") 3893 r = len(output_node.args) 3894 output_node.insert_arg(0, a) 3895 self.assertEqual(len(output_node.args), r + 1) 3896 self.assertEqual(len(a.users), 1) 3897 self.assertIs(output_node.args[0], a) 3898 self.assertIs(next(iter(a.users.keys())), output_node) 3899 output_node.insert_arg(2, a) 3900 self.assertEqual(len(output_node.args), r + 2) 3901 self.assertEqual(len(a.users), 1) 3902 self.assertIs(output_node.args[2], a) 3903 self.assertIs(next(iter(a.users.keys())), output_node) 3904 m.graph.lint() 3905 3906 def test_delete_unused_values(self): 3907 from torch.fx.experimental.proxy_tensor import make_fx 3908 3909 # disable mutable checking temporarily 3910 orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations 3911 torch.fx.proxy.TracerBase.check_mutable_operations = False 3912 3913 def fn(a, b, c, d): 3914 x = a + b 3915 y = c + d 3916 y.copy_(x) 3917 x = torch.relu(x) 3918 return x 3919 3920 a, b, c, d = (torch.randn(2, 4, requires_grad=False) for _ in range(4)) 3921 fx_fn = make_fx(fn)(a, b, c, d) 3922 print(fx_fn) 3923 3924 fx_fn.graph.eliminate_dead_code() 3925 py_code = fx_fn.recompile() 3926 self.assertTrue("copy_ = torch.ops.aten.copy_.default" in py_code.src) 3927 self.assertTrue("copy_ = None" in py_code.src) 3928 3929 # recorver mutable checking flag 3930 torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag 3931 3932def run_getitem_target(): 3933 from torch.fx._symbolic_trace import _wrapped_methods_to_patch 3934 _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) 3935 try: 3936 TestFX().getitem_inner() 3937 finally: 
3938 _wrapped_methods_to_patch.pop() 3939 3940 3941class TestOperatorSignatures(JitTestCase): 3942 def setUp(self): 3943 # Checking for mutable operations whil tracing is feature flagged 3944 # Enable it in testing but not by default 3945 self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations 3946 torch.fx.proxy.TracerBase.check_mutable_operations = True 3947 3948 def tearDown(self): 3949 torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag 3950 3951 @onlyCPU 3952 @ops(op_db, allowed_dtypes=(torch.float,)) 3953 def test_get_torch_func_signature_exhaustive(self, device, dtype, op): 3954 if not isinstance(op.op, types.BuiltinFunctionType): 3955 raise unittest.SkipTest("This path doesn't work on Python functions") 3956 sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) 3957 schemas = get_signature_for_torch_op(op.op) 3958 if not schemas: 3959 raise RuntimeError('No Schemas Returned') 3960 for sample_input in sample_inputs_itr: 3961 # Iterate through overloads until we hit a match. 
If we exit this 3962 # loop via `else`, we haven't found a match 3963 for schema in schemas: 3964 try: 3965 bound_args = schema.bind(sample_input.input, *sample_input.args, **sample_input.kwargs) 3966 bound_args.apply_defaults() 3967 op(*bound_args.args, **bound_args.kwargs) 3968 break 3969 except TypeError as e: 3970 pass 3971 else: 3972 raise RuntimeError(f'Did not match any schemas for op {op.name}!') 3973 3974 3975class TestFXAPIBackwardCompatibility(JitTestCase): 3976 def setUp(self): 3977 super().setUp() 3978 self.maxDiff = None 3979 3980 # Checking for mutable operations whil tracing is feature flagged 3981 # Enable it in testing but not by default 3982 self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations 3983 torch.fx.proxy.TracerBase.check_mutable_operations = True 3984 3985 def tearDown(self): 3986 super().tearDown() 3987 torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag 3988 3989 3990 def _fn_to_stable_annotation_str(self, obj): 3991 """ 3992 Unfortunately we have to serialize function signatures manually since 3993 serialization for `inspect.Signature` objects is not stable across 3994 python versions 3995 """ 3996 fn_name = torch.typename(obj) 3997 3998 signature = inspect.signature(obj) 3999 4000 sig_str = f'{fn_name}{signature}' 4001 4002 arg_strs = [] 4003 for k, v in signature.parameters.items(): 4004 maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\ 4005 if v.annotation is not inspect.Signature.empty else '' 4006 4007 def default_val_str(val): 4008 if isinstance(val, (tuple, list)): 4009 str_pieces = ['(' if isinstance(val, tuple) else '['] 4010 str_pieces.append(', '.join(default_val_str(v) for v in val)) 4011 if isinstance(val, tuple) and len(str_pieces) == 2: 4012 str_pieces.append(',') 4013 str_pieces.append(')' if isinstance(val, tuple) else ']') 4014 return ''.join(str_pieces) 4015 4016 # Need to fix up some default value strings. 
4017 # First case: modules. Default module `repr` contains the FS path of the module. 4018 # Don't leak that 4019 if isinstance(val, types.ModuleType): 4020 return f'<module {val.__name__}>' 4021 4022 # Second case: callables. Callables (such as lambdas) encode their address in 4023 # their string repr. Don't do that 4024 if callable(val): 4025 return f'<function {val.__name__}>' 4026 4027 return str(val) 4028 4029 if v.default is not inspect.Signature.empty: 4030 default_val_str = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'" 4031 maybe_default = f' = {default_val_str}' 4032 else: 4033 maybe_default = '' 4034 maybe_stars = '' 4035 if v.kind == inspect.Parameter.VAR_POSITIONAL: 4036 maybe_stars = '*' 4037 elif v.kind == inspect.Parameter.VAR_KEYWORD: 4038 maybe_stars = '**' 4039 arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}') 4040 4041 return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\ 4042 if signature.return_annotation is not inspect.Signature.empty else '' 4043 4044 return f'{fn_name}({", ".join(arg_strs)}){return_annot}' 4045 4046 def _annotation_type_to_stable_str(self, t, sig_str): 4047 if t is inspect.Signature.empty: 4048 return '' 4049 4050 # Forward ref 4051 if isinstance(t, str): 4052 return f"'{t}'" 4053 if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef): 4054 return t.__forward_arg__ 4055 if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef): 4056 return t.__forward_arg__ 4057 4058 trivial_mappings = { 4059 str : 'str', 4060 int : 'int', 4061 float: 'float', 4062 bool: 'bool', 4063 torch.dtype: 'torch.dtype', 4064 torch.Tensor: 'torch.Tensor', 4065 torch.device: 'torch.device', 4066 torch.memory_format: 'torch.memory_format', 4067 slice: 'slice', 4068 torch.nn.Module: 'torch.nn.modules.module.Module', 4069 torch.fx.Graph : 'torch.fx.graph.Graph', 4070 torch.fx.Node : 'torch.fx.node.Node', 4071 
torch.fx.Proxy : 'torch.fx.proxy.Proxy', 4072 torch.fx.node.Target : 'torch.fx.node.Target', 4073 torch.fx.node.Argument : 'torch.fx.node.Argument', 4074 torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode', 4075 torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule', 4076 torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match', 4077 Ellipsis : '...', 4078 typing.Any: 'Any', 4079 type(None): 'NoneType', 4080 None: 'None', 4081 typing.Iterator: 'Iterator', 4082 } 4083 4084 mapping = trivial_mappings.get(t, None) 4085 if mapping: 4086 return mapping 4087 4088 # Handle types with contained types 4089 contained = getattr(t, '__args__', None) or [] 4090 4091 # Callables contain a bare List for arguments 4092 contained = t if isinstance(t, list) else contained 4093 4094 # Python 3.8 puts type vars into __args__ for unbound types such as Dict 4095 if all(isinstance(ct, typing.TypeVar) for ct in contained): 4096 contained = [] 4097 4098 contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained] 4099 contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else '' 4100 4101 4102 origin = getattr(t, '__origin__', None) 4103 if origin is None: 4104 # Unbound types don't have `__origin__` in some Python versions, so fix that up here. 
4105 origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin 4106 4107 if origin in {tuple, typing.Tuple}: 4108 return f'Tuple{contained_type_str}' 4109 if origin in {typing.Union}: 4110 # Annoying hack to detect Optional 4111 if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)): 4112 not_none_param = contained[0] if contained[0] is not type(None) else contained[1] 4113 return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]' 4114 return f'Union{contained_type_str}' 4115 if origin in {dict, typing.Dict}: 4116 return f'Dict{contained_type_str}' 4117 if origin in {list, typing.List}: 4118 return f'List{contained_type_str}' 4119 if origin in {type, typing.Type}: 4120 return f'Type{contained_type_str}' 4121 if isinstance(t, typing.Callable): 4122 if len(contained) > 0 and contained[0] is not Ellipsis: 4123 return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]' 4124 else: 4125 return f'Callable{contained_type_str}' 4126 4127 raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}.' 4128 f'Please add support for this type and confirm with the ' 4129 f'FX team that your signature change is valid.') 4130 4131 4132 def test_function_back_compat(self): 4133 """ 4134 Test backward compatibility for function signatures with 4135 @compatibility(is_backward_compatible=True). Currently this checks for 4136 exact signature matches, which may lead to false positives. If this 4137 becomes too annoying, we can refine this check to actually parse out 4138 the saved schema strings and check if the change is truly backward- 4139 incompatible. 
4140 """ 4141 signature_strs = [] 4142 4143 for obj in _BACK_COMPAT_OBJECTS: 4144 if not isinstance(obj, type): 4145 signature_strs.append(self._fn_to_stable_annotation_str(obj)) 4146 4147 signature_strs.sort() 4148 4149 try: 4150 self.assertExpected('\n'.join(signature_strs) + '\n', 'fx_backcompat_function_signatures') 4151 except AssertionError as e: 4152 msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \ 4153 f"as backwards-compatible has experienced a signature change. See the " \ 4154 f"above exception context for more information. If this change was " \ 4155 f"unintended, please revert it. If it was intended, check with the FX " \ 4156 f"team to ensure that the proper deprecation protocols have been followed " \ 4157 f"and subsequently --accept the change." 4158 raise AssertionError(msg) # noqa: B904 4159 4160 def test_class_member_back_compat(self): 4161 """ 4162 Test backward compatibility for members of classes with 4163 @compatibility(is_backward_compatible=True). Currently this checks for 4164 exact matches on the publicly visible members of the class. 4165 """ 4166 class_method_strs = [] 4167 4168 for obj in _BACK_COMPAT_OBJECTS: 4169 if isinstance(obj, type): 4170 public_members = [name for name in obj.__dict__ if not name.startswith('_')] 4171 class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}') 4172 4173 class_method_strs.sort() 4174 4175 try: 4176 self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members') 4177 except AssertionError as e: 4178 msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \ 4179 f"as backwards-compatible has experienced change in its public members. See the " \ 4180 f"above exception context for more information. If this change was " \ 4181 f"unintended, please revert it. If it was intended, check with the FX " \ 4182 f"team to ensure that the proper deprecation protocols have been followed " \ 4183 f"and subsequently --accept the change." 
4184 raise AssertionError(msg) from e 4185 4186 def test_public_api_surface(self): 4187 non_back_compat_objects = {} 4188 4189 def check_symbols_have_bc_designation(m, seen): 4190 if not m.__name__.startswith('torch.fx'): 4191 return 4192 if m.__name__.startswith('torch.fx.experimental'): 4193 return 4194 # It's really common for inner functions to point to random modules 4195 # - make sure we don't recurse into modules we've already checked. 4196 seen.add(m.__name__) 4197 for k, v in m.__dict__.items(): 4198 if hasattr(v, '__name__') and v.__name__ in seen: 4199 continue 4200 if v is m: 4201 continue 4202 if k.startswith('_'): 4203 continue 4204 if isinstance(v, types.ModuleType): 4205 check_symbols_have_bc_designation(v, seen) 4206 elif isinstance(v, (type, types.FunctionType)): 4207 if v not in _MARKED_WITH_COMPATIBILITY: 4208 non_back_compat_objects.setdefault(v) 4209 4210 check_symbols_have_bc_designation(torch.fx, set()) 4211 check_symbols_have_bc_designation(torch.fx.passes, set()) 4212 4213 non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()] 4214 # Only want objects in torch.fx 4215 non_back_compat_strs = [ 4216 s for s in non_back_compat_strs if s.startswith('torch.fx') and not s.startswith('torch.fx.experimental')] 4217 # Only want objects in public namespaces 4218 non_back_compat_strs = [ 4219 s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))] 4220 non_back_compat_strs.sort() 4221 4222 if len(non_back_compat_strs) != 0: 4223 raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a " 4224 f"backwards-compatibility classification! 
Please decorate these " 4225 f"API(s) with `@torch.fx._compatibility.compatibility` to specify " 4226 f"BC guarantees.") 4227 4228 def test_adding_side_effect_function(self): 4229 class TestModule(torch.nn.Module): 4230 def forward(self, x): 4231 side_effect_func(x) 4232 return x 4233 4234 gm = torch.fx.symbolic_trace(TestModule()) 4235 self.assertEqual(len(gm.graph.nodes), 3) 4236 gm.graph.eliminate_dead_code() 4237 gm.recompile() 4238 self.assertEqual(len(gm.graph.nodes), 3) 4239 found = False 4240 for node in gm.graph.nodes: 4241 if node.op == 'call_function' and node.target == side_effect_func: 4242 found = True 4243 self.assertTrue(found) 4244 4245 def test_preserve_unused_attr_after_unpickle(self): 4246 gm = torch.fx.symbolic_trace(Add()) 4247 gm.add_submodule("foo", Add()) 4248 gm.dummy_buffer = torch.nn.Buffer(torch.empty(1)) 4249 gm.register_parameter("dummy_parameter", torch.nn.Parameter(torch.empty(1))) 4250 b = io.BytesIO() 4251 torch.save(gm, b) 4252 b.seek(0) 4253 # weights_only=False as this loads a GraphModule 4254 # GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default 4255 reload_gm = torch.load(b, weights_only=False) 4256 self.assertTrue(hasattr(reload_gm, "foo")) 4257 self.assertTrue(hasattr(reload_gm, "dummy_buffer")) 4258 self.assertTrue(hasattr(reload_gm, "dummy_parameter")) 4259 4260# This is failing on Python 3.12 : https://github.com/pytorch/pytorch/issues/119454 4261@unittest.skipIf( 4262 sys.version_info >= (3, 12), "Failing on python 3.12+" 4263) 4264class TestFunctionalTracing(JitTestCase): 4265 def setUp(self): 4266 super().setUp() 4267 # Checking for mutable operations whil tracing is feature flagged 4268 # Enable it in testing but not by default 4269 self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations 4270 torch.fx.proxy.TracerBase.check_mutable_operations = True 4271 4272 def tearDown(self): 4273 super().tearDown() 4274 torch.fx.proxy.TracerBase.check_mutable_operations 
= self.orig_tracer_mutable_flag 4275 4276 IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary", 4277 "has_torch_function_variadic", "handle_torch_function", 4278 "boolean_dispatch") 4279 TO_PATCH = {"has_torch_function": None, 4280 "has_torch_function_unary": None, 4281 "has_torch_function_variadic": None} 4282 4283 BUILT_IN_FUNC = (AssertionError, "") 4284 PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable") 4285 PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated") 4286 LEN_ERROR = (RuntimeError, r"'len' is not supported in symbolic tracing by default") 4287 ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$") 4288 CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow") 4289 INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined") 4290 MUTABLE = (RuntimeError, r"Tried to trace mutable operation") 4291 4292 UNTRACEABLE_FUNCTIONALS = { 4293 "adaptive_avg_pool1d": BUILT_IN_FUNC, 4294 "avg_pool1d": BUILT_IN_FUNC, 4295 "avg_pool2d": BUILT_IN_FUNC, 4296 "avg_pool3d": BUILT_IN_FUNC, 4297 "bilinear": BUILT_IN_FUNC, 4298 "celu_": BUILT_IN_FUNC, 4299 "channel_shuffle": BUILT_IN_FUNC, 4300 "native_channel_shuffle": BUILT_IN_FUNC, 4301 "conv1d": BUILT_IN_FUNC, 4302 "conv2d": BUILT_IN_FUNC, 4303 "conv3d": BUILT_IN_FUNC, 4304 "conv_tbc": BUILT_IN_FUNC, 4305 "conv_transpose1d": BUILT_IN_FUNC, 4306 "conv_transpose2d": BUILT_IN_FUNC, 4307 "conv_transpose3d": BUILT_IN_FUNC, 4308 "cosine_similarity": BUILT_IN_FUNC, 4309 "elu_": BUILT_IN_FUNC, 4310 "gelu": BUILT_IN_FUNC, 4311 "hardshrink": BUILT_IN_FUNC, 4312 "hardtanh_": BUILT_IN_FUNC, 4313 "leaky_relu_": BUILT_IN_FUNC, 4314 "linear": BUILT_IN_FUNC, 4315 "logsigmoid": BUILT_IN_FUNC, 4316 "one_hot": BUILT_IN_FUNC, 4317 "pad": ARG_TYPE_MISMATCH, 4318 "pairwise_distance": BUILT_IN_FUNC, 4319 "pdist": BUILT_IN_FUNC, 4320 "pixel_shuffle": BUILT_IN_FUNC, 4321 "pixel_unshuffle": BUILT_IN_FUNC, 4322 "prelu": 
BUILT_IN_FUNC, 4323 "relu_": BUILT_IN_FUNC, 4324 "rrelu_": BUILT_IN_FUNC, 4325 "selu_": BUILT_IN_FUNC, 4326 "scaled_dot_product_attention": BUILT_IN_FUNC, 4327 "softplus": BUILT_IN_FUNC, 4328 "softshrink": BUILT_IN_FUNC, 4329 "threshold_": BUILT_IN_FUNC, 4330 4331 "adaptive_avg_pool2d": LEN_ERROR, 4332 "adaptive_avg_pool3d": LEN_ERROR, 4333 "adaptive_max_pool2d_with_indices": LEN_ERROR, 4334 "adaptive_max_pool3d_with_indices": LEN_ERROR, 4335 "instance_norm": CONTROL_FLOW, 4336 4337 "adaptive_max_pool1d": PROXY_ITERABLE, 4338 "adaptive_max_pool2d": PROXY_ITERABLE, 4339 "adaptive_max_pool3d": PROXY_ITERABLE, 4340 "fractional_max_pool2d": PROXY_ITERABLE, 4341 "fractional_max_pool3d": PROXY_ITERABLE, 4342 "max_pool1d": PROXY_ITERABLE, 4343 "max_pool2d": PROXY_ITERABLE, 4344 "max_pool3d": PROXY_ITERABLE, 4345 4346 "lp_pool2d": PROXY_ITERATED, 4347 "lp_pool3d": PROXY_ITERATED, 4348 "max_unpool1d": PROXY_ITERATED, 4349 "max_unpool2d": PROXY_ITERATED, 4350 "max_unpool3d": PROXY_ITERATED, 4351 "fold": PROXY_ITERATED, 4352 "unfold": PROXY_ITERATED, 4353 4354 "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH, 4355 "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH, 4356 "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH, 4357 "layer_norm": ARG_TYPE_MISMATCH, 4358 "rms_norm": ARG_TYPE_MISMATCH, 4359 "lp_pool1d": ARG_TYPE_MISMATCH, 4360 4361 "affine_grid": CONTROL_FLOW, 4362 "alpha_dropout": CONTROL_FLOW, 4363 "batch_norm": CONTROL_FLOW, 4364 "binary_cross_entropy": CONTROL_FLOW, 4365 "binary_cross_entropy_with_logits": CONTROL_FLOW, 4366 "celu": CONTROL_FLOW, 4367 "cosine_embedding_loss": CONTROL_FLOW, 4368 "cross_entropy": CONTROL_FLOW, 4369 "ctc_loss": CONTROL_FLOW, 4370 "dropout": CONTROL_FLOW, 4371 "dropout1d": CONTROL_FLOW, 4372 "dropout2d": CONTROL_FLOW, 4373 "dropout3d": CONTROL_FLOW, 4374 "elu": CONTROL_FLOW, 4375 "embedding": CONTROL_FLOW, 4376 "embedding_bag": CONTROL_FLOW, 4377 "feature_alpha_dropout": CONTROL_FLOW, 4378 "gaussian_nll_loss": 
CONTROL_FLOW, 4379 "glu": CONTROL_FLOW, 4380 "grid_sample": CONTROL_FLOW, 4381 "group_norm": CONTROL_FLOW, 4382 "gumbel_softmax": CONTROL_FLOW, 4383 "hardsigmoid": CONTROL_FLOW, 4384 "hardswish": CONTROL_FLOW, 4385 "hardtanh": CONTROL_FLOW, 4386 "hinge_embedding_loss": CONTROL_FLOW, 4387 "huber_loss": CONTROL_FLOW, 4388 "interpolate": CONTROL_FLOW, 4389 "kl_div": CONTROL_FLOW, 4390 "l1_loss": CONTROL_FLOW, 4391 "leaky_relu": CONTROL_FLOW, 4392 "local_response_norm": CONTROL_FLOW, 4393 "margin_ranking_loss": CONTROL_FLOW, 4394 "max_pool1d_with_indices": ARG_TYPE_MISMATCH, 4395 "max_pool2d_with_indices": ARG_TYPE_MISMATCH, 4396 "max_pool3d_with_indices": ARG_TYPE_MISMATCH, 4397 "mse_loss": CONTROL_FLOW, 4398 "multi_head_attention_forward": CONTROL_FLOW, 4399 "multi_margin_loss": CONTROL_FLOW, 4400 "multilabel_margin_loss": CONTROL_FLOW, 4401 "multilabel_soft_margin_loss": CONTROL_FLOW, 4402 "nll_loss": CONTROL_FLOW, 4403 "poisson_nll_loss": CONTROL_FLOW, 4404 "relu": CONTROL_FLOW, 4405 "relu6": CONTROL_FLOW, 4406 "rrelu": CONTROL_FLOW, 4407 "selu": CONTROL_FLOW, 4408 "silu": CONTROL_FLOW, 4409 "mish": CONTROL_FLOW, 4410 "smooth_l1_loss": CONTROL_FLOW, 4411 "soft_margin_loss": CONTROL_FLOW, 4412 "threshold": CONTROL_FLOW, 4413 "triplet_margin_loss": CONTROL_FLOW, 4414 "triplet_margin_with_distance_loss": CONTROL_FLOW, 4415 "upsample": CONTROL_FLOW, 4416 4417 "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT, 4418 "upsample_nearest": INTERPOLATE_ARGS_CONFLICT, 4419 } 4420 4421 # List of nn.functionals with Tensor inputs but not with type annotation 4422 FUNCTIONALS_WITHOUT_ANNOTATION = ( 4423 "adaptive_max_pool1d", 4424 "adaptive_max_pool2d", 4425 "adaptive_max_pool3d", 4426 "fractional_max_pool2d", 4427 "fractional_max_pool3d", 4428 "max_pool1d", 4429 "max_pool2d", 4430 "max_pool3d", 4431 "gaussian_nll_loss", 4432 "upsample", 4433 "upsample_bilinear", 4434 "upsample_nearest", 4435 ) 4436 4437 # Inconsistent behavior between Python 3.8 and other Python versions: 4438 # - 
Python 3.8+: Re-raise internal exception like `PROXY_ITERATED` 4439 # - Other Python: Raise `argument of type 'Proxy' is not iterable` due to the same 4440 # internal exception above 4441 # Use the following map to override the expected exception for Python 3.8 4442 UNTRACEABLE_FUNCTIONALS_PY38 = { 4443 "adaptive_max_pool1d": PROXY_ITERATED, 4444 "adaptive_max_pool2d": PROXY_ITERATED, 4445 "adaptive_max_pool3d": PROXY_ITERATED, 4446 "fractional_max_pool2d": PROXY_ITERATED, 4447 "fractional_max_pool3d": PROXY_ITERATED, 4448 "max_pool1d": PROXY_ITERATED, 4449 "max_pool2d": PROXY_ITERATED, 4450 "max_pool3d": PROXY_ITERATED, 4451 4452 "group_norm": CONTROL_FLOW 4453 } 4454 4455 @classmethod 4456 def _get_functional(cls): 4457 functional_list = [] 4458 for f in dir(torch.nn.functional): 4459 if not f.islower(): 4460 continue 4461 # Ignore internal functions 4462 if f.startswith('_'): 4463 continue 4464 # Ignore supporting functions 4465 if f in cls.IGNORE_FUNCS: 4466 continue 4467 fn = getattr(torch.nn.functional, f) 4468 # Ignore non-callable object like modules 4469 if not isinstance(fn, Callable): 4470 continue 4471 if f not in cls.FUNCTIONALS_WITHOUT_ANNOTATION: 4472 try: 4473 sig = inspect.signature(fn) 4474 has_tensor_arg = False 4475 for param in sig.parameters.values(): 4476 if isinstance(param.annotation, type) and issubclass(param.annotation, torch.Tensor): 4477 has_tensor_arg = True 4478 if not has_tensor_arg: 4479 continue 4480 # No signature or Object is not supported 4481 except ValueError: 4482 pass 4483 functional_list.append((f, fn)) 4484 return functional_list 4485 4486 @classmethod 4487 def generate_test_func(cls, func_name, fn): 4488 4489 def functional_test(self): 4490 if func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 and \ 4491 sys.version_info >= (3, 8) and sys.version_info < (3, 12): 4492 exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name] 4493 with self.assertRaisesRegex(exc, err): 4494 symbolic_trace(fn) 4495 elif func_name in 
self.UNTRACEABLE_FUNCTIONALS: 4496 exc, err = self.UNTRACEABLE_FUNCTIONALS[func_name] 4497 with self.assertRaisesRegex(exc, err): 4498 symbolic_trace(fn) 4499 else: 4500 symbolic_trace(fn) 4501 return functional_test 4502 4503 @classmethod 4504 def generate_tests(cls): 4505 functional_list = cls._get_functional() 4506 for func_name, fn in functional_list: 4507 test_name = "test_nn_functional_" + func_name 4508 functional_test = cls.generate_test_func(func_name, fn) 4509 setattr(cls, test_name, functional_test) 4510 4511 @classmethod 4512 def setUpClass(cls): 4513 4514 def no(*args, **kwargs): 4515 return False 4516 4517 for name in cls.TO_PATCH.keys(): 4518 cls.TO_PATCH[name] = getattr(torch.nn.functional, name) 4519 setattr(torch.nn.functional, name, no) 4520 4521 @classmethod 4522 def tearDownClass(cls): 4523 for name in cls.TO_PATCH.keys(): 4524 setattr(torch.nn.functional, name, cls.TO_PATCH[name]) 4525 4526TestFunctionalTracing.generate_tests() 4527 4528 4529instantiate_device_type_tests(TestOperatorSignatures, globals()) 4530 4531@skipIfTorchDynamo("too slow") 4532@skipIfNoTorchVision 4533class TestVisionTracing(JitTestCase): 4534 def setUp(self): 4535 # Checking for mutable operations while tracing is feature flagged 4536 # Enable it in testing but not by default 4537 self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations 4538 torch.fx.proxy.TracerBase.check_mutable_operations = True 4539 4540 def tearDown(self): 4541 torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag 4542 4543 PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated") 4544 INCONSISTENT_TYPE = ( 4545 RuntimeError, 4546 r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor" 4547 ) 4548 4549 UNTRACEABLE_MODELS = { 4550 "fasterrcnn_resnet50_fpn": PROXY_ITERATED, 4551 "fasterrcnn_resnet50_fpn_v2": PROXY_ITERATED, 4552 "fasterrcnn_mobilenet_v3_large_320_fpn": PROXY_ITERATED, 
4553 "fasterrcnn_mobilenet_v3_large_fpn": PROXY_ITERATED, 4554 "maskrcnn_resnet50_fpn": PROXY_ITERATED, 4555 "maskrcnn_resnet50_fpn_v2": PROXY_ITERATED, 4556 "keypointrcnn_resnet50_fpn": PROXY_ITERATED, 4557 "retinanet_resnet50_fpn": PROXY_ITERATED, 4558 "retinanet_resnet50_fpn_v2": PROXY_ITERATED, 4559 "ssd300_vgg16": PROXY_ITERATED, 4560 "fcos_resnet50_fpn": PROXY_ITERATED, 4561 "ssdlite320_mobilenet_v3_large": PROXY_ITERATED, 4562 } 4563 UNSCRIPTABLE_MODELS = { 4564 "googlenet": INCONSISTENT_TYPE, 4565 "inception_v3": INCONSISTENT_TYPE, 4566 } 4567 4568 output_transform = { 4569 "fcn_resnet50": lambda x: x["out"], 4570 "fcn_resnet101": lambda x: x["out"], 4571 "deeplabv3_resnet50": lambda x: x["out"], 4572 "deeplabv3_resnet101": lambda x: x["out"], 4573 "deeplabv3_mobilenet_v3_large": lambda x: x["out"], 4574 "lraspp_mobilenet_v3_large": lambda x: x["out"], 4575 "fasterrcnn_resnet50_fpn": lambda x: x[1], 4576 "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1], 4577 "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1], 4578 "maskrcnn_resnet50_fpn": lambda x: x[1], 4579 "keypointrcnn_resnet50_fpn": lambda x: x[1], 4580 "retinanet_resnet50_fpn": lambda x: x[1], 4581 } 4582 4583 @classmethod 4584 def generate_test_fn(cls, name, x, kwargs): 4585 def run_test(self): 4586 model = torchvision_models.get_model(name, **kwargs) 4587 model = model.eval() 4588 if name in self.UNTRACEABLE_MODELS: 4589 err, exc = self.UNTRACEABLE_MODELS[name] 4590 with self.assertRaisesRegex(err, exc): 4591 graph = symbolic_trace(model) 4592 else: 4593 out_transform = self.output_transform.get(name, lambda x: x) 4594 graph : torch.fx.GraphModule = symbolic_trace(model) 4595 a = out_transform(model(x)) 4596 b = out_transform(graph(x)) 4597 self.assertEqual(a, b) 4598 4599 if name in self.UNSCRIPTABLE_MODELS: 4600 err, exc = self.UNSCRIPTABLE_MODELS[name] 4601 with self.assertRaisesRegex(err, exc): 4602 script = torch.jit.script(graph) 4603 else: 4604 script = torch.jit.script(graph) 4605 c 
= out_transform(script(x)) 4606 self.assertEqual(a, c) 4607 4608 return run_test 4609 4610 @classmethod 4611 def generate_classification_tests(cls): 4612 for k in torchvision_models.list_models(module=torchvision_models): 4613 test_name = 'test_torchvision_models_' + k 4614 x = torch.rand(1, 3, 299, 299) if k in ['inception_v3'] else torch.rand(1, 3, 224, 224) 4615 kwargs = dict(num_classes=50) 4616 model_test = cls.generate_test_fn(k, x, kwargs) 4617 setattr(cls, test_name, model_test) 4618 4619 @classmethod 4620 def generate_segmentation_tests(cls): 4621 for k in torchvision_models.list_models(module=torchvision_models.segmentation): 4622 test_name = 'test_torchvision_models_segmentation_' + k 4623 x = torch.rand(1, 3, 32, 32) 4624 kwargs = dict(num_classes=10, pretrained_backbone=False) 4625 model_test = cls.generate_test_fn(k, x, kwargs) 4626 setattr(cls, test_name, model_test) 4627 4628 @classmethod 4629 def generate_detection_tests(cls): 4630 for k in torchvision_models.list_models(module=torchvision_models.detection): 4631 test_name = 'test_torchvision_models_detection_' + k 4632 x = [torch.rand(3, 300, 300)] 4633 kwargs = dict(num_classes=10, pretrained_backbone=False) 4634 model_test = cls.generate_test_fn(k, x, kwargs) 4635 setattr(cls, test_name, model_test) 4636 4637 @classmethod 4638 def generate_video_tests(cls): 4639 for k in torchvision_models.list_models(module=torchvision_models.video): 4640 test_name = 'test_torchvision_models_video_' + k 4641 x = ( 4642 torch.rand(1, 3, 4, 112, 112) 4643 if k not in {"mvit_v1_b", "mvit_v2_s", "s3d"} 4644 else torch.rand(1, 3, 16, 224, 224) 4645 ) 4646 kwargs = dict(num_classes=50) 4647 model_test = cls.generate_test_fn(k, x, kwargs) 4648 setattr(cls, test_name, model_test) 4649 4650 @classmethod 4651 def generate_tests(cls): 4652 cls.generate_classification_tests() 4653 cls.generate_detection_tests() 4654 cls.generate_segmentation_tests() 4655 cls.generate_video_tests() 4656 4657if HAS_TORCHVISION: 4658 
TestVisionTracing.generate_tests() 4659 4660if __name__ == '__main__': 4661 run_tests() 4662