# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow collective Ops."""
from tensorflow.python.ops import gen_collective_ops


def all_reduce(t,
               group_size,
               group_key,
               instance_key,
               merge_op='Add',
               final_op='Id',
               subdiv_offsets=(0,),
               communication_hint='auto',
               timeout=0):
  """Reduces tensors collectively, across devices.

  Args:
    t: the tensor to be reduced.
    group_size: the total number of tensors to be collectively reduced.
      Each must reside on a different device.  Should be a positive integer.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    merge_op: string naming the binary Op to be applied to compute each
      partial reduction.
    final_op: string naming the unary Op to be applied to each fully
      reduced value.  Can be 'Id' for no operation.
    subdiv_offsets: a list of integer offsets into the tensor at which each
      independent subdivision should begin.  Use [0] if no subdivision should
      be done.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If non-zero, sets a completion timeout, in seconds, used
      to detect staleness.  If the timeout expires, a DeadlineExceededError is
      raised.  This feature is experimental.

  Returns:
    An Op implementing the distributed reduction.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size < 1:
    raise ValueError('Parameter `group_size` to all_reduce must be at least 1. '
                     f'Received: {group_size}.')
  return gen_collective_ops.collective_reduce(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      merge_op=merge_op,
      final_op=final_op,
      subdiv_offsets=subdiv_offsets,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)
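
# Example (illustrative sketch, not part of the public API): the same
# all_reduce must be launched on every participating device, e.g. inside one
# tf.function so the collective can complete.  The two logical CPU devices and
# the key values below are assumptions chosen only for illustration; `tf`
# stands for an assumed `import tensorflow as tf`.
#
#   cpus = tf.config.list_physical_devices('CPU')
#   tf.config.set_logical_device_configuration(
#       cpus[0], [tf.config.LogicalDeviceConfiguration()] * 2)
#
#   @tf.function
#   def run():
#     results = []
#     for i, dev in enumerate(['/device:CPU:0', '/device:CPU:1']):
#       with tf.device(dev):
#         results.append(
#             all_reduce(tf.constant([float(i + 1)]),
#                        group_size=2, group_key=1, instance_key=1))
#     return results  # both entries hold the sum [3.0]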


def assign_group_v2(group_assignment, device_index, base_key):
  """Assign group key based on group_assignment.

  Args:
    group_assignment: a 2-dimensional integer Tensor that encodes which devices
      belong to the same group. The values are device indices in the range
      [0, number of devices).
    device_index: an integer, the index of the current device.
    base_key: an integer used to offset the resulting group_key. The base key
      must be unique for different values of group_assignment in the same
      tf.function.
  Notes: The device_index argument must be consistent with the index of the
    device of this Op in the device assignment list. The behavior of this Op is
    undefined if they are inconsistent.

  Returns:
    group_size, group_key: The group size and group key for the current device.
  """
  group_size, group_key = gen_collective_ops.collective_assign_group_v2(
      group_assignment=group_assignment,
      device_index=device_index,
      base_key=base_key)
  return group_size, group_key
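
# Example (illustrative sketch): with four devices split into two groups of
# two, a group_assignment of [[0, 1], [2, 3]] means devices 0 and 1 form one
# group and devices 2 and 3 form another.  The call below, assumed to run on
# device 2, would return group_size == 2 and a group key derived from base_key
# that is shared by devices 2 and 3.  The base_key value is an arbitrary
# choice for illustration.
#
#   group_size, group_key = assign_group_v2(
#       group_assignment=[[0, 1], [2, 3]], device_index=2, base_key=1000)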


def all_reduce_v2(t,
                  group_size,
                  group_key,
                  instance_key,
                  merge_op='Add',
                  final_op='Id',
                  communication_hint='auto',
                  timeout=0,
                  ordering_token=None,
                  max_subdivs_per_device=-1,
                  name=None):
  """Reduces tensors collectively, across devices.

  Args:
    t: the tensor to be reduced.
    group_size: an int32 tensor. The total number of tensors to be collectively
      reduced.  Each must reside on a different device.  Should be a positive
      integer.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    merge_op: string naming the binary Op to be applied to compute each partial
      reduction.
    final_op: string naming the unary Op to be applied to each fully reduced
      value.  Can be 'Id' for no operation.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If non-zero, sets a completion timeout, in seconds, used
      to detect staleness.  If the timeout expires, a DeadlineExceededError is
      raised.  This feature is experimental.
    ordering_token: a resource tensor on the same device as the op to order
      the collectives in a per-device manner by auto control dependency.
      This argument can be omitted when there is one collective Op per
      `tf.function`, or when explicit control dependency is used instead of
      auto control dependency.
    max_subdivs_per_device: an int specifying the maximum number of
      subdivisions a tensor on a device can be divided into. The runtime uses
      this constraint to parallelize processing of each per-device tensor.
      Setting it to -1 disables subdivision and reverts to the previous
      behavior of not sub-dividing the tensor. Setting it to 0 uses system
      defaults.
    name: name of the Op.

  Returns:
    An Op implementing the distributed reduction.
  """
  if ordering_token is not None:
    ordering_token = [ordering_token]
  else:
    ordering_token = []

  return gen_collective_ops.collective_reduce_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      merge_op=merge_op,
      final_op=final_op,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout,
      ordering_token=ordering_token,
      max_subdivs_per_device=max_subdivs_per_device,
      name=name)
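
# Example (illustrative sketch): unlike all_reduce, all_reduce_v2 accepts the
# group_size, group_key and instance_key as tensors, so they may be computed
# at runtime.  Devices, key values and inputs below are assumptions for
# illustration, reusing the two-logical-CPU setup sketched above.
#
#   @tf.function
#   def run_v2(inputs):
#     outputs = []
#     for dev, x in zip(['/device:CPU:0', '/device:CPU:1'], inputs):
#       with tf.device(dev):
#         outputs.append(
#             all_reduce_v2(x,
#                           group_size=tf.constant(2),
#                           group_key=tf.constant(7),
#                           instance_key=tf.constant(1)))
#     return outputs
#
#   run_v2([tf.constant([1.0]), tf.constant([2.0])])  # both results are [3.0]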


def all_gather(t,
               group_size,
               group_key,
               instance_key,
               communication_hint='auto',
               timeout=0):
  """Accumulates tensors collectively, across devices, along first dimension.

  Args:
    t: the tensor to participate in the accumulation.
    group_size: the total number of tensors to be collectively accumulated.
      Each must reside on a different device. Should be a positive integer.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If non-zero, sets a completion timeout, in seconds, used
      to detect staleness. If the timeout expires, a DeadlineExceededError is
      raised. This feature is experimental.

  Returns:
    An Op implementing the distributed operation.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size < 1:
    raise ValueError('Parameter `group_size` to all_gather must be at least 1.'
                     f' Received: {group_size}.')
  return gen_collective_ops.collective_gather(
      t,
      shape=[0],
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


def all_gather_v2(t,
                  group_size,
                  group_key,
                  instance_key,
                  communication_hint='auto',
                  timeout=0,
                  ordering_token=None,
                  name=None):
  """Accumulates tensors collectively, across devices, along first dimension.

  Args:
    t: the tensor to participate in the accumulation.
    group_size: an int32 tensor, the total number of tensors to be collectively
      accumulated. Each must reside on a different device. Should be a positive
      integer.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If non-zero, sets a completion timeout, in seconds, used
      to detect staleness. If the timeout expires, a DeadlineExceededError is
      raised. This feature is experimental.
    ordering_token: a resource tensor on the same device as the op to order
      the collectives in a per-device manner by auto control dependency.
      This argument can be omitted when there is one collective Op per
      `tf.function`, or when explicit control dependency is used instead of
      auto control dependency.
    name: name of the Op.

  Returns:
    An Op implementing the distributed operation.
  """
  if ordering_token is not None:
    ordering_token = [ordering_token]
  else:
    ordering_token = []

  return gen_collective_ops.collective_gather_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout,
      ordering_token=ordering_token,
      name=name)
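
# Example (illustrative sketch): all_gather concatenates the per-device
# tensors along their first dimension, so with two devices contributing
# [[1.0, 2.0]] and [[3.0, 4.0]] every participant receives
# [[1.0, 2.0], [3.0, 4.0]].  Devices and key values are assumptions for
# illustration, reusing the setup sketched above.
#
#   @tf.function
#   def gather():
#     outputs = []
#     for i, dev in enumerate(['/device:CPU:0', '/device:CPU:1']):
#       with tf.device(dev):
#         outputs.append(
#             all_gather_v2(tf.constant([[2.0 * i + 1.0, 2.0 * i + 2.0]]),
#                           group_size=tf.constant(2),
#                           group_key=tf.constant(11),
#                           instance_key=tf.constant(1)))
#     return outputs  # each entry is [[1.0, 2.0], [3.0, 4.0]]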


def broadcast_send(t,
                   shape,
                   dtype,
                   group_size,
                   group_key,
                   instance_key,
                   communication_hint='auto',
                   timeout=0):
  """Broadcasts one tensor to a group of others, across devices.

  Args:
    t: the tensor to be sent.
    shape: the shape of the tensor being sent, which must agree with t.
    dtype: the type of the tensor being sent, which must agree with t.
    group_size: one plus the number of receiving tensors, i.e. the total
      number of devices participating.  Each tensor must reside on a
      different device.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If non-zero, sets a completion timeout, in seconds, used
      to detect staleness.  If the timeout expires, a DeadlineExceededError is
      raised.  This feature is experimental.

  Returns:
    An Op implementing the distributed broadcast send.

  Raises:
    ValueError: if any of the input parameter constraints are not met.

  Note that the shape and dtype arguments appear redundant since they
  should be obtainable from t.  There are two reasons for including
  them.  First, the shape and type of tensors passed via broadcast must
  be known ahead of time in their most specific form so that the receive
  side can allocate memory for the operation and shape/type inference can
  carry forward from there.  Including the same declarations on the
  send side clarifies a commitment already made.  Second, having nearly
  identical use syntax for send and receive sides may simplify tool-driven
  generation of broadcast.
  """
  if group_size <= 1:
    raise ValueError(
        'Parameter `group_size` to broadcast_send must be at least 2. '
        f'Received: {group_size}.')
  if t.shape != shape:
    raise ValueError(
        'Shape of broadcast_send tensor `t` not equal to declared shape. '
        f'Received {t.shape}, expected {shape}.')
  if t.dtype != dtype:
    raise ValueError(
        'Type of broadcast_send tensor `t` not equal to declared type. '
        f'Received {t.dtype}, expected {dtype}.')
  return gen_collective_ops.collective_bcast_send(
      t,
      shape=shape,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


def broadcast_send_v2(t,
                      group_size,
                      group_key,
                      instance_key,
                      communication_hint='auto',
                      timeout=0):
  """Broadcasts one tensor to a group of others, across devices.

  Args:
    t: the tensor to be sent.
    group_size: an int32 tensor.  One plus the number of receiving tensors, i.e.
        the total number of devices participating.  Each tensor must reside on a
        different device.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If non-zero, sets a completion timeout, in seconds, used
      to detect staleness.  If the timeout expires, a DeadlineExceededError is
      raised.  This feature is experimental.

  Returns:
    An Op implementing the distributed broadcast send.
  """
  return gen_collective_ops.collective_bcast_send_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


def broadcast_recv(shape,
                   dtype,
                   group_size,
                   group_key,
                   instance_key,
                   communication_hint='auto',
                   timeout=0):
  """Receives a broadcast tensor, across devices.

  Args:
    shape: Shape of the tensor to be received.
    dtype: Type of the tensor to be received.
    group_size: one plus the number of receiving tensors, i.e. the total
      number of devices participating.  Each tensor must reside on a
      different device.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If non-zero, sets a completion timeout, in seconds, used
      to detect staleness.  If the timeout expires, a DeadlineExceededError is
      raised.  This feature is experimental.

  Returns:
    An Op implementing the broadcast receive.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size <= 1:
    raise ValueError(
        'Parameter `group_size` to broadcast_recv must be at least 2. '
        f'Received: {group_size}.')
  return gen_collective_ops.collective_bcast_recv(
      shape=shape,
      T=dtype,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)
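
# Example (illustrative sketch): a broadcast pairs one broadcast_send with
# group_size - 1 broadcast_recv calls that use the same keys, shape and dtype,
# all running concurrently.  The devices and key values below are assumptions
# for illustration only.
#
#   @tf.function
#   def broadcast():
#     with tf.device('/device:CPU:0'):
#       sent = broadcast_send(
#           tf.constant([1.0, 2.0]), shape=[2], dtype=tf.float32,
#           group_size=2, group_key=3, instance_key=1)
#     with tf.device('/device:CPU:1'):
#       received = broadcast_recv(
#           shape=[2], dtype=tf.float32,
#           group_size=2, group_key=3, instance_key=1)
#     return sent, received  # both evaluate to [1.0, 2.0]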


def broadcast_recv_v2(shape,
                      dtype,
                      group_size,
                      group_key,
                      instance_key,
                      communication_hint='auto',
                      timeout=0):
  """Receives a broadcast tensor, across devices.

  Args:
    shape: an int tensor.  Shape of the tensor to be received.
    dtype: Type of the tensor to be received.
    group_size: an int32 tensor.  One plus the number of receiving tensors, i.e.
        the total number of devices participating.  Each tensor must reside on a
        different device.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If non-zero, sets a completion timeout, in seconds, used
      to detect staleness.  If the timeout expires, a DeadlineExceededError is
      raised.  This feature is experimental.

  Returns:
    An Op implementing the broadcast receive.
  """
  return gen_collective_ops.collective_bcast_recv_v2(
      T=dtype,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      shape=shape,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


def initialize_communicator(group_key,
                            rank,
                            group_size,
                            communication_hint='auto',
                            timeout_seconds=0):
  """Initializes a collective communicator.

  This creates a collective communicator, which represents membership in a
  collective group identified by the group_key. It should be called once per
  member of the group, and each member needs to be on a different device.
  It blocks until all members of the group run this op.

  Communicators of a group can only be initialized once. Trying to initialize
  communicators for an existing group key will result in an error.

  Args:
    group_key: an int32 `tf.Tensor` identifying the group.
    rank: a `tf.Tensor` specifying the rank of this device in the group. If
      specified, the rank is required to be unique in the group.
    group_size: an int32 `tf.Tensor`. The size of the group.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout_seconds: a float. If non-zero, sets a completion timeout, in
      seconds, used to detect staleness. If the timeout expires, a
      DeadlineExceededError is raised. This feature is experimental.

  Returns:
    A resource `tf.Tensor`.
  """
  return gen_collective_ops.collective_initialize_communicator(
      group_key=group_key,
      rank=rank,
      group_size=group_size,
      communication_hint=communication_hint,
      timeout_seconds=timeout_seconds)


def all_reduce_v3(communicator,
                  t,
                  reduction='Add',
                  group_assignment=None,
                  timeout_seconds=None):
  """Reduces tensors mutually.

  Args:
    communicator: the resource `tf.Tensor` returned from
      `initialize_communicator`.
    t: the `tf.Tensor` to be reduced.
    reduction: a string. The name of the operation to reduce the values.
      Accepted values are `"min"`, `"max"`, `"mul"`, `"add"`.
    group_assignment: Optional int32 `tf.Tensor` with shape [num_groups,
      num_ranks_per_group]. `group_assignment[i]` represents the ranks in the
      `ith` subgroup.
    timeout_seconds: a float. If non-zero, sets a completion timeout, in
      seconds, used to detect staleness. If the timeout expires, a
      DeadlineExceededError is raised. This feature is experimental.

  Returns:
    The reduced `tf.Tensor`.
  """
  if group_assignment is None:
    group_assignment = []
  return gen_collective_ops.collective_reduce_v3(
      communicator=communicator,
      input=t,
      group_assignment=group_assignment,
      reduction=reduction,
      timeout_seconds=timeout_seconds)
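
# Example (illustrative sketch): the v3 ops first build a communicator per
# device with initialize_communicator, then pass it to the collective.  The
# devices, group key and rank assignment below are assumptions for
# illustration only; both branches must run concurrently (e.g. inside one
# tf.function) for the collective to complete.
#
#   @tf.function
#   def reduce_v3():
#     outputs = []
#     for rank, dev in enumerate(['/device:CPU:0', '/device:CPU:1']):
#       with tf.device(dev):
#         communicator = initialize_communicator(
#             group_key=tf.constant(21), rank=tf.constant(rank),
#             group_size=tf.constant(2))
#         outputs.append(
#             all_reduce_v3(communicator, tf.constant([float(rank)])))
#     return outputs  # each entry is [1.0]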


def all_to_all_v3(communicator, t, group_assignment=None, timeout_seconds=None):
  """Exchanges tensors mutually.

  Args:
    communicator: the resource `tf.Tensor` returned from
      `initialize_communicator`.
    t: a `tf.Tensor`. The first dimension should have the same length as the
      size of the group. `t[i]` is sent to `rank i` within the group.
    group_assignment: Optional int32 `tf.Tensor` with shape [num_groups,
      num_ranks_per_group]. `group_assignment[i]` represents the ranks in the
      `ith` subgroup.
    timeout_seconds: a float. If non-zero, sets a completion timeout, in
      seconds, used to detect staleness. If the timeout expires, a
      DeadlineExceededError is raised. This feature is experimental.

  Returns:
    a `tf.Tensor`. The `i`th slice of the result is the slice received from
    `rank i` within the group.
  """
  if group_assignment is None:
    group_assignment = []
  return gen_collective_ops.collective_all_to_all_v3(
      communicator=communicator,
      input=t,
      group_assignment=group_assignment,
      timeout_seconds=timeout_seconds)
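# Example (illustrative sketch): with a group of size 2, each device supplies
# a tensor whose first dimension is 2; row i goes to rank i, and the result on
# each device stacks the rows it received.  If rank 0 supplies [[1.], [2.]]
# and rank 1 supplies [[3.], [4.]], rank 0 ends up with [[1.], [3.]] and
# rank 1 with [[2.], [4.]].  The names communicator0 and communicator1 are
# hypothetical handles assumed to have been created as in the sketch above,
# and, as with the other collectives, both calls must execute concurrently.
#
#   with tf.device('/device:CPU:0'):
#     out0 = all_to_all_v3(communicator0, tf.constant([[1.0], [2.0]]))
#   with tf.device('/device:CPU:1'):
#     out1 = all_to_all_v3(communicator1, tf.constant([[3.0], [4.0]]))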