# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for GPU collective operations implemented using NVIDIA nccl."""
import threading

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import device
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nccl_ops


_module_lock = threading.Lock()
_shared_name_counter = 0


def all_sum(tensors):
  """Returns a list of tensors with the all-reduce sum across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to sum; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the sum of the input tensors, where tensor i has
    the same device as `tensors[i]`.
  """
  return _apply_all_reduce('sum', tensors)

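# Usage sketch (illustrative only, not part of the original module): summing
# one tensor per GPU. Assumes two visible GPUs and the public `tf` namespace.
#
#   from tensorflow.python.ops import nccl_ops
#
#   with tf.device('/GPU:0'):
#     t0 = tf.ones([4])
#   with tf.device('/GPU:1'):
#     t1 = 2.0 * tf.ones([4])
#   s0, s1 = nccl_ops.all_sum([t0, t1])  # both results equal [3., 3., 3., 3.]
#
# `s0` is placed on /GPU:0 and `s1` on /GPU:1. All returned tensors must be
# evaluated together; the all-reduce hangs if only some of them run. The same
# pattern applies to `all_prod`, `all_min`, and `all_max`.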

@ops.RegisterGradient('NcclAllReduce')
def _all_sum_grad(op, grad):
  """The gradients for `all_sum`.

  Args:
    op: The `all_sum` `Operation` that we are differentiating.
    grad: Gradient with respect to the output of the `all_sum` op.

  Returns:
    The gradient with respect to the input of `all_sum`.

  Raises:
    LookupError: If `reduction` is not `sum`.
  """
  if op.get_attr('reduction') != b'sum':
    raise LookupError('No gradient defined for NcclAllReduce except for '
                      'reduction="sum".')

  _check_device(grad, expected=op.device)
  num_devices = op.get_attr('num_devices')
  shared_name = op.get_attr('shared_name') + b'_grad'

  with ops.device(op.device):
    return gen_nccl_ops.nccl_all_reduce(
        input=grad,
        reduction='sum',
        num_devices=num_devices,
        shared_name=shared_name)


def all_prod(tensors):
  """Returns a list of tensors with the all-reduce product across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to multiply; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the product of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('prod', tensors)


def all_min(tensors):
  """Returns a list of tensors with the all-reduce min across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to reduce; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the minimum of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('min', tensors)


def all_max(tensors):
  """Returns a list of tensors with the all-reduce max across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to reduce; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the maximum of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('max', tensors)


def reduce_sum(tensors):
  """Returns a tensor with the reduce sum across `tensors`.

  The computation is done with a reduce operation, so only one tensor is
  returned.

  Args:
    tensors: The input tensors across which to sum; must be assigned
      to GPU devices.

  Returns:
    A tensor containing the sum of the input tensors.

  Raises:
    ValueError: If `tensors` is empty, or if none of the input tensors is
      assigned to the current device.
  """
  return _apply_reduce('sum', tensors)

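# Usage sketch (illustrative only, not part of the original module): unlike the
# all_* functions, `reduce_sum` produces a single output, placed on the device
# that is current when it is called, and one of the inputs must already live on
# that device. Reusing `t0` (on /GPU:0) and `t1` (on /GPU:1) from the earlier
# sketch:
#
#   with tf.device('/GPU:0'):
#     total = nccl_ops.reduce_sum([t0, t1])  # total lives on /GPU:0
#
# If neither input were assigned to /GPU:0, a ValueError would be raised.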

@ops.RegisterGradient('NcclReduce')
def _reduce_sum_grad(op, grad):
  """The gradients for input `Operation` of `reduce_sum`.

  Args:
    op: The `sum send` `Operation` that we are differentiating.
    grad: Gradient with respect to the output of the `reduce_sum` op.

  Returns:
    The gradient with respect to the input of `reduce_sum` op.

  Raises:
    LookupError: If the reduction attribute of op is not `sum`.
  """
  if op.get_attr('reduction') != b'sum':
    raise LookupError('No gradient defined for NcclReduce except for '
                      'reduction="sum".')
  _check_device(grad, expected=op.device)

  with ops.device(op.device):
    result = gen_nccl_ops.nccl_broadcast(input=grad, shape=grad.shape)

  return [result] * len(op.inputs)


def broadcast(tensor):
  """Returns a tensor that can be efficiently transferred to other devices.

  Args:
    tensor: The tensor to send; must be assigned to a GPU device.

  Returns:
    A tensor with the value of `tensor`, which can be used as input to
    ops on other GPU devices.
  """
  _check_device(tensor)

  with ops.device(tensor.device):
    return gen_nccl_ops.nccl_broadcast(input=tensor, shape=tensor.shape)

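# Usage sketch (illustrative only, not part of the original module): `broadcast`
# wraps a tensor in an NcclBroadcast op so that consumers on other GPUs receive
# its value through NCCL rather than through separate device-to-device copies.
#
#   with tf.device('/GPU:0'):
#     weights = tf.ones([4])
#   shared = nccl_ops.broadcast(weights)
#   with tf.device('/GPU:1'):
#     replica_loss = tf.reduce_sum(shared)  # reads the broadcast value
#
# The gradient registered below turns the per-device gradients accumulated for
# `shared` back into a single NCCL reduce on the sender's device.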

@ops.RegisterGradient('NcclBroadcast')
def _broadcast_grad(op, accumulated_grad):
  """The gradients for input `Operation` of `broadcast`.

  Args:
    op: The `broadcast send` `Operation` that we are differentiating.
    accumulated_grad: Accumulated gradients with respect to the output of the
      `broadcast` op.

  Returns:
    Gradients with respect to the input of `broadcast`.
  """
  # Grab inputs of accumulated_grad and replace accumulation with reduce_sum.
  grads = [t for t in accumulated_grad.op.inputs]
  for t in grads:
    _check_device(t)

  with ops.device(op.device):
    return gen_nccl_ops.nccl_reduce(input=grads, reduction='sum')


def _apply_all_reduce(reduction, tensors):
  """Helper function for all_* functions."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to all reduce operations')

  shared_name = _get_shared_name()

  def _all_reduce():
    """Call nccl allreduce."""
    res = []
    for t in tensors:
      _check_device(t)
      with ops.device(t.device):
        res.append(
            gen_nccl_ops.nccl_all_reduce(
                input=t,
                reduction=reduction,
                num_devices=len(tensors),
                shared_name=shared_name))
    return res

  if context.executing_eagerly():
    # Nccl ops will block unless they are executed concurrently such as in a
    # graph or a defun.
    return def_function.function(_all_reduce)()
  else:
    return _all_reduce()

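# Note (illustrative, not part of the original module): the eager branch above
# exists because NCCL kernels belonging to one collective rendezvous with each
# other; the kernel launched for `tensors[0]` cannot complete until matching
# kernels for the remaining tensors have also been launched. Issuing the
# per-device ops one at a time under eager execution would therefore block on
# the first launch, so the helper traces the whole loop as a single
# `def_function.function`, which lets the runtime schedule all NcclAllReduce
# kernels together. The same idea expressed at the call site would be roughly:
#
#   summed = def_function.function(lambda: all_sum([t0, t1]))()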

def _apply_reduce(reduction, tensors):
  """Helper function for reduce_* functions."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to reduce operations')

  for t in tensors:
    _check_device(t)
  result = gen_nccl_ops.nccl_reduce(input=tensors, reduction=reduction)
  try:
    next(t for t in tensors if t.device == result.device)
  except StopIteration:
    raise ValueError('One input tensor must be assigned to current device')
  return result


def _get_shared_name():
  global _shared_name_counter

  with _module_lock:
    val = _shared_name_counter
    _shared_name_counter += 1
  return 'c%s' % val


def _check_device(tensor, expected=None):
  if not device.canonical_name(tensor.device):
    raise ValueError(f'Device assignment for tensor={tensor} required for nccl '
                     'collective ops')
  if expected and expected != tensor.device:
    raise ValueError(f'Expected device {expected}, got {tensor.device} for '
                     f'tensor={tensor}.')