from utils import process_bucket_with_remote_server

import torch
import torch.distributed as c10d


def allreduce_hook(state, bucket):
    r"""
    A DDP communication hook that uses the process_group allreduce implementation.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    cref = state.cref
    tensor = bucket.buffer()
    # Pre-divide by the world size so a SUM allreduce yields the average.
    tensors = [tensor / state.process_group.size()]
    key = state.get_key(bucket.get_index())
    if tensor.is_sparse:
        tensor = tensor.coalesce()
    tensor_type = "sparse" if tensor.is_sparse else "dense"
    cref.record_start(
        "hook_future_metric", key, f"{cref.backend}_{tensor_type}_allreduce"
    )
    fut = state.process_group.allreduce(tensors).get_future()

    def callback(fut):
        cref.record_end("hook_future_metric", key)
        return fut.wait()

    return fut.then(callback)


def hybrid_hook(state, bucket):
    r"""
    A DDP communication hook that uses the default Gloo process group for
    sparse gradients and a non-default NCCL process group for dense gradients.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    cref = state.cref
    tensor = bucket.buffer()
    key = state.get_key(bucket.get_index())

    if tensor.is_sparse:
        # Sparse gradients: synchronous allreduce on the default Gloo group,
        # wrapped in an already-completed future.
        cref.record_start("hook_c10d_metric", key, "gloo_sparse_allreduce")
        tensor = tensor.coalesce()
        tensor = tensor / state.process_group.size()
        c10d.all_reduce(tensor, op=c10d.ReduceOp.SUM)
        cref.record_end("hook_c10d_metric", key)
        fut = torch.futures.Future()
        fut.set_result([tensor])
    else:
        # Dense gradients: asynchronous allreduce on the NCCL process group.
        cref.record_start("hook_future_metric", key, "nccl_dense_allreduce")
        tensors = [bucket.buffer() / state.process_group.size()]
        fut = state.process_group.allreduce(tensors).get_future()

        def callback(fut):
            cref.record_end("hook_future_metric", key)
            return fut.wait()

        fut = fut.then(callback)
    return fut


def rpc_hook(state, bucket):
    r"""
    A DDP communication hook that averages sparse and dense tensors using
    the process_bucket_with_remote_server method.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    return process_bucket_with_remote_server(state, bucket)


def sparse_rpc_hook(state, bucket):
    r"""
    A DDP communication hook that uses the current backend's allreduce
    implementation for dense tensors and a remote server for sparse tensors.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    tensor = bucket.buffer()
    if tensor.is_sparse:
        return process_bucket_with_remote_server(state, bucket)
    else:
        cref = state.cref
        tensors = [tensor / state.process_group.size()]
        key = state.get_key(bucket.get_index())
        cref.record_start(
            "hook_future_metric", key, f"{cref.backend}_dense_allreduce"
        )
        fut = state.process_group.allreduce(tensors).get_future()

        def callback(fut):
            cref.record_end("hook_future_metric", key)
            return fut.wait()

        return fut.then(callback)
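

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the benchmark itself).
# The hooks above assume a ``state`` object that exposes ``cref``,
# ``process_group``, and ``get_key``; any such object can be passed through
# DDP's standard ``register_comm_hook`` API, which then invokes the chosen
# hook once per gradient bucket during the backward pass. The function name
# and argument names below are hypothetical.
# ---------------------------------------------------------------------------
def _example_register_hook(ddp_model, state, hook=allreduce_hook):
    r"""
    Attach one of the communication hooks above to a DDP-wrapped model.
    Args:
        ddp_model (torch.nn.parallel.DistributedDataParallel): wrapped model
        state (object): object providing cref, process_group, and get_key
        hook (callable): one of the hooks defined in this module
    """
    ddp_model.register_comm_hook(state, hook)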