#!/usr/bin/python3
# mypy: allow-untyped-defs

import itertools
from typing import List

import torch
from torch.autograd.profiler_legacy import profile

from . import (
    _disable_server_process_global_profiler,
    _enable_server_process_global_profiler,
)


__all__: List[str] = []


class _server_process_global_profile(profile):
    """
    This class has the same API as the ``torch.autograd.profiler.profile`` class,
    except that it enables profiling on all threads running RPC server request callbacks.

    Context manager that manages autograd profiler state and holds a summary of results.
    Under the hood it just records events of functions being executed in C++ and
    exposes those events to Python. You can wrap any code in it, and it will
    only report the runtime of PyTorch functions.
    Note: the profiler is thread-local and is automatically propagated into async tasks.

    Args:
        enabled (bool, optional): Setting this to False makes this context manager a no-op.
            Default: ``True``.

        use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
            Adds approximately 4us of overhead to each tensor operation.
            Default: ``False``.

        record_shapes (bool, optional): If shapes recording is set, information
            about input dimensions will be collected. This allows one to see which
            dimensions have been used under the hood and further group by them
            using prof.key_averages(group_by_input_shape=True). Please note that
            shape recording might skew your profiling data. It is recommended to
            use separate runs with and without shape recording to validate the timing.
            Most likely the skew will be negligible for bottommost events (in the case
            of nested function calls). But for higher-level functions the total
            self CPU time might be artificially increased because of the shape
            collection.

        profile_memory (bool, optional): Whether to report memory usage. Default: ``False``.

    .. warning::
        Enabling memory profiling incurs additional profiler overhead.

    .. warning::
        Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
        one cannot use the profiler with ``use_cuda = True`` to benchmark
        DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
        please use ``use_cuda = False`` or ``num_workers = 0``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> dst_worker_name = "worker1"
        >>> x, y = torch.tensor(1), torch.tensor(2)
        >>> outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
        >>> outer_profile_rref.rpc_sync().__enter__()
        >>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
        >>> inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
        >>> inner_profile_rref.rpc_sync().__enter__()
        >>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
        >>> inner_profile_rref.rpc_sync().__exit__(None, None, None)
        >>> outer_profile_rref.rpc_sync().__exit__(None, None, None)
        >>> print(inner_profile_rref.rpc_sync().key_averages())
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Name       Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        sub        85.06%           76.275us         100.00%          89.667us         89.667us         1
        empty      14.94%           13.392us         14.94%           13.392us         13.392us         1
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Self CPU time total: 89.667us
        >>> print(outer_profile_rref.rpc_sync().key_averages())
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Name       Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        sub        35.65%           76.275us         41.91%           89.667us         89.667us         1
        empty      12.67%           27.101us         12.67%           27.101us         13.551us         2
        add        51.68%           110.550us        58.09%           124.259us        124.259us        1
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Self CPU time total: 213.926us
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __enter__(self):
        """
        Turn on server-side process-global profiling.
        This enables the thread-local profiler on all RPC threads running server-side request callbacks.
        """
        if not self.enabled:
            return

        if self.entered:  # type: ignore[has-type]
            raise RuntimeError("autograd profiler traces are not reentrant")
        self.entered = True

        profiler_kind = (
            torch.autograd.ProfilerState.CUDA
            if self.use_cuda
            else torch.autograd.ProfilerState.CPU
        )
        profiler_config = torch.autograd.ProfilerConfig(
            profiler_kind,
            self.record_shapes,
            self.profile_memory,
            False,  # with_stack
            False,  # with_flops
            False,  # with_modules
            torch.profiler._ExperimentalConfig(),
        )
        _enable_server_process_global_profiler(profiler_config)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Turn off server-side process-global profiling.
        Aggregate all profiling events recorded by RPC threads.

        These attributes are assigned on exiting the context.

        Attributes:
            function_events (torch.autograd.profiler.EventList): A list with helper
                methods, such as 1) showing record items in a pretty-print table,
                2) averaging by grouping on keys, and more.

            process_global_function_events (List[List[torch.autograd.profiler.FunctionEvent]]):
                Each element is the list of ``FunctionEvent`` objects recorded while
                handling one RPC request within the profiling range.
        """
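        # Overview of the aggregation below: stop the process-global profiler,
        # convert each RPC thread's raw records into ``FunctionEvent`` objects,
        # sort each thread's events by start time (ties broken so that longer,
        # enclosing events come first), merge all threads into a single
        # ``EventList``, and finally let ``_build_tree`` re-establish the
        # parent/child nesting across the merged events.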
        if not self.enabled:
            return

        process_global_events = _disable_server_process_global_profiler()

        # Every element in this list is a thread profiling result from an RPC request handling.
        process_global_function_events = []
        for thread_local_events in process_global_events:
            # Parse from ``Event``s to ``FunctionEvent``s.
            thread_local_function_events = (
                torch.autograd.profiler_legacy._parse_legacy_records(
                    thread_local_events
                )
            )
            thread_local_function_events.sort(
                key=lambda function_event: [
                    function_event.time_range.start,
                    -(function_event.time_range.end),
                ]
            )
            process_global_function_events.append(thread_local_function_events)

        flattened_function_events = list(
            itertools.chain.from_iterable(process_global_function_events)
        )
        self.function_events = torch.autograd.profiler_util.EventList(
            flattened_function_events,
            use_device="cuda" if self.use_cuda else None,
            profile_memory=self.profile_memory,
        )
        self.function_events._build_tree()

        self.process_global_function_events = process_global_function_events

        return False

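

# A minimal usage sketch (illustrative only, not part of this module's API):
# it assumes ``rpc.init_rpc`` has already been called on this process and that
# peer workers issue RPCs against it while the block below is active. The
# helper name ``_local_usage_sketch`` is hypothetical; the calls made on
# ``function_events`` (``key_averages()`` and ``table()``) are the standard
# ``EventList`` helpers referenced in the class docstring above.
def _local_usage_sketch():
    with _server_process_global_profile() as prof:
        # Any RPC server request callbacks handled here are profiled.
        pass

    # Flattened, process-wide view of the recorded events.
    print(prof.function_events.key_averages().table())

    # Per-request (per-thread) event lists, before flattening.
    for request_events in prof.process_global_function_events:
        print(f"request recorded {len(request_events)} events")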