xref: /aosp_15_r20/external/tensorflow/tensorflow/python/checkpoint/trackable_view.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1"""Manages a Trackable object graph."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16import collections
17import weakref
18
19from tensorflow.python.trackable import base
20from tensorflow.python.trackable import converter
21from tensorflow.python.util import object_identity
22from tensorflow.python.util.tf_export import tf_export
23
24
@tf_export("train.TrackableView", v1=[])
class TrackableView(object):
  """Gathers and serializes a trackable view.

  Example usage:

  >>> class SimpleModule(tf.Module):
  ...   def __init__(self, name=None):
  ...     super().__init__(name=name)
  ...     self.a_var = tf.Variable(5.0)
  ...     self.b_var = tf.Variable(4.0)
  ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

  >>> root = SimpleModule(name="root")
  >>> root.leaf = SimpleModule(name="leaf")
  >>> trackable_view = tf.train.TrackableView(root)

  Pass root to tf.train.TrackableView.children() to get the dictionary of all
  children directly linked to root by name.
  >>> trackable_view_children = trackable_view.children(root)
  >>> for item in trackable_view_children.items():
  ...   print(item)
  ('a_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
  ('b_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
  ('vars', ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
  numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
  ('leaf', ...)

  """

  def __init__(self, root):
    """Configure the trackable view.

    Args:
      root: A `Trackable` object whose variables (including the variables of
        dependencies, recursively) should be saved. May be a weak reference.
    """
    # TrackableView should never contain a strong reference to root, since it
    # may result in a cycle:
    #   root -> deferred dependencies -> CheckpointPosition
    #   -> CheckpointRestoreCoordinator -> TrackableView -> root
    # A weak reference passed in by the caller is stored as-is; anything else
    # is wrapped in a new weak reference.
    self._root_ref = (root if isinstance(root, weakref.ref)
                      else weakref.ref(root))

  @classmethod
  def children(cls, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns all child trackables attached to obj.

    Args:
      obj: A `Trackable` object.
      save_type: A string, can be 'savedmodel' or 'checkpoint'.
      **kwargs: kwargs to use when retrieving the object's children.

    Returns:
      Dictionary of all children attached to the object with name to trackable.
    """
    # pylint: disable=protected-access
    obj._maybe_initialize_trackable()
    # Every child reported by the object is passed through the converter (with
    # `obj` as its parent) before being returned.
    return {
        name: converter.convert_to_trackable(ref, parent=obj)
        for name, ref in obj._trackable_children(save_type, **kwargs).items()
    }

  @property
  def root(self):
    """Returns the root object of this view.

    Returns:
      The object referred to by the stored weak reference. If the stored
      reference is not a `weakref.ref`, it is returned unchanged.

    Raises:
      AssertionError: If the weak reference no longer resolves to a live
        object.
    """
    if not isinstance(self._root_ref, weakref.ref):
      return self._root_ref
    derefed = self._root_ref()
    # Raise explicitly rather than using an `assert` statement: assert
    # statements are stripped when Python runs with -O, which would make this
    # property silently return None.
    if derefed is None:
      raise AssertionError(
          "The root object of this TrackableView is no longer alive; its weak "
          "reference returned None.")
    return derefed

  def descendants(self):
    """Returns a list of all nodes reachable from self.root.

    Nodes are listed in breadth first order, starting with the root itself.
    """
    return self._descendants_with_paths()[0]

  def _descendants_with_paths(self):
    """Returns all nodes reachable from self.root together with their paths.

    The object graph is walked breadth first starting at the root. Each node is
    recorded once, with the path through which it was first discovered.

    Returns:
      A tuple `(bfs_sorted, node_paths)`. `bfs_sorted` is the list of nodes in
      breadth first order. `node_paths` is an `ObjectIdentityDictionary` mapping
      each node to a tuple of `TrackableReference`s leading from the root to
      that node; the root maps to the empty tuple.
    """
    # Dereference the root once so the whole traversal works on the same
    # strongly-held object.
    root = self.root
    bfs_sorted = []
    to_visit = collections.deque([root])
    node_paths = object_identity.ObjectIdentityDictionary()
    node_paths[root] = ()
    while to_visit:
      current_trackable = to_visit.popleft()
      bfs_sorted.append(current_trackable)
      for name, dependency in self.children(current_trackable).items():
        # Presence in `node_paths` doubles as the "already seen" marker, so a
        # node reachable through several parents is queued only once.
        if dependency not in node_paths:
          node_paths[dependency] = (
              node_paths[current_trackable] +
              (base.TrackableReference(name, dependency),))
          to_visit.append(dependency)
    return bfs_sorted, node_paths
118