# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestUnevenParamShard(FSDPTest):
    def _get_ref_results(self, model, input, my_lr):
        with torch.no_grad():
            # Compute one iteration local output.
            weight = model.weight.T.clone().to(self.rank)
            v = torch.Tensor(input[self.rank]).to(self.rank)
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(input[: self.world_size]).to(self.rank)
            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(self.world_size)
            ref_weight_out = weight - grad.T * my_lr

        return ref_forward_output_my_rank, ref_weight_out

    @skip_if_lt_x_gpu(2)
    def test_one_iteration(self):
        """Test FSDP with an uneven division of parameter shards across ranks."""
        model = Linear(3, 3, bias=False)
        input = torch.rand(8, 3)
        my_lr = 0.1

        # Compute the expected local forward output and global weight update
        # without FSDP as the reference.
        ref_forward_output_my_rank, ref_weight_out = self._get_ref_results(
            model, input, my_lr
        )

        # Wrap the model with FSDP and run one forward/backward/optimizer step.
        model.to(self.rank)
        model = FSDP(model)
        optim = SGD(model.parameters(), lr=my_lr)
        self.assertTrue(len(input) >= self.world_size)
        in_data = torch.Tensor(input[self.rank]).to(self.rank)
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()

        # Gather the full (unsharded) parameters and compare against the reference.
        with model.summon_full_params(model):
            weight_out = model.module.weight.T.clone()
            self.assertEqual(ref_forward_output_my_rank, out)
            self.assertEqual(ref_weight_out, weight_out)


if __name__ == "__main__":
    run_tests()