#!/usr/bin/python3
# mypy: allow-untyped-defs

import itertools
from typing import List

import torch
from torch.autograd.profiler_legacy import profile

from . import (
    _disable_server_process_global_profiler,
    _enable_server_process_global_profiler,
)


__all__: List[str] = []


class _server_process_global_profile(profile):
    """
    It has the same API as ``torch.autograd.profiler.profile`` class,
    except that it enables profiling on all threads running RPC server request callbacks.

    Context manager that manages autograd profiler state and holds a summary of results.
    Under the hood it just records events of functions being executed in C++ and
    exposes those events to Python. You can wrap any code into it and it will
    only report runtime of PyTorch functions.

    Note: the underlying profiler is thread local and is automatically propagated
    into the async tasks. This class additionally turns it on for every RPC thread
    that runs server-side request callbacks, and aggregates the events recorded by
    all of those threads when the context is exited.

    Args:
        enabled (bool, optional): Setting this to False makes this context manager a no-op.
            Default: ``True``.

        use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
            Adds approximately 4us of overhead to each tensor operation.
            Default: ``False``

        record_shapes (bool, optional): If shapes recording is set, information
            about input dimensions will be collected. This allows one to see which
            dimensions have been used under the hood and further group by them
            using prof.key_averages(group_by_input_shape=True). Please note that
            shape recording might skew your profiling data. It is recommended to
            use separate runs with and without shape recording to validate the timing.
            Most likely the skew will be negligible for bottom most events (in a case
            of nested function calls). But for higher level functions the total
            self cpu time might be artificially increased because of the shape
            collection.

        profile_memory (bool, optional): Whether to report memory usage, default: ``False``

    .. warning::
        Enabling memory profiling incurs additional profiler overhead

    .. warning::
        Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
        one cannot use the profiler with ``use_cuda = True`` to benchmark
        DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
        please use ``use_cuda = False`` or ``num_workers = 0``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> dst_worker_name = "worker1"
        >>> x, y = torch.tensor(1), torch.tensor(2)
        >>> outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
        >>> outer_profile_rref.rpc_sync().__enter__()
        >>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
        >>> inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
        >>> inner_profile_rref.rpc_sync().__enter__()
        >>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
        >>> inner_profile_rref.rpc_sync().__exit__(None, None, None)
        >>> outer_profile_rref.rpc_sync().__exit__(None, None, None)
        >>> print(inner_profile_rref.rpc_sync().key_averages())
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Name       Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        sub        85.06%           76.275us         100.00%          89.667us         89.667us         1
        empty      14.94%           13.392us         14.94%           13.392us         13.392us         1
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Self CPU time total: 89.667us
        >>> print(outer_profile_rref.rpc_sync().key_averages())
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Name       Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        sub        35.65%           76.275us         41.91%           89.667us         89.667us         1
        empty      12.67%           27.101us         12.67%           27.101us         13.551us         2
        add        51.68%           110.550us        58.09%           124.259us        124.259us        1
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Self CPU time total: 213.926us
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """

    # Construction is fully delegated to the base ``profile`` class, which sets
    # ``enabled``, ``use_cuda``, ``record_shapes``, ``profile_memory`` and
    # ``entered`` used below.

    def __enter__(self):
        """
        Turn on server-side process-global profiling.
        This enables thread-local profiler on all RPC threads running server-side request callbacks.

        Returns:
            ``self`` when profiling is enabled, ``None`` when this profiler was
            constructed with ``enabled=False`` (in which case nothing is turned on).

        Raises:
            RuntimeError: if this profiler instance has already been entered.
        """
        if not self.enabled:
            return

        if self.entered:  # type: ignore[has-type]
            raise RuntimeError("autograd profiler traces are not reentrant")
        self.entered = True

        profiler_kind = (
            torch.autograd.ProfilerState.CUDA
            if self.use_cuda
            else torch.autograd.ProfilerState.CPU
        )
        # The three ``False`` flags keep the remaining optional profiler features
        # off. NOTE(review): presumably with_stack / with_flops / with_modules in
        # ``ProfilerConfig``'s positional order; confirm before changing them.
        profiler_config = torch.autograd.ProfilerConfig(
            profiler_kind,
            self.record_shapes,
            self.profile_memory,
            False,
            False,
            False,
            torch.profiler._ExperimentalConfig(),
        )
        _enable_server_process_global_profiler(profiler_config)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Turn off server-side process-global profiling.
        Aggregate all profiling events recorded by RPC threads.

        These attributes are assigned on exiting context.

        Attributes:
            function_events (torch.autograd.profiler.EventList). It's a list that has helper
            methods, like 1) show record items in a pretty-print table.
            2) do averaging by grouping on keys. 3) and more.

            process_global_function_events (List[List[torch.autograd.profiler.FunctionEvent]]).
            It's a list with one inner list per thread profiling result returned when
            profiling is turned off. Every inner list holds the ``FunctionEvent``
            elements of an RPC request handling within the profiling range, ordered by
            start time; events with equal start times are ordered so that the one
            ending later comes first.

        Returns:
            ``False`` so that any exception raised inside the ``with`` body is
            propagated; ``None`` when this profiler was constructed with
            ``enabled=False``.
        """
        if not self.enabled:
            return

        process_global_events = _disable_server_process_global_profiler()

        def _start_then_latest_end(function_event):
            # Ascending start time; for equal starts, the event that ends later
            # (i.e. the one spanning the other) is placed first.
            time_range = function_event.time_range
            return (time_range.start, -time_range.end)

        # Every element in this list is a thread profiling result from an RPC request
        # handling, parsed from ``Event``s to ``FunctionEvent``s and then sorted.
        process_global_function_events = [
            sorted(
                torch.autograd.profiler_legacy._parse_legacy_records(
                    thread_local_events
                ),
                key=_start_then_latest_end,
            )
            for thread_local_events in process_global_events
        ]

        flattened_function_events = list(
            itertools.chain.from_iterable(process_global_function_events)
        )
        self.function_events = torch.autograd.profiler_util.EventList(
            flattened_function_events,
            use_device="cuda" if self.use_cuda else None,
            profile_memory=self.profile_memory,
        )
        # Link parent/child relationships between the aggregated events.
        self.function_events._build_tree()

        self.process_global_function_events = process_global_function_events

        return False