# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities."""

import collections
import dataclasses
import functools
import io
import itertools
import threading

from absl import app

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest

try:
  import objgraph  # pylint:disable=g-import-not-at-top
except ImportError:
  objgraph = None


@dataclasses.dataclass
class TestClusterParams:
  cluster: dict
  max_num_worker: int
  max_num_ps: int


def get_cluster_def(cluster_params, num_workers, num_ps):
  """Returns a cluster def with `num_workers` workers and `num_ps` ps tasks."""
  if (num_workers > cluster_params.max_num_worker or
      num_ps > cluster_params.max_num_ps):
    raise ValueError("Requesting more servers than the maximum, adjust "
                     "cluster params' max_num_ps and max_num_worker")
  if cluster_params.cluster is None:
    cluster_params.cluster = multi_worker_test_base.create_in_process_cluster(
        num_workers=cluster_params.max_num_worker,
        num_ps=cluster_params.max_num_ps)
  return {
      "worker": cluster_params.cluster["worker"][:num_workers],
      "ps": cluster_params.cluster["ps"][:num_ps],
  }


def gather(strategy, value):
  """Gathers value from all workers.

  This is intended for tests before we implement an official all-gather API.

  Args:
    strategy: a `tf.distribute.Strategy`.
    value: a nested structure of n-dim `tf.distribute.DistributedValues` of
      `tf.Tensor`, or of a `tf.Tensor` if the strategy only has one replica.
      Cannot contain tf.sparse.SparseTensor.

  Returns:
    a (n+1)-dim `tf.Tensor`.
  """
  return nest.map_structure(functools.partial(_gather, strategy), value)
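

# Illustrative usage sketch (not part of the original module): one way a test
# file might combine `TestClusterParams` and `get_cluster_def` to lazily build
# an in-process cluster and slice out the tasks a test case needs. The
# `_example_*` name is hypothetical; real tests usually keep the
# `TestClusterParams` instance at module scope so the cluster is shared across
# test cases.
def _example_cluster_spec():
  """Returns a cluster def with 1 worker and 1 ps from a shared cluster."""
  params = TestClusterParams(cluster=None, max_num_worker=2, max_num_ps=3)
  # The first call creates the in-process cluster at the maximum sizes; the
  # returned dict is then sliced down to the requested number of tasks.
  return get_cluster_def(params, num_workers=1, num_ps=1)
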
def _gather(strategy, value):
  """Gathers a single value."""
  # pylint: disable=protected-access
  if not isinstance(value, values.DistributedValues):
    value = values.PerReplica([ops.convert_to_tensor(value)])
  if not isinstance(strategy.extended,
                    collective_all_reduce_strategy.CollectiveAllReduceExtended):
    return array_ops.stack(value._values)
  assert len(strategy.extended.worker_devices) == len(value._values)
  inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
  return strategy.gather(values.PerReplica(inputs), axis=0)
  # pylint: enable=protected-access


def set_logical_devices_to_at_least(device, num):
  """Creates at least `num` logical devices of the given device type."""
  if num < 1:
    raise ValueError("`num` must be at least 1 not %r" % (num,))
  physical_devices = config.list_physical_devices(device)
  if not physical_devices:
    raise RuntimeError("No {} found".format(device))
  if len(physical_devices) >= num:
    return
  # By default each physical device corresponds to one logical device. We
  # create multiple logical devices for the last physical device so that we
  # have `num` logical devices in total.
  num = num - len(physical_devices) + 1
  logical_devices = []
  for _ in range(num):
    if device.upper() == "GPU":
      logical_devices.append(
          context.LogicalDeviceConfiguration(memory_limit=2048))
    else:
      logical_devices.append(context.LogicalDeviceConfiguration())
  # Create logical devices from the last device since sometimes the first GPU
  # is the primary graphics card and may have less memory available.
  config.set_logical_device_configuration(physical_devices[-1],
                                          logical_devices)


def _set_logical_devices():
  if config.list_physical_devices("GPU"):
    set_logical_devices_to_at_least("GPU", 2)
  if config.list_physical_devices("CPU"):
    set_logical_devices_to_at_least("CPU", 2)


def main(enable_v2_behavior=True, config_logical_devices=True):
  """All-in-one main function for tf.distribute tests."""
  if config_logical_devices:
    app.call_after_init(_set_logical_devices)
  if enable_v2_behavior:
    v2_compat.enable_v2_behavior()
  else:
    v2_compat.disable_v2_behavior()
  multi_process_runner.test_main()
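

# Illustrative sketch (assumption, not from the original file): a tf.distribute
# test module typically uses `main` as its entry point so that logical devices
# are configured and `multi_process_runner.test_main()` drives the test
# processes, e.g.
#
#   if __name__ == "__main__":
#     test_util.main()
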
def _op_dependencies(op):
  """Returns the data and control dependencies of a tf.Operation combined."""
  deps = []
  for node in itertools.chain(op.inputs, op.control_inputs):
    if isinstance(node, ops.Tensor):
      node = node.op
    assert isinstance(node, ops.Operation)
    deps.append(node)
  return deps


def topological_sort_operations(operations):
  """Topologically sorts a list of operations.

  This does a topological sort of the operations in a graph. The edges include
  both data dependencies and control dependencies. Note that the edges go from
  an operation to its dependencies.

  The sort is intentionally unstable, reversing orders of operations and
  dependencies on ties.

  Args:
    operations: a list of tf.Operation in the same graph.

  Returns:
    A map from a tf.Operation to its topological order.
  """
  in_degrees = collections.OrderedDict()
  for op in reversed(operations):
    if op not in in_degrees:
      in_degrees[op] = 0
    for next_op in reversed(_op_dependencies(op)):
      in_degrees[next_op] = in_degrees.get(next_op, 0) + 1
  nexts = []
  for op, in_degree in in_degrees.items():
    if in_degree == 0:
      nexts.append(op)
  order = {}
  next_order = 0
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    order[op] = next_order
    next_order += 1
    for next_op in reversed(_op_dependencies(op)):
      in_degrees[next_op] -= 1
      if in_degrees[next_op] == 0:
        nexts.append(next_op)
  assert len(order) == len(operations)
  return order


def _exists_dependency(start, end):
  """Returns whether there exists a dependency chain from start to end."""
  nexts = [start]
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    for next_op in _op_dependencies(op):
      if next_op == end:
        return True
      nexts.append(next_op)
  return False


def assert_sequential_execution(order, operations):
  """Asserts there's a deterministic execution order between the operations.

  Args:
    order: a map from a tf.Operation to its topological order.
    operations: a list of operations that should be executed sequentially. It
      can be given in any order.
  """
  # Topological ordering guarantees that, if there's a dependency from N_a to
  # N_b, then order[N_a] < order[N_b]. If there does exist a path of
  # dependencies among the operations, it always goes from an operation with a
  # smaller topological order to one with a larger topological order.
  # Therefore, we only need to sort the operations by their topological orders,
  # and verify that there's a path of dependency between adjacent pairs.
  operations = sorted(operations, key=lambda op: order[op])
  for i in range(len(operations) - 1):
    if not _exists_dependency(operations[i], operations[i + 1]):
      print(operations[i].graph.as_graph_def())
      raise AssertionError(
          "No dependency between {} and {}. Graph is dumped to stdout.".format(
              operations[i].name, operations[i + 1].name))


def get_running_threads():
  """Returns a set of all running thread names."""
  running_threads = set()
  for thread in threading.enumerate():
    if thread.name is not None:
      running_threads.add(thread.name)
  return running_threads


def has_thread(prefix, running_threads):
  """Returns whether any thread in `running_threads` starts with `prefix`.

  Args:
    prefix: The prefix of the expected thread name.
    running_threads: A collection of the running thread names.
  """
  for thread in running_threads:
    if thread.startswith(prefix):
      return True
  return False
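

# Illustrative sketch (not part of the original module): how a test might use
# `topological_sort_operations` together with `assert_sequential_execution` to
# check that certain ops in a graph can only execute in one order. The graph
# argument and the "CollectiveReduce" name prefix are hypothetical.
def _example_assert_collectives_sequential(graph):
  """Asserts the 'CollectiveReduce*' ops in `graph` execute in a fixed order."""
  operations = graph.get_operations()
  order = topological_sort_operations(operations)
  collectives = [
      op for op in operations if op.name.startswith("CollectiveReduce")
  ]
  assert_sequential_execution(order, collectives)
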
def show_backref(target, max_depth=3):
  """Returns a dot graph of all the objects that are referencing the target.

  An object reference graph is useful for debugging memory leaks such as
  circular references. objgraph provides a better visualization of the memory
  graph than most Python built-in utilities like gc.get_referrers(), whose
  output is sometimes not human-readable.

  The dot graph will be written to a string IO object, and can be rendered with
  graphviz on the operating system, e.g. dot -Tpng {$dot_graph} -o output.png

  Args:
    target: The target object for the memory graph.
    max_depth: The maximum depth of the graph. By default 3 layers of
      references are used. Increasing this a lot may result in the graph
      growing too big.

  Returns:
    A string that contains the object reference graph.

  Raises:
    NotImplementedError: if objgraph is not installed.
  """
  if objgraph is None:
    raise NotImplementedError("objgraph is not installed.")
  string_io = io.StringIO()
  objgraph.show_backrefs(target, max_depth=max_depth, output=string_io)
  graph = string_io.getvalue()
  string_io.close()
  return graph


def create_per_replica(strategy, value_list):
  """Creates a PerReplica of Tensors from the value_list."""
  if len(strategy.extended.worker_devices) != len(value_list):
    raise ValueError(
        "the length of values must be the same as the number of worker devices"
    )
  tensors = []
  for device, value in zip(strategy.extended.worker_devices, value_list):
    with ops.device(device):
      tensors.append(ops.convert_to_tensor(value))
  return values.PerReplica(tensors)


def is_tpu_strategy(strategy):
  """Returns whether the strategy is a TPU strategy."""
  return isinstance(strategy,
                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
                     tpu_strategy.TPUStrategyV2))
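

# Illustrative sketch (not part of the original module): building a PerReplica
# value for a strategy with two worker devices and gathering it back into a
# single tensor. The `_example_*` name is hypothetical.
def _example_roundtrip(strategy):
  """Returns a tensor of shape [2] gathered from two per-replica scalars."""
  # `create_per_replica` requires one value per worker device of `strategy`.
  per_replica = create_per_replica(strategy, [1.0, 2.0])
  # `gather` stacks the per-replica components along a new leading axis.
  return gather(strategy, per_replica)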