1# Owner(s): ["oncall: distributed"] 2 3import copy 4import functools 5import itertools 6from typing import List, Union 7 8import torch 9import torch.distributed as dist 10import torch.nn as nn 11import torch.nn.functional as F 12from torch.distributed._composable import checkpoint, replicate 13from torch.distributed._composable.fsdp import fully_shard 14from torch.distributed._composable.fsdp._fsdp_param_group import ( 15 RegisterPostBackwardFunction, 16) 17from torch.testing._internal.common_distributed import skip_if_lt_x_gpu 18from torch.testing._internal.common_fsdp import ( 19 check_sharded_parity, 20 FSDPTest, 21 MLP, 22 patch_reduce_scatter, 23 patch_register_post_backward_hook_backward, 24 reduce_scatter_with_assert, 25) 26from torch.testing._internal.common_utils import run_tests 27 28 29class TestFullyShardFrozen(FSDPTest): 30 @property 31 def world_size(self) -> int: 32 return min(4, torch.cuda.device_count()) 33 34 @skip_if_lt_x_gpu(2) 35 def test_train_mixed_requires_grad_per_group(self): 36 """ 37 Tests training parity with DDP when mixing frozen and non-frozen 38 parameters in the same FSDP communication group. This checks that 39 the reduce-scatters reduce the expected numel and that they are called 40 via the custom autograd function backward (i.e. that they are not 41 delayed until the end of backward). 42 """ 43 self.run_subtests( 44 { 45 "reshard_after_forward": [False, True, 2], 46 "use_activation_checkpointing": [False, True], 47 "freeze_after_init": [False, True], 48 }, 49 self._test_train_mixed_requires_grad_per_group, 50 ) 51 52 def _test_train_mixed_requires_grad_per_group( 53 self, 54 reshard_after_forward: Union[bool, int], 55 use_activation_checkpointing: bool, 56 freeze_after_init: bool, 57 ): 58 torch.manual_seed(42) 59 num_mlps, lin_dim = (3, 32) 60 model = nn.Sequential( 61 *[MLP(lin_dim, torch.device("cpu")) for _ in range(num_mlps)] 62 ) 63 # Train biases only (e.g. like BitFit) 64 if not freeze_after_init: 65 for param_name, param in model.named_parameters(): 66 if "bias" not in param_name: 67 param.requires_grad_(False) 68 ref_model = replicate( 69 copy.deepcopy(model).cuda(), 70 device_ids=[self.rank], 71 find_unused_parameters=freeze_after_init, 72 ) 73 ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) 74 for mlp in model: 75 if use_activation_checkpointing: 76 checkpoint(mlp) 77 fully_shard(mlp, reshard_after_forward=reshard_after_forward) 78 fully_shard(model, reshard_after_forward=reshard_after_forward) 79 optim = torch.optim.Adam(model.parameters(), lr=1e-2) 80 orig_reduce_scatter = dist.reduce_scatter_tensor 81 if freeze_after_init: 82 for param_name, param in itertools.chain( 83 model.named_parameters(), ref_model.named_parameters() 84 ): 85 if "bias" not in param_name: 86 param.requires_grad_(False) 87 for mlp in model: 88 assert isinstance(mlp, MLP), ( 89 "The reduce-scatter numel check assumes the model consists of " 90 f"only the same MLP class but got {type(mlp)}" 91 ) 92 expected_numel = sum( 93 p._local_tensor.numel() 94 for n, p in model[0].named_parameters() 95 if "bias" in n 96 ) 97 98 def assert_fn(output: torch.Tensor): 99 self.assertEqual(output.numel(), expected_numel) 100 101 reduce_scatter = functools.partial( 102 reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn 103 ) 104 orig_backward = RegisterPostBackwardFunction.backward 105 backward_count = 0 106 107 def backward_with_count(*args, **kwargs): 108 nonlocal backward_count 109 backward_count += 1 110 return orig_backward(*args, **kwargs) 111 112 torch.manual_seed(42 + self.rank + 1) 113 device = torch.device("cuda") 114 with patch_reduce_scatter( 115 reduce_scatter 116 ), patch_register_post_backward_hook_backward(backward_with_count): 117 for iter_idx in range(10): 118 inp = torch.randn((8, lin_dim), device=device) 119 losses: List[torch.Tensor] = [] 120 for _model, _optim in ((ref_model, ref_optim), (model, optim)): 121 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) 122 losses.append(_model(inp).sum()) 123 losses[-1].backward() 124 _optim.step() 125 check_sharded_parity(self, ref_model, model) 126 self.assertEqual(losses[0], losses[1]) 127 # Check that the post-backward hooks ran through the autograd 128 # backward, not the final callback (except possibly that of the 129 # first MLP, which does not have an input that requires grad) 130 self.assertTrue(backward_count >= num_mlps - 1) 131 132 @skip_if_lt_x_gpu(2) 133 def test_train_mixed_requires_grad_across_groups(self): 134 """ 135 Tests training parity with DDP when mixing frozen and non-frozen 136 parameters across different FSDP communication groups, including 137 possibly unfreezing parameters. 138 """ 139 self.run_subtests( 140 { 141 "reshard_after_forward": [False, True, 2], 142 "unfreeze_params": [False, True], 143 }, 144 self._test_train_mixed_requires_grad_across_groups, 145 ) 146 147 def _test_train_mixed_requires_grad_across_groups( 148 self, 149 reshard_after_forward: Union[bool, int], 150 unfreeze_params: bool, 151 ): 152 torch.manual_seed(42) 153 num_linears, lin_dim = (6, 32) 154 modules: List[nn.Module] = [] 155 for _ in range(num_linears): 156 modules += [nn.Linear(lin_dim, lin_dim), nn.ReLU()] 157 model = nn.Sequential(*modules) 158 ref_model = replicate( 159 copy.deepcopy(model).cuda(), 160 device_ids=[self.rank], 161 find_unused_parameters=True, 162 ) 163 for module in model.modules(): 164 if isinstance(module, nn.Linear): 165 fully_shard(module, reshard_after_forward=reshard_after_forward) 166 ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) 167 optim = torch.optim.Adam(model.parameters(), lr=1e-2) 168 orig_backward = RegisterPostBackwardFunction.backward 169 backward_count = 0 170 171 def _set_requires_grad(seq: nn.Module, requires_grad: bool): 172 for i in range(num_linears): 173 # Interleave frozen -> non-frozen -> ... linears 174 if i % 2 == 0: 175 for param in seq[i % 2].parameters(): 176 param.requires_grad_(requires_grad) 177 178 def backward_with_count(*args, **kwargs): 179 nonlocal backward_count 180 backward_count += 1 181 return orig_backward(*args, **kwargs) 182 183 _set_requires_grad(model, False) 184 _set_requires_grad(ref_model, False) 185 num_iters, no_grad_iter_idx = (3, 1) 186 torch.manual_seed(42 + self.rank) 187 inp = torch.randn((8, lin_dim), device="cuda") 188 with patch_register_post_backward_hook_backward(backward_with_count): 189 for iter_idx in range(num_iters): 190 losses: List[torch.Tensor] = [] 191 for _model, _optim in ((ref_model, ref_optim), (model, optim)): 192 # Unfreeze the parameters on the last step to emulate some 193 # kinds of fine-tuning 194 if unfreeze_params and iter_idx == num_iters - 1: 195 _set_requires_grad(model, True) 196 if iter_idx == no_grad_iter_idx: 197 with torch.no_grad(): 198 losses.append(_model(inp).sum()) 199 else: 200 losses.append(_model(inp).sum()) 201 losses[-1].backward() 202 _optim.step() 203 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) 204 self.assertEqual(losses[0], losses[1]) 205 # Check that the post-backward hooks ran through the autograd 206 # backward, not the final callback (except possibly that of the 207 # first linear, which does not have an input that requires grad) 208 self.assertTrue(backward_count >= num_linears - 1) 209 210 @skip_if_lt_x_gpu(2) 211 def test_multi_forward_mixed_requires_grad(self): 212 """ 213 Tests training parity with DDP when having trainable and frozen modules 214 that participate multiple times in forward. 215 """ 216 self.run_subtests( 217 {"reshard_after_forward": [True, False, 2]}, 218 self._test_multi_forward_mixed_requires_grad, 219 ) 220 221 def _test_multi_forward_mixed_requires_grad( 222 self, 223 reshard_after_forward: Union[bool, int], 224 ): 225 class MultiForwardModule(nn.Module): 226 def __init__(self, device: torch.device): 227 super().__init__() 228 self.layer_0 = nn.Linear(5, 5, device=device) 229 self.layer_no_grad = nn.Linear(5, 5, device=device) 230 self.layer_with_grad = nn.Linear(5, 5, device=device) 231 self.layer_no_grad.requires_grad_(False) 232 233 def forward(self, x: torch.Tensor) -> torch.Tensor: 234 x = self.layer_0(x) 235 for _ in range(3): 236 x = self.layer_no_grad(F.relu(self.layer_with_grad(x))) 237 # Make sure that calling the same layer multiple times 238 # works regardless whether gradient is enabled 239 with torch.no_grad(): 240 x += F.relu(self.layer_with_grad(x)) 241 return x 242 243 torch.manual_seed(42) 244 model = MultiForwardModule(torch.device("cpu")) 245 ref_model = replicate(copy.deepcopy(model).cuda(), device_ids=[self.rank]) 246 ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) 247 for module in model.modules(): 248 if isinstance(module, nn.Linear): 249 fully_shard(module, reshard_after_forward=reshard_after_forward) 250 fully_shard(model, reshard_after_forward=reshard_after_forward) 251 optim = torch.optim.Adam(model.parameters(), lr=1e-2) 252 for iter_idx in range(10): 253 inp = torch.randn((8, 5), device="cuda") 254 losses: List[torch.Tensor] = [] 255 for _model, _optim in ((ref_model, ref_optim), (model, optim)): 256 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) 257 losses.append(_model(inp).sum()) 258 losses[-1].backward() 259 _optim.step() 260 self.assertEqual(losses[0], losses[1]) 261 262 263if __name__ == "__main__": 264 run_tests() 265