# mypy: allow-untyped-defs
import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as distributed_c10d
from torch.distributed._shard.sharded_tensor import _sharded_op_impl, ShardedTensor


def _communicate_result(result, pg):
    """All-reduce a local boolean result; True only if every rank reported True."""
    # Encode the local result as 1.0 (True) or 0.0 (False) on the current CUDA device.
    if result:
        result_tensor = torch.ones(1, device=torch.device(torch.cuda.current_device()))
    else:
        result_tensor = torch.zeros(1, device=torch.device(torch.cuda.current_device()))

    dist.all_reduce(result_tensor, group=pg)

    # The reduced sum equals world_size only if every rank contributed True.
    expected_result = torch.ones(
        1, device=torch.device(torch.cuda.current_device())
    ) * dist.get_world_size(pg)

    return torch.equal(result_tensor, expected_result)


def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):
    """Compare two ShardedTensors shard-by-shard using ``cmp_fun``.

    Each rank compares its local shards, then the per-rank verdicts are
    combined across the process group so that every rank returns the same
    result.
    """
    if len(args) != 2:
        raise ValueError(f"Expected two arguments for torch.{cmp_fun.__name__}")

    st1 = args[0]
    st2 = args[1]
    if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):
        raise TypeError(
            f"Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor"
        )

    # Verify that both tensors use the same process group.
    if st1._process_group != st2._process_group:
        return False

    # If the current rank is not part of either process group, it cannot join
    # the collectives below; the tensors compare equal from its point of view
    # only if it is excluded from both.
    if distributed_c10d._rank_not_in_group(
        st1._process_group
    ) or distributed_c10d._rank_not_in_group(st2._process_group):
        return distributed_c10d._rank_not_in_group(
            st1._process_group
        ) == distributed_c10d._rank_not_in_group(st2._process_group)

    # Verify global metadata.
    if st1.metadata() != st2.metadata():
        return _communicate_result(False, st1._process_group)

    # Verify number of local shards.
    st1_local_shards = st1.local_shards()
    st2_local_shards = st2.local_shards()
    if len(st1_local_shards) != len(st2_local_shards):
        return _communicate_result(False, st1._process_group)

    # Forward any extra keyword arguments (e.g. rtol/atol for torch.allclose)
    # to the local comparison function.
    if kwargs is None:
        kwargs = {}

    # Verify each local shard.
    for idx in range(len(st1_local_shards)):
        if st1_local_shards[idx].metadata != st2_local_shards[idx].metadata:
            return _communicate_result(False, st1._process_group)
        if not cmp_fun(
            st1_local_shards[idx].tensor, st2_local_shards[idx].tensor, **kwargs
        ):
            return _communicate_result(False, st1._process_group)

    return _communicate_result(True, st1._process_group)


@_sharded_op_impl(torch.equal)
def equal(types, args, kwargs, process_group):
    return binary_cmp(torch.equal, types, args, kwargs, process_group)


@_sharded_op_impl(torch.allclose)
def allclose(types, args, kwargs, process_group):
    return binary_cmp(torch.allclose, types, args, kwargs, process_group)
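

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the op registry above): it shows how
# plain torch.equal / torch.allclose calls on ShardedTensors dispatch to the
# handlers registered in this module. The guard keeps it a no-op on import.
# The 2-rank NCCL job, the ChunkShardingSpec placements, and the tensor sizes
# are assumptions for the example; it needs one CUDA device per rank and a
# launcher such as `torchrun --nproc_per_node=2`.
if __name__ == "__main__":
    import os

    from torch.distributed._shard import sharded_tensor
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec

    # torchrun provides LOCAL_RANK and the rendezvous environment variables.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # Shard an 8x4 tensor row-wise across the two ranks.
    spec = ChunkShardingSpec(
        dim=0,
        placements=["rank:0/cuda:0", "rank:1/cuda:1"],
    )
    st1 = sharded_tensor.ones(spec, 8, 4)
    st2 = sharded_tensor.ones(spec, 8, 4)

    # Both calls reach binary_cmp through the @_sharded_op_impl registrations;
    # every rank receives the same globally agreed result.
    print(f"rank {dist.get_rank()}: torch.equal -> {torch.equal(st1, st2)}")
    print(f"rank {dist.get_rank()}: torch.allclose -> {torch.allclose(st1, st2, atol=1e-6)}")

    dist.destroy_process_group()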