# Owner(s): ["oncall: distributed"]

import re
import sys

import torch
import torch.cuda
import torch.cuda.nccl as nccl
import torch.distributed as c10d
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    load_tests,
    NoTest,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_ROCM,
    TestCase,
)


HIP_VERSION = (
    0.0
    if torch.version.hip is None
    else float(re.search(r"^\d+\.\d+", torch.version.hip)[0])
)

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests

nGPUs = torch.cuda.device_count()
if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


datatypes = [torch.float]
if (
    TEST_CUDA and c10d.is_nccl_available() and nccl.version() >= (2, 10)
) or TEST_WITH_ROCM:
    datatypes.append(torch.bfloat16)


class TestNCCL(TestCase):
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    def test_unique_id(self, device):
        uid = nccl.unique_id()
        self.assertIsInstance(uid, bytes)
        self.assertGreater(len(uid), 1)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_broadcast(self, device, dtype):
        expected = torch.zeros(128).uniform_().to(dtype=dtype)
        tensors = [expected.cuda()]
        for device in range(1, torch.cuda.device_count()):
            tensors.append(torch.zeros(128, dtype=dtype, device=device))

        nccl.broadcast(tensors)
        for i in range(torch.cuda.device_count()):
            self.assertEqual(tensors[i], expected)

        # Test with tuple
        tensors = [expected.cuda()]
        for device in range(1, torch.cuda.device_count()):
            tensors.append(torch.zeros(128, dtype=dtype, device=device))

        nccl.broadcast(tuple(tensors))
        for i in range(torch.cuda.device_count()):
            self.assertEqual(tensors[i], expected)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_reduce(self, device, dtype):
        cpu_tensors = [
            torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(128, dtype=dtype)
        for t in cpu_tensors:
            expected.add_(t)

        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.reduce(tensors)

        self.assertEqual(tensors[0], expected)

        # Test with tuple
        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.reduce(tuple(tensors))

        self.assertEqual(tensors[0], expected)

    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5 and dtype == torch.bfloat16,  # noqa: F821
        "Skip bfloat16 test for ROCm < 3.5",
    )
    @dtypes(*datatypes)
    def test_all_reduce(self, device, dtype):
        cpu_tensors = [
            torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(128, dtype=dtype)
        for t in cpu_tensors:
            expected.add_(t)

        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)

        # Test with tuple.
        tensors = tuple(cpu_tensors[i].cuda(i) for i in range(nGPUs))
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)

        # Test with set.
        tensors = {cpu_tensors[i].cuda(i) for i in range(nGPUs)}
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    def test_collective_errors(self, device):
        t = torch.rand(10).cuda(0)
        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.all_reduce(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.reduce(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.broadcast(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.all_gather(t, t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.reduce_scatter(t, t)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_all_gather(self, device, dtype):
        cpu_inputs = [torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)]
        expected = torch.cat(cpu_inputs, 0)

        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [
            torch.zeros(128 * nGPUs, device=i, dtype=dtype) for i in range(nGPUs)
        ]
        nccl.all_gather(inputs, outputs)

        for tensor in outputs:
            self.assertEqual(tensor, expected)

        # Test with tuple.
        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [
            torch.zeros(128 * nGPUs, device=i, dtype=dtype) for i in range(nGPUs)
        ]
        nccl.all_gather(tuple(inputs), tuple(outputs))

        for tensor in outputs:
            self.assertEqual(tensor, expected)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_reduce_scatter(self, device, dtype):
        in_size = 32 * nGPUs
        out_size = 32

        cpu_inputs = [
            torch.zeros(in_size).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(in_size, dtype=dtype)
        for t in cpu_inputs:
            expected.add_(t)
        expected = expected.view(nGPUs, 32)

        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [torch.zeros(out_size, device=i, dtype=dtype) for i in range(nGPUs)]
        nccl.reduce_scatter(inputs, outputs)

        for i in range(nGPUs):
            self.assertEqual(outputs[i], expected[i])

        # Test with tuple
        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [torch.zeros(out_size, device=i, dtype=dtype) for i in range(nGPUs)]
        nccl.reduce_scatter(tuple(inputs), tuple(outputs))

        for i in range(nGPUs):
            self.assertEqual(outputs[i], expected[i])


instantiate_device_type_tests(TestNCCL, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()