# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Class MirroredStrategy implementing tf.distribute.Strategy."""

import contextlib
import functools
import threading
import weakref

from tensorflow.python import pywrap_tfe
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import shared_variable_creator
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.util import traceback_utils


def _is_gpu_device(device):
  return tf_device.DeviceSpec.from_string(device).device_type == "GPU"

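# Illustrative examples for `_is_gpu_device` (not executed; based on how
# tf_device.DeviceSpec parses device strings):
#   _is_gpu_device("/job:localhost/replica:0/task:0/device:GPU:0")  -> True
#   _is_gpu_device("/device:CPU:0")                                 -> False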

def call_for_each_replica(strategy, fn, args=None, kwargs=None):
  """Calls `fn` on each worker device (replica).

  It is highly recommended to wrap the call to this function inside a
  `tf.function`; otherwise the performance is poor.

  Args:
    strategy: `tf.distribute.Strategy`.
    fn: function to call on each worker device.
    args: positional arguments to `fn`.
    kwargs: keyword arguments to `fn`.

  Returns:
    Wrapped return values of `fn` from all replicas.
  """
  if args is None:
    args = ()
  if kwargs is None:
    kwargs = {}

  if isinstance(fn, def_function.Function):
    # Don't lift up the tf.function decoration if `fn` is compiled with XLA
    # and all devices are GPU. In this case we will use collectives to do
    # cross-device communication, thus no merge_call is in the path.
    if fn._jit_compile and all(  # pylint: disable=protected-access
        [_is_gpu_device(d) for d in strategy.extended.worker_devices]):
      return _call_for_each_replica(strategy, fn, args, kwargs)

    if strategy not in _cfer_fn_cache:
      _cfer_fn_cache[strategy] = weakref.WeakKeyDictionary()
    wrapped = _cfer_fn_cache[strategy].get(fn)
    if wrapped is None:
      # We need to wrap fn such that it triggers _call_for_each_replica inside
      # the tf.function. We use _clone() instead of @tf.function wrapped
      # call_for_each_replica() because we would like to retain the arguments to
      # the @tf.function decorator of fn.
      wrapped = fn._clone(  # pylint: disable=protected-access
          python_function=functools.partial(call_for_each_replica, strategy,
                                            fn.python_function))
      _cfer_fn_cache[strategy][fn] = wrapped
    return wrapped(args, kwargs)

  if context.executing_eagerly():
    logging.log_first_n(
        logging.WARN, "Using %s eagerly has significant "
        "overhead currently. We will be working on improving "
        "this in the future, but for now please wrap "
        "`call_for_each_replica` or `experimental_run` or "
        "`run` inside a tf.function to get "
        "the best performance." % strategy.__class__.__name__, 5)
  else:
    # When a tf.function is wrapped to trigger _call_for_each_replica (see
    # the other branch above), AutoGraph stops conversion at
    # _call_for_each_replica itself (TF library functions are allowlisted).
    # This makes sure that the Python function originally passed to
    # the tf.function is still converted.
    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())

  return _call_for_each_replica(strategy, fn, args, kwargs)

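
# NOTE: Illustrative usage sketch only; it is not part of this module's API and
# is never invoked. It assumes a machine with two visible GPUs and follows the
# recommendation in the docstring above: wrap the per-replica computation in a
# `tf.function` before handing it to `call_for_each_replica`.
def _example_call_for_each_replica():  # pragma: no cover
  import tensorflow as tf  # Deferred import; keeps the sketch self-contained.

  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

  @tf.function
  def replica_step(x):
    # Runs once per replica; each replica sees its own copy of `x`.
    return x * 2.0

  # Returns a per-replica value wrapping each replica's result.
  return call_for_each_replica(strategy, replica_step, args=(tf.constant(3.0),))
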

# Per strategy cache for call_for_each_replica def_function.Function objects.
_cfer_fn_cache = weakref.WeakKeyDictionary()


@contextlib.contextmanager
def _enter_graph(g, eager, creator_stack=None):
  """Context manager for selecting a graph and maybe eager mode."""
  if eager:
    with g.as_default(), context.eager_mode():
      if creator_stack is not None:
        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
      yield
  else:
    with g.as_default():
      if creator_stack is not None:
        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
      yield


@contextlib.contextmanager
def _maybe_enter_eager_mode(eager):
  if eager:
    with context.eager_mode():
      yield
  else:
    yield


def _cpu_device(device):
  cpu_device = tf_device.DeviceSpec.from_string(device)
  cpu_device = cpu_device.replace(device_type="CPU", device_index=0)
  return cpu_device.to_string()

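# Illustrative example for `_cpu_device` (not executed): it maps any device
# string to the CPU:0 device on the same job/replica/task, e.g.
#   _cpu_device("/job:worker/replica:0/task:1/device:GPU:3")
#     -> "/job:worker/replica:0/task:1/device:CPU:0"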

class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
  pass


def _get_thread_local_configuration_callable():
  if traceback_utils.is_traceback_filtering_enabled():
    thread_local_callables = {traceback_utils.enable_traceback_filtering}
  else:
    thread_local_callables = {traceback_utils.disable_traceback_filtering}
  return thread_local_callables


def _call_for_each_replica(distribution, fn, args, kwargs):
  """Runs `fn` in separate threads, once per replica/worker device.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per replica, each in its own thread).
    args: positional arguments for `fn`.
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas.

  Raises:
    RuntimeError: If fn() calls get_replica_context().merge_call() a different
        number of times on different replicas.
  """
  # TODO(josh11b): Add this option once we add synchronization to variable
  # creation. Until then, this is pretty unsafe to use.
  run_concurrently = False
  if not context.executing_eagerly():
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}
  devices = distribution.extended.worker_devices

  thread_local_callables = _get_thread_local_configuration_callable()

  # TODO(isaprykin): Create these threads once instead of during every call.
  threads = []
  for index in range(len(devices)):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = _MirroredReplicaThread(distribution, coord, index, devices,
                               variable_creator_fn, fn,
                               distribute_utils.caching_scope_local,
                               distribute_utils.select_replica(index, args),
                               distribute_utils.select_replica(index, kwargs),
                               thread_local_callables)
    threads.append(t)

  for t in threads:
    t.start()

  # To start running `fn`, the `should_run` event is set on each
  # _MirroredReplicaThread (`MRT`). The main thread then waits until
  # `MRT.has_paused` is set, which indicates that either `fn` is
  # complete or `get_replica_context().merge_call()` has been called.  If `fn`
  # is complete, then `MRT.done` is set to True.  Otherwise, the arguments
  # of `get_replica_context().merge_call` from all paused threads are grouped
  # and `merge_fn` is performed.  The result of the
  # `get_replica_context().merge_call` is then set to `MRT.merge_result`.
  # Each such `get_replica_context().merge_call` call returns the
  # `MRT.merge_result` for that thread once the `MRT.should_run` event is set
  # again, and execution of `fn` resumes. A minimal standalone sketch of this
  # two-event handshake appears after this function.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          if any(done):
            raise RuntimeError("Some replicas made a different number of "
                               "replica_context().merge_call() calls.")
          # get_replica_context().merge_call() case
          merge_args = distribute_utils.regroup(
              tuple(t.merge_args for t in threads))
          merge_kwargs = distribute_utils.regroup(
              tuple(t.merge_kwargs for t in threads))
          # We capture the name_scope of the MRT when we call merge_fn
          # to ensure that if we have opened a name scope in the MRT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MRT and assume it is
          # the same for all other MRTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          mtt_captured_var_scope = threads[0].captured_var_scope
          # Capture and merge the control dependencies from all the threads.
          mtt_captured_control_deps = set()
          for t in threads:
            mtt_captured_control_deps.update(t.captured_control_deps)

          # Control is transferred from the _MirroredReplicaThread (MRT) to the
          # main thread, i.e., here, to perform `merge_fn`, and thus we
          # preserve the name scope, control dependencies, etc. from the MRT at
          # the time `merge_call` is made.
          # One special case is that the `merge_call` is made under a
          # `tf.init_scope` in the MRT. `tf.init_scope` will clear control
          # dependencies, pause the gradient tape, and enter the lowest context
          # on the `context_stack` that is not building a graph function.
          # Entering the lowest context does one of two things: it installs a
          # graph as the default graph or switches into eager mode. If the
          # former happens and causes `merge_call` to be called in a different
          # graph from the one in which `call_for_each_replica` is called, we
          # do not allow this case (see the comment in `_merge_call`) and we
          # would not have arrived here due to the assertion in `_merge_call`.
          # However, if the latter happens, we want to make sure the main
          # thread enters an eager-mode scope as well so that `merge_fn` does
          # not have trouble accessing resources defined in the MRT under the
          # same context.
          with ops.name_scope(
              mtt_captured_name_scope), ops.control_dependencies(
                  mtt_captured_control_deps), variable_scope.variable_scope(
                      mtt_captured_var_scope), _maybe_enter_eager_mode(
                          threads[0].merge_call_entered_in_eager):
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          for r, t in enumerate(threads):
            t.merge_result = distribute_utils.select_replica(r, merge_result)
  finally:
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return distribute_utils.regroup(tuple(t.main_result for t in threads))

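
# NOTE: Illustrative sketch only; it is not part of this module's API and is
# never invoked. It reproduces, in isolation, the two-event handshake used
# above: the main thread sets `should_run` to let a replica thread take a step,
# and the replica thread sets `has_paused` when it either finishes or reaches a
# merge point. The names below are hypothetical and exist only for this sketch.
def _example_event_handshake():  # pragma: no cover
  should_run = threading.Event()
  has_paused = threading.Event()
  results = []

  def replica_thread_body():
    should_run.wait()       # Wait for the main thread's signal.
    should_run.clear()
    results.append("step")  # Stands in for running `fn` or pausing at a merge.
    has_paused.set()        # Hand control back to the main thread.

  t = threading.Thread(target=replica_thread_body)
  t.start()
  should_run.set()   # Main thread: let the replica thread run one step.
  has_paused.wait()  # Main thread: wait for it to pause or finish.
  has_paused.clear()
  t.join()
  return results
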

class _MirroredReplicaThread(threading.Thread):
  """A thread that runs a function on a given device."""

  def __init__(self, dist, coord, replica_id, devices, variable_creator_fn, fn,
               caching_scope, args, kwargs, thread_local_callables=None):
    super(_MirroredReplicaThread, self).__init__()
    self.coord = coord
    self.distribution = dist
    self.devices = devices
    self.replica_id = replica_id
    self.replica_id_in_sync_group = (
        dist.extended._get_replica_id_in_sync_group(replica_id))  # pylint: disable=protected-access

    self.variable_creator_fn = variable_creator_fn
    # State needed to run and return the results of `fn`.
    self.main_fn = fn
    self.main_args = args
    self.main_kwargs = kwargs
    self.main_result = None
    self.done = False
    # State needed to run the next merge_call() (if any) requested via
    # ReplicaContext.
    self.merge_fn = None
    self.merge_args = None
    self.merge_kwargs = None
    self.merge_result = None
    self.captured_name_scope = None
    self.captured_var_scope = None
    try:
      self.caching_scope_entered = caching_scope.new_cache_scope_count
      self.caching_scope_exited = caching_scope.cache_scope_exited_count
    except AttributeError:
      self.caching_scope_entered = None
      self.caching_scope_exited = None

    # We use a threading.Event for the main thread to signal when this
    # thread should start running (`should_run`), and another for
    # this thread to transfer control back to the main thread
    # (`has_paused`, either when it gets to a
    # `get_replica_context().merge_call` or when `fn` returns). In
    # either case the event starts cleared and is signaled by calling
    # set(). The receiving thread waits for the signal by calling
    # wait() and then immediately clears the event using clear().
    self.should_run = threading.Event()
    self.has_paused = threading.Event()
    # These fields have to do with inheriting various contexts from the
    # parent thread:
    context.ensure_initialized()
    ctx = context.context()
    self.in_eager = ctx.executing_eagerly()
    self.record_thread_local_summary_state()
    self.record_thread_local_eager_context_state()
    self.context_device_policy = (
        pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
            ctx._context_handle))  # pylint: disable=protected-access
    self.graph = ops.get_default_graph()
    with ops.init_scope():
      self._init_in_eager = context.executing_eagerly()
      self._init_graph = ops.get_default_graph()
    self._variable_creator_stack = self.graph._variable_creator_stack[:]  # pylint: disable=protected-access
    self._var_scope = variable_scope.get_variable_scope()
    # Adding a "/" at end lets us re-enter this scope later.
    self._name_scope = self.graph.get_name_scope()
    if self._name_scope:
      self._name_scope += "/"
    if self.replica_id > 0:
      if not self._name_scope:
        self._name_scope = ""
      self._name_scope += "replica_%d/" % self.replica_id

    self._thread_local_callables = thread_local_callables

  def run(self):
    self.should_run.wait()
    self.should_run.clear()
    try:
      if self.coord.should_stop():
        return
      self.restore_thread_local_summary_state()
      self.restore_thread_local_callable()
      self.restore_thread_local_eager_context_state()
      if (self.caching_scope_entered is not None and
          self.caching_scope_exited is not None):
        distribute_utils.caching_scope_local.new_cache_scope_count = self.caching_scope_entered
        distribute_utils.caching_scope_local.cache_scope_exited_count = self.caching_scope_exited
      # TODO(josh11b): Use current logical device instead of 0 here.
      with self.coord.stop_on_exception(), \
          _enter_graph(self._init_graph, self._init_in_eager), \
          _enter_graph(self.graph, self.in_eager,
                       self._variable_creator_stack), \
          context.device_policy(self.context_device_policy), \
          _MirroredReplicaContext(self.distribution,
                                  self.replica_id_in_sync_group), \
          ops.device(self.devices[self.replica_id]), \
          ops.name_scope(self._name_scope), \
          variable_scope.variable_scope(
              self._var_scope, reuse=self.replica_id > 0), \
          variable_scope.variable_creator_scope(self.variable_creator_fn):
        self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
        self.done = True
    finally:
      self.has_paused.set()

  def record_thread_local_summary_state(self):
    """Record the thread local summary state in self."""
    # TODO(slebedev): is this still relevant? the referenced bug is closed.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    self._summary_step = summary_state.step
    self._summary_writer = summary_state.writer
    self._summary_recording = summary_state.is_recording
    self._summary_recording_distribution_strategy = (
        summary_state.is_recording_distribution_strategy)

  def restore_thread_local_summary_state(self):
    """Restore thread local summary state from self."""
    # TODO(slebedev): is this still relevant? the referenced bug is closed.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.step = self._summary_step
    summary_state.writer = self._summary_writer
    summary_state.is_recording = self._summary_recording
    summary_state.is_recording_distribution_strategy = (
        self._summary_recording_distribution_strategy)

  def record_thread_local_eager_context_state(self):
    ctx = context.context()
    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
    self._eager_context_op_callbacks = eager_context_state.op_callbacks
    # TODO(b/125892694): record other fields in EagerContext.

  def restore_thread_local_eager_context_state(self):
    ctx = context.context()
    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
    eager_context_state.op_callbacks = self._eager_context_op_callbacks
    # TODO(b/125892694): record other fields in EagerContext.

  def restore_thread_local_callable(self):
    if self._thread_local_callables:
      for fn in self._thread_local_callables:
        fn()


class _MirroredReplicaContext(distribute_lib.ReplicaContext):
  """ReplicaContext for a synchronized replica."""

  def _merge_call(self, fn, args, kwargs):
    """`merge_call()` implementation for a synchronized replica.

    This pauses the current replica thread and passes `fn` and its arguments to
    the main thread. The main thread waits until all replicas pause, then
    invokes `fn` with grouped arguments. The current replica thread continues
    after `fn` completes.

    See `_call_for_each_replica` for the logic in the main thread.

    Args:
      fn: a function that is called in cross-replica context with grouped
        arguments from each replica. `fn` should return grouped values.
      args: positional arguments to `fn`.
      kwargs: keyword arguments to `fn`.

    Returns:
      Return value of `fn` for the current replica.

    Raises:
      RuntimeError: when `merge_call` happens in a different graph, e.g. in a
        different tf.function, which is not currently supported.
      _RequestedStop: when a stop is requested.
    """
    t = threading.current_thread()
    assert isinstance(t, _MirroredReplicaThread)
    t.merge_fn = fn
    t.merge_args = args
    t.merge_kwargs = kwargs
    t.captured_name_scope = t.graph.get_name_scope()
    # Adding a "/" at end lets us re-enter this scope later.
    if t.captured_name_scope:
      t.captured_name_scope += "/"

    t.captured_var_scope = variable_scope.get_variable_scope()
    t.captured_control_deps = t.graph._current_control_dependencies()  # pylint: disable=protected-access

    t.merge_call_entered_in_eager = context.context().executing_eagerly()

    # It is problematic if `merge_call` is called under a graph different from
    # the one that `_call_for_each_replica` is called under. There are three
    # cases in which this can happen:
    #
    #   1. The `fn` passed to `_call_for_each_replica` is decorated with
    #   `tf.function` and there is a `merge_call` in `fn`. Since
    #   MirroredStrategy traces a separate function per thread (per device),
    #   and each trace takes a shared lock, the lock is never released by the
    #   first thread and subsequent replica threads cannot proceed to trace
    #   their own functions. This issue is addressed by always converting
    #   `_call_for_each_replica(tf.function(f))` to
    #   `tf.function(_call_for_each_replica(f))` in
    #   `MirroredStrategy._call_for_each_replica`.
    #
    #   2. The `fn` passed to `_call_for_each_replica` contains a nested
    #   `tf.function`, and there is a `merge_call` in the nested `tf.function`.
    #   In this case each thread can successfully trace its own function, but
    #   since the `merge_fn` passed to `merge_call` is executed in the main
    #   thread (where `_call_for_each_replica` is executed), it can't access
    #   the tensors that come from different graphs.
    #
    #   3. The `fn` passed to `_call_for_each_replica` contains a control-flow
    #   statement with a `merge_call` inside the control-flow body, and `fn`
    #   or `_call_for_each_replica` is decorated with `tf.function`. A control
    #   flow statement creates a separate graph for its body, so, similar to
    #   #2, `merge_fn` executed in the main thread can't access the tensors
    #   that come from different graphs.
    #
    #   We raise an error for #2 and #3. A commented-out illustration of case
    #   #2 follows.
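    #
    # Hypothetical illustration of case #2 above (kept as a comment so it is
    # never executed); the names `strategy`, `step`, `optimizer` and
    # `grads_and_vars` are made up for this sketch:
    #
    #   @tf.function
    #   def step(optimizer, grads_and_vars):
    #     # `apply_gradients` calls `merge_call` internally; because it happens
    #     # inside this *nested* tf.function, the RuntimeError below is raised.
    #     optimizer.apply_gradients(grads_and_vars)
    #
    #   def fn(optimizer, grads_and_vars):
    #     step(optimizer, grads_and_vars)
    #
    #   strategy.run(fn, args=(optimizer, grads_and_vars))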
    if ops.get_default_graph() != t.graph:
      raise RuntimeError(
          "`merge_call` called while defining a new graph or a tf.function."
          " This can often happen if the function `fn` passed to"
          " `strategy.run()` contains a nested `@tf.function`, and the nested "
          "`@tf.function` contains a synchronization point, such as aggregating"
          " gradients (e.g. optimizer.apply_gradients), or if the function `fn`"
          " uses a control flow statement which contains a synchronization"
          " point in the body. Such behaviors are not yet supported. Instead,"
          " please avoid nested `tf.function`s or control flow statements that"
          " may potentially cross a synchronization boundary, for example,"
          " wrap the `fn` passed to `strategy.run` or the entire `strategy.run`"
          " inside a `tf.function` or move the control flow out of `fn`. If"
          " you are subclassing a `tf.keras.Model`, please avoid decorating"
          " overridden methods `test_step` and `train_step` in `tf.function`.")
    t.has_paused.set()
    t.should_run.wait()
    t.should_run.clear()
    if t.coord.should_stop():
      raise _RequestedStop()
    t.merge_call_entered_in_eager = None
    return t.merge_result

  @property
  def devices(self):
    distribute_lib.require_replica_context(self)
    return [
        self._strategy.extended.worker_devices_by_replica[
            self._replica_id_in_sync_group]
    ]
