# Owner(s): ["module: fx"]

import functools
import math
import numbers
import operator
import pickle
import sys
import sympy
import tempfile
import unittest
from types import BuiltinFunctionType
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import torch
import torch.fx.experimental.meta_tracer
import torch.fx.experimental.optimization as optimization
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.experimental import merge_matmul
from torch.fx.experimental.accelerator_partitioner import Partitioner
from torch.fx.experimental.normalize import NormalizeArgs, NormalizeOperators
from torch.fx.experimental.partitioner_utils import (
    Device,
    get_latency_of_partitioned_graph,
    get_partition_to_latency_mapping,
    NodeLatency,
    PartitionerConfig,
    PartitionMode,
)
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.operator_schemas import (
    _torchscript_type_to_python_type,
    create_type_hint,
    normalize_function,
    normalize_module,
    type_matches,
)
from torch.fx.passes import graph_manipulation
from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes
from torch.fx.passes.shape_prop import ShapeProp
from torch.fx.passes.split_module import split_module
from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_nn import module_tests, new_module_tests
from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
from torch.testing._internal.jit_utils import JitTestCase
import torch.utils._pytree as pytree

try:
    import torchvision.models
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
skipIfNoMkldnn = unittest.skipIf(
    not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()),
    "no MKLDNN",
)


def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule:
    return GraphModule(
        root if isinstance(root, torch.nn.Module) else torch.nn.Module(),
        RewritingTracer().trace(root),
    )


class TestFXExperimental(JitTestCase):
    def test_find_single_partition(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(1)
        b = torch.rand(1)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 150, 1),
            Device("dev_2", 125, 2),
        ]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        assert dag.nodes[0].logical_device_ids == [1]

    def test_lack_of_devices(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
109 a = torch.rand(4) 110 b = torch.rand(4) 111 graph_manipulation.get_size_of_all_nodes(traced, [a, b]) 112 partitioner = Partitioner() 113 devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)] 114 partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) 115 catch_runtime_error = False 116 try: 117 ret = partitioner.partition_graph(traced, m, partitioner_config) 118 except RuntimeError: 119 catch_runtime_error = True 120 assert catch_runtime_error 121 122 def test_large_node_error(self): 123 class TestModule(torch.nn.Module): 124 def __init__(self) -> None: 125 super().__init__() 126 self.linear = torch.nn.Linear(4, 4) 127 128 def forward(self, a): 129 linear = self.linear(a) 130 add = linear + a 131 return add 132 133 m = TestModule() 134 traced = symbolic_trace(m) 135 a = torch.rand(4) 136 graph_manipulation.get_size_of_all_nodes(traced, [a]) 137 partitioner = Partitioner() 138 devices = [ 139 Device("dev_0", 40, 0), 140 Device("dev_1", 40, 0), 141 Device("dev_2", 40, 0), 142 Device("dev_3", 40, 0), 143 Device("dev_4", 40, 0), 144 ] 145 partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) 146 catch_runtime_error = False 147 try: 148 ret = partitioner.partition_graph(traced, m, partitioner_config) 149 except RuntimeError: 150 catch_runtime_error = True 151 assert catch_runtime_error 152 153 def test_partition_node_manipulation(self): 154 class TestModule(torch.nn.Module): 155 def forward(self, a, b): 156 add_1 = a + b 157 add_2 = add_1 + torch.rand(4) 158 add_3 = add_2 + torch.rand(4) 159 return add_3 160 161 m = TestModule() 162 traced = symbolic_trace(m) 163 a, b = torch.rand(4), torch.rand(4) 164 graph_manipulation.get_size_of_all_nodes(traced, [a, b]) 165 partitioner = Partitioner() 166 devices = [Device("dev_0", 1000, 0)] 167 partitioner_config = PartitionerConfig(devices) 168 ret = partitioner.partition_graph(traced, m, partitioner_config) 169 partition = partitioner.partitions[0] 170 assert partition.used_mem_bytes == 112 171 # Select add_2 node to remove 172 selected_node = None 173 for node in partition.nodes: 174 if node.name == "add_2": 175 selected_node = node 176 partition.remove_node(selected_node) 177 assert partition.used_mem_bytes == 80 178 179 def test_size_based_partition(self): 180 class TestModule(torch.nn.Module): 181 def __init__(self) -> None: 182 super().__init__() 183 self.linear = torch.nn.Linear(4, 4) 184 self.c = torch.rand(4) 185 186 def forward(self, a, b): 187 add_1 = a + b 188 linear = self.linear(add_1) 189 add_2 = linear + self.c 190 return add_2 191 192 m = TestModule() 193 traced = symbolic_trace(m) 194 a = torch.rand(4) 195 b = torch.rand(4) 196 graph_manipulation.get_size_of_all_nodes(traced, [a, b]) 197 partitioner = Partitioner() 198 devices = [ 199 Device("dev_0", 125, 0), 200 Device("dev_1", 125, 1), 201 Device("dev_2", 125, 2), 202 ] 203 partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) 204 ret = partitioner.partition_graph(traced, m, partitioner_config) 205 module_with_submodules = ret.module_with_submodules 206 dag = ret.dag 207 self.assertEqual(traced(a, b), module_with_submodules(a, b)) 208 for i, node in enumerate(dag.nodes): 209 assert node.logical_device_ids == [i] 210 211 def test_partition_device_mapping(self): 212 class TestModule(torch.nn.Module): 213 def __init__(self) -> None: 214 super().__init__() 215 self.linear = torch.nn.Linear(4, 4) 216 217 def forward(self, a): 218 b = torch.rand(4) 219 add_1 = a + b 220 linear_1 = self.linear(add_1) 221 add_2 = torch.rand(4) + 
a 222 add_3 = add_2 + linear_1 223 return add_3 224 225 m = TestModule() 226 traced = symbolic_trace(m) 227 a = torch.rand(4) 228 graph_manipulation.get_size_of_all_nodes(traced, [a]) 229 partitioner = Partitioner() 230 devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)] 231 partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) 232 ret = partitioner.partition_graph(traced, m, partitioner_config) 233 module_with_submodules = ret.module_with_submodules 234 dag = ret.dag 235 self.assertEqual(traced(a), module_with_submodules(a)) 236 for i, node in enumerate(dag.nodes): 237 if i == 1: 238 assert node.logical_device_ids == [1] 239 else: 240 assert node.logical_device_ids == [0] 241 242 def test_sparse_nn_partition(self): 243 class MyRecommendationModule(torch.nn.Module): 244 def create_mlp(self, num_of_layers: int, input_size: int, output_size: int): 245 layers = torch.nn.ModuleList() 246 for _ in range(num_of_layers): 247 ll = torch.nn.Linear(input_size, output_size) 248 layers.append(ll) 249 layers.append(torch.nn.ReLU()) 250 return layers 251 252 def __init__(self) -> None: 253 super().__init__() 254 layers = self.create_mlp(4, 4, 4) 255 self.bottom_layers = torch.nn.Sequential(*layers) 256 layers = self.create_mlp(3, 24, 24) 257 self.top_layers = torch.nn.Sequential(*layers) 258 self.embedding_layers = torch.nn.ModuleList() 259 el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) 260 self.embedding_layers.append(el) 261 for i in range(3): 262 el = torch.nn.EmbeddingBag(1000000, 4, mode="sum", sparse=True) 263 self.embedding_layers.append(el) 264 el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) 265 self.embedding_layers.append(el) 266 267 def forward(self, a, b, offset): 268 x = self.bottom_layers(a) 269 y = [] 270 c = [] 271 for i in range(len(self.embedding_layers)): 272 temp = torch.randint(10, (8,)) 273 c.append(temp + b) 274 for i in range(len(self.embedding_layers)): 275 if i % 2 == 0: 276 y.append(self.embedding_layers[i](c[i], offset)) 277 else: 278 y.append( 279 self.embedding_layers[i](torch.randint(10, (8,)), offset) 280 ) 281 z = torch.cat([x] + y, dim=1) 282 p = self.top_layers(z) 283 return p 284 285 m = MyRecommendationModule() 286 a = torch.rand(2, 4) 287 b = torch.randint(10, (8,)) 288 offset = torch.randint(1, (2,)) 289 traced = symbolic_trace(m) 290 graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset]) 291 devices = [ 292 Device("dev_0", 33000000, 0), 293 Device("dev_1", 33000000, 1), 294 Device("dev_2", 33000000, 2), 295 ] 296 partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn) 297 partitioner = Partitioner() 298 ret = partitioner.partition_graph(traced, m, partitioner_config) 299 module_with_submodules = ret.module_with_submodules 300 dag = ret.dag 301 self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset)) 302 assert len(module_with_submodules.graph.nodes) == 24 303 304 def test_partition_latency(self): 305 class TestModule(torch.nn.Module): 306 def __init__(self) -> None: 307 super().__init__() 308 self.linear = torch.nn.Linear(4, 4) 309 310 def forward(self, a): 311 add_1 = a + torch.rand(4) 312 add_2 = add_1 + torch.rand(4) 313 linear_1 = self.linear(add_1) 314 add_3 = add_2 + linear_1 315 add_4 = add_2 + add_3 316 return add_4 317 318 def get_node_to_latency_mapping(fx_module: GraphModule): 319 """Given a fx module, generate node latency for each node 320 based on the size of each node 321 """ 322 node_to_latency_mapping: Dict[Node, NodeLatency] = {} 323 for 
node in fx_module.graph.nodes: 324 if node.op not in {"output", "placeholder", "get_attr"}: 325 if node.size_bytes.total_size == node.size_bytes.output_size: 326 node_to_latency_mapping[node] = NodeLatency( 327 node.size_bytes.total_size, 2.0 * node.size_bytes.total_size 328 ) 329 else: 330 node_to_latency_mapping[node] = NodeLatency( 331 node.size_bytes.total_size, node.size_bytes.output_size 332 ) 333 return node_to_latency_mapping 334 335 m = TestModule() 336 traced = symbolic_trace(m) 337 a = torch.rand(4) 338 graph_manipulation.get_size_of_all_nodes(traced, [a]) 339 node_to_latency_mapping = get_node_to_latency_mapping(traced) 340 devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)] 341 partitioner = Partitioner() 342 partitioner_config = PartitionerConfig(devices) 343 ret = partitioner.partition_graph(traced, m, partitioner_config) 344 module_with_submodules = ret.module_with_submodules 345 self.assertEqual(traced(a), module_with_submodules(a)) 346 partitions = partitioner.partitions 347 partition_to_latency_mapping = get_partition_to_latency_mapping( 348 partitions, node_to_latency_mapping 349 ) 350 for p in partition_to_latency_mapping: 351 if p.partition_id == 0: 352 assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0) 353 else: 354 assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0) 355 transfer_rate_bytes_per_sec = 2 356 critical_path_latency_sec = get_latency_of_partitioned_graph( 357 partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec 358 ) 359 assert critical_path_latency_sec == 208.0 360 361 def test_cost_aware_partition(self): 362 class MyModule(torch.nn.Module): 363 def __init__(self) -> None: 364 super().__init__() 365 self.linear = torch.nn.Linear(4, 4) 366 367 def forward(self, a): 368 add_1 = a + torch.rand(4) 369 add_2 = add_1 + torch.rand(4) 370 linear_1 = self.linear(add_1) 371 add_3 = add_2 + torch.rand(4) 372 add_4 = add_2 + linear_1 373 add_5 = add_3 + add_4 374 return add_5 375 376 def get_node_to_latency_mapping(fx_module: GraphModule): 377 node_to_latency_mapping: Dict[Node, NodeLatency] = {} 378 for node in fx_module.graph.nodes: 379 if node.op not in {"output", "placeholder", "get_attr"}: 380 if node.size_bytes.total_size == node.size_bytes.output_size: 381 node_to_latency_mapping[node] = NodeLatency( 382 node.size_bytes.total_size, 1 383 ) 384 else: 385 node_to_latency_mapping[node] = NodeLatency( 386 node.size_bytes.total_size, node.size_bytes.output_size 387 ) 388 return node_to_latency_mapping 389 390 m = MyModule() 391 traced = symbolic_trace(m) 392 a = torch.rand(4) 393 graph_manipulation.get_size_of_all_nodes(traced, [a]) 394 devices = [ 395 Device("dev_0", 125, 0), 396 Device("dev_1", 125, 1), 397 Device("dev_2", 125, 2), 398 Device("dev_3", 125, 3), 399 ] 400 node_to_latency_mapping = get_node_to_latency_mapping(traced) 401 partitioner_config = PartitionerConfig( 402 devices, 403 mode=PartitionMode.cost_aware, 404 transfer_rate_bytes_per_sec=2, 405 node_to_latency_mapping=node_to_latency_mapping, 406 ) 407 partitioner = Partitioner() 408 ret = partitioner.partition_graph(traced, m, partitioner_config) 409 module_with_submodules = ret.module_with_submodules 410 dag = ret.dag 411 self.assertEqual(traced(a), module_with_submodules(a)) 412 partitions = partitioner.partitions 413 partition_to_latency_mapping = get_partition_to_latency_mapping( 414 partitions, node_to_latency_mapping 415 ) 416 critical_path_latency_sec = get_latency_of_partitioned_graph( 417 partitions, 418 partition_to_latency_mapping, 419 
partitioner_config.transfer_rate_bytes_per_sec, 420 ) 421 assert critical_path_latency_sec == 160.0 422 423 def test_aot_based_partition(self): 424 class TestModule(torch.nn.Module): 425 def __init__(self) -> None: 426 super().__init__() 427 self.b = torch.rand(4) 428 self.c = torch.rand(4) 429 430 def forward(self, a): 431 add_1 = a + self.b 432 add_2 = self.c + add_1 433 return add_2 434 435 m = TestModule() 436 traced = symbolic_trace(m) 437 a = torch.rand(4) 438 node_to_partition_id = {} 439 partition_to_logical_devices = {} 440 count = 0 441 graph_manipulation.get_size_of_all_nodes(traced, [a]) 442 for node in traced.graph.nodes: 443 if node.op not in {"placeholder", "get_attr", "output"}: 444 node_to_partition_id[node] = count 445 partition_to_logical_devices[count] = [0] 446 count += 1 447 devices = [Device("dev_0", 200, 0)] 448 partitioner_config = PartitionerConfig( 449 devices=devices, 450 mode=PartitionMode.aot_based, 451 node_to_partition_mapping=node_to_partition_id, 452 partition_to_logical_device_mapping=partition_to_logical_devices, 453 ) 454 partitioner = Partitioner() 455 ret = partitioner.partition_graph(traced, m, partitioner_config) 456 module_with_submodules = ret.module_with_submodules 457 dag = ret.dag 458 self.assertEqual(module_with_submodules(a), traced(a)) 459 for node in dag.nodes: 460 assert node.size_bytes == 48 461 assert node.logical_device_ids == [0] 462 463 def test_replace_target_nodes_with(self): 464 class testModule(torch.nn.Module): 465 def forward(self, a, b): 466 return a + b 467 468 m = testModule() 469 traced = symbolic_trace(m) 470 input1 = torch.randn(1) 471 input2 = torch.randn(1) 472 assert (input1 + input2) == traced(input1, input2) 473 graph_manipulation.replace_target_nodes_with( 474 fx_module=traced, 475 old_op="call_function", 476 old_target=operator.add, 477 new_op="call_function", 478 new_target=operator.mul, 479 ) 480 assert (input1 * input2) == traced(input1, input2) 481 482 def test_saturate_host(self): 483 class TestModule(torch.nn.Module): 484 def __init__(self) -> None: 485 super().__init__() 486 self.linear = torch.nn.Linear(4, 4) 487 488 def forward(self, a): 489 add_1 = a + torch.rand(4) 490 add_2 = add_1 + torch.rand(4) 491 linear_1 = self.linear(add_1) 492 add_3 = add_2 + linear_1 493 add_4 = add_2 + add_3 494 return add_4 495 496 m = TestModule() 497 traced = symbolic_trace(m) 498 a = torch.rand(4) 499 graph_manipulation.get_size_of_all_nodes(traced, [a]) 500 devices = [ 501 Device("dev_0", 200, 0), 502 Device("dev_1", 200, 1), 503 Device("dev_2", 100, 2), 504 Device("dev_3", 100, 3), 505 Device("dev_4", 200, 4), 506 Device("dev_5", 100, 5), 507 ] 508 partitioner = Partitioner() 509 # Without host saturation, the model will be split into two partitions. 510 # dev_0 holds partition 0 of 192 bytes and dev_1 holds partition 1 of 48 bytes. 511 partitioner_config = PartitionerConfig(devices, saturate_host=True) 512 ret = partitioner.partition_graph(traced, m, partitioner_config) 513 module_with_submodules = ret.module_with_submodules 514 self.assertEqual(traced(a), module_with_submodules(a)) 515 516 partitions = partitioner.partitions 517 self.assertEqual(len(partitions), 2) 518 # With host saturation, partition 1 will be replicated to dev_4, and partition 2 519 # will be replicated to dev_2. 
520 self.assertEqual(partitions[0].logical_device_ids, [0, 4]) 521 self.assertEqual(partitions[1].logical_device_ids, [1, 2]) 522 523 @skipIfNoTorchVision 524 def test_conv_bn_fusion(self): 525 rn18 = resnet18().eval() 526 traced = symbolic_trace(rn18) 527 fused = optimization.fuse(traced) 528 529 self.assertTrue( 530 all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()) 531 ) 532 533 N, C, H, W = 20, 3, 224, 224 534 inp = torch.randn(N, C, H, W) 535 536 self.assertEqual(fused(inp), rn18(inp)) 537 538 def test_conv_bn_fusion_not_running_state(self): 539 class M(torch.nn.Module): 540 def __init__(self) -> None: 541 super().__init__() 542 self.conv = torch.nn.Conv2d(32, 64, 3, stride=2) 543 self.bn = torch.nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False) 544 545 def forward(self, x): 546 x = self.conv(x) 547 x = self.bn(x) 548 return x 549 550 model = M().eval() 551 552 traced = symbolic_trace(model) 553 fused = optimization.fuse(traced) 554 inp = torch.randn([1, 32, 50, 50]) 555 556 # bn need not be folded in conv 557 self.assertTrue( 558 any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()) 559 ) 560 self.assertEqual(fused(inp), model(inp)) 561 562 def test_conv_bn_fusion_mixed_dtype(self): 563 class M(torch.nn.Module): 564 def __init__(self) -> None: 565 super().__init__() 566 self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, dtype=torch.bfloat16) 567 self.bn = torch.nn.BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True) 568 569 def forward(self, x): 570 x = self.conv(x) 571 x = self.bn(x) 572 return x 573 574 model = M().eval() 575 576 traced = symbolic_trace(model) 577 fused = optimization.fuse(traced) 578 inp = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16) 579 580 self.assertTrue( 581 all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()) 582 ) 583 self.assertEqual(fused(inp), model(inp)) 584 585 def test_call_to_assert_no_msg(self): 586 class M(torch.nn.Module): 587 def forward(self, a, b): 588 assert a == b 589 return a + b 590 591 m = M() 592 traced = symbolic_trace_with_rewrite(m) 593 594 # Make sure the graph is well-formed 595 traced.graph.lint() 596 597 # Check the IR to make sure there's a call_function node with target == "Assert" 598 self.assertTrue( 599 any( 600 node.op == "call_function" and node.target == torch._assert 601 for node in traced.graph.nodes 602 ) 603 ) 604 605 # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to 606 traced(3, 3) 607 with self.assertRaisesRegex(AssertionError, ""): 608 traced(3, 5) 609 610 # Confirm that the output is correct 611 self.assertEqual(traced(3, 3), m(3, 3)) 612 613 def test_meta_tracer(self): 614 class MetaTracerTestModule(torch.nn.Module): 615 def __init__(self) -> None: 616 super().__init__() 617 self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16) 618 self.layernorm = torch.nn.LayerNorm(16) 619 620 def forward(self, x): 621 emb = self.emb(x) 622 emb = emb + torch.arange(emb.shape[-1], dtype=torch.float, device=emb.device) 623 lol = self.layernorm(emb) 624 return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol) 625 626 mttm = MetaTracerTestModule() 627 for BS in [15, 35]: 628 x = torch.zeros(BS, dtype=torch.long).random_(42) 629 meta_args = {'x' : x.to(device='meta')} 630 gm = torch.fx.experimental.meta_tracer.symbolic_trace(mttm, meta_args=meta_args) 631 torch.testing.assert_close(gm(x), mttm(x)) 632 633 # Test 
serialization/deserialization 634 with tempfile.TemporaryDirectory() as tmp_dir: 635 with open(f'{tmp_dir}/meta_module.pkl', 'wb') as f: 636 pickle.dump(gm, f) 637 638 with open(f'{tmp_dir}/meta_module.pkl', 'rb') as f: 639 loaded = pickle.load(f) 640 641 torch.testing.assert_close(loaded(x), mttm(x)) 642 643 644 def test_call_to_assert_with_msg(self): 645 class M(torch.nn.Module): 646 def forward(self, a, b): 647 assert a == b, "test message" 648 return a + b 649 650 m = M() 651 traced = symbolic_trace_with_rewrite(m) 652 653 # Make sure the graph is well-formed 654 traced.graph.lint() 655 656 # Check the IR to make sure there's a call_function node with target == "Assert" 657 self.assertTrue( 658 any( 659 node.op == "call_function" and node.target == torch._assert 660 for node in traced.graph.nodes 661 ) 662 ) 663 664 # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to 665 traced(3, 3) 666 with self.assertRaisesRegex(AssertionError, "test message"): 667 traced(3, 5) 668 669 # Confirm that the output is correct 670 self.assertEqual(traced(3, 3), m(3, 3)) 671 672 def test_call_to_assert_with_empty_msg(self): 673 class M(torch.nn.Module): 674 def forward(self, a, b): 675 assert a == b, "" 676 return a + b 677 678 m = M() 679 traced = symbolic_trace_with_rewrite(m) 680 681 # Make sure the graph is well-formed 682 traced.graph.lint() 683 684 # Check the IR to make sure there's a call_function node with target == "Assert" 685 self.assertTrue( 686 any( 687 node.op == "call_function" and node.target == torch._assert 688 for node in traced.graph.nodes 689 ) 690 ) 691 692 # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to 693 traced(3, 3) 694 with self.assertRaisesRegex(AssertionError, ""): 695 traced(3, 5) 696 697 # Confirm that the output is correct 698 self.assertEqual(traced(3, 3), m(3, 3)) 699 700 def test_call_to_assert_with_multiline_message(self): 701 class M(torch.nn.Module): 702 def forward(self, a, b): 703 error_msg = """ 704An error message with 705terrible spacing 706 """ 707 assert a == b, error_msg 708 return a + b 709 710 m = M() 711 traced = symbolic_trace_with_rewrite(m) 712 713 # Make sure the graph is well-formed 714 traced.graph.lint() 715 716 # Check the IR to make sure there's a call_function node with target == "Assert" 717 self.assertTrue( 718 any( 719 node.op == "call_function" and node.target == torch._assert 720 for node in traced.graph.nodes 721 ) 722 ) 723 724 # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to 725 error_msg = """ 726An error message with 727terrible spacing 728 """ 729 traced(3, 3) 730 with self.assertRaisesRegex(AssertionError, error_msg): 731 traced(3, 5) 732 733 # Confirm that the output is correct 734 self.assertEqual(traced(3, 3), m(3, 3)) 735 736 def test_subgraph_creation(self): 737 class MyModule(torch.nn.Module): 738 def __init__(self) -> None: 739 super().__init__() 740 self.param = torch.nn.Parameter(torch.rand(3, 4)) 741 self.linear = torch.nn.Linear(4, 5) 742 743 def forward(self, x, y): 744 z = self.linear(x + self.param).clamp(min=0.0, max=1.0) 745 w = self.linear(y).clamp(min=0.0, max=1.0) 746 return z + w 747 748 # symbolically trace model 749 my_module = MyModule() 750 my_module_traced = symbolic_trace(my_module) 751 752 # random mod partitioning 753 partition_counter = 0 754 NPARTITIONS = 3 755 756 # Add some random meta info to make sure it is kept around. 
757 for node in my_module_traced.graph.nodes: 758 if node.op != "output": 759 node.meta["test_meta_info"] = True 760 761 def mod_partition(node: Node): 762 nonlocal partition_counter 763 partition = partition_counter % NPARTITIONS 764 partition_counter = (partition_counter + 1) % NPARTITIONS 765 return partition 766 767 # split module in module with submodules 768 module_with_submodules = split_module( 769 my_module_traced, my_module, mod_partition 770 ) 771 772 # Check that test_meta_info was still on all nodes. 773 submodules = dict(module_with_submodules.named_modules()) 774 for node in module_with_submodules.graph.nodes: 775 if node.op == "call_module": 776 submod = submodules[node.target] 777 self.assertTrue(isinstance(submod, torch.fx.GraphModule)) 778 for submod_node in submod.graph.nodes: 779 if submod_node.op != "output": 780 stored_op = submod_node.meta.get("test_meta_info") 781 self.assertTrue(stored_op is not None and stored_op) 782 783 x = torch.rand(3, 4) 784 y = torch.rand(3, 4) 785 786 orig_out = my_module_traced(x, y) 787 submodules_out = module_with_submodules(x, y) 788 789 self.assertEqual(orig_out, submodules_out) 790 791 def test_split_module_dead_code(self): 792 class ModWithDeadCode(torch.nn.Module): 793 def forward(self, x): 794 output = x * 2 # we want this 795 dead_line = x + 2 # this is dead 796 return output 797 798 mod = ModWithDeadCode() 799 traced = torch.fx.symbolic_trace(mod) 800 801 # split into before (0), target (1), and after(2) 802 saw_mul = False 803 804 def split_callback(n): 805 nonlocal saw_mul 806 if n.target == operator.mul: 807 saw_mul = True 808 return 1 809 810 if not saw_mul: 811 return 0 812 if saw_mul: 813 return 2 814 815 split = split_module(traced, mod, split_callback) 816 817 x = torch.randn((5,)) 818 torch.testing.assert_close( 819 split(x), traced(x) 820 ) 821 822 823 def test_split_module_kwargs_expansion(self): 824 class ModuleWithKwargsExpansion(torch.nn.Module): 825 def forward(self, x, **kwargs): 826 return x + kwargs['foo'] 827 828 mod = ModuleWithKwargsExpansion() 829 traced = torch.fx.symbolic_trace(mod) 830 831 seen_getitem = False 832 833 def split_callback(n): 834 nonlocal seen_getitem 835 split_idx = int(seen_getitem) 836 if n.target == operator.getitem: 837 seen_getitem = True 838 return split_idx 839 840 split = split_module(traced, mod, split_callback) 841 842 x = torch.randn(5, 3) 843 foo = torch.randn(5, 3) 844 torch.testing.assert_close(split(x, foo=foo), traced(x, foo=foo)) 845 846 @skipIfNoTorchVision 847 def test_subgraph_trivial_resnet(self): 848 # Smoke test trivially splitting resnet into 1 partition works 849 # There was an issue before causing submodule names to be aliased 850 m = resnet18() 851 traced = symbolic_trace(m) 852 a = torch.rand(64, 3, 7, 7) 853 module_with_submodules = split_module(traced, m, lambda node: 0) 854 module_with_submodules(a) 855 856 def test_split_module_default_arg(self): 857 class ModelToTrace(torch.nn.Module): 858 def __init__(self) -> None: 859 super().__init__() 860 self.lin = torch.nn.Linear(512, 512) 861 862 def forward(self, x, targets=None): 863 x = self.lin(x) 864 865 if targets is not None: 866 x = x + targets 867 868 return x 869 870 mtt = ModelToTrace() 871 traced = torch.fx.symbolic_trace(mtt, concrete_args={'targets': None}) 872 873 split = split_module(traced, mtt, lambda node: 0) 874 875 x = torch.randn(50, 512) 876 torch.testing.assert_close(split(x), traced(x)) 877 878 def test_normalize_binary_operators(self): 879 ops_to_test = { 880 torch.add, 881 torch.mul, 882 
torch.sub, 883 torch.div, 884 torch.floor_divide, 885 torch.remainder, 886 torch.eq, 887 torch.ne, 888 torch.lt, 889 torch.le, 890 torch.gt, 891 torch.ge, 892 } 893 894 # Test Tensor/Tensor callsite 895 for op in ops_to_test: 896 897 class WrapperMod(torch.nn.Module): 898 def forward(self, x, y): 899 return op(x, y) 900 901 traced = symbolic_trace(WrapperMod()) 902 normalized = NormalizeOperators(traced).transform() 903 x, y = torch.randn(3, 4), torch.randn(3, 4) 904 torch.testing.assert_close(traced(x, y), normalized(x, y)) 905 self.assertFalse( 906 any(n.target in ops_to_test for n in normalized.graph.nodes) 907 ) 908 909 # Test Tensor/scalar callsite 910 for op in ops_to_test: 911 912 class WrapperMod(torch.nn.Module): 913 def forward(self, x): 914 return op(x, 42) 915 916 traced = symbolic_trace(WrapperMod()) 917 normalized = NormalizeOperators(traced).transform() 918 x = torch.randn(3, 4) 919 torch.testing.assert_close(traced(x), normalized(x)) 920 self.assertFalse( 921 any(n.target in ops_to_test for n in normalized.graph.nodes) 922 ) 923 924 @skipIfNoTorchVision 925 def test_normalize_args(self): 926 m = resnet18() 927 928 class FunctionalTracer(torch.fx.Tracer): 929 def is_leaf_module( 930 self, m: torch.nn.Module, module_qualified_name: str 931 ) -> bool: 932 # `leaves` contains the set of standard `nn.Modules` that are not 933 # currently symbolically traceable. Ideally this set would be empty 934 leaves = {torch.nn.BatchNorm2d} 935 return type(m) in leaves 936 937 traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m)) 938 939 input = torch.randn(5, 3, 224, 224) 940 ref_outs = traced(input) 941 942 ShapeProp(traced).propagate(input) 943 traced = NormalizeArgs(traced).transform() 944 945 modules = dict(traced.named_modules()) 946 947 for node in traced.graph.nodes: 948 if node.op == "call_function" and node.target != operator.add: 949 self.assertEqual(len(node.args), 0) 950 elif node.op == "call_module": 951 submod_class = modules[node.target].__class__ 952 nn_class = getattr(torch.nn, submod_class.__name__) 953 if submod_class == nn_class: 954 self.assertEqual(len(node.args), 0) 955 traced(input) 956 self.assertEqual(traced(input), ref_outs) 957 958 def test_normalize_modules_exhaustive(self): 959 """ 960 Exhaustively test `Node.normalized_arguments` on all standard 961 torch.nn Module classes 962 """ 963 for test_params in module_tests + new_module_tests: 964 if "constructor" not in test_params: 965 constructor = getattr(torch.nn, test_params["module_name"]) 966 else: 967 constructor = test_params["constructor"] 968 969 if "constructor_args" not in test_params: 970 args = () 971 else: 972 args = test_params["constructor_args"] 973 974 mod = constructor(*args) 975 # Skip modules that are not standard `torch.nn` 976 # instances, including functionals. 
(functionals 977 # are tested in test_normalize_args) 978 if mod.__class__.__name__ not in dir(torch.nn): 979 continue 980 981 if "input_fn" not in test_params: 982 inputs = torch.randn(test_params["input_size"]) 983 else: 984 inputs = test_params["input_fn"]() 985 986 if not isinstance(inputs, (tuple, list)): 987 inputs = (inputs,) 988 989 params = ", ".join(f"v{i}" for i in range(len(inputs))) 990 991 # Generate a class to wrap this standard `nn.Module` instance 992 test_classname = f"Test{mod.__class__.__name__}" 993 test_mod_code = f""" 994class {test_classname}(torch.nn.Module): 995 def __init__(self, mod): 996 super().__init__() 997 self.mod = mod 998 999 def forward(self, {params}): 1000 return self.mod({params}) 1001 """ 1002 1003 gbls = {"torch": torch} 1004 exec(test_mod_code, gbls) 1005 1006 test_instance = gbls[test_classname](mod) 1007 traced = symbolic_trace(test_instance) 1008 1009 # Use `Node.normalized_arguments` to get a new set of arguments 1010 # to feed to the Module. Then, rewrite the node to only take 1011 # in those arguments as kwargs 1012 modules = dict(traced.named_modules()) 1013 for node in traced.graph.nodes: 1014 if node.op == "call_module": 1015 submod_class = modules[node.target].__class__ 1016 nn_class = getattr(torch.nn, submod_class.__name__) 1017 if submod_class == nn_class: 1018 normalized_args = node.normalized_arguments(traced) 1019 normalized_args2 = normalize_module( 1020 traced, node.target, node.args, node.kwargs 1021 ) 1022 assert normalized_args == normalized_args2 1023 assert normalized_args 1024 node.args = normalized_args.args 1025 node.kwargs = normalized_args.kwargs 1026 1027 traced.recompile() 1028 1029 # These Modules have an RNG in their forward, so testing 1030 # correctness by comparing outputs is not correct. 
            # Skip that check for these
            stochastic_modules = {"FractionalMaxPool2d", "FractionalMaxPool3d", "RReLU"}

            if mod.__class__.__name__ not in stochastic_modules:
                self.assertEqual(traced(*inputs), mod(*inputs))

            traced = NormalizeArgs(symbolic_trace(test_instance)).transform()
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == "call_module":
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        self.assertEqual(len(node.args), 0)

    def test_normalize_args_preserve_meta(self):
        class MyModule(torch.nn.Module):
            def forward(self, a):
                return torch.add(a, 3)

        m = MyModule()
        traced = symbolic_trace(m)

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.meta["my_key"] = 7
                break
        else:
            self.fail("Didn't find call_function torch.add")

        input = torch.randn(2, 3)
        ShapeProp(traced).propagate(input)
        traced = NormalizeArgs(traced).transform()

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                self.assertTrue("my_key" in node.meta)
                self.assertEqual(node.meta["my_key"], 7)
                break
        else:
            self.fail("Didn't find call_function torch.add")

    def test_normalize_args_preserve_type(self):
        class MyModule(torch.nn.Module):
            def forward(self, a: List[torch.Tensor]):
                return torch.add(a[0], a[1])

        m = MyModule()
        traced = symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()

        for node in traced.graph.nodes:
            if node.op == "placeholder":
                self.assertEqual(node.type, List[torch.Tensor])

    @skipIfNoTorchVision
    def test_annotate_returns_with_schema(self):
        m = resnet18()

        traced_modules = symbolic_trace(m)
        traced_modules_annotated = AnnotateTypesWithSchema(traced_modules).transform()
        for node in traced_modules_annotated.graph.nodes:
            if node.type is None:
                check = (node.op, node.target)
                self.assertIn(
                    check,
                    {
                        ("placeholder", "x"),
                        ("call_module", "maxpool"),
                        ("call_function", operator.add),
                        ("call_function", torch.flatten),
                        ("output", "output"),
                    }
                )

        # Smoke test torchscript compilation since now we're emitting type annotations
        torch.jit.script(traced_modules_annotated)

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(
                self, m: torch.nn.Module, module_qualified_name: str
            ) -> bool:
                # `leaves` contains the set of standard `nn.Modules` that are not
                # currently symbolically traceable.
Ideally this set would be empty 1115 leaves = {torch.nn.BatchNorm2d} 1116 return type(m) in leaves 1117 1118 traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m)) 1119 1120 traced_functionals_annotated = AnnotateTypesWithSchema( 1121 traced_functionals 1122 ).transform() 1123 for node in traced_functionals_annotated.graph.nodes: 1124 if node.type is None: 1125 check = (node.op, node.target) 1126 excluded_nodes = { 1127 ("placeholder", "x"), 1128 # Return type differs based on boolean dispatch :( 1129 ("call_function", torch.nn.functional.max_pool2d), 1130 ("output", "output"), 1131 } 1132 # AnnotateTypesWithSchema doesn't work with bound C++ functions 1133 if not isinstance(node.target, BuiltinFunctionType): 1134 self.assertIn(check, excluded_nodes) 1135 1136 # Smoke test torchscript compilation since now we're emitting type annotations 1137 torch.jit.script(traced_functionals_annotated) 1138 1139 def test_annotate_getitem_node(self): 1140 class CustomType: 1141 pass 1142 1143 class CustomNamedTuple(NamedTuple): 1144 x: int 1145 y: float 1146 1147 class MyModule(torch.nn.Module): 1148 def forward(self, inp: Tuple[CustomType, torch.Tensor], inp2: List[CustomType], inp3: CustomNamedTuple): 1149 inp_0 = inp[0] 1150 inp_1 = inp[1] 1151 inp2_0 = inp2[0] 1152 inp3_x = inp3.x 1153 inp3_y = inp3.y 1154 return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y 1155 1156 my_module = MyModule() 1157 my_module_traced = torch.fx.symbolic_trace(my_module) 1158 1159 # by default, fx transform loses type annotation of getitem nodes. 1160 for node in my_module_traced.graph.nodes: 1161 if node.target == operator.getitem: 1162 assert node.type is None 1163 1164 annotate_getitem_nodes(my_module_traced.graph) 1165 1166 for node in my_module_traced.graph.nodes: 1167 if node.target == operator.getitem: 1168 self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.") 1169 1170 def test_subgraph_uniquename(self): 1171 class MyModule(torch.nn.Module): 1172 def __init__(self) -> None: 1173 super().__init__() 1174 self.linear = torch.nn.Linear(4, 4) 1175 1176 def forward(self, a, b, c, d): 1177 add_1 = a + b 1178 add_2 = add_1 + c 1179 linear_1 = self.linear(add_1) 1180 add_3 = add_2 + d 1181 add_4 = add_2 + linear_1 1182 add_5 = add_3 + add_4 1183 return add_5 1184 1185 a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4) 1186 mm = MyModule() 1187 traced = symbolic_trace(mm) 1188 1189 def split_cb(node: torch.fx.Node): 1190 if node.name == "a" or node.name == "b" or node.name == "add": 1191 return 0 1192 else: 1193 return 1 1194 1195 module_with_submodule = split_module(traced, mm, split_cb) 1196 self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d)) 1197 1198 def test_split_qualname_mapping(self): 1199 d_hid = 4 1200 1201 class ExampleCode(torch.nn.Module): 1202 def __init__(self) -> None: 1203 super().__init__() 1204 self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 1205 self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 1206 self.lin = torch.nn.Linear(d_hid, d_hid) 1207 1208 def forward(self, x): 1209 x = torch.mm(x, self.mm_param) 1210 x = torch.relu(x) 1211 x = torch.mm(x, self.mm_param) 1212 x = self.lin(x) 1213 x = torch.relu(x) 1214 x = torch.mm(x, self.mm_param2) 1215 x = self.lin(x) 1216 return x 1217 1218 my_module = ExampleCode() 1219 my_module_traced = symbolic_trace(my_module) 1220 1221 part_idx = 0 1222 1223 def split_callback(n : torch.fx.Node): 1224 nonlocal part_idx 1225 if (n.op, n.target) == ('call_module', 
'lin'): 1226 part_idx += 1 1227 return part_idx 1228 1229 # split module in module with submodules 1230 qualname_map : Dict[str, str] = {} 1231 module_with_submodules = split_module( 1232 my_module_traced, my_module, split_callback, qualname_map 1233 ) 1234 expected_qualname_map = { 1235 'submod_1.lin': 'lin', 'submod_2.lin': 'lin' 1236 } 1237 self.assertEqual(qualname_map, expected_qualname_map) 1238 1239 def test_traceable_function_with_nonstandard_name(self): 1240 def foo(x): 1241 return torch.relu(x) 1242 1243 traced = symbolic_trace_with_rewrite(foo) 1244 1245 def test_to_folder(self): 1246 class Test(torch.nn.Module): 1247 def __init__(self) -> None: 1248 super().__init__() 1249 self.W = torch.nn.Parameter(torch.randn(2)) 1250 self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2)) 1251 self.linear = torch.nn.Linear(2, 2) 1252 self.attr = torch.randn(2) 1253 self.attr2 = torch.nn.Buffer(torch.randn(2)) 1254 self.attr3 = torch.nn.Buffer(torch.ones(2, dtype=torch.int32)) 1255 1256 def forward(self, x): 1257 return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x)) 1258 1259 mod = symbolic_trace(Test()) 1260 module_name = "Foo" 1261 import tempfile 1262 from pathlib import Path 1263 1264 with tempfile.TemporaryDirectory() as tmp_dir: 1265 tmp_dir = Path(tmp_dir) 1266 mod.to_folder(tmp_dir, module_name) 1267 # Recipe taken from here: 1268 # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly 1269 import importlib.util 1270 1271 spec = importlib.util.spec_from_file_location( 1272 module_name, tmp_dir / "__init__.py" 1273 ) 1274 module = importlib.util.module_from_spec(spec) 1275 sys.modules[module_name] = module 1276 spec.loader.exec_module(module) 1277 t = torch.randn(2, 2) 1278 self.assertEqual(module.Foo()(t), mod(t)) 1279 1280 def test_fetch(self): 1281 attrs_for_lowering: Dict[str, List[str]] = { 1282 "torch.nn.modules.conv.Conv2d": [ 1283 "weight", 1284 "bias", 1285 "kernel_size", 1286 "stride", 1287 "padding", 1288 "dilation", 1289 "groups", 1290 "padding_mode", 1291 ], 1292 "torch.nn.modules.batchnorm.BatchNorm2d": [ 1293 "weight", 1294 "bias", 1295 "running_mean", 1296 "running_var", 1297 "eps", 1298 ], 1299 } 1300 1301 class TestModule(torch.nn.Module): 1302 def __init__(self) -> None: 1303 super().__init__() 1304 self.conv = torch.nn.Conv2d(3, 3, 2) 1305 self.bn = torch.nn.BatchNorm2d(3) 1306 1307 def forward(self, a): 1308 a = self.conv(a) 1309 a += a 1310 return self.bn(a) 1311 1312 mod = TestModule() 1313 traced = symbolic_trace(mod) 1314 lift_lowering_attrs_to_nodes(traced) 1315 1316 for node in traced.graph.nodes: 1317 if node.op == "call_module": 1318 assert hasattr(node, "attrs_for_lowering") 1319 para_list = attrs_for_lowering[node.attrs_for_lowering["name"]] 1320 1321 # node.attrs_for_lowering has an addition field of class name 1322 assert len(para_list) + 1 == len(node.attrs_for_lowering) 1323 for p_name in para_list: 1324 assert p_name in node.attrs_for_lowering 1325 1326 def test_merge_matmuls(self): 1327 """ 1328 A collection of test cases for torch.fx.experimental.merge_matmul, 1329 a graph transformation that merges matrix multiplication operations. 1330 """ 1331 # Utility function for counting matmuls for test assertions. 
        def _count_matmuls(mod):
            gm = torch.fx.symbolic_trace(mod)

            num_matmuls = 0
            for node in gm.graph.nodes:
                if node.target == torch.matmul:
                    num_matmuls += 1

            return num_matmuls

        # Simple test case in which there are two matmuls of the same size to merge.
        class SimpleMergeMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, x, y):
                a = torch.matmul(x, self.rhs)
                b = torch.matmul(y, self.rhs)
                return a + b

        # Initialize inputs.
        a = torch.randn(3, 3)
        b = torch.randn(3, 3)

        # Initialize RHS for matmuls.
        rhs = torch.randn(3, 4)

        # Construct SimpleMergeMatmulModule and call merge_matmul on it.
        module = SimpleMergeMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(a, b)
        after = opt_module(a, b)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; original module should have 2 matmuls
        # and optimized module should have 1.
        self.assertEqual(_count_matmuls(module), 2)
        self.assertEqual(_count_matmuls(opt_module), 1)

        # Test case in which there are multiple matmuls of different sizes to merge.
        class FiveMergeMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, a, b, c, d, e):
                s = torch.tensor([])
                matmuls = []

                # For some reason using a list comprehension or for-loop for this
                # doesn't work.
                matmuls.append(torch.matmul(a, self.rhs))
                matmuls.append(torch.matmul(b, self.rhs))
                matmuls.append(torch.matmul(c, self.rhs))
                matmuls.append(torch.matmul(d, self.rhs))
                matmuls.append(torch.matmul(e, self.rhs))

                for m in matmuls:
                    s += torch.sum(m)

                return s

        # Initialize inputs.
        inputs = [torch.randn(2 * i + 1, 5) for i in range(5)]

        # Initialize RHS.
        rhs = torch.randn(5, 4)

        # Construct FiveMergeMatmulModule and call merge_matmul on it.
        module = FiveMergeMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(*inputs)
        after = opt_module(*inputs)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; original module should have len(inputs) matmuls
        # and optimized module should have 1.
        self.assertEqual(_count_matmuls(module), len(inputs))
        self.assertEqual(_count_matmuls(opt_module), 1)

        # Simple test case in which two matmuls cannot be merged due to a data dependency between
        # the LHS operands.
        class UnmergeableMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, x):
                a = torch.matmul(x, self.rhs)
                a_abs = torch.abs(a)
                b = torch.matmul(a_abs.transpose(1, 0), self.rhs)
                return b

        # Initialize inputs.
        a = torch.randn(3, 3)

        # Initialize RHS for matmuls.
        rhs = torch.randn(3, 4)

        # Construct UnmergeableMatmulModule and call merge_matmul on it.
        module = UnmergeableMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(a)
        after = opt_module(a)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; the number of matrix multiplications should not have changed.
1446 self.assertEqual(_count_matmuls(module), 2) 1447 self.assertEqual(_count_matmuls(opt_module), 2) 1448 1449 def test_type_matches(self): 1450 should_be_equal = [ 1451 (int, int), 1452 (numbers.Number, int), 1453 (numbers.Number, float), 1454 (int, type(torch.float)), 1455 (Union[int, float], int), 1456 (Union[int, float], float), 1457 (List[int], int), 1458 (List[int], create_type_hint([int, int])), 1459 (List[int], create_type_hint((int, int))), 1460 (List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])), 1461 ( 1462 List[torch.Tensor], 1463 create_type_hint([torch.nn.Parameter, torch.nn.Parameter]), 1464 ), 1465 (torch.Tensor, torch.nn.Parameter), 1466 (List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])), 1467 (List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])), 1468 (List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))), 1469 ( 1470 List[torch.Tensor], 1471 create_type_hint((torch.nn.Parameter, torch.nn.Parameter)), 1472 ), 1473 (torch.Tensor, torch.nn.Parameter), 1474 (List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))), 1475 (List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))), 1476 (Optional[List[torch.Tensor]], List[torch.Tensor]), 1477 (Optional[List[int]], List[int]), 1478 ] 1479 for sig_type, arg_type in should_be_equal: 1480 self.assertTrue(type_matches(sig_type, arg_type)) 1481 1482 should_fail = [ 1483 (int, float), 1484 (Union[int, float], str), 1485 (List[torch.Tensor], List[int]), 1486 ] 1487 1488 for sig_type, arg_type in should_fail: 1489 self.assertFalse(type_matches(sig_type, arg_type)) 1490 1491 @skipIfNoMkldnn 1492 def test_optimize_for_inference_cpu(self): 1493 import torch.nn as nn 1494 1495 class Foo(nn.Module): 1496 def __init__(self) -> None: 1497 super().__init__() 1498 layers = [] 1499 layers2 = [] 1500 for _ in range(10): 1501 layers.append(nn.Conv2d(3, 3, 1)) 1502 layers.append(nn.BatchNorm2d(3)) 1503 layers.append(nn.ReLU()) 1504 1505 layers2.append(nn.Conv2d(3, 3, 1)) 1506 layers2.append(nn.BatchNorm2d(3)) 1507 layers2.append(nn.ReLU()) 1508 self.model = nn.Sequential(*layers) 1509 self.model2 = nn.Sequential(*layers2) 1510 1511 def forward(self, x): 1512 return self.model(x) + self.model2(x) 1513 1514 N, C, H, W, = ( 1515 1, 1516 3, 1517 224, 1518 224, 1519 ) 1520 inp = torch.randn(N, C, H, W) 1521 with torch.no_grad(): 1522 model = Foo().eval() 1523 optimized_model = optimization.optimize_for_inference(model) 1524 torch.testing.assert_close(model(inp), optimized_model(inp)) 1525 1526 optimized_model2 = optimization.optimize_for_inference( 1527 model, pass_config={"remove_dropout": False} 1528 ) 1529 torch.testing.assert_close(model(inp), optimized_model2(inp)) 1530 1531 @skipIfNoTorchVision 1532 @skipIfNoMkldnn 1533 def test_optimize_for_inference_cpu_torchvision(self): 1534 models = [ 1535 torchvision.models.resnet18, 1536 torchvision.models.resnet50, 1537 torchvision.models.densenet121, 1538 torchvision.models.shufflenet_v2_x1_0, 1539 torchvision.models.vgg16, 1540 torchvision.models.mobilenet_v2, 1541 torchvision.models.mnasnet1_0, 1542 torchvision.models.resnext50_32x4d, 1543 ] 1544 with torch.no_grad(): 1545 for model_type in models: 1546 model = model_type() 1547 C, H, W, = ( 1548 3, 1549 224, 1550 224, 1551 ) 1552 inp = torch.randn(3, C, H, W) 1553 model(inp) 1554 model.eval() 1555 inp = torch.randn(1, C, H, W) 1556 heuristic = optimization.gen_mkl_autotuner(inp, iters=0, warmup=0) 1557 optimized_model = 
optimization.optimize_for_inference(model) 1558 1559 orig_out = model(inp) 1560 new_out = optimized_model(inp) 1561 torch.testing.assert_close(orig_out, new_out) 1562 1563 1564class TestNormalizeOperators(JitTestCase): 1565 @onlyCPU 1566 @ops(op_db, allowed_dtypes=(torch.float,)) 1567 def test_normalize_operator_exhaustive(self, device, dtype, op): 1568 # These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors) 1569 fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot", "_upsample_bilinear2d_aa", "_chunk_cat"} 1570 sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) 1571 if isinstance(op.op, torch._ops.OpOverload): 1572 self.skipTest("normalize operator doesn't work on torch.ops") 1573 for sample_input in sample_inputs_itr: 1574 unsupported_arg_type = False 1575 arg_values = [sample_input.input] + list(sample_input.args) 1576 kwarg_values = sample_input.kwargs 1577 arg_types = [] 1578 kwarg_types = {} 1579 1580 def jit_infer_type(v): 1581 inferred_arg_type = torch._C._jit_try_infer_type(v) 1582 assert inferred_arg_type.success() 1583 t = _torchscript_type_to_python_type(inferred_arg_type.type()) 1584 return t 1585 1586 for v in arg_values: 1587 if isinstance(v, torch.Tensor): 1588 arg_types.append(type(v)) 1589 else: 1590 if isinstance(v, complex): 1591 # Complex type not supported in FX 1592 unsupported_arg_type = True 1593 arg_types.append(jit_infer_type(v)) 1594 1595 for k, v in kwarg_values.items(): 1596 if isinstance(v, torch.Tensor): 1597 kwarg_types[k] = type(v) 1598 else: 1599 if isinstance(v, complex): 1600 # Complex type not supported in FX 1601 unsupported_arg_type = True 1602 kwarg_types[k] = jit_infer_type(v) 1603 1604 if unsupported_arg_type: 1605 continue 1606 # Test normalize_function by itself 1607 ref_out = op.op(*arg_values, **kwarg_values) 1608 norm_args_and_kwargs = normalize_function( 1609 op.op, arg_values, kwarg_values, arg_types, kwarg_types 1610 ) 1611 if norm_args_and_kwargs is None: 1612 raise RuntimeError( 1613 """ 1614 FX failed to normalize op - add the op to the op_skip list. 
1615 A common reason is if your OpInfo was implemented with a lambda 1616 - otherwise, file an issue 1617 """ 1618 ) 1619 test_out = op.op(*norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs) 1620 self.assertEqual(test_out, ref_out) 1621 1622 # Test normalized_arguments as part of FX 1623 if op.name in fx_fail: 1624 continue 1625 param_names = [] 1626 param_values = [] 1627 fx_args = [] 1628 1629 idx = 0 1630 1631 def process_arg(arg, name): 1632 if isinstance(arg, torch.Tensor): 1633 param_names.append(name) 1634 param_values.append(arg) 1635 return name 1636 else: 1637 return f"{repr(arg)}" 1638 1639 def process_arg_with_idx(arg): 1640 nonlocal idx 1641 res = process_arg(arg, f"arg_{idx}") 1642 idx = idx + 1 1643 return res 1644 1645 def str_arg(arg): 1646 if isinstance(arg, tuple): 1647 args = [f"{str_arg(v)}, " for v in arg] 1648 return f"({' '.join(args)})" 1649 elif isinstance(arg, list): 1650 args = [f"{str_arg(v)}" for v in arg] 1651 return f"[{', '.join(args)}]" 1652 else: 1653 return arg 1654 1655 for v in arg_values: 1656 arg = pytree.tree_map(process_arg_with_idx, v) 1657 fx_args.append(str_arg(arg)) 1658 1659 for k, v in kwarg_values.items(): 1660 arg = pytree.tree_map(functools.partial(process_arg, name=k), v) 1661 fx_args.append(f"{k} = {str_arg(arg)}") 1662 1663 code = f""" 1664class TestModule(torch.nn.Module): 1665 def forward(self, {', '.join(param_names)}): 1666 return torch.{op.name}({', '.join(fx_args)}) 1667 """ 1668 1669 g = {"torch": torch, "inf": math.inf} 1670 exec(code, g) 1671 TestModule = g["TestModule"] 1672 1673 m = TestModule() 1674 traced = torch.fx.symbolic_trace(m) 1675 ref_out = traced(*param_values) 1676 1677 for node in traced.graph.nodes: 1678 if node.op == "call_function": 1679 normalized_args = node.normalized_arguments( 1680 traced, arg_types, kwarg_types 1681 ) 1682 assert normalized_args 1683 node.args = normalized_args.args 1684 node.kwargs = normalized_args.kwargs 1685 traced.recompile() 1686 1687 test_out = traced(*param_values) 1688 self.assertEqual(test_out, ref_out) 1689 1690 def test_normalize_quantized_eb(self): 1691 target = torch.ops.quantized.embedding_bag_byte_rowwise_offsets 1692 args = ( 1693 torch.empty((2, 3), dtype=torch.uint8), 1694 torch.empty((2,), dtype=torch.int64), 1695 torch.empty((2,), dtype=torch.int64), 1696 ) 1697 norm_args_and_kwargs = normalize_function( 1698 target, args, normalize_to_only_use_kwargs=True 1699 ) 1700 self.assertTrue(norm_args_and_kwargs is not None) 1701 self.assertEqual( 1702 set(norm_args_and_kwargs.kwargs.keys()), 1703 { 1704 "weight", 1705 "indices", 1706 "offsets", 1707 "scale_grad_by_freq", 1708 "mode", 1709 "pruned_weights", 1710 "per_sample_weights", 1711 "compressed_indices_mapping", 1712 "include_last_offset", 1713 }, 1714 ) 1715 self.assertEqual(norm_args_and_kwargs.args, ()) 1716 1717 def test_normalize_args_op_overload(self): 1718 for target in [torch.ops.aten.resize_as_.default, torch.ops.aten.resize_as_]: 1719 inp1 = torch.rand([1]) 1720 inp2 = torch.rand([4]) 1721 args, kwargs = normalize_function(target, (inp1,), {"the_template": inp2}, normalize_to_only_use_kwargs=True) 1722 self.assertIs(kwargs["input"], inp1) 1723 self.assertIs(kwargs["the_template"], inp2) 1724 1725 1726if TEST_Z3: 1727 import z3 1728 1729 import torch._dynamo.config 1730 1731 from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str 1732 from torch.utils._sympy.functions import FloorDiv, Mod 1733 1734 class TestTranslationValidation(TestCase): 1735 def 
_prepare_for_translation_validation(self):
            validator = TranslationValidator()

            # SymPy symbols.
            s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)

            # Z3 symbols.
            [validator.add_var(s, int) for s in (s0, s1, s2)]
            z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))

            return (s0, s1, s2), (z0, z1, z2), validator

        def test_sympy_to_z3(self):
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            test_cases = [
                # Integer constants.
                (sympy.S.Zero, z3.IntVal(0)),
                (sympy.S.One, z3.IntVal(1)),
                (sympy.S.NegativeOne, z3.IntVal(-1)),
                (sympy.Integer(2), z3.IntVal(2)),
                (
                    s0,
                    z0,
                ),
                # Arithmetic operations.
                *[
                    (op(s0, s1), op(z0, z1))
                    for op in (
                        operator.add,
                        operator.mul,
                        operator.pow,
                    )
                ],
                # Logical operations.
                *[
                    (sympy_op(s0, s1), z3_op(z0, z1))
                    for sympy_op, z3_op in (
                        (sympy.Eq, operator.eq),
                        (sympy.Ne, operator.ne),
                        (sympy.Lt, operator.lt),
                        (sympy.Le, operator.le),
                        (sympy.Gt, operator.gt),
                        (sympy.Ge, operator.ge),
                    )
                ],
                # Other operations.
                (
                    s0 - s1,
                    z0 + z3.IntVal(-1) * z1,
                ),
                (
                    s0 / s1,
                    z3.ToReal(z0) * (z1**-1),
                ),
                (FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
                (Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
                (
                    Mod(s2, (s0 / s1)),
                    z2
                    - z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
                    * (z3.ToReal(z0) * z1**-1),
                ),
                (
                    Mod(s2, s0**3),
                    z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
                ),
            ]

            toZ3 = SympyToZ3(validator)
            for sympy_expr, z3_expr in test_cases:
                result = toZ3.run(sympy_expr)
                self.assertTrue(
                    z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
                )

        def test_sat(self):
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            validator.add_source_expr(z0 > 5)
            validator.add_source_expr(z1 / 2 > z0)

            # Solutions for the target are a subset of the solutions for the source.
            validator.add_target_expr(s0 > 20)
            validator.add_target_expr(s1 > s0**2)

            validator.validate()

        def test_unsat(self):
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            validator.add_source_expr(z0 > 5)
            validator.add_source_expr(z1 / 2 > z0)

            # Solutions for the target are NOT a subset of the solutions for the source.
            validator.add_target_expr(s0 > 20)
            # This expression is less restrictive than its counterpart.
            validator.add_target_expr(s1 > s0 + 2)

            with self.assertRaisesRegex(ValidationException, "translation validation failed."):
                validator.validate()

        def test_z3str(self):
            a = z3.Int("a")
            b = z3.Int("b")
            special = z3.Real("this.size()[2]")

            test_cases = [
                (z3.IntVal(42), "42"),
                # Variable.
                (a, "a"),
                # Name with special characters.
                (special, "this.size()[2]"),
                # Renamed function applications.
                (a != b, "(!= a b)"),
                (a ** b, "(pow a b)"),
                # Chain of associative operations.
                *[
                    (op(op(a, 5), b), f"({opstr} 5 a b)")
                    for op, opstr in [
                        (operator.add, "+"),
                        (operator.mul, "*")
                    ]
                ],
                # Revert 'Not' conversions.
                (a != b, "(!= a b)"),
                (a < b, "(> b a)"),
                (a > b, "(> a b)"),
                # Ignore 'ToInt' and 'ToReal' functions.
                (z3.ToInt(special) + a, "(+ this.size()[2] a)"),
                (z3.ToReal(a + b), "(+ a b)"),
                # Convert to floor division: 'idiv'.
                (z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
            ]

            for expr, expected in test_cases:
                self.assertEqual(z3str(expr), expected)


instantiate_device_type_tests(TestNormalizeOperators, globals())

if __name__ == "__main__":
    run_tests()