1# Owner(s): ["oncall: distributed"] 2 3import sys 4 5import torch 6import torch.distributed as dist 7import torch.nn as nn 8from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 9from torch.testing._internal.common_distributed import skip_if_lt_x_gpu 10from torch.testing._internal.common_fsdp import ( 11 CUDAInitMode, 12 FSDPInitMode, 13 FSDPTest, 14 NestedWrappedModule, 15 TransformerWithSharedParams, 16) 17from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN 18 19 20if not dist.is_available(): 21 print("Distributed not available, skipping tests", file=sys.stderr) 22 sys.exit(0) 23 24if TEST_WITH_DEV_DBG_ASAN: 25 print( 26 "Skip dev-asan as torch + multiprocessing spawn have known issues", 27 file=sys.stderr, 28 ) 29 sys.exit(0) 30 31 32class TestApply(FSDPTest): 33 @property 34 def world_size(self): 35 return 2 36 37 @torch.no_grad() 38 def _init_linear_weights(self, m): 39 if type(m) == nn.Linear: 40 m.weight.fill_(1.0) 41 m.bias.fill_(1.0) 42 43 def check_weights(self, fsdp, expected_tensor_fn, check): 44 with FSDP.summon_full_params(fsdp, recurse=True): 45 linear_modules = [ 46 module for module in fsdp.modules() if type(module) == nn.Linear 47 ] 48 for module in linear_modules: 49 for param in module.parameters(): 50 expected = expected_tensor_fn(param) 51 check(param, expected, f"Got {param} but expected {expected}") 52 53 def _check_apply(self, fsdp): 54 # Assert linear weights are not all 1.0 55 self.check_weights( 56 fsdp, lambda param: torch.empty_like(param).fill_(1.0), self.assertNotEqual 57 ) 58 59 fsdp.apply(self._init_linear_weights) 60 61 # Ensure all weights are 1.0 62 self.check_weights( 63 fsdp, lambda param: torch.empty_like(param).fill_(1.0), self.assertEqual 64 ) 65 66 @skip_if_lt_x_gpu(2) 67 def test_nested_module_apply(self): 68 """Tests that ``apply()`` modifies parameter values in-place on a 69 non-FSDP-root nested FSDP-wrapped model.""" 70 nested_wrapped_module = NestedWrappedModule.init( 71 self.process_group, 72 FSDPInitMode.RECURSIVE, 73 CUDAInitMode.CUDA_AFTER, 74 ) 75 self._check_apply(nested_wrapped_module) 76 77 @skip_if_lt_x_gpu(2) 78 def test_transformer_module_apply(self): 79 """Tests that ``apply()`` modifies parameter values in-place on an 80 FSDP-wrapped transformer model with shared parameters.""" 81 transformer = TransformerWithSharedParams.init( 82 self.process_group, 83 FSDPInitMode.RECURSIVE, 84 CUDAInitMode.CUDA_AFTER, 85 ) 86 self._check_apply(transformer) 87 88 @skip_if_lt_x_gpu(2) 89 def test_apply_in_summon_raises_error(self): 90 """Tests that calling ``apply()`` on an FSDP instance inside the 91 ``summon_full_params()`` context raises an error.""" 92 transformer = TransformerWithSharedParams.init( 93 self.process_group, 94 FSDPInitMode.RECURSIVE, 95 CUDAInitMode.CUDA_AFTER, 96 ) 97 with transformer.summon_full_params(transformer): 98 with self.assertRaisesRegex(ValueError, "expected to be in states"): 99 transformer.apply(self._init_linear_weights) 100 101 102if __name__ == "__main__": 103 run_tests() 104