# mypy: allow-untyped-defs
import copy
import json
import re
import weakref
from collections import defaultdict
from typing import Any, Dict

import torch
import torch.nn
from torch._guards import detect_fake_mode
from torch.autograd.graph import register_multi_grad_hook
from torch.distributed._tools.mod_tracker import ModTracker
from torch.distributed.tensor._api import DTensor
from torch.nn.modules.module import (
    register_module_forward_hook,
    register_module_forward_pre_hook,
    register_module_full_backward_pre_hook,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten


__all__ = ["CommDebugMode"]

funcol_native = torch.ops._c10d_functional
funcol_py = torch.ops.c10d_functional
funcol_autograd = torch.ops._c10d_functional_autograd
c10d_ops = torch.ops.c10d

NATIVE_TO_PY_MAPPING = {
    funcol_native.all_gather_into_tensor: funcol_py.all_gather_into_tensor,
    funcol_native.all_gather_into_tensor_coalesced: funcol_py.all_gather_into_tensor_coalesced,
    funcol_native.all_reduce: funcol_py.all_reduce,
    funcol_native.all_reduce_coalesced: funcol_py.all_reduce_coalesced,
    funcol_native.all_to_all_single: funcol_py.all_to_all_single,
    funcol_native.broadcast: funcol_py.broadcast,
    funcol_native.reduce_scatter_tensor: funcol_py.reduce_scatter_tensor,
    funcol_native.reduce_scatter_tensor_coalesced: funcol_py.reduce_scatter_tensor_coalesced,
    # functional ops
    funcol_autograd.all_to_all_single: funcol_py.all_to_all_single,
}

c10d_collective_ops = {
    c10d_ops._allgather_base_,
    c10d_ops._reduce_scatter_base_,
    c10d_ops.allgather_,
    c10d_ops.allgather_coalesced_,
    c10d_ops.allgather_into_tensor_coalesced_,
    c10d_ops.allreduce_,
    c10d_ops.allreduce_coalesced_,
    c10d_ops.alltoall_,
    c10d_ops.alltoall_base_,
    c10d_ops.broadcast_,
    c10d_ops.gather_,
    c10d_ops.scatter_,
    c10d_ops.reduce_,
    c10d_ops.reduce_scatter_,
    c10d_ops.reduce_scatter_tensor_coalesced_,
}

trivial_ops = {
    "aten.detach.default",
    "aten.t.default",
    "aten.view.default",
    "aten._to_copy.default",
    "aten.as_strided.default",
    "aten.transpose.int",
}


class _CommModeModuleTracker(ModTracker):
    """
    Inherits ModTracker and expands on its functionality to track the
    parameters and sharding information of a model at the module level
    """

    def __init__(self):
        super().__init__()
        self.module_helper_dict = {}
        self.module_parameters_dict = {}
        self.module_parents_dict = {}
        self.register_forward_hook_handles = {}
        self.parent_dict = {}
        self.parent_list = []
        self.sharding_dict = {}
        self.activation_checkpointing = False
        self.name = ""

    def _fw_set_module_hook(self, mod, input, output):
        """
        Updates the current module after the module finishes running and
        all other hooks are resolved
        """

        if self.is_bw:
            self.activation_checkpointing = True
        else:
            self.activation_checkpointing = False

        if not self.activation_checkpointing:
            # module is no longer the parent of subsequent modules
            self.parent_list.pop()

            # set current module to previous parent module
            self.name = self.parent_list[-1]

    def _fw_pre_hook(self, mod, input):
        """
        This function is called before the forward pass of a module. It
        collects the parameters and sharding information of a module and
        stores it in a dictionary.
        """
        if self.is_bw:
            self.activation_checkpointing = True
        else:
            self.activation_checkpointing = False

        self.name = super()._get_mod_name(mod)
        w_mod = weakref.ref(mod)

        # adds current sub-module to module tracker parent class
        super()._get_append_fn(w_mod, self.name, False)()

        args, _ = tree_flatten(input)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if not self.is_bw and tensors:
            register_multi_grad_hook(
                tensors, super()._get_pop_fn(w_mod, self.name, True)
            )

        if not self.activation_checkpointing:
            # contains information about module ordering and depth in the module tree
            if self.name not in self.module_helper_dict:
                self.module_helper_dict[self.name] = {}

            self.module_helper_dict[self.name]["module_type"] = (
                str(type(mod)).replace("<", "").replace(">", "")
            )
            self.module_helper_dict[self.name]["depth"] = len(self.parents) - 1

            for param_name, param in mod.named_parameters(recurse=False):
                if self.name not in self.module_parameters_dict:
                    self.module_parameters_dict[self.name] = {}

                self.module_parameters_dict[self.name][param_name] = param.data

                if isinstance(param.data, DTensor):
                    key_name = self.name + "." + param_name
                    self.sharding_dict[key_name] = param.data.placements

                    if "parameters" not in self.module_helper_dict[self.name]:
                        self.module_helper_dict[self.name]["parameters"] = {}

                    self.module_helper_dict[self.name]["parameters"][param_name] = str(
                        param.data.placements
                    )

            # used to store module's parents to ensure correctness in backward pass/checkpointing
            if self.name not in self.module_parents_dict:
                self.module_parents_dict[self.name] = copy.deepcopy(self.parents)

            # used to create parent-child module associations for json dumps
            parent = self.parent_list[-1]
            if parent not in self.parent_dict:
                self.parent_dict[parent] = []

            self.parent_dict[parent].append(self.name)
            self.parent_list.append(self.name)

            self.register_forward_hook_handles[self.name] = mod.register_forward_hook(
                self._fw_set_module_hook
            )

    def _fw_post_hook(self, mod, input, output):
        """
        This function is called after the forward pass of a module finishes.
        It updates the module tracker and removes the module from the parent data
        """

        super()._fw_post_hook(mod, input, output)

    def _bw_hook(self, mod, output):
        """
        This function is called before the backward pass of a module runs. It
        updates the current module for the backward pass
        """
        self.activation_checkpointing = False
        self.name = super()._get_mod_name(mod)

    def __enter__(self):
        self.activation_checkpointing = False
        self.module_parameters_dict.clear()
        self.sharding_dict.clear()
        self.parent_dict.clear()
        self.parent_list = ["Global"]
        self.module_helper_dict.clear()
        self.module_helper_dict["Global"] = {"depth": 0}
        self.module_parents_dict.clear()
        self.module_parents_dict["Global"] = set()
        self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
        self._fw_post_handle = register_module_forward_hook(self._fw_post_hook)
        self.register_forward_hook_handles.clear()
        self._bw_handle = register_module_full_backward_pre_hook(self._bw_hook)
        self.name = "Global"

    def __exit__(self, *args):
        super().__exit__(*args)
        self._bw_handle.remove()

        # removes all forward_hook handles added in the pre-hook
        for handle in self.register_forward_hook_handles.values():
            handle.remove()

    def print_paramater_info(self):
        print(self.module_parameters_dict)

    def print_sharding_info(self):
        for key, value in self.sharding_dict.items():
            print(key + ": " + str(value))


class CommDebugMode(TorchDispatchMode):
    """
    :class:`CommDebugMode` is a context manager that counts the number of
    functional collectives within its context. It does this using a
    ``TorchDispatchMode``.

    .. note:: Not all collectives are supported yet.

    Example usage:

    .. code-block:: python

        mod = ...
        comm_mode = CommDebugMode()
        with comm_mode:
            mod.sum().backward()
        print(comm_mode.get_comm_counts())
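
    Illustrative follow-ups (``mod`` above is a placeholder module; the calls
    below are optional sketches, not required usage):

    .. code-block:: python

        # total number of collectives observed inside the context
        print(comm_mode.get_total_counts())
        # module-level table; noise_level=1 keeps DTensor ops and module info
        print(comm_mode.generate_comm_debug_tracing_table(noise_level=1))
        # JSON dump used to build the browser visual
        comm_mode.generate_json_dump(file_name="comm_mode_log.json", noise_level=1)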
    """

    def __init__(self):
        self.comm_counts: Dict[Any, int] = defaultdict(int)
        self.comm_module_counts = {}
        self.comm_module_operation_counts = {}
        self.comm_registry = set()
        for native_op, py_op in NATIVE_TO_PY_MAPPING.items():
            self.comm_registry.add(native_op)
            self.comm_registry.add(py_op)

        self.comm_registry.add(torch.ops._dtensor.shard_dim_alltoall)
        self.advanced_module_tracker = _CommModeModuleTracker()

    def generate_json_dump(self, file_name="comm_mode_log.json", noise_level=3):
        """
        Creates the JSON file used to build the browser visual
        0. prints module-level collective counts
        1. prints DTensor operations not included in trivial operations
        2. prints operations not included in trivial operations
        3. prints all operations
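
        Example (illustrative; assumes ``comm_mode`` has already traced a model
        as in the class docstring):

        .. code-block:: python

            comm_mode.generate_json_dump(file_name="comm_mode_log.json", noise_level=1)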
        """

        (
            include_DTensor_ops,
            include_module_data,
            include_ops,
            include_trivial_ops,
        ) = self._set_noise_parameters(noise_level)

        # recursively builds json data
        def add_json_information(json_dict, fqn):
            json_dict["fqn"] = fqn
            json_dict["module_type"] = ""
            json_dict["parameters"] = []
            json_dict["children"] = []
            json_dict["collectives_forward"] = []
            json_dict["collectives_backward"] = []
            json_dict["operations_forward"] = []
            json_dict["operations_backward"] = []

            # adds module layer type and parameters, and their sharding
            if (
                "module_type" in self.advanced_module_tracker.module_helper_dict[fqn]
                and include_module_data
            ):
                json_dict[
                    "module_type"
                ] = self.advanced_module_tracker.module_helper_dict[fqn]["module_type"]

                if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]:
                    for (
                        param_name,
                        placement,
                    ) in self.advanced_module_tracker.module_helper_dict[fqn][
                        "parameters"
                    ].items():
                        json_dict["parameters"].append((param_name, placement))

            # adds module collective information
            if fqn in self.comm_module_counts:
                for collective, count in self.comm_module_counts[fqn][
                    "forward"
                ].items():
                    json_dict["collectives_forward"].append((str(collective), count))

                for collective, count in self.comm_module_counts[fqn][
                    "backward"
                ].items():
                    json_dict["collectives_backward"].append((str(collective), count))

            # adds module operation information
            forward_operations = []
            backward_operations = []
            checkpointing_operations = []

            # only get operations if the minimum operation noise level is set to true
            if include_DTensor_ops:
                if fqn in self.comm_module_operation_counts:
                    (
                        forward_operations,
                        backward_operations,
                        checkpointing_operations,
                    ) = self._get_operations_list(
                        self.comm_module_operation_counts[fqn]
                    )

            # remove all operations that don't have DTensor inputs
            if not include_ops:
                forward_operations = [
                    op for op in forward_operations if len(op["input_sharding"])
                ]
                backward_operations = [
                    op for op in backward_operations if len(op["input_sharding"])
                ]
                checkpointing_operations = [
                    op for op in checkpointing_operations if len(op["input_sharding"])
                ]

            # remove all operations in trivial operations set
            if not include_trivial_ops:
                forward_operations = [
                    op
                    for op in forward_operations
                    if str(op["name"]) not in trivial_ops
                ]
                backward_operations = [
                    op
                    for op in backward_operations
                    if str(op["name"]) not in trivial_ops
                ]
                checkpointing_operations = [
                    op
                    for op in checkpointing_operations
                    if str(op["name"]) not in trivial_ops
                ]

            # converts operation information into string format for json.dumps()
            forward_operations = copy.deepcopy(forward_operations)
            for op in forward_operations:
                op["name"] = str(op["name"])

                for i in range(len(op["input_sharding"])):
                    op["input_sharding"][i] = str(op["input_sharding"][i])
                    op["input_shape"][i] = str(op["input_shape"][i])

            backward_operations = copy.deepcopy(backward_operations)
            for op in backward_operations:
                op["name"] = str(op["name"])

                for i in range(len(op["input_sharding"])):
                    op["input_sharding"][i] = str(op["input_sharding"][i])
                    op["input_shape"][i] = str(op["input_shape"][i])

            checkpointing_operations = copy.deepcopy(checkpointing_operations)
            for op in checkpointing_operations:
                op["name"] = str(op["name"])

                for i in range(len(op["input_sharding"])):
                    op["input_sharding"][i] = str(op["input_sharding"][i])
                    op["input_shape"][i] = str(op["input_shape"][i])

            json_dict["operations_forward"] = forward_operations
            json_dict["operations_backward"] = backward_operations
            json_dict["operations_checkpointing"] = checkpointing_operations

            if fqn not in self.advanced_module_tracker.parent_dict:
                return json_dict

            # recursively adds module's children
            for ele in self.advanced_module_tracker.parent_dict[fqn]:
                json_dict["children"].append(add_json_information({}, ele))

            return json_dict

        json_dict: Dict[str, Any] = {}
        add_json_information(json_dict, "Global")

        # converts dictionary into json file
        with open(file_name, "w") as json_file:
            json.dump(json_dict, json_file, indent=4)

    def generate_comm_debug_tracing_table(self, noise_level=3):
        """
        Generates a detailed table displaying operation and collective tracing
        information at the module level. The amount of information depends on noise_level

        0. prints module-level collective counts
        1. prints DTensor operations not included in trivial operations, module information
        2. prints operations not included in trivial operations
        3. prints all operations
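
        Example (illustrative; ``comm_mode`` is assumed to have already traced
        a model inside its context):

        .. code-block:: python

            print(comm_mode.generate_comm_debug_tracing_table(noise_level=1))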
        """

        (
            include_DTensor_ops,
            include_module_data,
            include_ops,
            include_trivial_ops,
        ) = self._set_noise_parameters(noise_level)

        table = ""
        for fqn in self.advanced_module_tracker.module_helper_dict:
            # setting up indentations for table formatting
            indent = "  " * (
                2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"]
            )
            table += f"{indent}{fqn}\n"

            if include_module_data:
                if (
                    "module_type"
                    in self.advanced_module_tracker.module_helper_dict[fqn]
                ):
                    module_type = self.advanced_module_tracker.module_helper_dict[fqn][
                        "module_type"
                    ]
                    table += f"{indent}*module type: {module_type}\n"

                if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]:
                    table += f"{indent}*Parameter List\n"
                    for (
                        param_name,
                        placement,
                    ) in self.advanced_module_tracker.module_helper_dict[fqn][
                        "parameters"
                    ].items():
                        table += f"{indent} *{param_name}: {placement}\n"

            indent += "  "
            collective_indent = "  " * (
                2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"] + 2
            )
            operation_indent = "  " * (
                2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"] + 3
            )

            # separate the module's collective and operations by forward and backward
            forward_collectives = {}
            backward_collectives = {}
            if fqn in self.comm_module_counts:
                forward_collectives = self.comm_module_counts[fqn]["forward"]
                backward_collectives = self.comm_module_counts[fqn]["backward"]

            forward_operations = []
            backward_operations = []
            checkpointing_operations = []

            if include_DTensor_ops:
                if fqn in self.comm_module_operation_counts:
                    (
                        forward_operations,
                        backward_operations,
                        checkpointing_operations,
                    ) = self._get_operations_list(
                        self.comm_module_operation_counts[fqn]
                    )

            def add_tracing_information(table, collectives_dict, operation_list):
                """
                adds tracing information for module's forward or backward
                """
                for collective, count in collectives_dict.items():
                    table += (
                        f"\033[1;33m{collective_indent}*{collective}: {count}\033[0m\n"
                    )

                def add_operations(
                    table, operation, collective_indent, operation_indent
                ):
                    """
                    adds operation information to the table
                    """
                    table += f"\033[1;33m{collective_indent}**{operation_name}\033[0m\n"

                    if len(operation["input_shape"]):
                        operation_shape = operation["input_shape"]
                        operation_sharding = operation["input_sharding"]
                        operation_device_mesh = operation["device_mesh"]

                        table += f"\033[1;31m{operation_indent}shape: {operation_shape}\033[0m\n"
                        table += f"\033[1;31m{operation_indent}sharding: {operation_sharding}\033[0m\n"
                        table += f"\033[1;31m{operation_indent}device mesh: {operation_device_mesh}\033[0m\n"

                    return table

                for operation in operation_list:
                    operation_name = str(operation["name"])

                    # include all operations
                    if include_trivial_ops:
                        table = add_operations(
                            table, operation, collective_indent, operation_indent
                        )

                    # include all operations not in trivial operations
                    elif include_ops and operation_name not in trivial_ops:
                        table = add_operations(
                            table, operation, collective_indent, operation_indent
                        )

                    # only include DTensor operations not in trivial set
                    elif (
                        include_DTensor_ops
                        and (operation_name not in trivial_ops)
                        and len(operation["input_shape"])
                    ):
                        table = add_operations(
                            table, operation, collective_indent, operation_indent
                        )

                return table

            if len(forward_collectives) or len(forward_operations):
                table += f"{indent}FORWARD PASS\n"
                table = add_tracing_information(
                    table, forward_collectives, forward_operations
                )

            if len(backward_collectives) or len(backward_operations):
                table += f"{indent}BACKWARD PASS\n"
                table = add_tracing_information(
                    table, backward_collectives, backward_operations
                )

            if len(checkpointing_operations):
                table += f"{indent}ACTIVATION CHECKPOINTING\n"
                table = add_tracing_information(table, {}, checkpointing_operations)

        return table

    def _get_operations_list(self, module_operation_counts):
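        """
        Splits a module's recorded "operations_list" into forward, backward,
        and activation-checkpointing operations based on the per-op flags.
        """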
        forward_operations = [
            op for op in module_operation_counts["operations_list"] if not op["is_bw"]
        ]
        backward_operations = [
            op
            for op in module_operation_counts["operations_list"]
            if op["is_bw"] and not op["is_activation_checkpointing"]
        ]
        checkpointing_operations = [
            op
            for op in module_operation_counts["operations_list"]
            if op["is_activation_checkpointing"]
        ]

        return forward_operations, backward_operations, checkpointing_operations

    def get_total_counts(self) -> int:
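        """Returns the total number of collectives counted across all collective types."""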
        return sum(self.comm_counts.values())

    def get_comm_counts(self) -> Dict[Any, int]:
        """Returns the communication counts as a dictionary.

        Returns:
            Dict[Any, int]: The communication counts as a dictionary.
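
        Example (illustrative; after native-to-legacy translation the keys are
        legacy functional collective ops such as
        ``torch.ops.c10d_functional.all_reduce``):

        .. code-block:: python

            counts = comm_mode.get_comm_counts()
            num_all_reduce = counts[torch.ops.c10d_functional.all_reduce]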
        """
        return self.comm_counts

    def get_parameter_info(self) -> Dict[str, Dict[str, Any]]:
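        """
        Returns the parameters recorded by the module tracker, keyed by module
        FQN and then by parameter name.
        """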
        return self.advanced_module_tracker.module_parameters_dict

    def get_sharding_info(self) -> Dict[str, Dict[str, Any]]:
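        """
        Returns the recorded sharding information, keyed by parameter FQN
        (module FQN + "." + parameter name) with the parameter's placements as values.
        """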
        return self.advanced_module_tracker.sharding_dict

    def __enter__(self):
        self.comm_counts.clear()
        self.comm_module_counts.clear()
        self.comm_module_counts["Global"] = {}
        self.comm_module_counts["Global"]["forward"] = defaultdict(int)
        self.comm_module_counts["Global"]["backward"] = defaultdict(int)

        self.comm_module_operation_counts.clear()

        super().__enter__()
        self.advanced_module_tracker.__enter__()
        return self

    def __exit__(self, *args):
        self.advanced_module_tracker.__exit__()
        super().__exit__(*args)

    def log_comm_debug_tracing_table_to_file(
        self, file_name="comm_mode_log.txt", noise_level=3
    ):
        """
        Alternative to the console CommDebugMode output; writes the table to a file specified by the user
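
        Example (illustrative):

        .. code-block:: python

            comm_mode.log_comm_debug_tracing_table_to_file(
                file_name="comm_mode_log.txt", noise_level=1
            )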
        """
        ansi_escape = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")
        table = ansi_escape.sub(
            "", self.generate_comm_debug_tracing_table(noise_level)
        )

        with open(file_name, "w") as log_file:
            log_file.write(table)

    def _set_noise_parameters(self, noise_level):
        """
        sets variables controlling what information is displayed based on the noise level
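
        A sketch of the returned flags, derived from the checks below (order:
        include_DTensor_ops, include_module_data, include_ops, include_trivial_ops)::

            noise_level 0 -> (False, False, False, False)
            noise_level 1 -> (True,  True,  False, False)
            noise_level 2 -> (True,  True,  True,  False)
            noise_level 3 -> (True,  True,  True,  True)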
        """
        include_DTensor_ops = False
        include_module_data = False
        include_ops = False
        include_trivial_ops = False

        if noise_level > 0:
            include_DTensor_ops = True
            include_module_data = True

        if noise_level > 1:
            include_ops = True

        if noise_level > 2:
            include_trivial_ops = True

        return (
            include_DTensor_ops,
            include_module_data,
            include_ops,
            include_trivial_ops,
        )

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # When running this mode with DTensor, ordinarily all modes will
        # run **before** subclasses get a chance to run.
        # Returning NotImplemented here gives us a chance to let DTensor
        # run and desugar into comms ops, before CommDebugMode sees them.

        # sets up operation-level collective count
        if self.advanced_module_tracker.name not in self.comm_module_operation_counts:
            # dictionary should hold module input and output shape, operations list and collective counter
            self.comm_module_operation_counts[self.advanced_module_tracker.name] = {
                "operations_list": []
            }
        operation_dict = {}
        operation_dict["name"] = func

        operation_dict["input_shape"] = []
        operation_dict["input_sharding"] = []
        operation_dict["device_mesh"] = ""

        # tracks if the operation is part of the backward pass
        operation_dict["is_bw"] = self.advanced_module_tracker.is_bw

        # tracks if the operation is part of activation checkpointing
        operation_dict[
            "is_activation_checkpointing"
        ] = self.advanced_module_tracker.activation_checkpointing

        if any(t == DTensor for t in types):
            for ele in args:
                if isinstance(ele, DTensor):
                    # saves shapes and placements of all DTensor args
                    operation_dict["input_shape"].append(ele.shape)
                    operation_dict["input_sharding"].append(ele.placements)
                    operation_dict["device_mesh"] = str(ele.device_mesh)

            self.comm_module_operation_counts[self.advanced_module_tracker.name][
                "operations_list"
            ].append(operation_dict)

            return NotImplemented

        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        func_packet = func._overloadpacket

        # We have many tests that use CommDebugMode to verify the occurrence of
        # collectives. These tests do so by querying comm_counts with legacy
        # funcol ops as key. For the purpose of native funcol migration, we
        # need these tests to work for both legacy and native funcol. To avoid
        # the need to modify all tests to accommodate the two implementations,
        # we make CommDebugMode translate native funcol ops into legacy funcol
        # ops until the migration finishes.

        if func_packet in self.comm_registry or func_packet in c10d_collective_ops:
            if func_packet in NATIVE_TO_PY_MAPPING:
                func_packet = NATIVE_TO_PY_MAPPING[func_packet]
            self.comm_counts[func_packet] += 1

            key = "forward"
            if self.advanced_module_tracker.is_bw:
                key = "backward"

            # adds collective count to current module
            if self.advanced_module_tracker.name not in self.comm_module_counts:
                self.comm_module_counts[self.advanced_module_tracker.name] = {}
                self.comm_module_counts[self.advanced_module_tracker.name][
                    "forward"
                ] = defaultdict(int)
                self.comm_module_counts[self.advanced_module_tracker.name][
                    "backward"
                ] = defaultdict(int)
            self.comm_module_counts[self.advanced_module_tracker.name][key][
                func_packet
            ] += 1

            # adds collective count to parent modules
            for par in self.advanced_module_tracker.module_parents_dict[
                self.advanced_module_tracker.name
            ]:
                # makes sure we aren't double counting when current sub-module hasn't been removed from parents
                if par != self.advanced_module_tracker.name:
                    if par not in self.comm_module_counts:
                        self.comm_module_counts[par] = {}
                        self.comm_module_counts[par]["forward"] = defaultdict(int)
                        self.comm_module_counts[par]["backward"] = defaultdict(int)
                    self.comm_module_counts[par][key][func_packet] += 1

        # if tensor op uses fake tensors, return
        if detect_fake_mode(args):
            return out

        # add tensor operation to module operation list
        self.comm_module_operation_counts[self.advanced_module_tracker.name][
            "operations_list"
        ].append(operation_dict)

        return out
