# mypy: allow-untyped-defs
import copy
import json
import re
import weakref
from collections import defaultdict
from typing import Any, Dict

import torch
import torch.nn
from torch._guards import detect_fake_mode
from torch.autograd.graph import register_multi_grad_hook
from torch.distributed._tools.mod_tracker import ModTracker
from torch.distributed.tensor._api import DTensor
from torch.nn.modules.module import (
    register_module_forward_hook,
    register_module_forward_pre_hook,
    register_module_full_backward_pre_hook,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten


__all__ = ["CommDebugMode"]

funcol_native = torch.ops._c10d_functional
funcol_py = torch.ops.c10d_functional
funcol_autograd = torch.ops._c10d_functional_autograd
c10d_ops = torch.ops.c10d

NATIVE_TO_PY_MAPPING = {
    funcol_native.all_gather_into_tensor: funcol_py.all_gather_into_tensor,
    funcol_native.all_gather_into_tensor_coalesced: funcol_py.all_gather_into_tensor_coalesced,
    funcol_native.all_reduce: funcol_py.all_reduce,
    funcol_native.all_reduce_coalesced: funcol_py.all_reduce_coalesced,
    funcol_native.all_to_all_single: funcol_py.all_to_all_single,
    funcol_native.broadcast: funcol_py.broadcast,
    funcol_native.reduce_scatter_tensor: funcol_py.reduce_scatter_tensor,
    funcol_native.reduce_scatter_tensor_coalesced: funcol_py.reduce_scatter_tensor_coalesced,
    # functional ops
    funcol_autograd.all_to_all_single: funcol_py.all_to_all_single,
}

c10d_collective_ops = {
    c10d_ops._allgather_base_,
    c10d_ops._reduce_scatter_base_,
    c10d_ops.allgather_,
    c10d_ops.allgather_coalesced_,
    c10d_ops.allgather_into_tensor_coalesced_,
    c10d_ops.allreduce_,
    c10d_ops.allreduce_coalesced_,
    c10d_ops.alltoall_,
    c10d_ops.alltoall_base_,
    c10d_ops.broadcast_,
    c10d_ops.gather_,
    c10d_ops.scatter_,
    c10d_ops.reduce_,
    c10d_ops.reduce_scatter_,
    c10d_ops.reduce_scatter_tensor_coalesced_,
}

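# Ops considered "trivial" for debugging purposes (views, transposes, detach,
# dtype/device copies); they are filtered out of the debug output unless the
# highest noise level explicitly includes them.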
trivial_ops = {
    "aten.detach.default",
    "aten.t.default",
    "aten.view.default",
    "aten._to_copy.default",
    "aten.as_strided.default",
    "aten.transpose.int",
}


class _CommModeModuleTracker(ModTracker):
    """
    Inherits ModTracker and expands on its functionality to track the
    parameters and sharding information of a model at the module level
    """

    def __init__(self):
        super().__init__()
        self.module_helper_dict = {}
        self.module_parameters_dict = {}
        self.module_parents_dict = {}
        self.register_forward_hook_handles = {}
        self.parent_dict = {}
        self.parent_list = []
        self.sharding_dict = {}
        self.activation_checkpointing = False
        self.name = ""

    def _fw_set_module_hook(self, mod, input, output):
        """
        Updates the current module after the module finishes running and
        all other hooks are resolved
        """

        if self.is_bw:
            self.activation_checkpointing = True
        else:
            self.activation_checkpointing = False

        if not self.activation_checkpointing:
            # module is no longer a parent of the next modules
            self.parent_list.pop()

            # set current module to the previous parent module
            self.name = self.parent_list[-1]

    def _fw_pre_hook(self, mod, input):
        """
        This function is called before the forward pass of a module. It
        collects the parameters and sharding information of the module and
        stores them in a dictionary.
        """
        if self.is_bw:
            self.activation_checkpointing = True
        else:
            self.activation_checkpointing = False

        self.name = super()._get_mod_name(mod)
        w_mod = weakref.ref(mod)

        # adds current sub-module to module tracker parent class
        super()._get_append_fn(w_mod, self.name, False)()

        args, _ = tree_flatten(input)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if not self.is_bw and tensors:
            register_multi_grad_hook(
                tensors, super()._get_pop_fn(w_mod, self.name, True)
            )

        if not self.activation_checkpointing:
            # contains information about module ordering and depth in the module tree
            if self.name not in self.module_helper_dict:
                self.module_helper_dict[self.name] = {}

            self.module_helper_dict[self.name]["module_type"] = (
                str(type(mod)).replace("<", "").replace(">", "")
            )
            self.module_helper_dict[self.name]["depth"] = len(self.parents) - 1

            for param_name, param in mod.named_parameters(recurse=False):
                if self.name not in self.module_parameters_dict:
                    self.module_parameters_dict[self.name] = {}

                self.module_parameters_dict[self.name][param_name] = param.data

                if isinstance(param.data, DTensor):
                    key_name = self.name + "." + param_name
                    self.sharding_dict[key_name] = param.data.placements

                    if "parameters" not in self.module_helper_dict[self.name]:
                        self.module_helper_dict[self.name]["parameters"] = {}

                    self.module_helper_dict[self.name]["parameters"][param_name] = str(
                        param.data.placements
                    )

            # used to store module's parents to ensure correctness in backward pass/checkpointing
            if self.name not in self.module_parents_dict:
                self.module_parents_dict[self.name] = copy.deepcopy(self.parents)

            # used to create parent-child module associations for json dumps
            parent = self.parent_list[-1]
            if parent not in self.parent_dict:
                self.parent_dict[parent] = []

            self.parent_dict[parent].append(self.name)
            self.parent_list.append(self.name)

        self.register_forward_hook_handles[self.name] = mod.register_forward_hook(
            self._fw_set_module_hook
        )

    def _fw_post_hook(self, mod, input, output):
        """
        This function is called after the forward pass of a module finishes.
        It updates the module tracker and removes the module from the parent data
        """

        super()._fw_post_hook(mod, input, output)

    def _bw_hook(self, mod, output):
        """
        This function is called before the backward pass of a module runs. It
        updates the current module for the backward pass
        """
        self.activation_checkpointing = False
        self.name = super()._get_mod_name(mod)

    def __enter__(self):
        self.activation_checkpointing = False
        self.module_parameters_dict.clear()
        self.sharding_dict.clear()
        self.parent_dict.clear()
        self.parent_list = ["Global"]
        self.module_helper_dict.clear()
        self.module_helper_dict["Global"] = {"depth": 0}
        self.module_parents_dict.clear()
        self.module_parents_dict["Global"] = set()
        self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
        self._fw_post_handle = register_module_forward_hook(self._fw_post_hook)
        self.register_forward_hook_handles.clear()
        self._bw_handle = register_module_full_backward_pre_hook(self._bw_hook)
        self.name = "Global"

    def __exit__(self, *args):
        super().__exit__(*args)
        self._bw_handle.remove()

        # removes all forward_hook handles added in the pre-hook
        for handle in self.register_forward_hook_handles.values():
            handle.remove()

    def print_paramater_info(self):
        print(self.module_parameters_dict)

    def print_sharding_info(self):
        for key, value in self.sharding_dict.items():
            print(key + ": " + str(value))


class CommDebugMode(TorchDispatchMode):
    """
    :class:`CommDebugMode` is a context manager that counts the number of
    functional collectives within its context. It does this using a
    ``TorchDispatchMode``.

    .. note:: Not all collectives are supported yet.

    Example usage

    .. code-block:: python

        mod = ...
        comm_mode = CommDebugMode()
        with comm_mode:
            mod.sum().backward()
        print(comm_mode.get_comm_counts())
    """

    def __init__(self):
        self.comm_counts: Dict[Any, int] = defaultdict(int)
        self.comm_module_counts = {}
        self.comm_module_operation_counts = {}
        self.comm_registry = set()
        for native_op, py_op in NATIVE_TO_PY_MAPPING.items():
            self.comm_registry.add(native_op)
            self.comm_registry.add(py_op)

        self.comm_registry.add(torch.ops._dtensor.shard_dim_alltoall)
        self.advanced_module_tracker = _CommModeModuleTracker()

    def generate_json_dump(self, file_name="comm_mode_log.json", noise_level=3):
        """
        Creates a JSON file used to build the browser visual. The amount of
        information depends on noise_level:

        0. prints module-level collective counts
        1. prints DTensor operations not included in trivial operations
        2. prints operations not included in trivial operations
        3. prints all operations
        """
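        # Minimal usage sketch (illustrative only; assumes ``comm_mode`` was
        # already used as a context manager around a forward/backward pass, and
        # the file name is just an example):
        #
        #   comm_mode.generate_json_dump(file_name="comm_mode_log.json", noise_level=1)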

        (
            include_DTensor_ops,
            include_module_data,
            include_ops,
            include_trivial_ops,
        ) = self._set_noise_parameters(noise_level)

        # recursively builds json data
        def add_json_information(json_dict, fqn):
            json_dict["fqn"] = fqn
            json_dict["module_type"] = ""
            json_dict["parameters"] = []
            json_dict["children"] = []
            json_dict["collectives_forward"] = []
            json_dict["collectives_backward"] = []
            json_dict["operations_forward"] = []
            json_dict["operations_backward"] = []

            # adds module layer type and parameters, and their sharding
            if (
                "module_type" in self.advanced_module_tracker.module_helper_dict[fqn]
                and include_module_data
            ):
                json_dict[
                    "module_type"
                ] = self.advanced_module_tracker.module_helper_dict[fqn]["module_type"]

                if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]:
                    for (
                        param_name,
                        placement,
                    ) in self.advanced_module_tracker.module_helper_dict[fqn][
                        "parameters"
                    ].items():
                        json_dict["parameters"].append((param_name, placement))

            # adds module collective information
            if fqn in self.comm_module_counts:
                for collective, count in self.comm_module_counts[fqn][
                    "forward"
                ].items():
                    json_dict["collectives_forward"].append((str(collective), count))

                for collective, count in self.comm_module_counts[fqn][
                    "backward"
                ].items():
                    json_dict["collectives_backward"].append((str(collective), count))

            # adds module operation information
            forward_operations = []
            backward_operations = []
            checkpointing_operations = []

            # only collect operations if the minimum operation noise level is enabled
            if include_DTensor_ops:
                if fqn in self.comm_module_operation_counts:
                    (
                        forward_operations,
                        backward_operations,
                        checkpointing_operations,
                    ) = self._get_operations_list(
                        self.comm_module_operation_counts[fqn]
                    )

            # remove all operations that don't have DTensor inputs
            if not include_ops:
                forward_operations = [
                    op for op in forward_operations if len(op["input_sharding"])
                ]
                backward_operations = [
                    op for op in backward_operations if len(op["input_sharding"])
                ]
                checkpointing_operations = [
                    op for op in checkpointing_operations if len(op["input_sharding"])
                ]

            # remove all operations in the trivial operations set
            if not include_trivial_ops:
                forward_operations = [
                    op
                    for op in forward_operations
                    if str(op["name"]) not in trivial_ops
                ]
                backward_operations = [
                    op
                    for op in backward_operations
                    if str(op["name"]) not in trivial_ops
                ]
                checkpointing_operations = [
                    op
                    for op in checkpointing_operations
                    if str(op["name"]) not in trivial_ops
                ]

            # converts operation information into string format for json.dumps()
            forward_operations = copy.deepcopy(forward_operations)
            for op in forward_operations:
                op["name"] = str(op["name"])

                for i in range(len(op["input_sharding"])):
                    op["input_sharding"][i] = str(op["input_sharding"][i])
                    op["input_shape"][i] = str(op["input_shape"][i])

            backward_operations = copy.deepcopy(backward_operations)
            for op in backward_operations:
                op["name"] = str(op["name"])

                for i in range(len(op["input_sharding"])):
                    op["input_sharding"][i] = str(op["input_sharding"][i])
                    op["input_shape"][i] = str(op["input_shape"][i])

            checkpointing_operations = copy.deepcopy(checkpointing_operations)
            for op in checkpointing_operations:
                op["name"] = str(op["name"])

                for i in range(len(op["input_sharding"])):
                    op["input_sharding"][i] = str(op["input_sharding"][i])
                    op["input_shape"][i] = str(op["input_shape"][i])

            json_dict["operations_forward"] = forward_operations
            json_dict["operations_backward"] = backward_operations
            json_dict["operations_checkpointing"] = checkpointing_operations

            if fqn not in self.advanced_module_tracker.parent_dict:
                return json_dict

            # recursively adds module's children
            for ele in self.advanced_module_tracker.parent_dict[fqn]:
                json_dict["children"].append(add_json_information({}, ele))

            return json_dict

        json_dict: Dict[str, Any] = {}
        add_json_information(json_dict, "Global")

        # converts the dictionary into a json file
        with open(file_name, "w") as json_file:
            json.dump(json_dict, json_file, indent=4)

    def generate_comm_debug_tracing_table(self, noise_level=3):
        """
        Generates a detailed table displaying operations and collective tracing
        information on a module level. The amount of information depends on noise_level:

        0. prints module-level collective counts
        1. prints DTensor operations not included in trivial operations, module information
        2. prints operations not included in trivial operations
        3. prints all operations
        """
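        # Minimal usage sketch (illustrative only; assumes ``comm_mode`` has
        # already traced a forward/backward pass; noise_level=1 keeps module
        # info and DTensor operations while dropping trivial ops):
        #
        #   print(comm_mode.generate_comm_debug_tracing_table(noise_level=1))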

        (
            include_DTensor_ops,
            include_module_data,
            include_ops,
            include_trivial_ops,
        ) = self._set_noise_parameters(noise_level)

        table = ""
        for fqn in self.advanced_module_tracker.module_helper_dict:
            # setting up indentations for table formatting
            indent = " " * (
                2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"]
            )
            table += f"{indent}{fqn}\n"

            if include_module_data:
                if (
                    "module_type"
                    in self.advanced_module_tracker.module_helper_dict[fqn]
                ):
                    module_type = self.advanced_module_tracker.module_helper_dict[fqn][
                        "module_type"
                    ]
                    table += f"{indent}*module type: {module_type}\n"

                if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]:
                    table += f"{indent}*Parameter List\n"
                    for (
                        param_name,
                        placement,
                    ) in self.advanced_module_tracker.module_helper_dict[fqn][
                        "parameters"
                    ].items():
                        table += f"{indent} *{param_name}: {placement}\n"

            indent += " "
            collective_indent = " " * (
                2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"] + 2
            )
            operation_indent = " " * (
                2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"] + 3
            )

            # separate the module's collectives and operations by forward and backward
            forward_collectives = {}
            backward_collectives = {}
            if fqn in self.comm_module_counts:
                forward_collectives = self.comm_module_counts[fqn]["forward"]
                backward_collectives = self.comm_module_counts[fqn]["backward"]

            forward_operations = []
            backward_operations = []
            checkpointing_operations = []

            if include_DTensor_ops:
                if fqn in self.comm_module_operation_counts:
                    (
                        forward_operations,
                        backward_operations,
                        checkpointing_operations,
                    ) = self._get_operations_list(
                        self.comm_module_operation_counts[fqn]
                    )

            def add_tracing_information(table, collectives_dict, operation_list):
                """
                adds tracing information for the module's forward or backward pass
                """
                for collective, count in collectives_dict.items():
                    table += (
                        f"\033[1;33m{collective_indent}*{collective}: {count}\033[0m\n"
                    )

                def add_operations(
                    table, operation, collective_indent, operation_indent
                ):
                    """
                    adds operation information to the table
                    """
                    table += f"\033[1;33m{collective_indent}**{operation_name}\033[0m\n"

                    if len(operation["input_shape"]):
                        operation_shape = operation["input_shape"]
                        operation_sharding = operation["input_sharding"]
                        operation_device_mesh = operation["device_mesh"]

                        table += f"\033[1;31m{operation_indent}shape: {operation_shape}\033[0m\n"
                        table += f"\033[1;31m{operation_indent}sharding: {operation_sharding}\033[0m\n"
                        table += f"\033[1;31m{operation_indent}device mesh: {operation_device_mesh}\033[0m\n"

                    return table

                for operation in operation_list:
                    operation_name = str(operation["name"])

                    # include all operations
                    if include_trivial_ops:
                        table = add_operations(
                            table, operation, collective_indent, operation_indent
                        )

                    # include all operations not in trivial operations
                    elif include_ops and operation_name not in trivial_ops:
                        table = add_operations(
                            table, operation, collective_indent, operation_indent
                        )

                    # only include DTensor operations not in the trivial set
                    elif (
                        include_DTensor_ops
                        and (operation_name not in trivial_ops)
                        and len(operation["input_shape"])
                    ):
                        table = add_operations(
                            table, operation, collective_indent, operation_indent
                        )

                return table

            if len(forward_collectives) or len(forward_operations):
                table += f"{indent}FORWARD PASS\n"
                table = add_tracing_information(
                    table, forward_collectives, forward_operations
                )

            if len(backward_collectives) or len(backward_operations):
                table += f"{indent}BACKWARD PASS\n"
                table = add_tracing_information(
                    table, backward_collectives, backward_operations
                )

            if len(checkpointing_operations):
                table += f"{indent}ACTIVATION CHECKPOINTING\n"
                table = add_tracing_information(table, {}, checkpointing_operations)

        return table

    def _get_operations_list(self, module_operation_counts):
        forward_operations = [
            op for op in module_operation_counts["operations_list"] if not op["is_bw"]
        ]
        backward_operations = [
            op
            for op in module_operation_counts["operations_list"]
            if op["is_bw"] and not op["is_activation_checkpointing"]
        ]
        checkpointing_operations = [
            op
            for op in module_operation_counts["operations_list"]
            if op["is_activation_checkpointing"]
        ]

        return forward_operations, backward_operations, checkpointing_operations

    def get_total_counts(self) -> int:
        return sum(self.comm_counts.values())

    def get_comm_counts(self) -> Dict[Any, int]:
        """Returns the communication counts as a dictionary.

        Returns:
            Dict[Any, int]: The communication counts as a dictionary.
        """
        return self.comm_counts

    def get_parameter_info(self) -> Dict[str, Dict[str, Any]]:
        return self.advanced_module_tracker.module_parameters_dict

    def get_sharding_info(self) -> Dict[str, Dict[str, Any]]:
        return self.advanced_module_tracker.sharding_dict

    def __enter__(self):
        self.comm_counts.clear()
        self.comm_module_counts.clear()
        self.comm_module_counts["Global"] = {}
        self.comm_module_counts["Global"]["forward"] = defaultdict(int)
        self.comm_module_counts["Global"]["backward"] = defaultdict(int)

        self.comm_module_operation_counts.clear()

        super().__enter__()
        self.advanced_module_tracker.__enter__()
        return self

    def __exit__(self, *args):
        self.advanced_module_tracker.__exit__()
        super().__exit__(*args)

    def log_comm_debug_tracing_table_to_file(
        self, file_name="comm_mode_log.txt", noise_level=3
    ):
        """
        Alternative to the console CommDebugMode output; writes the tracing table
        to the file specified by the user
        """
        ansi_escape = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")
        table = ansi_escape.sub("", self.generate_comm_debug_tracing_table(noise_level))

        with open(file_name, "w") as log_file:
            log_file.write(table)

    def _set_noise_parameters(self, noise_level):
        """
        sets the variables controlling what information is displayed based on the noise level
        """
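        # Derived mapping, for reference (values are returned in the order
        # include_DTensor_ops, include_module_data, include_ops, include_trivial_ops):
        #   noise_level 0  -> (False, False, False, False)
        #   noise_level 1  -> (True, True, False, False)
        #   noise_level 2  -> (True, True, True, False)
        #   noise_level 3+ -> (True, True, True, True)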
        include_DTensor_ops = False
        include_module_data = False
        include_ops = False
        include_trivial_ops = False

        if noise_level > 0:
            include_DTensor_ops = True
            include_module_data = True

        if noise_level > 1:
            include_ops = True

        if noise_level > 2:
            include_trivial_ops = True

        return (
            include_DTensor_ops,
            include_module_data,
            include_ops,
            include_trivial_ops,
        )

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # When running this mode with DTensor, ordinarily all modes will
        # run **before** subclasses get a chance to run.
        # Returning NotImplemented here gives us a chance to let DTensor
        # run and desugar into comms ops, before CommDebugMode sees them.

        # sets up operation-level collective count
        if self.advanced_module_tracker.name not in self.comm_module_operation_counts:
            # dictionary should hold module input and output shape, operations list and collective counter
            self.comm_module_operation_counts[self.advanced_module_tracker.name] = {
                "operations_list": []
            }
        operation_dict = {}
        operation_dict["name"] = func

        operation_dict["input_shape"] = []
        operation_dict["input_sharding"] = []
        operation_dict["device_mesh"] = ""

        # tracks if the operation is part of the backward pass
        operation_dict["is_bw"] = self.advanced_module_tracker.is_bw

        # tracks if the operation is part of activation checkpointing
        operation_dict[
            "is_activation_checkpointing"
        ] = self.advanced_module_tracker.activation_checkpointing

        if any(t == DTensor for t in types):
            for ele in args:
                if isinstance(ele, DTensor):
                    # saves shapes and placements of all DTensor args
                    operation_dict["input_shape"].append(ele.shape)
                    operation_dict["input_sharding"].append(ele.placements)
                    operation_dict["device_mesh"] = str(ele.device_mesh)

            self.comm_module_operation_counts[self.advanced_module_tracker.name][
                "operations_list"
            ].append(operation_dict)

            return NotImplemented

        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        func_packet = func._overloadpacket

        # We have many tests that use CommDebugMode to verify the occurrence of
        # collectives. These tests do so by querying comm_counts with legacy
        # funcol ops as key. For the purpose of native funcol migration, we
        # need these tests to work for both legacy and native funcol. To avoid
        # the need to modify all tests to accommodate the two implementations,
        # we make CommDebugMode translate native funcol ops into legacy funcol
        # ops until the migration finishes.
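        # For example, a count recorded for the native funcol_native.all_reduce
        # op ends up stored under the legacy funcol_py.all_reduce key below.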

        if func_packet in self.comm_registry or func_packet in c10d_collective_ops:
            if func_packet in NATIVE_TO_PY_MAPPING:
                func_packet = NATIVE_TO_PY_MAPPING[func_packet]
            self.comm_counts[func_packet] += 1

            key = "forward"
            if self.advanced_module_tracker.is_bw:
                key = "backward"

            # adds collective count to current module
            if self.advanced_module_tracker.name not in self.comm_module_counts:
                self.comm_module_counts[self.advanced_module_tracker.name] = {}
                self.comm_module_counts[self.advanced_module_tracker.name][
                    "forward"
                ] = defaultdict(int)
                self.comm_module_counts[self.advanced_module_tracker.name][
                    "backward"
                ] = defaultdict(int)
            self.comm_module_counts[self.advanced_module_tracker.name][key][
                func_packet
            ] += 1

            # adds collective count to parent modules
            for par in self.advanced_module_tracker.module_parents_dict[
                self.advanced_module_tracker.name
            ]:
                # makes sure we aren't double counting when current sub-module hasn't been removed from parents
                if par != self.advanced_module_tracker.name:
                    if par not in self.comm_module_counts:
                        self.comm_module_counts[par] = {}
                        self.comm_module_counts[par]["forward"] = defaultdict(int)
                        self.comm_module_counts[par]["backward"] = defaultdict(int)
                    self.comm_module_counts[par][key][func_packet] += 1

        # if tensor op uses fake tensors, return
        if detect_fake_mode(args):
            return out

        # add tensor operation to module operation list
        self.comm_module_operation_counts[self.advanced_module_tracker.name][
            "operations_list"
        ].append(operation_dict)

        return out
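

# A fuller usage sketch (illustrative only; ``model`` and ``inp`` stand in for
# any DTensor-parallelized module and its input, and the file names are just
# examples). It exercises the public entry points defined above:
#
#     comm_mode = CommDebugMode()
#     with comm_mode:
#         model(inp).sum().backward()
#
#     print(comm_mode.get_total_counts())
#     print(comm_mode.generate_comm_debug_tracing_table(noise_level=1))
#     comm_mode.log_comm_debug_tracing_table_to_file("comm_mode_log.txt", noise_level=2)
#     comm_mode.generate_json_dump(file_name="comm_mode_log.json", noise_level=2)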