xref: /aosp_15_r20/external/tensorflow/tensorflow/python/trackable/data_structures.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1"""Trackable data structures."""
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16import collections
17import copy
18import operator
19import sys
20
21try:
22  import wrapt
23except ImportError:
24  # Fall back to the build-time dependency if the system package is not available.
25  from .....third_party import wrapt  # pylint: disable=relative-beyond-top-level
26
27from tensorflow.python.eager import def_function
28from tensorflow.python.eager import function as defun
29from tensorflow.python.ops import variables
30from tensorflow.python.saved_model import revived_types
31from tensorflow.python.trackable import base
32from tensorflow.python.trackable import layer_utils
33from tensorflow.python.util import lazy_loader
34from tensorflow.python.util.compat import collections_abc
35from tensorflow.python.util.tf_export import tf_export
36
37
# `tensorflow.python.module.module` is imported lazily through LazyLoader; it
# supplies `module.Module`, which TrackableDataStructure's variable
# aggregation checks for.
module = lazy_loader.LazyLoader(
    "module",
    globals(),
    "tensorflow.python.module.module",
)
40
41
class NoDependency:
  """Marks a value so assigning it to a `Trackable` adds no dependency.

  The marker is stripped on assignment: the attribute receives the bare
  wrapped value, but no checkpoint dependency is recorded for it.

  Example usage:
  ```python
  obj = Trackable()
  obj.has_dependency = tf.Variable(0., name="dep")
  obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
  assert obj.no_dependency.name == "nodep:0"
  ```

  Here `obj` depends on the variable "dep" only, and both attributes hold
  plain (un-wrapped) `Variable` objects.

  `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
  dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
  `Layer` to the attribute without a checkpoint dependency, but the `Model` will
  still track the `Layer` (so it will appear in `Model.layers`, and its
  variables will appear in `Model.variables`).
  """

  # Single slot: markers are tiny, short-lived holders with no __dict__.
  __slots__ = ["value"]

  def __init__(self, value):
    # The payload handed back by `wrap_or_unwrap` when the marker is stripped.
    self.value = value
67
68
def _should_wrap_tuple(t):
  """Returns True if tuple `t` holds anything needing tracking or unwrapping.

  That is any element which is a `NoDependency` marker, a `Trackable`, an
  exact built-in `dict` / `collections.OrderedDict` / `list`, or a nested tuple
  that itself qualifies. Tuples are immutable, so mutation is not a concern:
  a tuple with none of these elements does not need a wrapper.
  """
  # Exact type checks (not isinstance) so list/dict subclasses with custom
  # logic, e.g. collections.Counter, are left alone.
  wrappable_types = (dict, collections.OrderedDict, list)
  for item in t:
    # A NoDependency marker must be stripped from the tuple, so it counts.
    if isinstance(item, (NoDependency, base.Trackable)):
      return True
    if type(item) in wrappable_types:  # pylint: disable=unidiomatic-typecheck
      return True
    if isinstance(item, tuple) and _should_wrap_tuple(item):
      return True
  return False
91
92
@tf_export("__internal__.tracking.wrap", v1=[])
def wrap_or_unwrap(value):
  """Wraps `value` in a trackable data structure, or unwraps a `NoDependency`.

  Exact built-in `dict` and `collections.OrderedDict` instances become
  `_DictWrapper`, exact `list` instances become `ListWrapper`, and tuples that
  contain anything trackable become `_TupleWrapper`. Such wrapped structures
  are tracked once associated with a `tf.Module`, so SavedModel/checkpoints can
  record the dependency. A `NoDependency` marker yields its payload; objects
  that are already `Trackable`, and everything else, are returned unchanged.

  Args:
    value: The object to wrap or unwrap.

  Returns:
    The wrapped trackable data structure, the `NoDependency` payload, or
    `value` itself.
  """
  if isinstance(value, NoDependency):
    return value.value
  if isinstance(value, base.Trackable):
    return value  # Already trackable; no conversion needed.

  # Exact type checks (not isinstance) avoid mucking up custom logic in
  # list/dict subclasses, e.g. collections.Counter.
  value_type = type(value)
  # pylint: disable=unidiomatic-typecheck
  if value_type == dict or value_type == collections.OrderedDict:
    return _DictWrapper(value)
  if value_type == list:
    return ListWrapper(value)
  # pylint: enable=unidiomatic-typecheck

  if isinstance(value, tuple) and _should_wrap_tuple(value):
    # The tuple holds trackable elements or mutable containers; wrap it.
    return _TupleWrapper(value)
  return value
129
130
@tf_export("__internal__.tracking.sticky_attribute_assignment", v1=[])
def sticky_attribute_assignment(trackable, name, value):
  """Adds dependencies, generally called from __setattr__.

  Shared between Trackable and Model. `NoDependency` markers are honored
  (unwrapped, with no dependency added); otherwise common data structures are
  converted into trackable ones and any trackable result is tracked on
  `trackable` under the attribute name.

  Args:
    trackable: The object to add dependencies to (generally the one having
      an attribute assigned).
    name: The attribute name being assigned.
    value: The value being assigned. Not necessarily a trackable object.

  Returns:
    The value which should be stored in the attribute (unwrapped from a
    NoDependency object if necessary).
  """
  # Decide before unwrapping: the marker itself is what opts out of tracking.
  wants_dependency = not isinstance(value, NoDependency)
  value = wrap_or_unwrap(value)
  if wants_dependency and isinstance(value, base.Trackable):
    trackable._track_trackable(  # pylint: disable=protected-access
        value, name=name,
        # Allow the user to switch the Trackable tracked under this name;
        # assigning a new variable to an attribute has historically been
        # fine (e.g. Adam did this).
        overwrite=True)
  return value
165
166
class _UntrackableError(ValueError):
  """Raised when a non-trackable value is stored in a trackable container."""

  def __init__(self, value):
    # Forward `value` so `args` is populated. Pickling and copying an
    # exception rebuild it as `cls(*args)`; with empty args that call would
    # fail with a TypeError for the missing `value` argument.
    super().__init__(value)
    self._value = value

  def __str__(self):
    return ("Only trackable objects (such as Layers or Optimizers) may be "
            f"stored in a List object. Got {self._value}, which does not "
            "inherit from Trackable.")
176
177
@tf_export("__internal__.tracking.TrackableDataStructure", v1=[])
class TrackableDataStructure(base.Trackable):
  """Base class for data structures which contain trackable objects.

  Subclasses provide `_values`. This class derives layer, weight/variable,
  update and loss aggregation from those values, and uses object identity for
  hashing and equality.
  """

  def __init__(self):
    # Every attribute is prefixed with "_self_" for compatibility with
    # wrapt.ObjectProxy. Any additional attribute MUST follow the same
    # pattern, because extending `__slots__` on a subclass of ObjectProxy
    # breaks in a variety of ways.
    self._self_trainable = True
    self._self_extra_variables = []
    self._self_attribute_sentinel = layer_utils.AttributeSentinel(True)

  @property
  def _attribute_sentinel(self):
    return self._self_attribute_sentinel

  @property
  def trainable(self):
    return self._self_trainable

  @trainable.setter
  def trainable(self, value):
    self._self_trainable = value

  def _track_value(self, value, name):
    """Add a dependency on `value`.

    Args:
      value: The object to track. `NoDependency` markers are unwrapped and
        plain containers are wrapped by `sticky_attribute_assignment`.
      name: The dependency name to record.

    Returns:
      The (possibly wrapped or unwrapped) value that should be stored.

    Raises:
      _UntrackableError: If the resulting value is not `Trackable`.
    """
    tracked = sticky_attribute_assignment(
        trackable=self, value=value, name=name)
    if isinstance(tracked, variables.Variable):
      self._self_extra_variables.append(tracked)
    if not isinstance(tracked, base.Trackable):
      raise _UntrackableError(tracked)
    if hasattr(tracked, "_use_resource_variables"):
      # In subclassed models, legacy layers (tf.layers) must always use
      # resource variables.
      tracked._use_resource_variables = True  # pylint: disable=protected-access
    child_sentinel = getattr(tracked, "_attribute_sentinel", None)
    if child_sentinel:
      child_sentinel.add_parent(self._attribute_sentinel)
    return tracked

  @property
  def _values(self):
    """An iterable/sequence which may contain trackable objects."""
    raise NotImplementedError("Abstract method")

  @property
  def _layers(self):
    """All Layers and Layer containers, including empty containers."""
    # Filtered on demand so wrapper objects use values from the thing they
    # wrap even if the two are out of sync.
    return [
        candidate for candidate in self._values
        if (isinstance(candidate, TrackableDataStructure)
            or layer_utils.is_layer(candidate)
            or layer_utils.has_weights(candidate))
    ]

  @property
  def layers(self):
    return list(layer_utils.filter_empty_layer_containers(self._layers))

  @property
  def trainable_weights(self):
    if not self._self_trainable:
      return []
    nested_trainable = []
    for candidate in self._values:
      if isinstance(candidate, (TrackableDataStructure, module.Module)):
        nested_trainable.extend(candidate.trainable_variables)
    trainable_extras = [
        var for var in self._self_extra_variables if var.trainable
    ]
    return nested_trainable + trainable_extras

  @property
  def non_trainable_weights(self):
    extras = self._self_extra_variables
    trainable_extras = [var for var in extras if var.trainable]
    non_trainable_extras = [var for var in extras if not var.trainable]
    containers = [
        candidate for candidate in self._values
        if isinstance(candidate, (TrackableDataStructure, module.Module))
    ]

    nested_non_trainable = []
    for container in containers:
      nested_non_trainable.extend(container.non_trainable_variables)

    if self._self_trainable:
      return nested_non_trainable + non_trainable_extras

    # A frozen container reports everything as non-trainable. Return order is
    # all trainable vars, then all non-trainable vars.
    nested_trainable = []
    for container in containers:
      nested_trainable.extend(container.trainable_variables)
    return (nested_trainable + trainable_extras +
            nested_non_trainable + non_trainable_extras)

  @property
  def weights(self):
    return self.trainable_weights + self.non_trainable_weights

  @property
  def trainable_variables(self):
    return self.trainable_weights

  @property
  def non_trainable_variables(self):
    return self.non_trainable_weights

  @property
  def variables(self):
    return self.weights

  @property
  def updates(self):
    """Aggregate updates from any `Layer` instances."""
    # Updates and conditional losses are forwarded as-is rather than filtered
    # by inputs: this is just a container and never has any inputs.
    collected = []
    for layer in self.layers:
      if hasattr(layer, "updates"):
        collected.extend(layer.updates)
    return collected

  @property
  def losses(self):
    """Aggregate losses from any `Layer` instances."""
    collected = []
    for layer in self.layers:
      if hasattr(layer, "losses"):
        collected.extend(layer.losses)
    return collected

  def __hash__(self):
    # Object-identity hashing lets these structures be keys in sets/dicts.
    return id(self)

  def __eq__(self, other):
    # Like Tensors, trackable data structures use object-identity equality to
    # support set/dict membership.
    return self is other
330
331
class List(TrackableDataStructure, collections_abc.Sequence):
  """An append-only sequence type which is trackable.

  Maintains checkpoint dependencies on its contents (which must also be
  trackable), and forwards any `Layer` metadata such as updates and losses.

  Note that `List` is purely a container. It lets a `tf.keras.Model` or
  other trackable object know about its contents, but does not call any
  `Layer` instances which are added to it. To indicate a sequence of `Layer`
  instances which should be called sequentially, use `tf.keras.Sequential`.

  Example usage:
  ```python
  class HasList(tf.keras.Model):

    def __init__(self):
      super().__init__()
      self.layer_list = List([layers.Dense(3)])
      self.layer_list.append(layers.Dense(4))

    def call(self, x):
      aggregation = 0.
      for l in self.layer_list:
        x = l(x)
        aggregation += tf.reduce_sum(x)
      return aggregation
  ```

  This kind of wrapping is necessary because `Trackable` objects do not
  (yet) deeply inspect regular Python data structures, so for example assigning
  a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a
  checkpoint dependency and does not add the `Layer` instance's weights to its
  parent `Model`.
  """

  def __init__(self, *args, **kwargs):
    """Construct a new sequence. Arguments are passed to `list()`."""
    super().__init__()
    self._storage = self._make_storage(*args, **kwargs)
    # Track each initial element under its positional name, in place.
    for position, item in enumerate(self._storage):
      self._storage[position] = self._track_value(
          item, name=self._name_element(position))

  def copy(self):
    storage_copy = copy.copy(self._storage)
    return type(self)(storage_copy)

  def __copy__(self):
    return self.copy()

  def __deepcopy__(self, memo):
    storage_copy = copy.deepcopy(self._storage, memo)
    return type(self)(storage_copy)

  def _make_storage(self, *args, **kwargs):
    """Determines the backing storage (overridden in subclasses)."""
    return list(*args, **kwargs)

  def _name_element(self, index):
    # Dependencies are named after the element's position.
    return "%d" % (index,)

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    return self

  def append(self, value):
    """Add a new trackable value."""
    slot = len(self._storage)
    tracked = self._track_value(value, self._name_element(slot))
    self._storage.append(tracked)

  def extend(self, values):
    """Add a sequence of trackable values."""
    for item in values:
      self.append(item)

  def __iadd__(self, values):
    self.extend(values)
    return self

  def __add__(self, other):
    # Concatenation yields a plain list, unwrapping `other` if it has storage.
    return self._storage + getattr(other, "_storage", other)

  def __imul__(self, y):
    if y <= 0:
      raise ValueError(
          f"List only supports append, multiplying in place by {y} removes "
          "elements.")

    # Re-append the original elements y - 1 more times, one by one.
    original_length = len(self._storage)
    for _ in range(y - 1):
      for position in range(original_length):
        self.append(self._storage[position])

    return self

  def __mul__(self, n):
    return self._storage * n

  def __rmul__(self, n):
    return self * n

  def __radd__(self, other):
    return other + self._storage

  def __getitem__(self, key):
    return self._storage[key]

  def __getslice__(self, i, j):
    return self._storage[slice(i, j)]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return f"List({self._storage!r})"

  def __sizeof__(self):
    own_size = super().__sizeof__()
    return own_size + sys.getsizeof(self._storage)
449
450
# TODO(tomhennigan) Update to collections.UserList?
# TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop
# Python 3.4 support (may still be tricky).
class ListWrapper(
    List,
    collections_abc.MutableSequence,
    # Shadowed, but there for isinstance checks.
    list):
  """Wraps the built-in `list` to support restore-on-create for variables.

  Unlike `List`, this sequence type is mutable in the same ways built-in lists
  are. Instead of throwing an error immediately like `List`, it records
  problematic mutations (e.g. assigning a new element to a position already
  occupied, meaning both elements get the same names at different times) and
  refuses to save.

  On assignment to an attribute of a Model or Trackable object, Python
  lists are replaced with ListWrapper. Wrapping a list in a
  `NoDependency` object prevents this.
  """

  def __init__(self, wrapped_list):
    """Construct a new list wrapper.

    Args:
      wrapped_list: The initial value of the data structure. A shallow copy may
        be maintained for error checking. `wrapped_list` itself should not be
        modified directly after constructing the `ListWrapper`, and if changes
        are detected the `ListWrapper` will throw an exception on save.
    """
    # Monotonic flags which indicate this object would not be restored properly,
    # and therefore should throw an error on save to avoid giving the impression
    # that restoring it will work.
    self._non_append_mutation_value = False
    self._external_modification_value = False
    super().__init__(wrapped_list)
    self._last_wrapped_list_snapshot = list(self._storage)

  @property
  def _non_append_mutation(self):
    """True once a non-append mutation made this list unsaveable."""
    return self._non_append_mutation_value

  @_non_append_mutation.setter
  def _non_append_mutation(self, value):
    # Trackable only cares that a mutation occurred at some point; when
    # attempting to save it checks whether a mutation occurred and the object is
    # in a "dirty" state but otherwise the specifics of how it got to that state
    # are ignored. By contrast, the attribute cache needs to signal the mutation
    # immediately since a caller could query the value of an attribute (And
    # should not hit the cached value since the mutation may have affected the
    # result.)
    self._attribute_sentinel.invalidate_all()
    self._non_append_mutation_value = value

  @property
  def _external_modification(self):
    """True once the wrapped list was found modified outside the wrapper."""
    return self._external_modification_value

  @_external_modification.setter
  def _external_modification(self, value):
    # Invalidate for the same reason as `_non_append_mutation`
    self._attribute_sentinel.invalidate_all()
    self._external_modification_value = value

  # pylint: disable=protected-access
  def __copy__(self):
    copied = super().__copy__()
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied

  def __deepcopy__(self, memo):
    copied = super().__deepcopy__(memo)
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied
  # pylint: enable=protected-access

  def __reduce_ex__(self, protocol):
    return (self.__class__,
            (self._storage,))

  def _make_storage(self, wrapped_list):
    """Use the user's original list for storage."""
    return wrapped_list

  def _check_external_modification(self):
    """Checks for any changes to the wrapped list not through the wrapper.

    On the first detected difference the `_external_modification` flag is set.
    The last acknowledged snapshot is kept rather than discarded, because the
    error raised by `_trackable_children` reports it as the list's value "when a
    checkpoint dependency was added". `_update_snapshot` stops refreshing it
    once the flag is set, so it stays frozen at the last consistent state.
    """
    if self._external_modification or self._non_append_mutation:
      return
    if self._storage != self._last_wrapped_list_snapshot:
      self._external_modification = True

  def _update_snapshot(self):
    """Acknowledges tracked changes to the wrapped list."""

    # Mutation tracking for attributes reuses the same infrastructure as
    # Trackable mutation tracking.
    self._attribute_sentinel.invalidate_all()
    if self._external_modification or self._non_append_mutation:
      return
    self._last_wrapped_list_snapshot = list(self._storage)

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    self._check_external_modification()
    if self._non_append_mutation:
      raise ValueError(
          f"Unable to save the object {self} (a list wrapper constructed to "
          "track trackable TensorFlow objects). A list element was replaced "
          "(__setitem__, __setslice__), deleted (__delitem__, __delslice__), "
          "or moved (sort). In order to support restoration on object "
          "creation, tracking is exclusively for append-only data structures."
          "\n\nIf you don't need this list checkpointed, wrap it in a "
          "non-trackable object; it will be subsequently ignored.")
    if self._external_modification:
      raise ValueError(
          f"Unable to save the object {self} (a list wrapper constructed to "
          "track trackable TensorFlow objects). The wrapped list was modified "
          f"outside the wrapper (its final value was {self._storage}, its value"
          " when a checkpoint dependency was added was "
          f"{self._last_wrapped_list_snapshot}), which breaks "
          "restoration on object creation.\n\nIf you don't need this list "
          "checkpointed, wrap it in a NoDependency object; it will be "
          "subsequently ignored.")
    children = super()._trackable_children(save_type, **kwargs)

    if save_type == base.SaveType.SAVEDMODEL:
      # Add functions to be serialized.
      children.update({
          str(key): value
          for key, value in enumerate(self)
          if _is_function(value)
      })

    return children

  def _has_mutation_or_trackable(self):
    """Short-circuits a check for trackables if there's already a mutation."""
    if self._non_append_mutation:
      return True
    return any(isinstance(element, base.Trackable) for element in self._storage)

  def __delitem__(self, key):
    self._check_external_modification()
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    del self._storage[key]
    self._update_snapshot()

  def __setitem__(self, key, value):
    self._check_external_modification()

    if isinstance(key, slice):
      # Note: this is quite inefficient, but the list API supports a broad range
      # of slice setters (e.g. truncate, extend, replace) and imitating this
      # for a range of Python versions is non-trivial.
      storage_copy = list(self._storage)
      self._storage[key] = value

      len_before = len(storage_copy)
      len_now = len(self._storage)
      for i in range(max(len_before, len_now)):
        value_now = self._storage[i] if i < len_now else None
        value_before = storage_copy[i] if i < len_before else None

        if isinstance(value_before, base.Trackable):
          self._non_append_mutation = True

        # Compare by identity, not `!=`. For elements such as NumPy arrays
        # (and presumably non-scalar TF variables/tensors) `!=` is elementwise
        # and has no single truth value. Equal-but-distinct replacements, e.g.
        # a fresh list equal to an existing ListWrapper, must still go through
        # `_track_value`.
        if value_now is not None and value_now is not value_before:
          self._storage[i] = self._track_value(self._storage[i],
                                               self._name_element(i))

    else:
      if isinstance(self._storage[key], base.Trackable):
        self._non_append_mutation = True
      self._storage[key] = self._track_value(value, self._name_element(key))

    self._update_snapshot()

  def append(self, value):
    """Add a new trackable value."""
    self._check_external_modification()
    super().append(value)
    self._update_snapshot()

  def extend(self, values):
    """Add a sequence of trackable values."""
    self._check_external_modification()
    super().extend(values)
    self._update_snapshot()

  def __imul__(self, y):
    if y <= 0:
      self._check_external_modification()
      if self._has_mutation_or_trackable():
        self._non_append_mutation = True
      self._storage *= y
      self._update_snapshot()
      return self

    # Relies on super() calling append, which updates the snapshot.
    return super().__imul__(y)

  def __eq__(self, other):
    return self._storage == getattr(other, "_storage", other)

  def __ne__(self, other):
    return self._storage != getattr(other, "_storage", other)

  def __lt__(self, other):
    return self._storage < getattr(other, "_storage", other)

  def __le__(self, other):
    return self._storage <= getattr(other, "_storage", other)

  def __gt__(self, other):
    return self._storage > getattr(other, "_storage", other)

  def __ge__(self, other):
    return self._storage >= getattr(other, "_storage", other)

  def __hash__(self):
    # List wrappers need to compare like regular lists, and so like regular
    # lists they don't belong in hash tables.
    raise TypeError("unhashable type: 'ListWrapper'")

  def insert(self, index, obj):
    # NOTE(review): unlike `append` and `__setitem__`, `obj` is stored as-is
    # without going through `_track_value`, so it is not wrapped, unwrapped
    # (NoDependency) or tracked. Inserting a trackable marks the list
    # unsaveable; confirm whether plain containers should be wrapped here.
    self._check_external_modification()
    if (self._has_mutation_or_trackable() or isinstance(obj, base.Trackable)):
      self._non_append_mutation = True
    self._storage.insert(index, obj)
    self._update_snapshot()

  def sort(self, *, key=None, reverse=False):
    """Sorts the wrapped list in place; accepts `list.sort`'s arguments.

    Reordering a list that holds trackables (or was already mutated) marks it
    unsaveable, since element names are positional.

    Args:
      key: Optional one-argument function extracting a comparison key.
      reverse: If True, sort in descending order.
    """
    self._check_external_modification()
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    self._storage.sort(key=key, reverse=reverse)
    self._update_snapshot()

  def __setslice__(self, i, j, y):
    self.__setitem__(slice(i, j), y)

  def __delslice__(self, i, j):
    self._check_external_modification()
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    del self._storage[slice(i, j)]
    self._update_snapshot()

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects."""
    try:
      value = super()._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    return "ListWrapper(%s)" % (repr(self._storage),)
715
716
class Mapping(TrackableDataStructure, collections_abc.Mapping):
  """An append-only trackable mapping data structure with string keys.

  Maintains checkpoint dependencies on its contents (which must also be
  trackable), named based on its keys.

  Note that once a key has been added, it may not be deleted or replaced.
  """

  def __init__(self, *args, **kwargs):
    """Construct a new mapping. Arguments are passed to `dict()`."""
    super().__init__()
    self._storage = self._make_storage(*args, **kwargs)
    self._storage.update(
        {key: self._track_value(
            value, name=self._name_element(key))
         for key, value in self._storage.items()})

  def __copy__(self):
    return type(self)(copy.copy(self._storage))

  def __deepcopy__(self, memo):
    return type(self)(copy.deepcopy(self._storage, memo))

  def _make_storage(self, *args, **kwargs):
    return dict(*args, **kwargs)

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    # Sort items deterministically by key
    ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
    if ordered:
      return ordered[1]
    return []

  def _name_element(self, key):
    if not isinstance(key, str):
      raise TypeError(
          f"Mapping accepts only string keys, but got a key {repr(key)}.")
    return str(key)

  def __setitem__(self, key, value):
    """Adds a new `key`; existing keys cannot be rebound.

    Re-assigning the exact object already stored under `key` is a no-op.

    Raises:
      TypeError: If `key` is not a string.
      ValueError: If `key` already maps to a different object, or if `value`
        is not trackable.
    """
    name = self._name_element(key)
    if key in self._storage:
      # Reject before calling `_track_value`. Tracking first would re-point the
      # dependency named `key` at the rejected value (and record it as an
      # extra variable) even though the assignment fails.
      current_value = self._storage[key]
      # NoDependency markers are unwrapped before storage, so compare against
      # the object that would actually have been stored.
      candidate = value.value if isinstance(value, NoDependency) else value
      if candidate is current_value:
        return
      raise ValueError(
          "Mappings are an append-only data structure. Tried to overwrite the "
          f"key '{key}' with value {candidate}, but it already contains "
          f"{current_value}")
    self._storage[key] = self._track_value(value, name=name)

  def update(self, *args, **kwargs):
    for key, value in dict(*args, **kwargs).items():
      self[key] = value

  def __getitem__(self, key):
    return self._storage[key]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "Mapping(%s)" % (repr(self._storage),)

  def __iter__(self):
    return iter(self._storage)
784
785
786class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
787  """Wraps built-in dicts to support restore-on-create for variables.
788
789  _DictWrapper is to Mapping as ListWrapper is to List. Unlike Mapping,
790  _DictWrapper allows non-string keys and values and arbitrary mutations (delete
791  keys, reassign values). Like ListWrapper, these mutations mean that
792  _DictWrapper will raise an exception on save.
793  """
794
795  def __init__(self, wrapped_dict=None):
796    if wrapped_dict is None:
797      # Allow zero-argument construction, e.g. from session.run's re-wrapping.
798      wrapped_dict = {}
799    if not isinstance(wrapped_dict, collections_abc.Mapping):
800      # Allow construction from a sequence, e.g. from nest.pack_sequence_as.
801      wrapped_dict = dict(wrapped_dict)
802    wrapt.ObjectProxy.__init__(self, wrapped_dict)
803    TrackableDataStructure.__init__(self)
804    self._self_non_string_key = False
805    self._self_external_modification = False
806    self.__wrapped__.update(
807        {key: self._track_value(
808            value, name=self._name_element(key))
809         for key, value in self.__wrapped__.items()})
810    self._update_snapshot()
811
812  def __reduce_ex__(self, protocol):
813    return (self.__class__,
814            (self.__wrapped__,))
815
816  def __getattribute__(self, name):
817    if (hasattr(type(self), name)
818        and isinstance(getattr(type(self), name), property)):
819      # Bypass ObjectProxy for properties. Whether this workaround is necessary
820      # appears to depend on the Python version but not the wrapt version: 3.4
821      # in particular seems to look up properties on the wrapped object instead
822      # of the wrapper without this logic.
823      return object.__getattribute__(self, name)
824    else:
825      return super().__getattribute__(name)
826
827  def copy(self):
828    return copy.copy(self)
829
830  # pylint: disable=protected-access
831  def __copy__(self):
832    copied = _DictWrapper(copy.copy(self.__wrapped__))
833    copied._self_external_modification = self._self_external_modification
834    copied._self_non_string_key = self._self_non_string_key
835    return copied
836
837  def __deepcopy__(self, memo):
838    copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo))
839    copied._self_external_modification = self._self_external_modification
840    copied._self_non_string_key = self._self_non_string_key
841    return copied
842  # pylint: enable=protected-access
843
844  @property
845  def _values(self):
846    """Collect values for TrackableDataStructure."""
847    # Sort items deterministically by key
848    ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
849    if ordered:
850      return ordered[1]
851    return []
852
853  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
854    """Check that the object is saveable before listing its dependencies."""
855    self._check_self_external_modification()
856    if self._self_non_string_key:
857      raise ValueError(
858          f"Unable to save the object {self} (a dictionary wrapper constructed "
859          "automatically on attribute assignment). The wrapped dictionary "
860          "contains a non-string key which maps to a trackable object or "
861          "mutable data structure.\n\nIf you don't need this dictionary "
862          "checkpointed, wrap it in a non-trackable "
863          "object; it will be subsequently ignored.")
864    if self._self_external_modification:
865      raise ValueError(
866          f"Unable to save the object {self} (a dictionary wrapper constructed "
867          "automatically on attribute assignment). The wrapped dictionary was "
868          f"modified outside the wrapper (its final value was {self}, its value"
869          " when a checkpoint dependency was added was "
870          f"{self._self_last_wrapped_dict_snapshot}), which breaks "
871          "restoration on object creation.\n\nIf you don't need this "
872          "dictionary checkpointed, wrap it in a "
873          "non-trackable object; it will be subsequently ignored.")
874    assert not self._dirty  # Any reason for dirtiness should have an exception.
875    children = super()._trackable_children(save_type, **kwargs)
876
877    if save_type == base.SaveType.SAVEDMODEL:
878      # Add functions to be serialized.
879      children.update(
880          {key: value for key, value in self.items() if _is_function(value)})
881
882    return children
883
884  @property
885  def _dirty(self):
886    """Check if there has already been a mutation which prevents saving."""
887    return (self._self_external_modification
888            or self._self_non_string_key)
889
890  def _check_self_external_modification(self):
891    """Checks for any changes to the wrapped dict not through the wrapper."""
892    if self._dirty:
893      return
894    if self != self._self_last_wrapped_dict_snapshot:
895      self._self_external_modification = True
896      self._self_last_wrapped_dict_snapshot = None
897
898  def _update_snapshot(self):
899    """Acknowledges tracked changes to the wrapped dict."""
900    self._attribute_sentinel.invalidate_all()
901    if self._dirty:
902      return
903    self._self_last_wrapped_dict_snapshot = dict(self)
904
905  def _track_value(self, value, name):
906    """Allows storage of non-trackable objects."""
907    if isinstance(name, str):
908      string_key = True
909    else:
910      name = "-non_string_key"
911      string_key = False
912    try:
913      no_dependency = isinstance(value, NoDependency)
914      value = super()._track_value(value=value, name=name)
915      if not (string_key or no_dependency):
916        # A non-string key maps to a trackable value. This data structure
917        # is not saveable.
918        self._self_non_string_key = True
919      return value
920    except ValueError:
921      # Even if this value isn't trackable, we need to make sure
922      # NoDependency objects get unwrapped.
923      return sticky_attribute_assignment(
924          trackable=self, value=value, name=name)
925
926  def _name_element(self, key):
927    """Tells TrackableDataStructure to use keys as names as-is."""
928    return key
929
930  def __setitem__(self, key, value):
931    """Allow any modifications, but possibly mark the wrapper as unsaveable."""
932    self._check_self_external_modification()
933    self._maybe_initialize_trackable()
934    no_dep = isinstance(value, NoDependency)
935    if isinstance(key, str):
936      value = self._track_value(value, name=key)
937    else:
938      value = wrap_or_unwrap(value)
939      if not no_dep and isinstance(value, base.Trackable):
940        # Non-string keys are OK as long as we have no reason to add a
941        # dependency on the value (either because the value is not
942        # trackable, or because it was wrapped in a NoDependency object).
943        self._self_non_string_key = True
944    self.__wrapped__[key] = value
945
946    self._update_snapshot()
947
948  def __delitem__(self, key):
949    self._check_self_external_modification()
950    del self.__wrapped__[key]
951    self._update_snapshot()
952
953  def __repr__(self):
954    return "DictWrapper(%s)" % (repr(self.__wrapped__),)
955
956  def __hash__(self):
957    raise TypeError("unhashable type: 'DictWrapper'")
958
959  def __eq__(self, other):
960    # Override the TrackableDataStructure "== -> is" forwarding and go back to
961    # the wrapt implementation.
962    return self.__wrapped__ == other
963
964  def update(self, *args, **kwargs):
965    for key, value in dict(*args, **kwargs).items():
966      self[key] = value
967
968
class _TupleWrapper(TrackableDataStructure, wrapt.ObjectProxy):
  """Trackable wrapper for tuples and namedtuples.

  Each element is passed through `wrap_or_unwrap` (defined elsewhere in this
  module) and the tuple is rebuilt with the original type. A dependency is
  added on every element that was not wrapped in `NoDependency`. Elements are
  tracked by index ("0", "1", ...) and, for namedtuples, also by field name.
  """

  def __init__(self, original_wrapped_tuple=()):
    """Wraps `original_wrapped_tuple` and adds dependencies on its elements.

    Args:
      original_wrapped_tuple: A tuple or namedtuple. Elements wrapped in
        `NoDependency` get no dependency.
    """
    # One flag per element: False for elements wrapped in NoDependency.
    add_dependency = []
    substituted_wrapped_tuple = []
    for element in original_wrapped_tuple:
      if isinstance(element, NoDependency):
        add_dependency.append(False)
      else:
        add_dependency.append(True)
      substituted_wrapped_tuple.append(wrap_or_unwrap(element))
    try:
      fields = original_wrapped_tuple._fields
    except AttributeError:
      # Not a namedtuple
      is_namedtuple = False
    else:
      is_namedtuple = True
    original_type = type(original_wrapped_tuple)
    # Flag to poison saving if we can't re-construct a namedtuple because its
    # __new__ takes different keyword arguments than its _fields.
    self._self_tuple_is_constructable = True
    if is_namedtuple:
      try:
        # NamedTuples take N arguments, unlike tuple which takes a sequence.
        substituted_wrapped_tuple = original_type(
            **dict(zip(fields, substituted_wrapped_tuple)))
      except TypeError:
        # The type cannot be rebuilt from its _fields. Wrap the original tuple
        # unchanged, add no element dependencies, and mark it unsaveable (see
        # _trackable_children).
        wrapt.ObjectProxy.__init__(self, original_wrapped_tuple)
        TrackableDataStructure.__init__(self)
        self._self_tuple_is_constructable = False
        return
    else:
      substituted_wrapped_tuple = original_type(substituted_wrapped_tuple)
    wrapt.ObjectProxy.__init__(self, substituted_wrapped_tuple)
    TrackableDataStructure.__init__(self)

    # NOTE: names are tracked before indices; keep this order so the
    # sequence of dependency registrations stays stable.
    if is_namedtuple:
      # For namedtuples, also track by names for compatibility with
      # dictionaries.
      for name, should_depend, element in zip(
          fields, add_dependency, substituted_wrapped_tuple):
        if should_depend:
          self._track_value(element, name=name)

    # Track by index as well, for compatibility with lists.
    for index, (should_depend, element) in enumerate(
        zip(add_dependency, substituted_wrapped_tuple)):
      if should_depend:
        self._track_value(element, name="%d" % (index,))

  @property
  def _values(self):
    """Collect values for TrackableDataStructure (the wrapper itself)."""
    return self

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects.

    Values the parent class refuses with a ValueError still go through
    `sticky_attribute_assignment`, whose result is returned.
    """
    try:
      value = super()._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    """Shows the wrapped tuple inside a `_TupleWrapper(...)` label."""
    return "_TupleWrapper(%s)" % (repr(self.__wrapped__),)

  def __hash__(self):
    """Hashes like the wrapped tuple."""
    # Override the TrackableDataStructure hash forwarding and go back to
    # the wrapt implementation.
    return hash(self.__wrapped__)

  def __eq__(self, other):
    """Compares the wrapped tuple with `other` using ordinary `==`."""
    # Override the TrackableDataStructure "== -> is" forwarding and go back to
    # the wrapt implementation.
    return self.__wrapped__ == other

  def __copy__(self):
    """Returns a new wrapper around a shallow copy of the wrapped tuple."""
    return _TupleWrapper(copy.copy(self.__wrapped__))

  def __deepcopy__(self, memo):
    """Returns a new wrapper around a deep copy of the wrapped tuple."""
    return _TupleWrapper(copy.deepcopy(self.__wrapped__, memo))

  def __reduce_ex__(self, protocol):
    """Reduces to a call of `self.__class__` with the wrapped tuple.

    NOTE(review): `__getattribute__` below prefers attributes of the wrapped
    tuple, and tuples define `__reduce_ex__` and `__class__`. Instance-level
    lookups (such as the one pickle performs) may therefore resolve to the
    tuple's own versions instead of this method. Confirm whether this override
    is reached.
    """
    return (self.__class__,
            (self.__wrapped__,))

  # imul and iadd are the only tuple-relevant in-place operators. They need to
  # be special-cased to avoid mutating the original proxy object.
  def __imul__(self, y):
    """Avoid running self.__wrapped__ *= y, which mutates `self`."""
    return self.__wrapped__ * y

  def __iadd__(self, y):
    """Avoid running self.__wrapped__ += y, which mutates `self`."""
    return self.__wrapped__ + y

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Lists dependencies, refusing namedtuples that could not be rebuilt.

    Raises:
      ValueError: If `__init__` could not reconstruct the namedtuple from its
        `_fields` (so `_self_tuple_is_constructable` is False).
    """
    if not self._self_tuple_is_constructable:
      raise ValueError(
          f"Unable to save because the namedtuple {self.__wrapped__} is not "
          "constructable from its _fields (i.e. __new__ is overridden). "
          f"Expected keyword arguments {self.__wrapped__._fields}. If you do "
          "not need to save this object, consider wrapping it in a custom "
          "object that does not inherit from tuple.")
    return super()._trackable_children(save_type, **kwargs)

  def __getattribute__(self, name):
    """Resolves attributes, preferring those of the wrapped tuple.

    NOTE(review): because wrapped attributes win, explicit lookups such as
    `w.__repr__` or `w.__eq__` return the tuple's methods. Implicit special
    method calls (`repr(w)`, `w == x`) are resolved on the type and use this
    class's overrides.
    """
    if name != "__wrapped__" and hasattr(self.__wrapped__, name):
      # Prefer attributes on the wrapped object when they conflict with
      # attributes on the wrapper object.
      return getattr(self.__wrapped__, name)

    if (hasattr(type(self), name)
        and isinstance(getattr(type(self), name), property)):
      # Bypass ObjectProxy for properties. Whether this workaround is necessary
      # appears to depend on the Python version but not the wrapt version: 3.4
      # in particular seems to look up properties on the wrapped object instead
      # of the wrapper without this logic.
      return object.__getattribute__(self, name)
    else:
      return super().__getattribute__(name)
1095
1096
def _is_function(x):
  """Returns True if `x` is a `def_function.Function` or `ConcreteFunction`."""
  function_types = (def_function.Function, defun.ConcreteFunction)
  return isinstance(x, function_types)
1099
1100
revived_types.register_revived_type(
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            # Standard dependencies are enough to reconstruct the trackable
            # items in dictionaries, so no extra information is saved. Revival
            # starts from an empty wrapper; children are then assigned with
            # `operator.setitem` (presumably once per restored dependency).
            object_factory=lambda proto: _DictWrapper({}),
            setter=operator.setitem,
        )
    ],
)
1112
1113
1114def _set_list_item(list_object, index_string, value):
1115  item_index = int(index_string)
1116  if len(list_object) <= item_index:
1117    list_object.extend([None] * (1 + item_index - len(list_object)))
1118  list_object[item_index] = value
1119
1120
revived_types.register_revived_type(
    "trackable_list_wrapper",
    lambda obj: isinstance(obj, ListWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            # Revival starts from an empty ListWrapper; `_set_list_item` places
            # each restored child at its integer index, padding with None.
            object_factory=lambda proto: ListWrapper([]),
            setter=_set_list_item,
        )
    ],
)
1130
1131
def _set_tuple_item(list_object, index_string, value):
  """Setter for revived tuples, which are loaded as `ListWrapper`s.

  Tuple wrappers track each element by index and, for namedtuples, also by
  field name. Only the index-named dependencies are placed: they go to that
  position, with the list padded by `None` as needed. Field-named
  dependencies are ignored.

  Args:
    list_object: The list being filled in during revival.
    index_string: The dependency name; an integer index or a namedtuple field.
    value: The revived child object.
  """
  try:
    int(index_string)
  except ValueError:
    # Ignore namedtuple fields.
    return
  # Share the pad-and-assign logic with revived lists instead of duplicating it.
  _set_list_item(list_object, index_string, value)
1141
1142
# Revive tuples as lists so we can append any dependencies during loading.
# `_set_tuple_item` places index-named children and skips namedtuple fields.
revived_types.register_revived_type(
    "trackable_tuple_wrapper",
    lambda obj: isinstance(obj, _TupleWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            object_factory=lambda proto: ListWrapper([]),
            setter=_set_tuple_item,
        )
    ],
)
1153