# Owner(s): ["oncall: distributed"]

import sys
import time
from statistics import mean
from unittest.mock import patch

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.cuda import Event
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    get_cycles_per_ms,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class Layer(nn.Module):
    def __init__(self, compute_cycles, has_params: bool):
        super().__init__()
        self.sleep_cycles = compute_cycles
        self.optional_param = None
        if has_params:
            self.optional_param = nn.Parameter(torch.rand(1))

    def forward(self, x):
        # Get 2 events.
        self.e1 = Event(enable_timing=True)
        self.e2 = Event(enable_timing=True)

        # Record the fake forward compute time.
        self.e1.record()
        if self.sleep_cycles > 0:
            torch.cuda._sleep(self.sleep_cycles)
        if self.optional_param is not None:
            x = x + self.optional_param  # force the param to be part of the graph
        self.e2.record()
        return x

    def get_time(self):
        # Return the recorded duration.
        return self.e1.elapsed_time(self.e2)


def _create_model(compute_cycles, has_params: bool):
    # Use `limit_all_gathers=False` since the timing being tested relies on the
    # CPU running ahead of the GPU.
    model = FSDP(
        nn.Sequential(
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
        ),
        limit_all_gathers=False,
    ).cuda()
    return model


class Min10:
    def __init__(self) -> None:
        self.data = []

    def add(self, new_data):
        if len(self.data) < 10:
            self.data.append(new_data)
        else:
            self.data = sorted(self.data)
            if new_data < self.data[-1]:
                self.data[-1] = new_data

    def avg(self):
        return mean(self.data)


class TestForwardOverlapWorldSizeOne(FSDPTest):
    @property
    def world_size(self):
        return 1

    def _dist_train(self):
        rank = self.rank
        world_size = self.world_size
        # Save the original torch.distributed.all_gather_into_tensor function since we will
        # patch it to include an artificial delay.
        orig_all_gather = torch.distributed.all_gather_into_tensor

        def run(compute_cycles, all_gather_cycles):
            has_params = all_gather_cycles > 0
            model = _create_model(compute_cycles, has_params)

            # Get the input and set its requires_grad to True because we have a
            # fake compute in the forward pass.
            batch = torch.rand(1).cuda()
            batch.requires_grad = True

            # Run one dummy iteration to trigger the execution order validation
            # all-gathers.
            out = model(batch)
            out.backward()
            model.zero_grad(set_to_none=True)

            # We run 20 iterations but only collect timing data from the 10 smallest
            # data points because nondeterministic system events can disturb the timing.
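            # (Min10 keeps only the 10 smallest samples it is given, so the averages
            # reported below come from the least-disturbed iterations.)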
            cpu_iter = Min10()
            cpu_wait = Min10()
            gpu_compute = Min10()
            gpu_total = Min10()
            for _ in range(20):
                # Get two events for measuring the overall time.
                e1 = Event(enable_timing=True)
                e2 = Event(enable_timing=True)

                cpu_start = time.process_time()

                all_gather_called = False

                def _delayed_all_gather(*args, **kwargs):
                    nonlocal all_gather_called
                    all_gather_called = True
                    torch.cuda._sleep(all_gather_cycles)
                    assert orig_all_gather
                    return orig_all_gather(*args, **kwargs)

                # forward pass
                #
                # Even though both e1 & e2 are on the compute stream, since
                # compute depends on all_gather, e2-e1 includes all_gather time.
                e1.record()
                with patch(
                    "torch.distributed.all_gather_into_tensor", _delayed_all_gather
                ):
                    out = model(batch)
                    if has_params and world_size > 1:
                        self.assertTrue(all_gather_called)
                    else:
                        self.assertFalse(all_gather_called)
                e2.record()

                # backward pass
                out.backward()
                model.zero_grad(set_to_none=True)

                cpu_iter_time = time.process_time() - cpu_start

                # wait for gpu
                out.item()
                cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time

                # get sum of the compute time
                times = []
                for mod in model.modules():
                    if not isinstance(mod, Layer):
                        continue
                    times.append(mod.get_time())

                # get gpu compute + all_gather time
                overall_gpu_time = e1.elapsed_time(e2)

                cpu_iter.add(cpu_iter_time)
                cpu_wait.add(cpu_wait_for_gpu_time)
                gpu_compute.add(sum(times))
                gpu_total.add(overall_gpu_time)

            del model
            return {
                "cpu_iter": cpu_iter.avg(),
                "cpu_wait": cpu_wait.avg(),
                "gpu_compute": gpu_compute.avg(),
                "gpu_total": gpu_total.avg(),
            }

        sleep_cycles = int(100 * get_cycles_per_ms())

        e1 = run(0, 0)  # no compute, no all-gather
        e2 = run(0, sleep_cycles)  # no compute, only all-gather
        e3 = run(sleep_cycles, 0)  # only compute, no all-gather
        e4 = run(sleep_cycles, sleep_cycles)  # both compute and all-gather
        debug_string = f"\nrank{rank}:\n e1: {e1}\n e2: {e2}\n e3: {e3}\n e4: {e4}"
        print(debug_string)

        # Check the cpu/gpu timing. The CPU should run ahead of the GPU. Therefore,
        # the cpu-gpu wait should be long, except when there is no real work on the GPU.
        #
        # If the assertions below fail, we likely have a cpu-gpu wait in the
        # forward/backward pass. e4["cpu_iter"] may not be short since the CPU may
        # take some time to queue both compute and all-gather.
        short = [
            e1["cpu_iter"],
            e2["cpu_iter"],
            e3["cpu_iter"],
            e1["cpu_wait"],
        ]
        long = [e3["cpu_wait"], e4["cpu_wait"]]
        if world_size == 1:
            short.append(e2["cpu_wait"])  # all-gather should not be happening.
        else:
            long.append(
                e2["cpu_wait"]
            )  # all-gather should happen and prolong the cpu-gpu wait.
        for s in short:
            for l in long:
                # 10X longer is a safe margin, since the GPU work timing is around
                # 100X that of the CPU.
                self.assertTrue(s * 10 < l)

        # Check the GPU timing.
        short = [e1["gpu_compute"], e1["gpu_total"], e2["gpu_compute"]]
        long = [e3["gpu_compute"], e3["gpu_total"], e4["gpu_compute"], e4["gpu_total"]]
        if world_size == 1:
            short.append(e2["gpu_total"])  # all-gather should not be happening.
        else:
            long.append(
                e2["gpu_total"]
            )  # all-gather should happen and prolong the GPU total time.
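        # As with the CPU timings above, every "short" GPU timing must be at least
        # 10x smaller than every "long" one.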
        for s in short:
            for l in long:
                # 10X longer is a safe margin, since the time is around 100X longer
                # when there is work on GPU vs. no work.
                self.assertTrue(s * 10 < l)

        # Check the GPU overlapping when there is all-gather.
        if world_size > 1:
            compute_only = e3["gpu_compute"]
            all_gather_only = e2["gpu_total"]
            both = e4["gpu_total"]
            self.assertTrue(compute_only + all_gather_only > 1.1 * both)

    @skip_if_lt_x_gpu(2)
    def test_forward_overlap(self):
        self._dist_train()


class TestForwardOverlapWorldSizeTwo(TestForwardOverlapWorldSizeOne):
    @property
    def world_size(self):
        return 2


if __name__ == "__main__":
    run_tests()
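
# A minimal usage sketch (assumptions: this file is saved as test_fsdp_overlap.py and
# the machine has at least 2 CUDA GPUs; the FSDPTest base class spawns one process per
# rank for each test):
#
#   python test_fsdp_overlap.py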