xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/coordinator/coordinator_context.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The execution context for ClusterCoordinator."""
16
17import contextlib
18import threading
19
20from tensorflow.python.util.lazy_loader import LazyLoader
21
# There is a circular dependency between this module and the
# `cluster_coordinator` module, so `cluster_coordinator` is loaded lazily to
# work around it.
# NOTE(review): presumably the real import is deferred until the first
# attribute access on `cluster_coordinator` (e.g. from the `DispatchContext`
# methods in this file), and the first argument must match the global name
# the loader is bound to -- confirm against `LazyLoader`.
cluster_coordinator = LazyLoader(
    "cluster_coordinator", globals(),
    "tensorflow.python.distribute.coordinator.cluster_coordinator"
)

# Thread-local holder for the active dispatch context: each thread sees its
# own `current` attribute, so a context installed on one thread is not
# visible from another.
_dispatch_context = threading.local()
30
31
def get_current_dispatch_context():
  """Returns the dispatch context stored for the calling thread.

  Returns:
    The `DispatchContext` currently recorded in this thread's thread-local
    slot, or `None` if no context has been set on this thread (or the slot
    holds `None`).
  """
  # `getattr` with a default covers the "never set on this thread" case,
  # where the thread-local has no `current` attribute at all.
  return getattr(_dispatch_context, "current", None)
37
38
@contextlib.contextmanager
def with_dispatch_context(worker_obj):
  """Installs a `DispatchContext` for `worker_obj` on the current thread.

  While the `with` body runs, `get_current_dispatch_context()` on this thread
  returns a `DispatchContext` wrapping `worker_obj`. On exit, whether normal
  or through an exception, the previously stored context (possibly `None`)
  is restored, so nested uses unwind correctly.

  Args:
    worker_obj: The worker the enclosed code executes on. It must expose a
      `worker_index` attribute, which `DispatchContext.__init__` reads.

  Yields:
    None.
  """
  previous_context = getattr(_dispatch_context, "current", None)
  _dispatch_context.current = DispatchContext(worker_obj)
  try:
    yield
  finally:
    # Restore even when the body raises. Without `finally`, the exception
    # thrown into this generator at `yield` would skip the restore and leave
    # a stale context visible to later code on this thread.
    _dispatch_context.current = previous_context
45
46
class DispatchContext(object):
  """Execution context identifying the worker a closure is dispatched to.

  An instance wraps a worker object and caches its `worker_index` when it is
  constructed. `with_dispatch_context` stores one in a thread-local slot,
  where `get_current_dispatch_context()` can retrieve it.
  """

  def __init__(self, worker_obj):
    """Binds this context to `worker_obj`.

    Args:
      worker_obj: The worker executing the closure. Its `worker_index`
        attribute is read once here and cached; later changes to it are not
        reflected by `worker_index`.
    """
    self._worker, self._worker_index = worker_obj, worker_obj.worker_index

  @property
  def worker(self):
    """The worker object this context was created for."""
    return self._worker

  @property
  def worker_index(self):
    """The worker's `worker_index`, as captured at construction time."""
    return self._worker_index

  def maybe_rebuild_remote_values(self, remote_value):
    """Asks `cluster_coordinator` to rebuild `remote_value` on this worker.

    Args:
      remote_value: Passed, along with this context's worker, to
        `cluster_coordinator._maybe_rebuild_remote_values`.

    Raises:
      cluster_coordinator.ClosureInputError: If the helper returns an error.
        An error that is not already a `ClosureInputError` is wrapped in one
        before being raised.
    """
    # pylint: disable=protected-access
    failure = cluster_coordinator._maybe_rebuild_remote_values(
        self._worker, remote_value)
    # pylint: enable=protected-access
    if not failure:
      return
    wrapper_cls = cluster_coordinator.ClosureInputError
    raise (failure if isinstance(failure, wrapper_cls)
           else wrapper_cls(failure))

  def maybe_get_remote_value(self, ret):
    """Forwards `ret` to `cluster_coordinator._maybe_get_remote_value`.

    Args:
      ret: Passed through unchanged. Presumably a closure's return value that
        may contain remote values; confirm against the helper.

    Returns:
      Whatever the helper returns for `ret`.
    """
    getter = cluster_coordinator._maybe_get_remote_value  # pylint: disable=protected-access
    return getter(ret)
73