# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The execution context for ClusterCoordinator."""

import contextlib
import threading

from tensorflow.python.util.lazy_loader import LazyLoader

# There is a circular dependency between this and the `cluster_coordinator`
# module. So we load it lazily to work around this.
cluster_coordinator = LazyLoader(
    "cluster_coordinator", globals(),
    "tensorflow.python.distribute.coordinator.cluster_coordinator"
)

# Thread-local storage: each thread has its own `current` attribute, so a
# dispatch context installed on one thread is not visible from another.
_dispatch_context = threading.local()


def get_current_dispatch_context():
  """Returns the `DispatchContext` active on the calling thread, if any.

  Returns:
    The context installed by the innermost active `with_dispatch_context`
    block on this thread, or `None` if there is none.
  """
  return getattr(_dispatch_context, "current", None)


@contextlib.contextmanager
def with_dispatch_context(worker_obj):
  """Installs a `DispatchContext` for `worker_obj` on the calling thread.

  The previously active context (or `None`) is restored when the block exits,
  whether it exits normally or by raising, so nested uses unwind correctly.

  Args:
    worker_obj: The worker the enclosed code runs on. It must expose a
      `worker_index` attribute, which `DispatchContext.__init__` reads.

  Yields:
    None.
  """
  previous_context = getattr(_dispatch_context, "current", None)
  _dispatch_context.current = DispatchContext(worker_obj)
  try:
    yield
  finally:
    # Restore even when the body raises; otherwise a failed closure would
    # leave a stale context on this (possibly reused) thread.
    _dispatch_context.current = previous_context


class DispatchContext(object):
  """Context entered when executing a closure on a given worker."""

  def __init__(self, worker_obj):
    """Creates a context bound to `worker_obj`.

    Args:
      worker_obj: The worker object; its `worker_index` is captured here.
    """
    self._worker = worker_obj
    self._worker_index = worker_obj.worker_index

  @property
  def worker(self):
    """The worker object this context was created for."""
    return self._worker

  @property
  def worker_index(self):
    """The worker's `worker_index`, captured when this context was created."""
    return self._worker_index

  def maybe_rebuild_remote_values(self, remote_value):
    """Runs the coordinator's rebuild helper for this worker and raises errors.

    Calls `cluster_coordinator._maybe_rebuild_remote_values(worker,
    remote_value)`. A truthy return value from that helper is treated as an
    error. It is raised as-is if it is already a
    `cluster_coordinator.ClosureInputError`; otherwise it is first wrapped in
    one.

    Args:
      remote_value: Passed through unchanged to the helper.
        NOTE(review): the accepted types are defined by the helper; confirm
        there.

    Raises:
      cluster_coordinator.ClosureInputError: If the helper reports an error.
    """
    e = cluster_coordinator._maybe_rebuild_remote_values(  # pylint: disable=protected-access
        self._worker, remote_value)
    if e:
      if not isinstance(e, cluster_coordinator.ClosureInputError):
        e = cluster_coordinator.ClosureInputError(e)
      raise e

  def maybe_get_remote_value(self, ret):
    """Returns `cluster_coordinator._maybe_get_remote_value(ret)`.

    Args:
      ret: Passed through unchanged to the helper.
    """
    return cluster_coordinator._maybe_get_remote_value(ret)  # pylint: disable=protected-access