# Owner(s): ["oncall: distributed"]

import contextlib
import sys
from copy import deepcopy
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    offload_wrapper,
)
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import _maybe_wrap_fsdp, FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)
from torch.utils.checkpoint import checkpoint


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


# Global flag flipped by the patched save_on_cpu below; used to verify that
# activation offloading actually ran during a forward pass.
_save_on_cpu_called = False


# Returns a wrapper around the checkpoint_wrapper module's save_on_cpu that
# records (via the global flag above) that it was invoked before delegating.
def get_patched_save_on_cpu():
    orig_save_on_cpu = (
        torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu
    )

    def patched_save_on_cpu(*args, **kwargs):
        global _save_on_cpu_called
        _save_on_cpu_called = True
        return orig_save_on_cpu(*args, **kwargs)

    return patched_save_on_cpu


# Temporarily swaps out save_on_cpu in the checkpoint_wrapper module and
# restores the original on exit.
@contextlib.contextmanager
def patch_save_on_cpu(new_save_on_cpu):
    orig_save_on_cpu = (
        torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu
    )
    torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu = (
        new_save_on_cpu
    )
    try:
        yield
    finally:
        torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu = (
            orig_save_on_cpu
        )


class TestFSDPCheckpoint(FSDPTest):
    class SequentialModule(nn.Module):
        def __init__(
            self,
            checkpoint_layer=False,
            offload_activations=False,
            wrap_fsdp=False,
            *fsdp_args,
            **fsdp_kwargs,
        ):
            torch.manual_seed(0)
            torch.cuda.manual_seed(0)
            super().__init__()
            l1 = nn.Linear(3, 3).cuda()
            l2 = nn.Linear(3, 3).cuda()
            l3 = nn.Linear(3, 3).cuda()

            if checkpoint_layer:
                if offload_activations:
                    ckpt_wrapper = offload_wrapper
                else:
                    ckpt_wrapper = checkpoint_wrapper

                l1 = ckpt_wrapper(l1)
                l2 = ckpt_wrapper(l2)
                l3 = ckpt_wrapper(l3)

            fsdp_wrapper = partial(
                _maybe_wrap_fsdp, *fsdp_args, wrap_fsdp=wrap_fsdp, **fsdp_kwargs
            )
            self.ffn = nn.Sequential(
                fsdp_wrapper(l1),
                fsdp_wrapper(l2),
                fsdp_wrapper(l3),
            )

        def forward(self, x):
            return self.ffn(x)

    def _verify_parity(self, losses, outputs, models):
        assert losses
        assert outputs
        assert models

        for l, o in zip(losses[1:], outputs[1:]):
            self.assertEqual(losses[0], l)
            self.assertEqual(outputs[0], o)

        # Verify grads
        ref_model = models[0]
        ref_grads = [p.grad for p in ref_model.parameters()]
        for m in models[1:]:
            grads = [p.grad for p in m.parameters()]
            for ref_g, g in zip(ref_grads, grads):
                self.assertEqual(ref_g, g)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("offload_activations", [True, False])
    @parametrize("use_orig_params", [False, True])
    def test_checkpoint_fsdp_wrapping(
        self,
        cpu_offload: CPUOffload,
        offload_activations: bool,
        use_orig_params: bool,
    ):
        # Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
        if offload_activations:
            wrapper_to_use = offload_wrapper
        else:
            wrapper_to_use = checkpoint_wrapper

        fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
        ckpt_sequential_wrapped_fsdp = wrapper_to_use(
            TestFSDPCheckpoint.SequentialModule(
                wrap_fsdp=True,
                **fsdp_kwargs,
            ),
        )
        # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
        inner_ckpt = TestFSDPCheckpoint.SequentialModule(
            checkpoint_layer=True,
            offload_activations=offload_activations,
            wrap_fsdp=True,
            **fsdp_kwargs,
        )

        baseline = TestFSDPCheckpoint.SequentialModule(
            wrap_fsdp=True,
            **fsdp_kwargs,
        )

        # note that reentrant-based checkpointing requires inputs to have grad
        # flag set.
        inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

        global _save_on_cpu_called
        models = [ckpt_sequential_wrapped_fsdp, inner_ckpt, baseline]
        with patch_save_on_cpu(get_patched_save_on_cpu()):
            for i in range(2):
                losses = []
                outputs = []
                for m in models:
                    check_offload = m != baseline and i == 0 and offload_activations
                    if check_offload:
                        self.assertFalse(_save_on_cpu_called)
                    out = m(inp)
                    if check_offload:
                        self.assertTrue(_save_on_cpu_called)
                        _save_on_cpu_called = False
                    loss = out.sum()
                    loss.backward()
                    losses.append(loss)
                    outputs.append(out)

                self._verify_parity(losses, outputs, models)

        dist.barrier()

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("offload_activations", [True, False])
    @parametrize("use_orig_params", [False, True])
    def test_basic_checkpoint_end_to_end(
        self,
        cpu_offload: CPUOffload,
        offload_activations: bool,
        use_orig_params: bool,
    ):
        fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
        global _save_on_cpu_called
        with patch_save_on_cpu(get_patched_save_on_cpu()):
            seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
            # Runs FSDP with no checkpointing
            fsdp_only_seq = FSDP(deepcopy(seq), **fsdp_kwargs)
            # Runs checkpoint-wrapped FSDP
            if offload_activations:
                wrapper_to_use = offload_wrapper
            else:
                wrapper_to_use = checkpoint_wrapper

            checkpointed_fsdp = wrapper_to_use(
                FSDP(deepcopy(seq), **fsdp_kwargs),
            )
            # Runs FSDP-wrapped checkpointed module
            fsdp_wrapped_checkpoint = FSDP(
                wrapper_to_use(deepcopy(seq)),
                **fsdp_kwargs,
            )
            # Runs FSDP with manual calls to checkpoint.
            fsdp_call_checkpoint = FSDP(deepcopy(seq), **fsdp_kwargs)
            # note that reentrant-based checkpointing requires inputs to have grad
            # flag set.

            inp = torch.randn(
                10, 3, device=torch.cuda.current_device(), requires_grad=True
            )

            models = [
                fsdp_only_seq,
                checkpointed_fsdp,
                fsdp_wrapped_checkpoint,
                fsdp_call_checkpoint,
            ]
            # Ensure _save_on_cpu is not yet called
            self.assertFalse(_save_on_cpu_called)
            for i in range(6):
                losses = []
                outputs = []
                for m in models:
                    check_offload = (
                        m != fsdp_only_seq and i == 0 and offload_activations
                    )
                    if m == fsdp_call_checkpoint:
                        # _save_on_cpu should not be called yet
                        self.assertFalse(_save_on_cpu_called)
                        offload_ctx = (
                            get_patched_save_on_cpu()(pin_memory=True)
                            if offload_activations
                            else contextlib.nullcontext()
                        )
                        with offload_ctx:
                            out = checkpoint(m, inp, use_reentrant=True)
                    else:
                        # _save_on_cpu should not be called yet
                        self.assertFalse(_save_on_cpu_called)
                        out = m(inp)

                    if check_offload:
                        self.assertTrue(_save_on_cpu_called)
                    loss = out.sum()
                    loss.backward()
                    losses.append(loss)
                    outputs.append(out)
                    _save_on_cpu_called = False

                self._verify_parity(losses, outputs, models)

        dist.barrier()


instantiate_parametrized_tests(TestFSDPCheckpoint)


class CheckpointModule(nn.Module):
    def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
        super().__init__()
        self.seq = nn.Sequential(*[nn.Linear(100, 100) for _ in range(4)])
        self.checkpoint = checkpoint
        self.use_reentrant = use_reentrant

    def forward(self, x):
        return (
            checkpoint(self.seq, x, use_reentrant=self.use_reentrant)
            if self.checkpoint
            else self.seq(x)
        )


class ModelWithCheckpointSubmodule(nn.Module):
    def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
        super().__init__()
        self.l1 = nn.Linear(100, 100)
        self.s1 = CheckpointModule(checkpoint, use_reentrant)
        self.s2 = CheckpointModule(checkpoint, use_reentrant)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(100, 100)

    def forward(self, x):
        return self.l2(self.relu(self.s2(self.s1(self.l1(x)))))


class TestModel(nn.Module):
    def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
        super().__init__()
        self.l1 = nn.Linear(100, 100)
        self.relu = nn.ReLU()
        self.checkpoint1 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant)
        self.checkpoint2 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant)
        self.l2 = nn.Linear(100, 100)

    def forward(self, x):
        return self.l2(self.relu(self.checkpoint2(self.checkpoint1(self.l1(x)))))


class TestFSDPCheckpointSubmodule(FSDPTest):
    # TODO: grad value checks occasionally fail when use_reentrant = True
    @skip_if_lt_x_gpu(2)
    @parametrize("use_reentrant", [False])
    def test_checkpoint_submodule(self, use_reentrant: bool):
        model = TestModel(use_reentrant=use_reentrant).cuda()
        model_ac = deepcopy(model)

        for _, m in model_ac.named_modules():
            if isinstance(m, CheckpointModule):
                m.checkpoint = True

        self.assertTrue(model_ac.checkpoint1.s1.checkpoint)
        self.assertTrue(model_ac.checkpoint2.s2.checkpoint)

        fsdp_kwargs = {
            "device_id": torch.cuda.current_device(),
            "sharding_strategy": ShardingStrategy.NO_SHARD,
        }

        # Wrap non-checkpointing model submodules with FSDP
        model.checkpoint1 = FSDP(module=model.checkpoint1, **fsdp_kwargs)
        model.checkpoint2 = FSDP(module=model.checkpoint2, **fsdp_kwargs)

        # Wrap checkpointing model submodules with FSDP
        model_ac.checkpoint1 = FSDP(module=model_ac.checkpoint1, **fsdp_kwargs)
        model_ac.checkpoint2 = FSDP(module=model_ac.checkpoint2, **fsdp_kwargs)

        x = torch.randn(2, 100, device="cuda")

        model(x).sum().backward()
        model_ac(x).sum().backward()

        for (n1, p1), (n2, p2) in zip(
            model.named_parameters(), model_ac.named_parameters()
        ):
            self.assertEqual(n1, n2)
            self.assertTrue(p1.grad.allclose(p2.grad))


instantiate_parametrized_tests(TestFSDPCheckpointSubmodule)


if __name__ == "__main__":
    run_tests()