xref: /aosp_15_r20/external/tensorflow/tensorflow/python/trackable/autotrackable.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Dependency tracking for trackable objects."""
16
17import warnings
18
19from absl import logging
20
21from tensorflow.python.eager import def_function
22from tensorflow.python.eager import function as defun
23from tensorflow.python.trackable import base
24from tensorflow.python.trackable import data_structures
25from tensorflow.python.types import core as core_types
26from tensorflow.python.util.tf_export import tf_export
27
28
@tf_export("__internal__.tracking.AutoTrackable", v1=[])
class AutoTrackable(base.Trackable):
  """Manages dependencies on other objects.

  `Trackable` objects may have dependencies: other `Trackable` objects
  which should be saved if the object declaring the dependency is saved. A
  correctly saveable program has a dependency graph such that if changing a
  global variable affects an object (e.g. changes the behavior of any of its
  methods) then there is a chain of dependencies from the influenced object to
  the variable.

  Dependency edges have names, and are created implicitly when a
  `Trackable` object is assigned to an attribute of an `AutoTrackable`
  object (see `__setattr__`). For example:

  ```
  obj = AutoTrackable()
  obj.v = ResourceVariable(0.)
  ```

  The `AutoTrackable` object `obj` now has a dependency named "v" on a
  variable.

  `Trackable` objects may specify `Tensor`s to be saved and restored
  directly (e.g. a `Variable` indicating how to save itself) rather than through
  dependencies on other objects. See
  `Trackable._gather_saveables_for_checkpoint` for details.
  """

  def __setattr__(self, name, value):
    """Support self.foo = trackable syntax.

    Unless the instance attribute `_self_setattr_tracking` is set to a falsy
    value, the assigned value is first passed through
    `data_structures.sticky_attribute_assignment`, and whatever that call
    returns is what gets stored on the object.

    Args:
      name: Name of the attribute being assigned.
      value: Object to assign to the attribute.
    """
    try:
      if getattr(self, name) is value:
        # Short circuit for `self.$x = self.$x`.
        return
    except AttributeError:
      # The attribute does not exist yet; fall through to a normal assignment.
      pass

    if getattr(self, "_self_setattr_tracking", True):
      value = data_structures.sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    super().__setattr__(name, value)

  def __delattr__(self, name):
    """Stops tracking `name`, then deletes the attribute itself.

    Tracking is removed before the attribute is deleted, so it is dropped even
    if the subsequent attribute deletion raises.

    Args:
      name: Name of the attribute to delete.
    """
    self._delete_tracking(name)
    super().__delattr__(name)

  def _no_dependency(self, value):
    """Override to allow TrackableBase to disable dependency tracking.

    Args:
      value: Object to wrap.

    Returns:
      `value` wrapped in `data_structures.NoDependency`.
    """
    return data_structures.NoDependency(value)

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns all children of a trackable, including functions.

    For any `save_type` other than `SaveType.SAVEDMODEL` this defers entirely
    to the parent class. For SavedModel it additionally collects every
    attribute that is a `def_function.Function` or `defun.ConcreteFunction`.

    Args:
      save_type: A `base.SaveType` value selecting which children to return.
      **kwargs: Extra keyword arguments, forwarded to the parent
        implementation when `save_type` is not `SAVEDMODEL`.

    Returns:
      Dictionary mapping child names to child objects.

    Raises:
      ValueError: If a tracked (non-function) dependency and a function
        attribute share the same name but are different objects.
    """
    if save_type != base.SaveType.SAVEDMODEL:
      return super()._trackable_children(save_type, **kwargs)

    functions = {}
    # Capture the verbosity *before* entering `try`: the `finally` clause reads
    # this variable, so it must already be bound whenever `finally` runs.
    logging_verbosity = logging.get_verbosity()
    try:
      # We get the attributes, suppressing warnings and exceptions.
      logging.set_verbosity(logging.FATAL)
      for attribute_name in dir(self):
        try:
          with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            attribute_value = getattr(self, attribute_name, None)
        except Exception:  # pylint: disable=broad-except
          # NOTE: If we make the exception catching here less broad, we might
          # need to revisit `finally` block below.
          # We really don't want to throw an exception just because some
          # object's attribute accessor is broken.
          attribute_value = None
        if isinstance(attribute_value, (def_function.Function,
                                        defun.ConcreteFunction)):
          functions[attribute_name] = attribute_value
    finally:
      logging.set_verbosity(logging_verbosity)

    # Trace concrete functions to force side-effects:
    #   1. populate the cache for functions that have an input_signature
    #      and have not been called
    #   2. force side effects of creation of concrete functions, e.g. create
    #      variables on first run.
    for fn in functions.values():
      if isinstance(fn, core_types.GenericFunction):
        fn._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access

    # Additional dependencies may have been generated during function tracing
    # (e.g. captured variables). Make sure we return those too.
    children = {}
    for name, child in self._checkpoint_dependencies:
      if isinstance(child, (core_types.GenericFunction,
                            core_types.ConcreteFunction)):
        # Skip "tracked" functions for now since there may be objects that
        # automatically track functions that should not be saved.
        # TODO(kathywu): remove once `_list_functions_for_serialization` has
        # been fully deprecated.
        continue

      if name in functions and child is not functions[name]:
        raise ValueError(
            "Can't save object because it has multiple children with the same "
            f"name. Object: {self}, attribute name: {name}, child 1: "
            f"{child}, child 2: {functions[name]}")

      children[name] = child

    # Function attributes are merged in last; name clashes with a different
    # tracked object were already rejected above.
    children.update(functions)
    return children

  def _delete_tracking(self, name):
    """Removes the tracking of name.

    Does nothing if `name` is not an unconditional dependency of this object.

    Args:
      name: Name of the dependency to stop tracking.
    """
    self._maybe_initialize_trackable()
    if name in self._unconditional_dependency_names:
      del self._unconditional_dependency_names[name]
      for index, (dep_name, _) in enumerate(
          self._unconditional_checkpoint_dependencies):
        if dep_name == name:
          # Only the first matching entry is removed; the loop exits right
          # after the deletion, so the list is not iterated after mutation.
          del self._unconditional_checkpoint_dependencies[index]
          break

  def _add_trackable_child(self, name, value):
    """Adds `value` as a child named `name` by assigning it as an attribute.

    Args:
      name: Attribute (and dependency) name.
      value: Object to assign; goes through this class's `__setattr__`.
    """
    self.__setattr__(name, value)
153