# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class MirroredStrategy implementing tf.distribute.Strategy."""

import contextlib
import functools
import threading
import weakref

from tensorflow.python import pywrap_tfe
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import shared_variable_creator
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.util import traceback_utils


def _is_gpu_device(device):
  return tf_device.DeviceSpec.from_string(device).device_type == "GPU"


def call_for_each_replica(strategy, fn, args=None, kwargs=None):
  """Call `fn` on each worker device (replica).

  It is highly recommended to wrap the call to this function inside a
  `tf.function`; otherwise the performance is poor.

  Args:
    strategy: `tf.distribute.Strategy`.
    fn: function to call on each worker device.
    args: positional arguments to `fn`.
    kwargs: keyword arguments to `fn`.

  Returns:
    Wrapped returned value of `fn` from all replicas.
  """
  if args is None:
    args = ()
  if kwargs is None:
    kwargs = {}

  if isinstance(fn, def_function.Function):
    # Don't lift up the tf.function decoration if `fn` is compiled with XLA
    # and all devices are GPU. In this case we will use collectives to do
    # cross-device communication, thus no merge_call is in the path.
    if fn._jit_compile and all(  # pylint: disable=protected-access
        [_is_gpu_device(d) for d in strategy.extended.worker_devices]):
      return _call_for_each_replica(strategy, fn, args, kwargs)

    if strategy not in _cfer_fn_cache:
      _cfer_fn_cache[strategy] = weakref.WeakKeyDictionary()
    wrapped = _cfer_fn_cache[strategy].get(fn)
    if wrapped is None:
      # We need to wrap `fn` such that it triggers _call_for_each_replica
      # inside the tf.function. We use _clone() instead of a @tf.function-
      # wrapped call_for_each_replica() because we would like to retain the
      # arguments to the @tf.function decorator of `fn`.
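      # Conceptually this turns `_call_for_each_replica(tf.function(f))` into
      # `tf.function(_call_for_each_replica(f))`, which avoids each replica
      # thread trying to trace `fn` separately while holding the shared
      # tracing lock (see case #1 in `_MirroredReplicaContext._merge_call`).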
      wrapped = fn._clone(  # pylint: disable=protected-access
          python_function=functools.partial(call_for_each_replica, strategy,
                                            fn.python_function))
      _cfer_fn_cache[strategy][fn] = wrapped
    return wrapped(args, kwargs)

  if context.executing_eagerly():
    logging.log_first_n(
        logging.WARN, "Using %s eagerly has significant "
        "overhead currently. We will be working on improving "
        "this in the future, but for now please wrap "
        "`call_for_each_replica` or `experimental_run` or "
        "`run` inside a tf.function to get "
        "the best performance." % strategy.__class__.__name__, 5)
  else:
    # When a tf.function is wrapped to trigger _call_for_each_replica (see
    # the other branch above), AutoGraph stops conversion at
    # _call_for_each_replica itself (TF library functions are allowlisted).
    # This makes sure that the Python function that was originally passed to
    # the tf.function is still converted.
    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())

  return _call_for_each_replica(strategy, fn, args, kwargs)


# Per strategy cache for call_for_each_replica def_function.Function objects.
_cfer_fn_cache = weakref.WeakKeyDictionary()


@contextlib.contextmanager
def _enter_graph(g, eager, creator_stack=None):
  """Context manager for selecting a graph and maybe eager mode."""
  if eager:
    with g.as_default(), context.eager_mode():
      if creator_stack is not None:
        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
      yield
  else:
    with g.as_default():
      if creator_stack is not None:
        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
      yield


@contextlib.contextmanager
def _maybe_enter_eager_mode(eager):
  if eager:
    with context.eager_mode():
      yield
  else:
    yield


def _cpu_device(device):
  cpu_device = tf_device.DeviceSpec.from_string(device)
  cpu_device = cpu_device.replace(device_type="CPU", device_index=0)
  return cpu_device.to_string()


class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
  pass


def _get_thread_local_configuration_callable():
  if traceback_utils.is_traceback_filtering_enabled():
    thread_local_callables = {traceback_utils.enable_traceback_filtering}
  else:
    thread_local_callables = {traceback_utils.disable_traceback_filtering}
  return thread_local_callables


def _call_for_each_replica(distribution, fn, args, kwargs):
  """Run `fn` in separate threads, once per replica/worker device.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per replica, each in its own thread).
    args: positional arguments for `fn`.
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas.

  Raises:
    RuntimeError: if `fn()` calls `get_replica_context().merge_call()` a
      different number of times on different replicas.
  """
  # TODO(josh11b): Add this option once we add synchronization to variable
  # creation. Until then, this is pretty unsafe to use.
  run_concurrently = False
  if not context.executing_eagerly():
    # Needed for per-thread device, etc. contexts in graph mode.
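    # switch_to_thread_local() makes the graph's device, colocation and
    # control-dependency stacks thread-local, so each replica thread below
    # can enter its own device scope without affecting the other threads.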
    ops.get_default_graph().switch_to_thread_local()

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}
  devices = distribution.extended.worker_devices

  thread_local_callables = _get_thread_local_configuration_callable()

  # TODO(isaprykin): Create these threads once instead of during every call.
  threads = []
  for index in range(len(devices)):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = _MirroredReplicaThread(distribution, coord, index, devices,
                               variable_creator_fn, fn,
                               distribute_utils.caching_scope_local,
                               distribute_utils.select_replica(index, args),
                               distribute_utils.select_replica(index, kwargs),
                               thread_local_callables)
    threads.append(t)

  for t in threads:
    t.start()

  # When `fn` starts, the `should_run` event is set on the
  # _MirroredReplicaThread (`MRT`) threads. The execution then waits until
  # `MRT.has_paused` is set, which indicates that either `fn` is
  # complete or a `get_replica_context().merge_call()` is called. If `fn` is
  # complete, then `MRT.done` is set to True. Otherwise, the arguments
  # of `get_replica_context().merge_call` from all paused threads are grouped
  # and the `merge_fn` is performed. Results of the
  # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
  # Each such `get_replica_context().merge_call` call returns the
  # `MRT.merge_result` for that thread when the `MRT.should_run` event
  # is set again. Execution of `fn` then resumes.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          if any(done):
            raise RuntimeError("Some replicas made a different number of "
                               "replica_context().merge_call() calls.")
          # get_replica_context().merge_call() case
          merge_args = distribute_utils.regroup(
              tuple(t.merge_args for t in threads))
          merge_kwargs = distribute_utils.regroup(
              tuple(t.merge_kwargs for t in threads))
          # We capture the name_scope of the MRT when we call merge_fn
          # to ensure that if we have opened a name scope in the MRT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MRT and assume it is
          # the same for all other MRTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          mtt_captured_var_scope = threads[0].captured_var_scope
          # Capture and merge the control dependencies from all the threads.
          mtt_captured_control_deps = set()
          for t in threads:
            mtt_captured_control_deps.update(t.captured_control_deps)

          # Control is transferred from _MirroredReplicaThread (MRT) to the
          # main thread, i.e., here, to perform `merge_fn`, and thus we
          # preserve the name scope, control dependencies, etc. from MRT at
          # the time `merge_call` is made.
          # One special case is that the `merge_call` is made under a
          # `tf.init_scope` in the MRT. `tf.init_scope` will clear control
          # dependencies, pause gradient tape, and enter the lowest context on
          # the `context_stack` that is not building a graph function. Entering
          # the lowest context could be one of two things: installation of a
          # graph as the default graph, or a switch into eager mode. If the
          # former is done and causes `merge_call` to be called in a different
          # graph from the one in which `call_for_each_replica` is called, we
          # do not allow this case (see comment in `_merge_call`) and we would
          # not have arrived here due to the assertion in `_merge_call`.
          # However, if the latter is done, we want to make sure the main
          # thread enters an eager mode scope as well so that `merge_fn` does
          # not have trouble accessing resources defined in MRT under the same
          # context.
          with ops.name_scope(
              mtt_captured_name_scope), ops.control_dependencies(
                  mtt_captured_control_deps), variable_scope.variable_scope(
                      mtt_captured_var_scope), _maybe_enter_eager_mode(
                          threads[0].merge_call_entered_in_eager):
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          for r, t in enumerate(threads):
            t.merge_result = distribute_utils.select_replica(r, merge_result)
  finally:
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return distribute_utils.regroup(tuple(t.main_result for t in threads))


class _MirroredReplicaThread(threading.Thread):
  """A thread that runs a function on a device."""

  def __init__(self, dist, coord, replica_id, devices, variable_creator_fn, fn,
               caching_scope, args, kwargs, thread_local_callables=None):
    super(_MirroredReplicaThread, self).__init__()
    self.coord = coord
    self.distribution = dist
    self.devices = devices
    self.replica_id = replica_id
    self.replica_id_in_sync_group = (
        dist.extended._get_replica_id_in_sync_group(replica_id))  # pylint: disable=protected-access

    self.variable_creator_fn = variable_creator_fn
    # State needed to run and return the results of `fn`.
    self.main_fn = fn
    self.main_args = args
    self.main_kwargs = kwargs
    self.main_result = None
    self.done = False
    # State needed to run the next merge_call() (if any) requested via
    # ReplicaContext.
    self.merge_fn = None
    self.merge_args = None
    self.merge_kwargs = None
    self.merge_result = None
    self.captured_name_scope = None
    self.captured_var_scope = None
    try:
      self.caching_scope_entered = caching_scope.new_cache_scope_count
      self.caching_scope_exited = caching_scope.cache_scope_exited_count
    except AttributeError:
      self.caching_scope_entered = None
      self.caching_scope_exited = None

    # We use a threading.Event for the main thread to signal when this
    # thread should start running (`should_run`), and another for
    # this thread to transfer control back to the main thread
    # (`has_paused`, either when it gets to a
    # `get_replica_context().merge_call` or when `fn` returns). In
    # either case the event starts cleared and is signaled by calling
    # set(). The receiving thread waits for the signal by calling
    # wait() and then immediately clears the event using clear().
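    #
    # The handshake between the main thread and this replica thread looks
    # roughly like this:
    #
    #   main thread                           replica thread
    #   t.should_run.set()         -------->  should_run.wait(); clear()
    #                                         run `fn` until merge_call/return
    #   has_paused.wait(); clear() <--------  has_paused.set()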
    self.should_run = threading.Event()
    self.has_paused = threading.Event()
    # These fields have to do with inheriting various contexts from the
    # parent thread:
    context.ensure_initialized()
    ctx = context.context()
    self.in_eager = ctx.executing_eagerly()
    self.record_thread_local_summary_state()
    self.record_thread_local_eager_context_state()
    self.context_device_policy = (
        pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
            ctx._context_handle))  # pylint: disable=protected-access
    self.graph = ops.get_default_graph()
    with ops.init_scope():
      self._init_in_eager = context.executing_eagerly()
      self._init_graph = ops.get_default_graph()
    self._variable_creator_stack = self.graph._variable_creator_stack[:]  # pylint: disable=protected-access
    self._var_scope = variable_scope.get_variable_scope()
    # Adding a "/" at end lets us re-enter this scope later.
    self._name_scope = self.graph.get_name_scope()
    if self._name_scope:
      self._name_scope += "/"
    if self.replica_id > 0:
      if not self._name_scope:
        self._name_scope = ""
      self._name_scope += "replica_%d/" % self.replica_id

    self._thread_local_callables = thread_local_callables

  def run(self):
    self.should_run.wait()
    self.should_run.clear()
    try:
      if self.coord.should_stop():
        return
      self.restore_thread_local_summary_state()
      self.restore_thread_local_callable()
      self.restore_thread_local_eager_context_state()
      if (self.caching_scope_entered is not None and
          self.caching_scope_exited is not None):
        distribute_utils.caching_scope_local.new_cache_scope_count = self.caching_scope_entered
        distribute_utils.caching_scope_local.cache_scope_exited_count = self.caching_scope_exited
      # TODO(josh11b): Use current logical device instead of 0 here.
      with self.coord.stop_on_exception(), \
          _enter_graph(self._init_graph, self._init_in_eager), \
          _enter_graph(self.graph, self.in_eager,
                       self._variable_creator_stack), \
          context.device_policy(self.context_device_policy), \
          _MirroredReplicaContext(self.distribution,
                                  self.replica_id_in_sync_group), \
          ops.device(self.devices[self.replica_id]), \
          ops.name_scope(self._name_scope), \
          variable_scope.variable_scope(
              self._var_scope, reuse=self.replica_id > 0), \
          variable_scope.variable_creator_scope(self.variable_creator_fn):
        self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
        self.done = True
    finally:
      self.has_paused.set()

  def record_thread_local_summary_state(self):
    """Record the thread local summary state in self."""
    # TODO(slebedev): is this still relevant? the referenced bug is closed.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    self._summary_step = summary_state.step
    self._summary_writer = summary_state.writer
    self._summary_recording = summary_state.is_recording
    self._summary_recording_distribution_strategy = (
        summary_state.is_recording_distribution_strategy)

  def restore_thread_local_summary_state(self):
    """Restore thread local summary state from self."""
    # TODO(slebedev): is this still relevant? the referenced bug is closed.
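    # `_summary_state` is thread-local, so this replica thread does not
    # automatically inherit the parent thread's step/writer/recording flags;
    # copy the values captured in `record_thread_local_summary_state` here.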
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.step = self._summary_step
    summary_state.writer = self._summary_writer
    summary_state.is_recording = self._summary_recording
    summary_state.is_recording_distribution_strategy = (
        self._summary_recording_distribution_strategy)

  def record_thread_local_eager_context_state(self):
    ctx = context.context()
    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
    self._eager_context_op_callbacks = eager_context_state.op_callbacks
    # TODO(b/125892694): record other fields in EagerContext.

  def restore_thread_local_eager_context_state(self):
    ctx = context.context()
    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
    eager_context_state.op_callbacks = self._eager_context_op_callbacks
    # TODO(b/125892694): record other fields in EagerContext.

  def restore_thread_local_callable(self):
    if self._thread_local_callables:
      for fn in self._thread_local_callables:
        fn()


class _MirroredReplicaContext(distribute_lib.ReplicaContext):
  """ReplicaContext for synchronized replica."""

  def _merge_call(self, fn, args, kwargs):
    """`merge_call()` implementation for synchronized replica.

    This pauses the current replica thread and passes `fn` and its arguments
    to the main thread. The main thread will wait until all replicas pause,
    then invoke `fn` with grouped arguments. The current replica thread will
    continue after `fn` completes.

    See `_call_for_each_replica` for the logic in the main thread.

    Args:
      fn: a function that is called in cross replica context with grouped
        arguments from each replica. `fn` should return grouped values.
      args: positional arguments to `fn`.
      kwargs: keyword arguments to `fn`.

    Returns:
      Return value of `fn` for the current replica.

    Raises:
      RuntimeError: when `merge_call` happens in a different graph, e.g. in a
        different tf.function, which is not yet supported.
      _RequestedStop: when stop is requested.
    """
    t = threading.current_thread()
    assert isinstance(t, _MirroredReplicaThread)
    t.merge_fn = fn
    t.merge_args = args
    t.merge_kwargs = kwargs
    t.captured_name_scope = t.graph.get_name_scope()
    # Adding a "/" at end lets us re-enter this scope later.
    if t.captured_name_scope:
      t.captured_name_scope += "/"

    t.captured_var_scope = variable_scope.get_variable_scope()
    t.captured_control_deps = t.graph._current_control_dependencies()  # pylint: disable=protected-access

    t.merge_call_entered_in_eager = context.context().executing_eagerly()

    # It is problematic if `merge_call` is called under a graph different from
    # the one that `_call_for_each_replica` is called under. There are
    # 3 cases in which this can happen:
    #
    #   1. The `fn` passed to `_call_for_each_replica` is decorated with
    #   `tf.function` and there is a `merge_call` in `fn`. Since
    #   MirroredStrategy traces a separate function per thread (per device),
    #   and each trace takes a shared lock, the lock is never released by the
    #   first thread and subsequent replica threads cannot proceed to trace
    #   their own functions. This issue is addressed by always converting
    #   `_call_for_each_replica(tf.function(f))` to
    #   `tf.function(_call_for_each_replica(f))` in
    #   `MirroredStrategy._call_for_each_replica`.
    #
    #   2. The `fn` passed to `_call_for_each_replica` contains a nested
    #   `tf.function`, and there is a `merge_call` in the nested `tf.function`.
    #   In this case each thread can successfully trace its own function, but
    #   since the `merge_fn` passed to `merge_call` is executed in the main
    #   thread (where `_call_for_each_replica` is executed), it can't access
    #   the tensors that come from different graphs.
    #
    #   3. The `fn` passed to `_call_for_each_replica` contains a control-flow
    #   statement, there is a `merge_call` inside the control-flow body, and
    #   `fn` or `_call_for_each_replica` is decorated with `tf.function`. A
    #   control flow statement creates a separate graph for its body, so,
    #   similar to #2, the `merge_fn` executed in the main thread can't access
    #   the tensors that come from different graphs.
    #
    # We raise an error for #2 and #3.
    if ops.get_default_graph() != t.graph:
      raise RuntimeError(
          "`merge_call` called while defining a new graph or a tf.function."
          " This can often happen if the function `fn` passed to"
          " `strategy.run()` contains a nested `@tf.function`, and the nested "
          "`@tf.function` contains a synchronization point, such as aggregating"
          " gradients (e.g., optimizer.apply_gradients), or if the function `fn`"
          " uses a control flow statement which contains a synchronization"
          " point in the body. Such behaviors are not yet supported. Instead,"
          " please avoid nested `tf.function`s or control flow statements that"
          " may potentially cross a synchronization boundary, for example,"
          " wrap the `fn` passed to `strategy.run` or the entire `strategy.run`"
          " inside a `tf.function` or move the control flow out of `fn`. If"
          " you are subclassing a `tf.keras.Model`, please avoid decorating"
          " overridden methods `test_step` and `train_step` in `tf.function`.")

    t.has_paused.set()
    t.should_run.wait()
    t.should_run.clear()
    if t.coord.should_stop():
      raise _RequestedStop()
    t.merge_call_entered_in_eager = None
    return t.merge_result

  @property
  def devices(self):
    distribute_lib.require_replica_context(self)
    return [
        self._strategy.extended.worker_devices_by_replica[
            self._replica_id_in_sync_group]
    ]
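

# A minimal usage sketch (illustrative only; `model`, `optimizer`,
# `compute_loss` and `dist_inputs` below are placeholders, not part of this
# module). `strategy.run` reaches `call_for_each_replica` above, which runs
# `step_fn` once per replica on a `_MirroredReplicaThread`; aggregating
# gradients in `optimizer.apply_gradients` is the kind of synchronization
# point handled by `_MirroredReplicaContext._merge_call`:
#
#   strategy = tf.distribute.MirroredStrategy()
#
#   @tf.function
#   def step_fn(inputs):
#     with tf.GradientTape() as tape:
#       loss = compute_loss(model(inputs))
#     grads = tape.gradient(loss, model.trainable_variables)
#     optimizer.apply_gradients(zip(grads, model.trainable_variables))
#     return loss
#
#   per_replica_losses = strategy.run(step_fn, args=(dist_inputs,))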