1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Various classes representing distributed values for PS.""" 16 17import contextlib 18import copy 19import threading 20import weakref 21 22import numpy as np 23 24from tensorflow.python.distribute import distribute_lib 25from tensorflow.python.distribute import distribute_utils 26from tensorflow.python.distribute import distribution_strategy_context as ds_context 27from tensorflow.python.distribute import values 28from tensorflow.python.distribute import values_util 29from tensorflow.python.distribute.coordinator import coordinator_context 30from tensorflow.python.eager import context 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import tensor_spec 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import lookup_ops 36from tensorflow.python.ops import resource_variable_ops 37from tensorflow.python.ops import variable_scope as vs 38from tensorflow.python.saved_model import save_context 39from tensorflow.python.trackable import base as trackable 40from tensorflow.python.types import core 41from tensorflow.python.util.lazy_loader import LazyLoader 42 43load_context = LazyLoader( 44 "load_context", globals(), 45 "tensorflow.python.keras.saving.saved_model.load_context" 46) 47 48TRACKABLE_RESOURCE_METHODS = [ 49 "_create_resource", "_initialize", "_destroy_resource" 50] 51 52 53# Variable used in PSStrategy TF 1, TF2 and CentralStorageStrategy. 54class AggregatingVariable(resource_variable_ops.BaseResourceVariable, 55 core.Tensor): 56 """A wrapper around a variable that aggregates updates across replicas.""" 57 58 def __init__(self, strategy, v, aggregation): 59 self._distribute_strategy = strategy 60 self._v = v 61 # NOTE: We don't use "_distributed_container" here because we don't want 62 # to trigger that code path in regroup(). 63 v._aggregating_container = weakref.ref(self) # pylint: disable=protected-access 64 self._aggregation = aggregation 65 66 def __deepcopy__(self, memo): 67 """Perform a deepcopy of the `AggregatingVariable`. 68 69 Unlike the deepcopy of a regular tf.Variable, this keeps the original 70 strategy and devices of the `AggregatingVariable`. To avoid confusion 71 with the behavior of deepcopy on a regular `Variable` (which does 72 copy into new devices), we only allow a deepcopy of a `AggregatingVariable` 73 within its originating strategy scope. 74 75 Args: 76 memo: The memoization object for `deepcopy`. 77 78 Returns: 79 A deep copy of the current `AggregatingVariable`. 80 81 Raises: 82 RuntimeError: If trying to deepcopy into a different strategy. 83 """ 84 with ds_context.enter_or_assert_strategy(self._distribute_strategy): 85 v = copy.deepcopy(self._v, memo) 86 87 copied_variable = type(self)( 88 strategy=self._distribute_strategy, 89 v=v, 90 aggregation=self._aggregation) 91 92 memo[id(self)] = copied_variable 93 94 return copied_variable 95 96 def get(self): 97 return self._v 98 99 @property 100 def distribute_strategy(self): 101 return self._distribute_strategy 102 103 def __getattr__(self, name): 104 return getattr(self._v, name) 105 106 def _assign_func(self, *args, **kwargs): 107 with ds_context.enter_or_assert_strategy(self._distribute_strategy): 108 f = kwargs.pop("f") 109 if ds_context.in_cross_replica_context(): 110 if distribute_lib.get_update_replica_id() is not None: 111 # We are calling an assign function in an update context. 112 return f(self._v, *args, **kwargs) 113 114 # We are calling an assign function in cross replica context, wrap it in 115 # an update call. 116 return self._distribute_strategy.extended.update( 117 self, f, args=args, kwargs=kwargs) 118 else: 119 replica_context = ds_context.get_replica_context() 120 assert replica_context 121 # We are calling an assign function in replica context. 122 # We reduce the value we want to assign/add/sub. More details about how 123 # we handle the different use cases can be found in the _reduce method. 124 # We call the function with the reduced value. 125 if self._aggregation == vs.VariableAggregation.NONE: 126 raise ValueError( 127 values_util.aggregation_error_msg.format( 128 variable_type="AggregatingVariable")) 129 130 def merge_fn(strategy, 131 value, 132 use_locking=False, 133 name=None, 134 read_value=True): 135 v = values_util.apply_aggregation(strategy, value, self._aggregation, 136 self) 137 if name and isinstance(name, values.PerReplica): 138 name = name.values[0] 139 return strategy.extended.update( 140 self, 141 f, 142 args=(v,), 143 kwargs={ 144 "use_locking": use_locking, 145 "name": name, 146 "read_value": read_value 147 }) 148 return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs) 149 150 def assign_sub(self, *args, **kwargs): 151 assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) 152 return self._assign_func(f=assign_sub_fn, *args, **kwargs) 153 154 def assign_add(self, *args, **kwargs): 155 assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw) 156 return self._assign_func(f=assign_add_fn, *args, **kwargs) 157 158 def assign(self, *args, **kwargs): 159 assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) 160 return self._assign_func(f=assign_fn, *args, **kwargs) 161 162 @property 163 def initializer(self): 164 return self._v.initializer 165 166 def initialized_value(self): 167 return self._v.initialized_value() 168 169 @property 170 def initial_value(self): 171 return self._v.initial_value 172 173 @property 174 def op(self): 175 return self._v.op 176 177 def value(self): 178 return self._v.value() 179 180 def read_value(self): 181 return self._v.read_value() 182 183 def sparse_read(self, indices, name=None): 184 return self._v.sparse_read(indices, name=name) 185 186 def eval(self, session=None): 187 return self._v.eval(session) 188 189 @property 190 def graph(self): 191 return self._v.graph 192 193 @property 194 def device(self): 195 return self._v.device 196 197 @property 198 def shape(self): 199 return self._v.shape 200 201 @property 202 def aggregation(self): 203 return self._aggregation 204 205 @property 206 def synchronization(self): 207 return self._v.synchronization 208 209 @property 210 def name(self): 211 return self._v.name 212 213 @property 214 def trainable(self): 215 return self._v.trainable 216 217 @property 218 def dtype(self): 219 return self._v.dtype 220 221 # TODO(josh11b): Test saving & restoring. 222 def _gather_saveables_for_checkpoint(self): 223 if isinstance(self._v, CachingVariable): 224 return self._v._gather_saveables_for_checkpoint() # pylint:disable=protected-access 225 return {trackable.VARIABLE_VALUE_KEY: self._v} 226 227 def _map_resources(self, save_options): 228 """For implementing `Trackable`.""" 229 # By delegating this method to the wrapped variable, SavedModel with 230 # AggregatingVariable are identical to SavedModel with normal variables. 231 obj_map, resource_map = self._v._map_resources(save_options) # pylint:disable=protected-access 232 obj_map[self] = obj_map[self._v] 233 return obj_map, resource_map 234 235 # pylint: disable=multiple-statements 236 def __add__(self, o): 237 return self._v + o 238 239 def __radd__(self, o): 240 return o + self._v 241 242 def __sub__(self, o): 243 return self._v - o 244 245 def __rsub__(self, o): 246 return o - self._v 247 248 def __mul__(self, o): 249 return self._v * o 250 251 def __rmul__(self, o): 252 return o * self._v 253 254 def __truediv__(self, o): 255 return self._v / o 256 257 def __rtruediv__(self, o): 258 return o / self._v 259 260 def __floordiv__(self, o): 261 return self._v // o 262 263 def __rfloordiv__(self, o): 264 return o // self._v 265 266 def __mod__(self, o): 267 return self._v % o 268 269 def __rmod__(self, o): 270 return o % self._v 271 272 def __lt__(self, o): 273 return self._v < o 274 275 def __le__(self, o): 276 return self._v <= o 277 278 def __gt__(self, o): 279 return self._v > o 280 281 def __ge__(self, o): 282 return self._v >= o 283 284 def __and__(self, o): 285 return self._v & o 286 287 def __rand__(self, o): 288 return o & self._v 289 290 def __or__(self, o): 291 return self._v | o 292 293 def __ror__(self, o): 294 return o | self._v 295 296 def __xor__(self, o): 297 return self._v ^ o 298 299 def __rxor__(self, o): 300 return o ^ self._v 301 302 def __getitem__(self, o): 303 return self._v[o] 304 305 def __pow__(self, o, modulo=None): 306 return pow(self._v, o, modulo) 307 308 def __rpow__(self, o): 309 return pow(o, self._v) 310 311 def __invert__(self): 312 return ~self._v 313 314 def __neg__(self): 315 return -self._v 316 317 def __abs__(self): 318 return abs(self._v) 319 320 def __div__(self, o): 321 try: 322 return self._v.__div__(o) 323 except AttributeError: 324 # See https://docs.python.org/3/library/constants.html#NotImplemented 325 return NotImplemented 326 327 def __rdiv__(self, o): 328 try: 329 return self._v.__rdiv__(o) 330 except AttributeError: 331 # See https://docs.python.org/3/library/constants.html#NotImplemented 332 return NotImplemented 333 334 def __matmul__(self, o): 335 try: 336 return self._v.__matmul__(o) 337 except AttributeError: 338 # See https://docs.python.org/3/library/constants.html#NotImplemented 339 return NotImplemented 340 341 def __rmatmul__(self, o): 342 try: 343 return self._v.__rmatmul__(o) 344 except AttributeError: 345 # See https://docs.python.org/3/library/constants.html#NotImplemented 346 return NotImplemented 347 348 def __str__(self): 349 return str(self._v) 350 351 def __repr__(self): 352 return repr(self._v) 353 354 def _should_act_as_resource_variable(self): 355 """Pass resource_variable_ops.is_resource_variable check.""" 356 pass 357 358 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): 359 return self._v._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access 360 361 362class CachingVariable(resource_variable_ops.BaseResourceVariable, core.Tensor): 363 """A wrapper around a variable that caches read value locally.""" 364 365 def __init__(self, v): 366 self._v = v 367 self._cache = None 368 self._current_new_cache_scope_count = 0 369 370 def get(self): 371 return self._v 372 373 def __getattr__(self, name): 374 return getattr(self._v, name) 375 376 def read_value(self): 377 if distribute_utils.caching_scope_local.in_caching_scope(): 378 return self.cached_read_value() 379 return self._v.read_value() 380 381 def sparse_read(self, indices, name=None): 382 return self._v.sparse_read(indices, name=name) 383 384 def cached_read_value(self): 385 if (distribute_utils.caching_scope_local.new_cache_scope_count > 386 self._current_new_cache_scope_count): 387 self._current_new_cache_scope_count += 1 388 self._cache = None 389 390 with ops.device("CPU:0"): 391 if self._cache is not None: 392 return self._cache 393 else: 394 self._cache = array_ops.identity(self._v) 395 return self._cache 396 397 def assign_sub(self, *args, **kwargs): 398 return self._v.assign_sub(*args, **kwargs) 399 400 def assign_add(self, *args, **kwargs): 401 return self._v.assign_add(*args, **kwargs) 402 403 def assign(self, *args, **kwargs): 404 return self._v.assign(*args, **kwargs) 405 406 @property 407 def initializer(self): 408 return self._v.initializer 409 410 def initialized_value(self): 411 return self._v.initialized_value() 412 413 @property 414 def initial_value(self): 415 return self._v.initial_value 416 417 @property 418 def op(self): 419 return self._v.op 420 421 def value(self): 422 if distribute_utils.caching_scope_local.in_caching_scope(): 423 return self.cached_read_value() 424 return self._v.value() 425 426 def eval(self, session=None): 427 return self._v.eval(session) 428 429 @property 430 def graph(self): 431 return self._v.graph 432 433 @property 434 def device(self): 435 return self._v.device 436 437 @property 438 def shape(self): 439 return self._v.shape 440 441 @property 442 def synchronization(self): 443 return self._v.synchronization 444 445 @property 446 def name(self): 447 return self._v.name 448 449 @property 450 def trainable(self): 451 return self._v.trainable 452 453 @property 454 def dtype(self): 455 return self._v.dtype 456 457 @property 458 def constraint(self): 459 return self._v.constraint 460 461 def __array__(self, dtype=None): 462 return np.asarray(self.numpy(), dtype=dtype) 463 464 def __complex__(self): 465 return complex(self.value().numpy()) 466 467 def __int__(self): 468 return int(self.value().numpy()) 469 470 def __float__(self): 471 return float(self.value().numpy()) 472 473 def numpy(self): 474 if context.executing_eagerly(): 475 return self.read_value().numpy() 476 else: 477 raise NotImplementedError( 478 "numpy() is only available when eager execution is enabled.") 479 480 def __str__(self): 481 return str(self._v) 482 483 def __repr__(self): 484 return repr(self._v) 485 486 def _should_act_as_resource_variable(self): 487 """Pass resource_variable_ops.is_resource_variable check.""" 488 pass 489 490 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): 491 if distribute_utils.caching_scope_local.in_caching_scope(): 492 return self.cached_read_value() 493 return self._v._dense_var_to_tensor(dtype=dtype, name=name, as_ref=False) # pylint: disable=protected-access 494 495 @classmethod 496 def _overload_overloadable_operators(cls): 497 """Register overloads for all operators.""" 498 for operator in ops.Tensor.OVERLOADABLE_OPERATORS: 499 # Overloading __eq__ or __ne__ does not work as expected. 500 if operator == "__eq__" or operator == "__ne__": 501 continue 502 cls._tensor_overload_operator(operator) 503 504 @classmethod 505 def _tensor_overload_operator(cls, operator): 506 """Delegate an operator overload to `ops.Tensor`.""" 507 tensor_operator = getattr(ops.Tensor, operator) 508 509 def _operator(v, *args, **kwargs): 510 return tensor_operator(v.value(), *args, **kwargs) # pylint: disable=protected-access 511 setattr(cls, operator, _operator) 512 513 def _gather_saveables_for_checkpoint(self): 514 return {trackable.VARIABLE_VALUE_KEY: self._v} 515 516 def _map_resources(self, save_options): 517 """For implementing `Trackable`.""" 518 # By delegating this method to the wrapped variable, SavedModel with 519 # AggregatingVariable are identical to SavedModel with normal variables. 520 obj_map, resource_map = self._v._map_resources(save_options) # pylint:disable=protected-access 521 obj_map[self] = obj_map[self._v] 522 return obj_map, resource_map 523 524 525# Register a conversion function which reads the value of the variable, 526# allowing instances of the class to be used as tensors. 527def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False): 528 return var._dense_var_to_tensor(dtype, name, as_ref) # pylint: disable=protected-access 529 530 531ops.register_tensor_conversion_function(AggregatingVariable, 532 _tensor_conversion_aggregate) 533 534 535# Register a conversion function which reads the value of the variable, 536# allowing instances of the class to be used as tensors. 537def _tensor_conversion_caching(var, dtype=None, name=None, as_ref=False): 538 return var._dense_var_to_tensor(dtype, name, as_ref) # pylint: disable=protected-access 539 540 541ops.register_tensor_conversion_function(CachingVariable, 542 _tensor_conversion_caching) 543 544CachingVariable._overload_overloadable_operators() # pylint: disable=protected-access 545 546 547class DistributedTable(lookup_ops.StaticHashTable): 548 """A distributed StaticHashTable for ParameterServerStrategy. 549 550 An instance of DistributedTable has copies of a StaticHashTable and its 551 resource handle on the coordinator of each worker, created at the 552 DistributedTable instance initialization time with initializers on each 553 worker. Users can call methods on a DistributedTable as if it were a 554 StaticHashTable, which leads to execution with the resource local to the 555 consumer worker (or the coordinator, if calling from the coordinator). This 556 implementation relies on the fact that the methods of StaticHashTable are 557 queried with the resource handle (instead of the python object). 558 559 Currently, at saving time, a DistributedTable is saved as a StaticHashTable on 560 the coordinator, and restoring a DistributedTable from SavedModel is not 561 supported. 562 """ 563 564 def __init__(self, strategy, wrapped_creator): 565 distribute_lib.distribution_strategy_input_api_counter.get_cell( 566 self.__class__.__name__, "PSSDistributedLookupTable").increase_by(1) 567 self._coordinator_instance = wrapped_creator() 568 self._wrapped_creator = wrapped_creator 569 self._coordinator = strategy._cluster_coordinator 570 # self._distributed_table is a RemoteValue mapping worker_index to 571 # RemoteValue that wraps a resource handle on the worker 572 self._distributed_table = None 573 self._distributed_table_creation_lock = threading.Lock() 574 575 if not save_context.in_save_context(): 576 self._maybe_build_distributed_table() 577 578 def __getattr__(self, attr): 579 # This allows copy.copy(DistributedTable), e.g. at saving time. 580 # (DistributedVariable uses the same fix.) When copying an object, copy.copy 581 # doesn't invoke its __init__ method, instead it makes a new empty object, 582 # then copies the attributes over. copy.copy looks for attributes like 583 # "__setstate__" in case the object implements its custom unpickling. Since 584 # DistributedTable doesn't have those attributes defined, __getattr__ will 585 # be invoked, which tries to access the `_coordinator_instance` attribute. 586 # But that doesn't exist either because this is an empty object, and again 587 # __getattr__ is invoked, leading to an infinite recursion. 588 if attr == "_coordinator_instance": 589 raise AttributeError() 590 591 if attr in self._coordinator_instance.__dict__: 592 attr_value = self._coordinator_instance.__dict__[attr] 593 if callable(attr_value): 594 595 def wrapper(*args, **kwargs): 596 return attr_value(self, *args, **kwargs) 597 598 return wrapper 599 elif isinstance(attr_value, property): 600 return attr_value 601 else: 602 return getattr(self._coordinator_instance, attr) 603 else: 604 return getattr(self._coordinator_instance, attr) 605 606 def resource_handle_call_time_value(self): 607 """Returns a closure to run for a resource handle at call time and its spec. 608 609 This function is called in self.resource_handle to create a placeholder 610 which returns a resource handle on some worker or on the coordinator. 611 """ 612 613 def closure(): 614 # function to be evaluated at function call time, returning a nest of 615 # tensors compatible with `spec`. 616 dispatch_context = coordinator_context.get_current_dispatch_context() 617 if dispatch_context: 618 remote_value = self._distributed_table._values[ # pylint: disable=protected-access 619 dispatch_context.worker_index] 620 ret = dispatch_context.maybe_get_remote_value(remote_value) 621 return ret 622 623 else: 624 return self._coordinator_instance.resource_handle 625 626 return closure, tensor_spec.TensorSpec([], dtype=dtypes.resource) 627 628 def _maybe_build_distributed_table(self): 629 """Create table objects and resources on each worker if hasn't been created.""" 630 with self._distributed_table_creation_lock: 631 if not self._distributed_table: 632 633 def create_copy(): 634 new_table = self._wrapped_creator() 635 ret = new_table.resource_handle 636 return ret 637 638 self._distributed_table = ( 639 self._coordinator._create_per_worker_resources(create_copy)) # pylint: disable=protected-access 640 641 @property 642 def resource_handle(self): 643 if context.executing_eagerly() or save_context.in_save_context(): 644 return self._coordinator_instance.resource_handle 645 else: 646 self._maybe_build_distributed_table() 647 closure, spec = self.resource_handle_call_time_value() 648 return ops.get_default_graph().capture_call_time_value( 649 closure, 650 spec, 651 default_value=self._coordinator_instance.resource_handle) 652 653 @property 654 def is_distributed_table(self): 655 return True 656 657 def __tf_experimental_restore_capture__( 658 self, concrete_function, internal_capture): 659 closure, spec = self.resource_handle_call_time_value() 660 concrete_function.graph.replace_capture_with_deferred_capture( 661 self._coordinator_instance.resource_handle, 662 closure, 663 spec, 664 default_value=self._coordinator_instance.resource_handle, 665 placeholder=internal_capture) 666 return concrete_function.graph.deferred_external_captures[-1] 667 668 669_local_resource_restore_context = threading.local() 670 671 672def get_current_local_resource_restore_context(): 673 try: 674 return _local_resource_restore_context.current 675 except AttributeError: 676 return None 677 678 679@contextlib.contextmanager 680def with_local_resource_restore_context(instance): 681 previous_context = getattr(_local_resource_restore_context, "current", None) 682 _local_resource_restore_context.current = LocalResourceRestoreContext( 683 instance) 684 yield 685 _local_resource_restore_context.current = previous_context 686 687 688class LocalResourceRestoreContext(object): 689 """Class holding information of a distributed instance, e.g. StaticHashTable. 690 691 Pairing use with context manager `with_local_resource_restore_context` allows 692 operations under this context manager to conveniently gets information of a 693 component of the `RestoredDistributedTable` (and other restored distributed 694 `CapturableResource` if we're supporting their distribution in the future), 695 instead of looking it up from the mapping of the worker-to-resource handle. 696 This is especially useful when we know which instance the operations should 697 execute with and the mapping is not available yet. 698 """ 699 700 def __init__(self, instance): 701 self.instance = instance 702 703 704class RestoredDistributedTable(DistributedTable): 705 """A restored and distributed StaticHashTable for ParameterServerStrategy.""" 706 707 def __init__(self, strategy, wrapped_creator): 708 # Wait for all resource functions to have been set before building the table 709 self._has_resource_functions = threading.Condition() 710 super().__init__(strategy, wrapped_creator) 711 712 def resource_handle_call_time_value(self): 713 """Returns a closure to run for a resource handle at call time and its spec. 714 715 This function is called in self.resource_handle to create a placeholder 716 which returns a resource handle on some worker or on the coordinator. 717 """ 718 719 def closure(): 720 # function to be evaluated at function call time, returning a nest of 721 # tensors compatible with `spec`. 722 dispatch_context = coordinator_context.get_current_dispatch_context() 723 if dispatch_context: 724 local_resource_restore_context = ( 725 get_current_local_resource_restore_context()) 726 727 # A LocalResourceRestoreContext is entered in the process of remote 728 # table creation and initialization if we're in the process of loading 729 # from a SavedModel. A LocalResourceRestoreContext carries the 730 # information regarding which table is being created and initialized. In 731 # order to initialize a table, we need the restored `_initialize` 732 # function, which captures this closure as table resource. And when this 733 # closure is executed, we will read the table info from the 734 # LocalResourceRestoreContext and return its handle, rather than 735 # following the normal procedure of fetching from 736 # `self._distributed_table`, because we're still in the middle of 737 # building `self._distributed_table`. 738 if local_resource_restore_context: 739 remote_value = local_resource_restore_context.instance.resource_handle 740 741 else: 742 remote_value = self._distributed_table._values[ # pylint: disable=protected-access 743 dispatch_context.worker_index] 744 745 ret = dispatch_context.maybe_get_remote_value(remote_value) 746 return ret 747 748 else: 749 750 return self._coordinator_instance.resource_handle 751 752 return closure, tensor_spec.TensorSpec(shape=(), dtype=dtypes.resource) 753 754 def __setattr__(self, name, value): 755 if name in TRACKABLE_RESOURCE_METHODS: 756 # When a StaticHashTable is loaded with `tf.saved_model.load`, it becomes 757 # a RestoredResource with dummy `_create_resource`, `_initialize`, and 758 # `_destroy_resource" methods. Similarly, when loaded with 759 # `tf.keras.models.load_model`, its initializer becomes a dummy one. In 760 # both cases, these methods needs to be set to some RestoredFunctions 761 # through `__setattr__`. Thus we need to store and set these methods for 762 # the distributed tables (a.k.a. `self._distributed_table`) on the 763 # workers too, besides setting for the coordinator instance. However, we 764 # cannot set them at this point, since the distributed tables have not 765 # been created. We store them in '_restored_function' and set them to the 766 # distributed tables when they're created in 767 # `self._maybe_build_distributed_table.create_copy`. 768 if not hasattr(self, "_restored_function"): 769 self._restored_function = {} 770 self._restored_function[name] = value 771 if all(method in self._restored_function 772 for method in TRACKABLE_RESOURCE_METHODS): 773 with self._has_resource_functions: 774 self._has_resource_functions.notify_all() 775 return self._coordinator_instance.__setattr__(name, value) 776 else: 777 return super(RestoredDistributedTable, self).__setattr__(name, value) 778 779 def _create_resource(self): 780 """A function that creates a resource handle for a table on coordinator.""" 781 return self._coordinator_instance._create_resource() # pylint: disable=protected-access 782 783 def _initialize(self): 784 """A function that initializes the resource.""" 785 return self._coordinator_instance._initialize() # pylint: disable=protected-access 786 787 def _destroy_resource(self): 788 """A function that destroys the resource.""" 789 return self._coordinator_instance._destroy_resource() # pylint: disable=protected-access 790 791 def _maybe_build_distributed_table(self): 792 """Create table objects and resources on each worker if hasn't been created.""" 793 with self._distributed_table_creation_lock: 794 if not self._distributed_table: 795 796 def create_copy(): 797 new_table = self._wrapped_creator() 798 # Wait until all resource functions are available before setting them 799 # on new_table. 800 with self._has_resource_functions: 801 while not hasattr(self, "_restored_function") or any( 802 method not in self._restored_function 803 for method in TRACKABLE_RESOURCE_METHODS): 804 self._has_resource_functions.wait() 805 806 if hasattr(self, "_restored_function"): 807 with with_local_resource_restore_context(new_table): 808 for name, tf_function in self._restored_function.items(): 809 setattr(new_table, name, tf_function) 810 init_op = new_table._initialize() # pylint: disable=protected-access 811 if not context.executing_eagerly(): 812 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) 813 814 ret = new_table.resource_handle 815 return ret 816 817 self._distributed_table = ( 818 self._coordinator._create_per_worker_resources(create_copy)) # pylint: disable=protected-access 819