# Owner(s): ["oncall: distributed"]

import itertools
import sys
from typing import Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import (
    always_wrap_policy as always_wrap,
    enable_wrap,
    ModuleWrapPolicy,
    wrap,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
)


_TORCHDISTX_AVAIL = True
try:
    from torchdistx import deferred_init
except ImportError:
    _TORCHDISTX_AVAIL = False


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


def _reset_params_if_meta(is_meta: bool, model: nn.Module):
    # For torchdistX init, we do not need to call reset_parameters() since
    # deferred_init(model).materialize() is equivalent to model().
    if is_meta:
        for module in model.modules():
            # Assume that a module has `reset_parameters()` iff it has directly
            # managed parameters or buffers
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()


class MyLinear(nn.Linear):
    """
    Linear layer with deterministic reset_parameters for testing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def reset_parameters(self, *args, **kwargs):
        torch.manual_seed(42)
        with torch.no_grad():
            # Use an initialization method that depends on shape
            torch.nn.init.xavier_uniform_(self.weight, 1.0)


class MyBuffer(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.buf = torch.nn.Buffer(torch.empty((3, 3), device=device))

    def reset_parameters(self, *args, **kwargs):
        torch.manual_seed(42)
        # Use an initialization method that depends on shape
        torch.nn.init.xavier_uniform_(self.buf, 0.5)


class MyModel(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.lin1 = MyLinear(2, 2, bias=False, device=device)
        self.lin2 = MyLinear(2, 2, bias=False, device=device)
        self.buf_mod = MyBuffer(device)

    def forward(self, x):
        return self.lin2(self.lin1(x))


class NestedModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.lin1 = MyLinear(2, 2, bias=False, device=device)
        self.lin1 = wrap(self.lin1)
        self.lin2 = MyLinear(2, 2, bias=False, device=device)
        self.l3 = MyModel(device=device)
        self.l3 = wrap(self.l3)

    def forward(self, x):
        return self.l3(self.lin2(self.lin1(x)))


def _init_with_reset_params(module: nn.Module):
    """
    to_empty() + reset_parameters() init function example for modules
    initialized with device="meta"
    """
    has_meta_states = any(
        t.is_meta
        for t in itertools.chain(
            module.parameters(recurse=False), module.buffers(recurse=False)
        )
    )
    if has_meta_states:
        device = torch.device("cuda", torch.cuda.current_device())
        module.to_empty(device=device, recurse=False)
        module.reset_parameters()


def _init_with_torchdistX(module: nn.Module):
    """
    torchdistX-based deferred module initialization function example
    using ``materialize_module``.
    """
    assert _TORCHDISTX_AVAIL

    def check_fn(k):
        return not isinstance(k, FSDP)

    deferred_init.materialize_module(module, check_fn=check_fn)
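

# Illustrative sketch only (not exercised by the tests below): the two
# ``param_init_fn`` helpers above are meant to be passed to the FSDP
# constructor, which calls the function on submodules that still hold meta
# tensors so they are materialized before being flattened and sharded.
# ``_example_wrap_meta_model`` is a hypothetical helper added purely for
# documentation; nothing in this file calls it.
def _example_wrap_meta_model() -> FSDP:
    model = MyModel(device="meta")  # parameters/buffers start as meta tensors
    return FSDP(
        model,
        auto_wrap_policy=always_wrap,
        param_init_fn=_init_with_reset_params,
    )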


class TestFSDPWithMetaDevice(FSDPTest):
    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    def _compare_fsdp(self, fsdp1, fsdp2):
        with FSDP.summon_full_params(fsdp1):
            with FSDP.summon_full_params(fsdp2):
                for p1, p2 in zip(fsdp1.parameters(), fsdp2.parameters()):
                    self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")

    def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None):
        # Create model on meta device and wrap with FSDP.
        model = meta_module_fn()
        is_meta = next(model.parameters()).is_meta
        fsdp_meta = FSDP(
            model,
            auto_wrap_policy=always_wrap,
            param_init_fn=init_fn,
        )

        meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

        # Verify that the materialized parameters match those of a regularly
        # initialized FSDP model.
        regular = MyModel(device="cuda")
        _reset_params_if_meta(is_meta, regular)
        fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap)
        regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        self._compare_fsdp(fsdp_meta, fsdp_regular)
        inp = torch.randn(10, 2, device="cuda")
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)

        # Test that meta init works if all submodules are contained in only a
        # single FSDP unit.
        model = meta_module_fn()
        fsdp_meta = FSDP(model, param_init_fn=init_fn)
        meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
        regular = MyModel(device="cuda")
        _reset_params_if_meta(is_meta, regular)
        fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap)
        regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        # Run a forward + backward pass + optimizer step
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)

    @skip_if_lt_x_gpu(2)
    def test_simple_model_with_meta_device_reset_params(self):
        def meta_module_fn():
            return MyModel(device="meta")

        self._test_simple_model_with_meta_device(
            meta_module_fn, _init_with_reset_params
        )

    @skip_if_lt_x_gpu(2)
    def test_simple_model_with_meta_device_default_init(self):
        def meta_module_fn():
            return MyModel(device="meta")

        self._test_simple_model_with_meta_device(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    def test_simple_model_with_torchdistX_default_init(self):
        def meta_module_fn():
            return deferred_init.deferred_init(MyModel, device="cuda")

        self._test_simple_model_with_meta_device(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    def test_simple_model_with_torchdistX_init_fn(self):
        def meta_module_fn():
            return deferred_init.deferred_init(MyModel, device="cuda")

        self._test_simple_model_with_meta_device(
            meta_module_fn, init_fn=_init_with_torchdistX
        )
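
    # The nested-model helper below exercises two wrapping modes. With
    # ``auto_wrap=True``, the module is constructed first and ``always_wrap``
    # turns every submodule into its own FSDP unit. With ``auto_wrap=False``,
    # construction happens inside ``enable_wrap`` so the explicit ``wrap()``
    # calls in ``NestedModel.__init__`` take effect and the remaining
    # submodules become part of the root FSDP unit.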
    def _test_nested_model_with_meta_device(
        self, auto_wrap, meta_module_fn, init_fn=None
    ):
        if auto_wrap:
            module = meta_module_fn()
            is_meta = (
                next(module.parameters()).is_meta or next(module.buffers()).is_meta
            )
            fsdp_meta = FSDP(
                module,
                auto_wrap_policy=always_wrap,
                param_init_fn=init_fn,
            )
            meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
            module_regular = NestedModel(device="cuda")
            _reset_params_if_meta(is_meta, module_regular)
            fsdp_regular = FSDP(
                module_regular,
                auto_wrap_policy=always_wrap,
            )
            regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)
        else:
            with enable_wrap(
                wrapper_cls=FSDP,
                param_init_fn=init_fn,
            ):
                module = meta_module_fn()
                is_meta = next(module.parameters()).is_meta
                # Non-FSDP modules are still initialized because they bubble up
                # to be part of a larger FSDP unit.
                fsdp_meta = wrap(module)
                meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

            # Initialize and reset parameters before wrapping so that
            # reset_parameters() matches up with the meta device's
            # initialization.
            module_regular = NestedModel(device="cuda")
            _reset_params_if_meta(is_meta, module_regular)
            with enable_wrap(wrapper_cls=FSDP):
                module_regular.lin1 = wrap(module_regular.lin1)
                module_regular.l3 = wrap(module_regular.l3)
                fsdp_regular = wrap(module_regular)
                regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        # Compare the two models before training
        self._compare_fsdp(fsdp_meta, fsdp_regular)
        inp = torch.randn(10, 2, device="cuda")
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)

    @skip_if_lt_x_gpu(2)
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_meta_device_reset_params(self, auto_wrap):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap,
            meta_module_fn=meta_module_fn,
            init_fn=_init_with_reset_params,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_meta_device_default_init(self, auto_wrap):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap,
            meta_module_fn=meta_module_fn,
        )

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_torchdistX_default_init(self, auto_wrap):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, device="cuda")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap, meta_module_fn=meta_module_fn
        )

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_torchdistX_init_fn(self, auto_wrap):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, device="cuda")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap,
            meta_module_fn=meta_module_fn,
            init_fn=_init_with_torchdistX,
        )
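
    # FSDP validates that ``param_init_fn`` is callable and raises a
    # ``ValueError`` otherwise; the helpers below check this for both
    # meta-device and torchdistX-deferred modules.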
    def _test_bad_arg(self, meta_module_fn):
        mod = meta_module_fn()
        with self.assertRaisesRegex(ValueError, "to be callable"):
            FSDP(mod, param_init_fn=42)

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    def test_bad_arg_torchdistx(self):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, "cuda")

        self._test_bad_arg(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    def test_bad_arg_meta(self):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_bad_arg(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    def test_meta_device_with_mixed_precision(self):
        """
        Tests meta device initialization with a ``param_init_fn`` when
        specifying mixed precision with ``param_dtype=torch.float32``.
        """

        class FakeLinear(nn.Module):
            def __init__(
                self, in_dim: int, out_dim: int, device: Union[torch.device, str]
            ) -> None:
                super().__init__()
                self.weight = nn.Parameter(
                    torch.randn((in_dim, out_dim), device=device)
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x @ self.weight

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin1 = nn.Linear(5, 5, device="meta")
                self.lin2 = FakeLinear(5, 5, device="meta")
                self.relu = nn.ReLU()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.lin2(self.relu(self.lin1(x)))

            def _module_init_fn(self, module: nn.Module):
                if isinstance(module, nn.Linear):
                    torch.nn.init.normal_(module.weight, mean=0.0, std=0.1)
                    if module.bias is not None:
                        torch.nn.init.zeros_(module.bias)

        def _param_init_fn(module: nn.Module) -> None:
            # TODO: `module.to_empty()` is not generally correct for meta
            # device initialization.
            # https://github.com/pytorch/pytorch/issues/90465
            module.to_empty(device=torch.device("cuda"))
            module.apply(model._module_init_fn)

        model = Model()
        # Wrap `lin1` and the top-level `model` to create nested FSDP instances
        # where each instance has parameters
        FSDP(
            model,
            auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
            mixed_precision=MixedPrecision(
                param_dtype=torch.float32, reduce_dtype=torch.float16
            ),
            param_init_fn=_param_init_fn,
            device_id=torch.cuda.current_device(),
        )


instantiate_parametrized_tests(TestFSDPWithMetaDevice)

if __name__ == "__main__":
    run_tests()