# mypy: allow-untyped-defs
"""Cross-rank comparison ops (torch.equal, torch.allclose) for ShardedTensor."""
import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as distributed_c10d
from torch.distributed._shard.sharded_tensor import _sharded_op_impl, ShardedTensor


def _communicate_result(result, pg):
    """All-reduce a local boolean so every rank in ``pg`` learns whether the
    comparison succeeded on all ranks."""
    # Encode the local result as a one-element 1/0 tensor on the current CUDA
    # device (this op assumes a CUDA-capable backend such as NCCL).
    if result:
        result_tensor = torch.ones(1, device=torch.device(torch.cuda.current_device()))
    else:
        result_tensor = torch.zeros(1, device=torch.device(torch.cuda.current_device()))

    # The default reduce op is SUM: the reduced value equals world_size iff
    # every rank contributed a 1, so the sum acts as a logical AND.
    dist.all_reduce(result_tensor, group=pg)

    expected_result = torch.ones(
        1, device=torch.device(torch.cuda.current_device())
    ) * dist.get_world_size(pg)

    return torch.equal(result_tensor, expected_result)


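# A minimal sketch of the consensus scheme above (illustrative helper, not
# called anywhere in this module): summing one 0/1 flag per rank acts as a
# logical AND, since the sum equals the group size only when every rank
# contributed a 1.
def _consensus_sketch(per_rank_results):
    # Emulates dist.all_reduce with the default SUM op over one flag per rank.
    total = sum(1 if flag else 0 for flag in per_rank_results)
    return total == len(per_rank_results)


# e.g. _consensus_sketch([True, True, True]) -> True
#      _consensus_sketch([True, False, True]) -> False

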
def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):
    """Compare two ShardedTensors with ``cmp_fun`` (e.g. torch.equal) and
    return a boolean that is consistent across all ranks of their group."""
    if len(args) != 2:
        raise ValueError(f"Expected two arguments for torch.{cmp_fun.__name__}")

    st1 = args[0]
    st2 = args[1]
    if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):
        raise TypeError(
            f"Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor"
        )

    # Tensors sharded over different process groups are never considered equal.
    if st1._process_group != st2._process_group:
        return False

    # A rank outside the process group cannot join the collectives below;
    # equality then only holds if this rank is outside (or inside) both
    # tensors' groups.
    st1_not_in_group = distributed_c10d._rank_not_in_group(st1._process_group)
    st2_not_in_group = distributed_c10d._rank_not_in_group(st2._process_group)
    if st1_not_in_group or st2_not_in_group:
        return st1_not_in_group == st2_not_in_group

    # Verify global metadata (shard layout, size, dtype, ...).
    if st1.metadata() != st2.metadata():
        return _communicate_result(False, st1._process_group)

    # Verify the number of local shards.
    st1_local_shards = st1.local_shards()
    st2_local_shards = st2.local_shards()
    if len(st1_local_shards) != len(st2_local_shards):
        return _communicate_result(False, st1._process_group)

    # kwargs are forwarded to cmp_fun (e.g. rtol/atol for torch.allclose).
    if kwargs is None:
        kwargs = {}
    # Verify each local shard's metadata and data.
    for shard1, shard2 in zip(st1_local_shards, st2_local_shards):
        if shard1.metadata != shard2.metadata:
            return _communicate_result(False, st1._process_group)
        if not cmp_fun(shard1.tensor, shard2.tensor, **kwargs):
            return _communicate_result(False, st1._process_group)

    return _communicate_result(True, st1._process_group)


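# _sharded_op_impl(op) registers the decorated handler in the table that
# ShardedTensor's __torch_function__ consults, so e.g. torch.equal(st1, st2)
# with ShardedTensor arguments is routed to equal() below rather than to the
# dense-tensor implementation. A minimal sketch of that registration pattern
# (hypothetical names, not the actual registry):
#
#     _OP_TABLE = {}
#
#     def _op_impl(op):
#         def decorator(handler):
#             _OP_TABLE[op] = handler
#             return handler
#         return decorator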
@_sharded_op_impl(torch.equal)
def equal(types, args, kwargs, process_group):
    return binary_cmp(torch.equal, types, args, kwargs, process_group)


@_sharded_op_impl(torch.allclose)
def allclose(types, args, kwargs, process_group):
    return binary_cmp(torch.allclose, types, args, kwargs, process_group)


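# A minimal usage sketch, assuming torch.distributed is initialized with a
# CUDA-capable backend (e.g. NCCL) and that every rank in the group runs the
# comparison, since it performs a collective; the construction below is
# illustrative only:
#
#     from torch.distributed._shard import sharded_tensor
#     from torch.distributed._shard.sharding_spec import ChunkShardingSpec
#
#     spec = ChunkShardingSpec(
#         dim=0,
#         placements=["rank:0/cuda:0", "rank:1/cuda:1"],
#     )
#     st1 = sharded_tensor.ones(spec, 4, 4)
#     st2 = sharded_tensor.ones(spec, 4, 4)
#     assert torch.equal(st1, st2)                # routed to equal() above
#     assert torch.allclose(st1, st2, rtol=1e-5)  # kwargs reach cmp_fun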