xref: /aosp_15_r20/external/pytorch/benchmarks/distributed/rpc/parameter_server/trainer/hooks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from utils import process_bucket_with_remote_server
2
3import torch
4import torch.distributed as c10d
5
6
def allreduce_hook(state, bucket):
    r"""
    A ddp communication hook that performs the reduction with the
    process_group's own allreduce implementation.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    cref = state.cref
    grad = bucket.buffer()
    # Pre-divide by world size so the allreduce SUM yields an average.
    scaled = [grad / state.process_group.size()]
    metric_key = state.get_key(bucket.get_index())
    if grad.is_sparse:
        grad = grad.coalesce()
    kind = "sparse" if grad.is_sparse else "dense"
    cref.record_start(
        "hook_future_metric", metric_key, f"{cref.backend}_{kind}_allreduce"
    )

    def _on_done(fut):
        # Close the timing window once the collective has completed.
        cref.record_end("hook_future_metric", metric_key)
        return fut.wait()

    return state.process_group.allreduce(scaled).get_future().then(_on_done)
31
32
def hybrid_hook(state, bucket):
    r"""
    A ddp communication hook that routes sparse gradients through the
    Gloo default process group and dense gradients through a NCCL
    non-default process group.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    cref = state.cref
    grad = bucket.buffer()
    metric_key = state.get_key(bucket.get_index())

    if not grad.is_sparse:
        # Dense: asynchronous allreduce on the non-default process group.
        cref.record_start("hook_future_metric", metric_key, "nccl_dense_allreduce")
        scaled = [bucket.buffer() / state.process_group.size()]

        def _on_done(fut):
            cref.record_end("hook_future_metric", metric_key)
            return fut.wait()

        return state.process_group.allreduce(scaled).get_future().then(_on_done)

    # Sparse: synchronous allreduce on the default (Gloo) process group,
    # wrapped in an already-completed future so both branches return a Future.
    cref.record_start("hook_c10d_metric", metric_key, "gloo_sparse_allreduce")
    grad = grad.coalesce() / state.process_group.size()
    c10d.all_reduce(grad, op=c10d.ReduceOp.SUM)
    cref.record_end("hook_c10d_metric", metric_key)
    result = torch.futures.Future()
    result.set_result([grad])
    return result
65
66
def rpc_hook(state, bucket):
    r"""
    A ddp communication hook that averages both sparse and dense tensors
    by delegating the entire bucket to a remote parameter server via
    process_bucket_with_remote_server.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    return process_bucket_with_remote_server(state, bucket)
76
77
def sparse_rpc_hook(state, bucket):
    r"""
    A ddp communication hook that relies on the current backend's
    allreduce for dense tensors while sending sparse tensors to a server.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    grad = bucket.buffer()
    if grad.is_sparse:
        # Sparse gradients are averaged remotely on the parameter server.
        return process_bucket_with_remote_server(state, bucket)

    cref = state.cref
    # Pre-divide by world size so the allreduce SUM yields an average.
    scaled = [grad / state.process_group.size()]
    metric_key = state.get_key(bucket.get_index())
    cref.record_start("hook_future_metric", metric_key, f"{cref.backend}_dense_allreduce")

    def _on_done(fut):
        cref.record_end("hook_future_metric", metric_key)
        return fut.wait()

    return state.process_group.allreduce(scaled).get_future().then(_on_done)
101