# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for GPU collective operations implemented using NVIDIA NCCL."""
import threading

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import device
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nccl_ops


_module_lock = threading.Lock()
_shared_name_counter = 0


def all_sum(tensors):
  """Returns a list of tensors with the all-reduce sum across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to sum; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the sum of the input tensors, where tensor i has
    the same device as `tensors[i]`.
  """
  return _apply_all_reduce('sum', tensors)


@ops.RegisterGradient('NcclAllReduce')
def _all_sum_grad(op, grad):
  """The gradients for `all_sum`.

  Args:
    op: The `all_sum` `Operation` that we are differentiating.
    grad: Gradient with respect to the output of the `all_sum` op.

  Returns:
    The gradient with respect to the input of `all_sum`.

  Raises:
    LookupError: If `reduction` is not `sum`.
  """
  if op.get_attr('reduction') != b'sum':
    raise LookupError('No gradient defined for NcclAllReduce except for '
                      'reduction="sum".')

  _check_device(grad, expected=op.device)
  num_devices = op.get_attr('num_devices')
  shared_name = op.get_attr('shared_name') + b'_grad'

  with ops.device(op.device):
    return gen_nccl_ops.nccl_all_reduce(
        input=grad,
        reduction='sum',
        num_devices=num_devices,
        shared_name=shared_name)


def all_prod(tensors):
  """Returns a list of tensors with the all-reduce product across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to multiply; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the product of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('prod', tensors)


def all_min(tensors):
  """Returns a list of tensors with the all-reduce min across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to reduce; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the minimum of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('min', tensors)
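

# Usage sketch for the all_* functions above (illustrative only; the device
# strings, values, and the `constant_op` import are assumptions, not part of
# this module). Each input must already be placed on its own GPU, and all
# returned tensors must be evaluated together, since the underlying NCCL
# all-reduce blocks otherwise:
#
#   from tensorflow.python.framework import constant_op
#
#   with ops.device('/device:GPU:0'):
#     a = constant_op.constant([1.0, 2.0])
#   with ops.device('/device:GPU:1'):
#     b = constant_op.constant([3.0, 4.0])
#   summed = all_sum([a, b])
#   # summed[0] == summed[1] == [4.0, 6.0]; summed[i] is placed on the same
#   # device as its input tensors[i].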


def all_max(tensors):
  """Returns a list of tensors with the all-reduce max across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to reduce; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the maximum of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('max', tensors)


def reduce_sum(tensors):
  """Returns a tensor with the reduce sum across `tensors`.

  The computation is done with a reduce operation, so only one tensor is
  returned.

  Args:
    tensors: The input tensors across which to sum; must be assigned
      to GPU devices.

  Returns:
    A tensor containing the sum of the input tensors.

  Raises:
    ValueError: If `tensors` is empty or no input tensor is assigned to the
      current device.
  """
  return _apply_reduce('sum', tensors)


@ops.RegisterGradient('NcclReduce')
def _reduce_sum_grad(op, grad):
  """The gradients for the input `Operation` of `reduce_sum`.

  Args:
    op: The `reduce_sum` `Operation` that we are differentiating.
    grad: Gradient with respect to the output of the `reduce_sum` op.

  Returns:
    The gradient with respect to the inputs of the `reduce_sum` op.

  Raises:
    LookupError: If the `reduction` attribute of `op` is not `sum`.
  """
  if op.get_attr('reduction') != b'sum':
    raise LookupError('No gradient defined for NcclReduce except for '
                      'reduction="sum".')
  _check_device(grad, expected=op.device)

  with ops.device(op.device):
    result = gen_nccl_ops.nccl_broadcast(input=grad, shape=grad.shape)

  return [result] * len(op.inputs)


def broadcast(tensor):
  """Returns a tensor that can be efficiently transferred to other devices.

  Args:
    tensor: The tensor to send; must be assigned to a GPU device.

  Returns:
    A tensor with the value of `tensor`, which can be used as input to
    ops on other GPU devices.
  """
  _check_device(tensor)

  with ops.device(tensor.device):
    return gen_nccl_ops.nccl_broadcast(input=tensor, shape=tensor.shape)


@ops.RegisterGradient('NcclBroadcast')
def _broadcast_grad(op, accumulated_grad):
  """The gradients for the input `Operation` of `broadcast`.

  Args:
    op: The `broadcast` `Operation` that we are differentiating.
    accumulated_grad: Accumulated gradients with respect to the output of the
      `broadcast` op.

  Returns:
    Gradients with respect to the input of `broadcast`.
  """
  # Grab inputs of accumulated_grad and replace accumulation with reduce_sum.
  grads = [t for t in accumulated_grad.op.inputs]
  for t in grads:
    _check_device(t)

  with ops.device(op.device):
    return gen_nccl_ops.nccl_reduce(input=grads, reduction='sum')
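

# Graph-mode usage sketch for `broadcast` and `reduce_sum` (illustrative only;
# device strings, values, and `constant_op` are assumptions). `broadcast`
# wraps a tensor so ops on other GPUs can consume it, while `reduce_sum`
# produces a single summed tensor placed on the device of one of its inputs:
#
#   with ops.device('/device:GPU:0'):
#     t0 = constant_op.constant([1.0, 1.0])
#     shared = broadcast(t0)               # sendable copy of t0
#   with ops.device('/device:GPU:1'):
#     t1 = constant_op.constant([2.0, 2.0])
#     doubled = shared * t1                # consumes the broadcast on GPU:1
#   total = reduce_sum([t0, t1])           # one tensor with value [3.0, 3.0]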


def _apply_all_reduce(reduction, tensors):
  """Helper function for all_* functions."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to all reduce operations')

  shared_name = _get_shared_name()

  def _all_reduce():
    """Call nccl allreduce."""
    res = []
    for t in tensors:
      _check_device(t)
      with ops.device(t.device):
        res.append(
            gen_nccl_ops.nccl_all_reduce(
                input=t,
                reduction=reduction,
                num_devices=len(tensors),
                shared_name=shared_name))
    return res

  if context.executing_eagerly():
    # Nccl ops will block unless they are executed concurrently such as in a
    # graph or a defun.
    return def_function.function(_all_reduce)()
  else:
    return _all_reduce()


def _apply_reduce(reduction, tensors):
  """Helper function for reduce_* functions."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to reduce operations')

  for t in tensors:
    _check_device(t)
  result = gen_nccl_ops.nccl_reduce(input=tensors, reduction=reduction)
  try:
    next(t for t in tensors if t.device == result.device)
  except StopIteration:
    raise ValueError('One input tensor must be assigned to current device')
  return result


def _get_shared_name():
  global _shared_name_counter

  with _module_lock:
    val = _shared_name_counter
    _shared_name_counter += 1
  return 'c%s' % val


def _check_device(tensor, expected=None):
  if not device.canonical_name(tensor.device):
    raise ValueError(f'Device assignment for tensor={tensor} required for nccl '
                     'collective ops')
  if expected and expected != tensor.device:
    raise ValueError(f'Expected device {expected}, got {tensor.device} for '
                     f'tensor={tensor}.')
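

# Note on op pairing (explanatory, derived from the helpers above): every call
# to `_apply_all_reduce` draws a fresh name ('c0', 'c1', ...) from
# `_get_shared_name`, and all per-device NcclAllReduce ops created by that call
# carry the same `shared_name` and `num_devices` attributes, which is how they
# are grouped into one NCCL collective. The gradient reuses the name with a
# '_grad' suffix so that backward all-reduces form their own group.
#
#   # e.g. an all-reduce over three tensors creates, on each GPU,
#   #   NcclAllReduce(reduction='sum', num_devices=3, shared_name='c0')
#   # and its gradient creates
#   #   NcclAllReduce(reduction='sum', num_devices=3, shared_name='c0_grad')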