# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities."""

import collections
import dataclasses
import functools
import io
import itertools
import threading

from absl import app

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest

try:
  import objgraph  # pylint:disable=g-import-not-at-top
except ImportError:
  objgraph = None


@dataclasses.dataclass
class TestClusterParams:
  cluster: dict
  max_num_worker: int
  max_num_ps: int


def get_cluster_def(cluster_params, num_workers, num_ps):
  if (num_workers > cluster_params.max_num_worker or
      num_ps > cluster_params.max_num_ps):
    raise ValueError("Requesting more servers than the maximum, adjust "
                     "cluster params' max_num_ps and max_num_worker")
  if cluster_params.cluster is None:
    cluster_params.cluster = multi_worker_test_base.create_in_process_cluster(
        num_workers=cluster_params.max_num_worker,
        num_ps=cluster_params.max_num_ps)
  return {
      "worker": cluster_params.cluster["worker"][:num_workers],
      "ps": cluster_params.cluster["ps"][:num_ps],
  }
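
# A hedged usage sketch of the helpers above (illustrative only; the
# `_cluster_params` module-level cache is hypothetical, not part of this
# module). Tests typically share one in-process cluster sized to the maximum
# and slice it per test case:
#
#   _cluster_params = TestClusterParams(cluster=None, max_num_worker=2,
#                                       max_num_ps=3)
#   cluster_def = get_cluster_def(_cluster_params, num_workers=1, num_ps=2)
#   # `cluster_def` holds 1 worker address and 2 ps addresses, taken from the
#   # shared cluster that is created lazily on first use.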


def gather(strategy, value):
  """Gathers value from all workers.

  This is intended for tests before we implement an official all-gather API.

  Args:
    strategy: a `tf.distribute.Strategy`.
    value: a nested structure of n-dim `tf.distribute.DistributedValues` of
      `tf.Tensor`, or of a `tf.Tensor` if the strategy only has one replica.
      Cannot contain tf.sparse.SparseTensor.

  Returns:
    an (n+1)-dim `tf.Tensor`.
  """
  return nest.map_structure(functools.partial(_gather, strategy), value)
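
# A hedged usage sketch (illustrative; `strategy` and `replica_fn` are assumed
# to come from the surrounding test):
#
#   per_replica = strategy.run(replica_fn)   # n-dim value on each replica
#   stacked = gather(strategy, per_replica)  # (n+1)-dim, replicas stacked
#                                            # along a new leading axis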


def _gather(strategy, value):
  """Gathers a single value."""
  # pylint: disable=protected-access
  if not isinstance(value, values.DistributedValues):
    value = values.PerReplica([ops.convert_to_tensor(value)])
  if not isinstance(strategy.extended,
                    collective_all_reduce_strategy.CollectiveAllReduceExtended):
    return array_ops.stack(value._values)
  assert len(strategy.extended.worker_devices) == len(value._values)
  inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
  return strategy.gather(values.PerReplica(inputs), axis=0)
  # pylint: enable=protected-access


def set_logical_devices_to_at_least(device, num):
  """Ensures there are at least `num` logical devices of type `device`."""
  if num < 1:
    raise ValueError("`num` must be at least 1, not %r" % (num,))
  physical_devices = config.list_physical_devices(device)
  if not physical_devices:
    raise RuntimeError("No {} found".format(device))
  if len(physical_devices) >= num:
    return
  # By default each physical device corresponds to one logical device. We create
  # multiple logical devices for the last physical device so that we have `num`
  # logical devices in total.
  num = num - len(physical_devices) + 1
  logical_devices = []
  for _ in range(num):
    if device.upper() == "GPU":
      logical_devices.append(
          context.LogicalDeviceConfiguration(memory_limit=2048))
    else:
      logical_devices.append(context.LogicalDeviceConfiguration())
  # Create logical devices from the last device since sometimes the first GPU
  # is the primary graphics card and may have less memory available.
  config.set_logical_device_configuration(physical_devices[-1], logical_devices)
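
# A hedged usage sketch (illustrative; `tf.distribute.MirroredStrategy` is not
# imported by this module): a multi-replica test on a single-CPU host can
# split the one physical CPU into two logical devices before building a
# strategy.
#
#   set_logical_devices_to_at_least("CPU", 2)
#   strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])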


def _set_logical_devices():
  if config.list_physical_devices("GPU"):
    set_logical_devices_to_at_least("GPU", 2)
  if config.list_physical_devices("CPU"):
    set_logical_devices_to_at_least("CPU", 2)


def main(enable_v2_behavior=True, config_logical_devices=True):
  """All-in-one main function for tf.distribute tests."""
  if config_logical_devices:
    app.call_after_init(_set_logical_devices)
  if enable_v2_behavior:
    v2_compat.enable_v2_behavior()
  else:
    v2_compat.disable_v2_behavior()
  multi_process_runner.test_main()
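
# Typical use in a tf.distribute test file (sketch): replace the usual
# `tf.test.main()` entry point with this module's `main()` so logical devices
# and the multi-process runner are set up consistently.
#
#   if __name__ == "__main__":
#     test_util.main()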


def _op_dependencies(op):
  """Returns the data and control dependencies of a tf.Operation combined."""
  deps = []
  for node in itertools.chain(op.inputs, op.control_inputs):
    if isinstance(node, ops.Tensor):
      node = node.op
    assert isinstance(node, ops.Operation)
    deps.append(node)
  return deps


def topological_sort_operations(operations):
  """Topologically sorts a list of operations.

  This does a topological sort of the operations in a graph. The edges include
  both data dependencies and control dependencies. Note that the edge goes from
  an operation to its dependencies.

  The sort is intentionally unstable: it reverses the order of operations and
  dependencies on ties.

  Args:
    operations: a list of tf.Operation in the same graph.

  Returns:
    A map from a tf.Operation to its topological order.
  """
  in_degrees = collections.OrderedDict()
  for op in reversed(operations):
    if op not in in_degrees:
      in_degrees[op] = 0
    for next_op in reversed(_op_dependencies(op)):
      in_degrees[next_op] = in_degrees.get(next_op, 0) + 1
  nexts = []
  for op, in_degree in in_degrees.items():
    if in_degree == 0:
      nexts.append(op)
  order = {}
  next_order = 0
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    order[op] = next_order
    next_order += 1
    for next_op in reversed(_op_dependencies(op)):
      in_degrees[next_op] -= 1
      if in_degrees[next_op] == 0:
        nexts.append(next_op)
  assert len(order) == len(operations)
  return order
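
# A hedged sketch of how this can be used in graph-mode tests (illustrative;
# the tensors are made up). With the edge convention above, an operation gets
# a smaller order than the operations it depends on:
#
#   with ops.Graph().as_default() as g:
#     a = array_ops.ones([])
#     b = array_ops.identity(a)
#     order = topological_sort_operations(g.get_operations())
#     # `b.op` consumes `a.op`, so order[b.op] < order[a.op].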


def _exists_dependency(start, end):
  """Returns whether there exists a dependency chain from start to end."""
  nexts = [start]
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    for next_op in _op_dependencies(op):
      if next_op == end:
        return True
      nexts.append(next_op)
  return False


def assert_sequential_execution(order, operations):
  """Asserts there's a deterministic execution order between the operations.

  Args:
    order: a map from a tf.Operation to its topological order.
    operations: a list of operations that should be executed sequentially. It
      can be given in any order.
  """
  # Topological ordering guarantees that, if there's a dependency from N_a to
  # N_b, then order[N_a] < order[N_b]. If there exists a path of dependencies
  # among the operations, it always goes from an operation with a smaller
  # topological order to one with a larger topological order. Therefore, we
  # only need to sort the operations by their topological orders, and verify
  # that there's a path of dependency between each adjacent pair.
  operations = sorted(operations, key=lambda op: order[op])
  for i in range(len(operations) - 1):
    if not _exists_dependency(operations[i], operations[i + 1]):
      print(operations[i].graph.as_graph_def())
      raise AssertionError(
          "No dependency between {} and {}. Graph is dumped to stdout.".format(
              operations[i].name, operations[i + 1].name))
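
# A hedged sketch of the intended use (illustrative; `op_a`, `op_b` and `op_c`
# stand for operations whose relative execution order a test wants to verify,
# e.g. collectives that must not be reordered):
#
#   order = topological_sort_operations(graph.get_operations())
#   assert_sequential_execution(order, [op_a, op_b, op_c])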


def get_running_threads():
  """Returns a set of all running thread names."""
  running_threads = set()
  for thread in threading.enumerate():
    if thread.name is not None:
      running_threads.add(thread.name)
  return running_threads


def has_thread(prefix, running_threads):
  """Returns whether any name in `running_threads` starts with `prefix`.

  Args:
    prefix: The prefix of the expected thread name.
    running_threads: A collection of the running thread names.
  """
  for thread in running_threads:
    if thread.startswith(prefix):
      return True
  return False
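
# A hedged usage sketch (illustrative; the thread-name prefix is hypothetical
# and depends on what the code under test names its threads):
#
#   threads = get_running_threads()
#   self.assertTrue(has_thread("some_worker_thread_prefix", threads))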


def show_backref(target, max_depth=3):
  """Returns a dot graph of all the objects that are referencing the target.

  An object reference graph is useful for debugging memory leaks such as
  circular references. objgraph provides a better visualization of the memory
  graph than most Python built-in utilities like gc.get_referrers(), whose
  output is often not human-readable.

  The dot graph is written to a string IO object and can be rendered with
  Graphviz after saving it to a file, e.g.:
    dot -Tpng <graph_file>.dot -o output.png

  Args:
    target: The target object for the memory graph.
    max_depth: The maximum depth of the graph. By default 3 layers of
      references are used. Increasing this a lot may result in the graph
      growing too big.

  Returns:
    A string that contains the object reference graph.

  Raises:
    NotImplementedError: if objgraph is not installed.
  """
  if objgraph is None:
    raise NotImplementedError("objgraph is not installed.")
  string_io = io.StringIO()
  objgraph.show_backrefs(target, max_depth=max_depth, output=string_io)
  graph = string_io.getvalue()
  string_io.close()
  return graph
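
# A hedged usage sketch (illustrative): when a test suspects that `obj` is
# kept alive by a reference cycle, dump its referrers to a .dot file and
# render it offline with Graphviz.
#
#   dot_graph = show_backref(obj, max_depth=3)
#   with open("/tmp/backrefs.dot", "w") as f:
#     f.write(dot_graph)
#   # Outside the test: dot -Tpng /tmp/backrefs.dot -o backrefs.png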


def create_per_replica(strategy, value_list):
  """Creates a PerReplica of Tensors from the value_list."""
  if len(strategy.extended.worker_devices) != len(value_list):
    raise ValueError(
        "the length of `value_list` must equal the number of worker devices")
  tensors = []
  for device, value in zip(strategy.extended.worker_devices, value_list):
    with ops.device(device):
      tensors.append(ops.convert_to_tensor(value))
  return values.PerReplica(tensors)
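
# A hedged usage sketch (illustrative; assumes `strategy` has two replicas):
#
#   per_replica = create_per_replica(strategy, [1.0, 2.0])
#   stacked = gather(strategy, per_replica)
#   # `stacked` now holds both replicas' values along a new leading axis.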


def is_tpu_strategy(strategy):
  """Returns whether the strategy is a TPU strategy."""
  return isinstance(strategy,
                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
                     tpu_strategy.TPUStrategyV2))
295