# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed values for parameter server training."""

import contextlib
import copy
import threading
import weakref

import numpy as np

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.distribute.coordinator import coordinator_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.saved_model import save_context
from tensorflow.python.trackable import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util.lazy_loader import LazyLoader

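# `load_context` is imported lazily below, presumably to avoid a circular
# import with the Keras SavedModel loading machinery at module import time.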
load_context = LazyLoader(
    "load_context", globals(),
    "tensorflow.python.keras.saving.saved_model.load_context"
)

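# Methods of a trackable resource that are swapped out for restored functions
# when loading from a SavedModel; see `RestoredDistributedTable.__setattr__`.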
TRACKABLE_RESOURCE_METHODS = [
    "_create_resource", "_initialize", "_destroy_resource"
]


# Variable used in PSStrategy TF 1, TF 2, and CentralStorageStrategy.
class AggregatingVariable(resource_variable_ops.BaseResourceVariable,
                          core.Tensor):
  """A wrapper around a variable that aggregates updates across replicas."""

  def __init__(self, strategy, v, aggregation):
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def __deepcopy__(self, memo):
    """Perform a deepcopy of the `AggregatingVariable`.

    Unlike the deepcopy of a regular tf.Variable, this keeps the original
    strategy and devices of the `AggregatingVariable`.  To avoid confusion
    with the behavior of deepcopy on a regular `Variable` (which does
    copy into new devices), we only allow a deepcopy of an
    `AggregatingVariable` within its originating strategy scope.

    Args:
      memo: The memoization object for `deepcopy`.

    Returns:
      A deep copy of the current `AggregatingVariable`.

    Raises:
      RuntimeError: If trying to deepcopy into a different strategy.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      v = copy.deepcopy(self._v, memo)

    copied_variable = type(self)(
        strategy=self._distribute_strategy,
        v=v,
        aggregation=self._aggregation)

    memo[id(self)] = copied_variable

    return copied_variable

  def get(self):
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if ds_context.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it in
        # an update call.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = ds_context.get_replica_context()
        assert replica_context
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              values_util.aggregation_error_msg.format(
                  variable_type="AggregatingVariable"))

        def merge_fn(strategy,
                     value,
                     use_locking=False,
                     name=None,
                     read_value=True):
          v = values_util.apply_aggregation(strategy, value, self._aggregation,
                                            self)
          if name and isinstance(name, values.PerReplica):
            name = name.values[0]
          return strategy.extended.update(
              self,
              f,
              args=(v,),
              kwargs={
                  "use_locking": use_locking,
                  "name": name,
                  "read_value": read_value
              })
        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  @property
  def initializer(self):
    return self._v.initializer

  def initialized_value(self):
    return self._v.initialized_value()

  @property
  def initial_value(self):
    return self._v.initial_value

  @property
  def op(self):
    return self._v.op

  def value(self):
    return self._v.value()

  def read_value(self):
    return self._v.read_value()

  def sparse_read(self, indices, name=None):
    return self._v.sparse_read(indices, name=name)

  def eval(self, session=None):
    return self._v.eval(session)

  @property
  def graph(self):
    return self._v.graph

  @property
  def device(self):
    return self._v.device

  @property
  def shape(self):
    return self._v.shape

  @property
  def aggregation(self):
    return self._aggregation

  @property
  def synchronization(self):
    return self._v.synchronization

  @property
  def name(self):
    return self._v.name

  @property
  def trainable(self):
    return self._v.trainable

  @property
  def dtype(self):
    return self._v.dtype

  # TODO(josh11b): Test saving & restoring.
  def _gather_saveables_for_checkpoint(self):
    if isinstance(self._v, CachingVariable):
      return self._v._gather_saveables_for_checkpoint()  # pylint:disable=protected-access
    return {trackable.VARIABLE_VALUE_KEY: self._v}

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    # By delegating this method to the wrapped variable, a SavedModel with
    # AggregatingVariable is identical to a SavedModel with a normal variable.
    obj_map, resource_map = self._v._map_resources(save_options)  # pylint:disable=protected-access
    obj_map[self] = obj_map[self._v]
    return obj_map, resource_map

  # pylint: disable=multiple-statements
  def __add__(self, o):
    return self._v + o

  def __radd__(self, o):
    return o + self._v

  def __sub__(self, o):
    return self._v - o

  def __rsub__(self, o):
    return o - self._v

  def __mul__(self, o):
    return self._v * o

  def __rmul__(self, o):
    return o * self._v

  def __truediv__(self, o):
    return self._v / o

  def __rtruediv__(self, o):
    return o / self._v

  def __floordiv__(self, o):
    return self._v // o

  def __rfloordiv__(self, o):
    return o // self._v

  def __mod__(self, o):
    return self._v % o

  def __rmod__(self, o):
    return o % self._v

  def __lt__(self, o):
    return self._v < o

  def __le__(self, o):
    return self._v <= o

  def __gt__(self, o):
    return self._v > o

  def __ge__(self, o):
    return self._v >= o

  def __and__(self, o):
    return self._v & o

  def __rand__(self, o):
    return o & self._v

  def __or__(self, o):
    return self._v | o

  def __ror__(self, o):
    return o | self._v

  def __xor__(self, o):
    return self._v ^ o

  def __rxor__(self, o):
    return o ^ self._v

  def __getitem__(self, o):
    return self._v[o]

  def __pow__(self, o, modulo=None):
    return pow(self._v, o, modulo)

  def __rpow__(self, o):
    return pow(o, self._v)

  def __invert__(self):
    return ~self._v

  def __neg__(self):
    return -self._v

  def __abs__(self):
    return abs(self._v)

  def __div__(self, o):
    try:
      return self._v.__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self._v.__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self._v.__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self._v.__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __str__(self):
    return str(self._v)

  def __repr__(self):
    return repr(self._v)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    return self._v._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


class CachingVariable(resource_variable_ops.BaseResourceVariable, core.Tensor):
  """A wrapper around a variable that caches its read value locally."""

  def __init__(self, v):
    self._v = v
    self._cache = None
    self._current_new_cache_scope_count = 0

  def get(self):
    return self._v

  def __getattr__(self, name):
    return getattr(self._v, name)

  def read_value(self):
    if distribute_utils.caching_scope_local.in_caching_scope():
      return self.cached_read_value()
    return self._v.read_value()

  def sparse_read(self, indices, name=None):
    return self._v.sparse_read(indices, name=name)

  def cached_read_value(self):
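    # A caching scope newer than the one this cache was built in invalidates
    # the cached copy: bump the local counter and drop the cache so the read
    # below re-materializes it on the CPU.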
    if (distribute_utils.caching_scope_local.new_cache_scope_count >
        self._current_new_cache_scope_count):
      self._current_new_cache_scope_count += 1
      self._cache = None

    with ops.device("CPU:0"):
      if self._cache is not None:
        return self._cache
      else:
        self._cache = array_ops.identity(self._v)
        return self._cache

  def assign_sub(self, *args, **kwargs):
    return self._v.assign_sub(*args, **kwargs)

  def assign_add(self, *args, **kwargs):
    return self._v.assign_add(*args, **kwargs)

  def assign(self, *args, **kwargs):
    return self._v.assign(*args, **kwargs)

  @property
  def initializer(self):
    return self._v.initializer

  def initialized_value(self):
    return self._v.initialized_value()

  @property
  def initial_value(self):
    return self._v.initial_value

  @property
  def op(self):
    return self._v.op

  def value(self):
    if distribute_utils.caching_scope_local.in_caching_scope():
      return self.cached_read_value()
    return self._v.value()

  def eval(self, session=None):
    return self._v.eval(session)

  @property
  def graph(self):
    return self._v.graph

  @property
  def device(self):
    return self._v.device

  @property
  def shape(self):
    return self._v.shape

  @property
  def synchronization(self):
    return self._v.synchronization

  @property
  def name(self):
    return self._v.name

  @property
  def trainable(self):
    return self._v.trainable

  @property
  def dtype(self):
    return self._v.dtype

  @property
  def constraint(self):
    return self._v.constraint

  def __array__(self, dtype=None):
    return np.asarray(self.numpy(), dtype=dtype)

  def __complex__(self):
    return complex(self.value().numpy())

  def __int__(self):
    return int(self.value().numpy())

  def __float__(self):
    return float(self.value().numpy())

  def numpy(self):
    if context.executing_eagerly():
      return self.read_value().numpy()
    else:
      raise NotImplementedError(
          "numpy() is only available when eager execution is enabled.")

  def __str__(self):
    return str(self._v)

  def __repr__(self):
    return repr(self._v)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    if distribute_utils.caching_scope_local.in_caching_scope():
      return self.cached_read_value()
    return self._v._dense_var_to_tensor(dtype=dtype, name=name, as_ref=False)  # pylint: disable=protected-access

  @classmethod
  def _overload_overloadable_operators(cls):
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      # Overloading __eq__ or __ne__ does not work as expected.
      if operator == "__eq__" or operator == "__ne__":
        continue
      cls._tensor_overload_operator(operator)

  @classmethod
  def _tensor_overload_operator(cls, operator):
    """Delegate an operator overload to `ops.Tensor`."""
    tensor_operator = getattr(ops.Tensor, operator)

    def _operator(v, *args, **kwargs):
      return tensor_operator(v.value(), *args, **kwargs)  # pylint: disable=protected-access
    setattr(cls, operator, _operator)

  def _gather_saveables_for_checkpoint(self):
    return {trackable.VARIABLE_VALUE_KEY: self._v}

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    # By delegating this method to the wrapped variable, a SavedModel with
    # CachingVariable is identical to a SavedModel with a normal variable.
    obj_map, resource_map = self._v._map_resources(save_options)  # pylint:disable=protected-access
    obj_map[self] = obj_map[self._v]
    return obj_map, resource_map


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype, name, as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(AggregatingVariable,
                                        _tensor_conversion_aggregate)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_caching(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype, name, as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(CachingVariable,
                                        _tensor_conversion_caching)

CachingVariable._overload_overloadable_operators()  # pylint: disable=protected-access
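# With the overloads installed above, arithmetic on a CachingVariable (e.g.
# `cv * 2.0`) is evaluated against `cv.value()`, so operations inside a caching
# scope also read from the local cache.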


class DistributedTable(lookup_ops.StaticHashTable):
  """A distributed StaticHashTable for ParameterServerStrategy.

  An instance of DistributedTable has copies of a StaticHashTable and its
  resource handle on the coordinator and on each worker, created at
  DistributedTable initialization time with initializers on each worker.
  Users can call methods on a DistributedTable as if it were a
  StaticHashTable, which leads to execution with the resource local to the
  consumer worker (or the coordinator, if calling from the coordinator). This
  implementation relies on the fact that the methods of StaticHashTable are
  queried with the resource handle (instead of the python object).

  Currently, at saving time, a DistributedTable is saved as a StaticHashTable
  on the coordinator, and restoring a DistributedTable from a SavedModel is
  not supported.
  """

  def __init__(self, strategy, wrapped_creator):
    distribute_lib.distribution_strategy_input_api_counter.get_cell(
        self.__class__.__name__, "PSSDistributedLookupTable").increase_by(1)
    self._coordinator_instance = wrapped_creator()
    self._wrapped_creator = wrapped_creator
    self._coordinator = strategy._cluster_coordinator
    # self._distributed_table maps worker_index to a RemoteValue that wraps a
    # resource handle on that worker.
    self._distributed_table = None
    self._distributed_table_creation_lock = threading.Lock()

    if not save_context.in_save_context():
      self._maybe_build_distributed_table()

  def __getattr__(self, attr):
    # This allows copy.copy(DistributedTable), e.g. at saving time.
    # (DistributedVariable uses the same fix.) When copying an object,
    # copy.copy doesn't invoke its __init__ method; instead it makes a new
    # empty object and then copies the attributes over. copy.copy looks for
    # attributes like "__setstate__" in case the object implements its custom
    # unpickling. Since DistributedTable doesn't have those attributes defined,
    # __getattr__ will be invoked, which tries to access the
    # `_coordinator_instance` attribute. But that doesn't exist either because
    # this is an empty object, and again __getattr__ is invoked, leading to an
    # infinite recursion.
    if attr == "_coordinator_instance":
      raise AttributeError()

    if attr in self._coordinator_instance.__dict__:
      attr_value = self._coordinator_instance.__dict__[attr]
      if callable(attr_value):

        def wrapper(*args, **kwargs):
          return attr_value(self, *args, **kwargs)

        return wrapper
      elif isinstance(attr_value, property):
        return attr_value
      else:
        return getattr(self._coordinator_instance, attr)
    else:
      return getattr(self._coordinator_instance, attr)

  def resource_handle_call_time_value(self):
    """Returns a closure to run for a resource handle at call time and its spec.

    This function is called in self.resource_handle to create a placeholder
    which returns a resource handle on some worker or on the coordinator.
    """

    def closure():
      # Function to be evaluated at function call time, returning a nest of
      # tensors compatible with `spec`.
      dispatch_context = coordinator_context.get_current_dispatch_context()
      if dispatch_context:
        remote_value = self._distributed_table._values[  # pylint: disable=protected-access
            dispatch_context.worker_index]
        ret = dispatch_context.maybe_get_remote_value(remote_value)
        return ret

      else:
        return self._coordinator_instance.resource_handle

    return closure, tensor_spec.TensorSpec([], dtype=dtypes.resource)

  def _maybe_build_distributed_table(self):
    """Creates table objects and resources on each worker if not yet created."""
    with self._distributed_table_creation_lock:
      if not self._distributed_table:

        def create_copy():
          new_table = self._wrapped_creator()
          ret = new_table.resource_handle
          return ret

        self._distributed_table = (
            self._coordinator._create_per_worker_resources(create_copy))  # pylint: disable=protected-access

  @property
  def resource_handle(self):
    if context.executing_eagerly() or save_context.in_save_context():
      return self._coordinator_instance.resource_handle
    else:
      self._maybe_build_distributed_table()
      closure, spec = self.resource_handle_call_time_value()
      return ops.get_default_graph().capture_call_time_value(
          closure,
          spec,
          default_value=self._coordinator_instance.resource_handle)

  @property
  def is_distributed_table(self):
    return True

  def __tf_experimental_restore_capture__(
      self, concrete_function, internal_capture):
    closure, spec = self.resource_handle_call_time_value()
    concrete_function.graph.replace_capture_with_deferred_capture(
        self._coordinator_instance.resource_handle,
        closure,
        spec,
        default_value=self._coordinator_instance.resource_handle,
        placeholder=internal_capture)
    return concrete_function.graph.deferred_external_captures[-1]


_local_resource_restore_context = threading.local()


def get_current_local_resource_restore_context():
  try:
    return _local_resource_restore_context.current
  except AttributeError:
    return None


@contextlib.contextmanager
def with_local_resource_restore_context(instance):
  previous_context = getattr(_local_resource_restore_context, "current", None)
  _local_resource_restore_context.current = LocalResourceRestoreContext(
      instance)
  try:
    yield
  finally:
    # Restore the previous context even if the body raises.
    _local_resource_restore_context.current = previous_context


class LocalResourceRestoreContext(object):
  """Holds information about a distributed instance, e.g. a StaticHashTable.

  Pairing use with the context manager `with_local_resource_restore_context`
  allows operations under that context manager to conveniently get information
  about a component of the `RestoredDistributedTable` (and of other restored
  distributed `CapturableResource`s, should their distribution be supported in
  the future), instead of looking it up from the worker-to-resource-handle
  mapping. This is especially useful when we know which instance the
  operations should execute with and the mapping is not yet available.
  """

  def __init__(self, instance):
    self.instance = instance


class RestoredDistributedTable(DistributedTable):
  """A restored and distributed StaticHashTable for ParameterServerStrategy."""

  def __init__(self, strategy, wrapped_creator):
    # Wait for all resource functions to have been set before building the
    # table.
    self._has_resource_functions = threading.Condition()
    super().__init__(strategy, wrapped_creator)

  def resource_handle_call_time_value(self):
    """Returns a closure to run for a resource handle at call time and its spec.

    This function is called in self.resource_handle to create a placeholder
    which returns a resource handle on some worker or on the coordinator.
    """

    def closure():
      # Function to be evaluated at function call time, returning a nest of
      # tensors compatible with `spec`.
      dispatch_context = coordinator_context.get_current_dispatch_context()
      if dispatch_context:
        local_resource_restore_context = (
            get_current_local_resource_restore_context())

        # A LocalResourceRestoreContext is entered during remote table creation
        # and initialization when we're loading from a SavedModel. It carries
        # the information regarding which table is being created and
        # initialized. In order to initialize a table, we need the restored
        # `_initialize` function, which captures this closure as the table
        # resource. When this closure is executed, we read the table info from
        # the LocalResourceRestoreContext and return its handle, rather than
        # following the normal procedure of fetching from
        # `self._distributed_table`, because we're still in the middle of
        # building `self._distributed_table`.
        if local_resource_restore_context:
          remote_value = local_resource_restore_context.instance.resource_handle

        else:
          remote_value = self._distributed_table._values[  # pylint: disable=protected-access
              dispatch_context.worker_index]

        ret = dispatch_context.maybe_get_remote_value(remote_value)
        return ret

      else:
        return self._coordinator_instance.resource_handle

    return closure, tensor_spec.TensorSpec(shape=(), dtype=dtypes.resource)

  def __setattr__(self, name, value):
    if name in TRACKABLE_RESOURCE_METHODS:
      # When a StaticHashTable is loaded with `tf.saved_model.load`, it becomes
      # a RestoredResource with dummy `_create_resource`, `_initialize`, and
      # `_destroy_resource` methods. Similarly, when loaded with
      # `tf.keras.models.load_model`, its initializer becomes a dummy one. In
      # both cases, these methods need to be set to some RestoredFunctions
      # through `__setattr__`. Thus we need to store and set these methods for
      # the distributed tables (a.k.a. `self._distributed_table`) on the
      # workers too, besides setting them for the coordinator instance.
      # However, we cannot set them at this point, since the distributed tables
      # have not been created yet. We store them in `_restored_function` and
      # set them on the distributed tables when they're created in
      # `self._maybe_build_distributed_table.create_copy`.
      if not hasattr(self, "_restored_function"):
        self._restored_function = {}
      self._restored_function[name] = value
      if all(method in self._restored_function
             for method in TRACKABLE_RESOURCE_METHODS):
        with self._has_resource_functions:
          self._has_resource_functions.notify_all()
      return self._coordinator_instance.__setattr__(name, value)
    else:
      return super(RestoredDistributedTable, self).__setattr__(name, value)

  def _create_resource(self):
    """A function that creates a resource handle for a table on the coordinator."""
    return self._coordinator_instance._create_resource()  # pylint: disable=protected-access

  def _initialize(self):
    """A function that initializes the resource."""
    return self._coordinator_instance._initialize()  # pylint: disable=protected-access

  def _destroy_resource(self):
    """A function that destroys the resource."""
    return self._coordinator_instance._destroy_resource()  # pylint: disable=protected-access

  def _maybe_build_distributed_table(self):
    """Creates table objects and resources on each worker if not yet created."""
    with self._distributed_table_creation_lock:
      if not self._distributed_table:

        def create_copy():
          new_table = self._wrapped_creator()
          # Wait until all resource functions are available before setting them
          # on new_table.
          with self._has_resource_functions:
            while not hasattr(self, "_restored_function") or any(
                method not in self._restored_function
                for method in TRACKABLE_RESOURCE_METHODS):
              self._has_resource_functions.wait()

          if hasattr(self, "_restored_function"):
            with with_local_resource_restore_context(new_table):
              for name, tf_function in self._restored_function.items():
                setattr(new_table, name, tf_function)
              init_op = new_table._initialize()  # pylint: disable=protected-access
              if not context.executing_eagerly():
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)

          ret = new_table.resource_handle
          return ret

        self._distributed_table = (
            self._coordinator._create_per_worker_resources(create_copy))  # pylint: disable=protected-access
