xref: /aosp_15_r20/external/tensorflow/tensorflow/python/checkpoint/util.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Utilities for extracting and writing checkpoint info."""
16
17from tensorflow.core.protobuf import trackable_object_graph_pb2
18from tensorflow.python.ops import resource_variable_ops
19from tensorflow.python.ops import variables
20from tensorflow.python.trackable import trackable_utils
21from tensorflow.python.training import optimizer as optimizer_v1
22from tensorflow.python.util import object_identity
23
24
def serialize_slot_variables(trackable_objects, node_ids, object_names):
  """Gathers optimizer slot variables and registers them as graph nodes.

  Every object in `trackable_objects` that looks like an optimizer (an
  `optimizer_v1.Optimizer`, or anything exposing `get_slot_names`) is asked
  for a slot for each of its slot names, paired with each object originally in
  `trackable_objects`. Each slot variable found is appended to
  `trackable_objects`, assigned the next node id in `node_ids`, and given a
  checkpoint name in `object_names`.

  Args:
    trackable_objects: List of trackable objects. Slot variables are appended
      to it in place.
    node_ids: Object-identity mapping from object to node id. Updated in place.
    object_names: Object-identity mapping from object to checkpoint name.
      Updated in place.

  Returns:
    An `ObjectIdentityDictionary` mapping each optimizer for which at least one
    slot variable was found to a list of `SlotVariableReference` protos.

  Raises:
    NotImplementedError: If a slot variable has trackable children of its own,
      or if it is already present in `node_ids`.
  """
  # Iterate over a snapshot: slot variables get appended to
  # `trackable_objects` inside the loop and must not be revisited.
  original_objects = list(trackable_objects)
  references_by_optimizer = object_identity.ObjectIdentityDictionary()
  for candidate in original_objects:
    # TODO(b/110718070): Fix Keras imports.
    # Note: dir() is used rather than hasattr() here to avoid triggering
    # custom __getattr__ code, see b/152031870 for context.
    is_optimizer = (isinstance(candidate, optimizer_v1.Optimizer)
                    or "get_slot_names" in dir(candidate))
    if not is_optimizer:
      continue
    for slot_name in candidate.get_slot_names():
      for variable_node_id, variable in enumerate(original_objects):
        try:
          slot = candidate.get_slot(variable, slot_name)
        except (AttributeError, KeyError):
          # A failed lookup is treated the same as "no slot for this pair".
          continue
        if slot is None:
          continue
        slot._maybe_initialize_trackable()  # pylint: disable=protected-access
        if slot._trackable_children():  # pylint: disable=protected-access
          # TODO(allenl): Gather dependencies of slot variables.
          raise NotImplementedError(
              "Currently only variables with no dependencies can be saved as "
              "slot variables. File a feature request if this limitation "
              "bothers you.")
        if slot in node_ids:
          raise NotImplementedError(
              "A slot variable was re-used as a dependency of a Trackable "
              f"object: {slot}. This is not currently allowed. "
              "File a feature request if this limitation bothers you.")
        object_names[slot] = trackable_utils.slot_variable_key(
            variable_path=object_names[variable],
            optimizer_path=object_names[candidate],
            slot_name=slot_name)
        # The new node id is the slot variable's position once appended.
        new_node_id = len(trackable_objects)
        node_ids[slot] = new_node_id
        trackable_objects.append(slot)
        reference = (
            trackable_object_graph_pb2.TrackableObjectGraph.TrackableObject
            .SlotVariableReference(
                slot_name=slot_name,
                original_variable_node_id=variable_node_id,
                slot_variable_node_id=new_node_id))
        references_by_optimizer.setdefault(candidate, []).append(reference)
  return references_by_optimizer
73
74
def get_mapped_trackable(trackable, object_map):
  """Looks up `trackable` in `object_map`, falling back to `trackable` itself.

  Args:
    trackable: The object to look up.
    object_map: A mapping consulted for a replacement object, or None.

  Returns:
    `object_map[trackable]` when `object_map` is not None and contains
    `trackable`; otherwise `trackable` unchanged.
  """
  if object_map is not None:
    return object_map.get(trackable, trackable)
  return trackable
81
82
def get_full_name(var):
  """Returns the full name of `var` for name-based checkpoint compatibility.

  Args:
    var: Any object.

  Returns:
    "" if `var` is neither a `variables.Variable` nor a resource variable.
    Otherwise, the `full_name` of its `_save_slice_info` when that attribute
    exists and is not None, else its `_shared_name`.
  """
  # Some objects do not subclass Variable but still act as one.
  acts_like_variable = (
      isinstance(var, variables.Variable)
      or resource_variable_ops.is_resource_variable(var))
  if not acts_like_variable:
    return ""

  # pylint: disable=protected-access
  # getattr() with a default because `_save_slice_info` may be missing or None.
  slice_info = getattr(var, "_save_slice_info", None)
  if slice_info is None:
    return var._shared_name
  return slice_info.full_name
  # pylint: enable=protected-access
97
98
def add_checkpoint_values_check(object_graph_proto):
  """Determines which objects have checkpoint values and saves this to the proto.

  A node has checkpoint values if it directly holds something (attributes,
  slot variables, or a registered saver), or if any node reachable from it via
  `children` edges does. The result is written to each node's
  `has_checkpoint_values.value` field.

  Args:
    object_graph_proto: A `TrackableObjectGraph` proto. Modified in place.
  """
  nodes = object_graph_proto.nodes

  # Node id -> set of node ids that list it as a child (its "parents"). If a
  # node has checkpoint values, all of its parents are marked as well.
  parents = {}
  # Ids of nodes known to have checkpoint values. Node ids are plain ints, so
  # an ordinary set is sufficient.
  checkpointed = set()

  # First pass: build the child -> parents map and seed `checkpointed` with
  # nodes that directly hold values.
  for node_id, object_proto in enumerate(nodes):
    if (object_proto.attributes or object_proto.slot_variables or
        object_proto.HasField("registered_saver")):
      checkpointed.add(node_id)
    for child_proto in object_proto.children:
      parents.setdefault(child_proto.node_id, set()).add(node_id)

  # Second pass: propagate upward to all connected parents. A node's parent
  # set is popped once processed, so each edge is followed at most once; this
  # also guarantees termination on cyclic graphs.
  to_visit = set(checkpointed)
  while to_visit:
    node_id = to_visit.pop()
    current_parents = parents.pop(node_id, None)
    if current_parents is None:
      # No parents (e.g. slot variables), or already processed.
      continue
    checkpointed.update(current_parents)
    to_visit.update(p for p in current_parents if p in parents)

  for node_id, object_proto in enumerate(nodes):
    object_proto.has_checkpoint_values.value = node_id in checkpointed
142
143
def objects_ids_and_slot_variables_and_paths(graph_view):
  """Traverses the object graph and lists every accessible object.

  Collects the `Trackable` objects returned by
  `graph_view.breadth_first_traversal()`. Slot variables are included only when
  both the variable they slot for and the optimizer owning them are part of
  that traversal (i.e. when they would be saved with a checkpoint).

  Args:
    graph_view: A GraphView object.

  Returns:
    A tuple of (trackable objects, paths from root for each object,
                object -> node id, slot variables, object_names).
    Any slot variables found are appended to the returned list of trackable
    objects, after the traversed objects.
  """
  traversed, paths = graph_view.breadth_first_traversal()

  # Checkpoint name of each traversed object, derived from its path.
  names = object_identity.ObjectIdentityDictionary()
  for trackable, object_path in paths.items():
    names[trackable] = trackable_utils.object_path_to_string(object_path)

  # Node ids are positions in the traversal order.
  ids = object_identity.ObjectIdentityDictionary()
  for position, trackable in enumerate(traversed):
    ids[trackable] = position

  slot_references = serialize_slot_variables(
      trackable_objects=traversed,
      node_ids=ids,
      object_names=names)
  return traversed, paths, ids, slot_references, names
171
172
def list_objects(graph_view):
  """Traverses the object graph and returns all accessible objects.

  Args:
    graph_view: A GraphView object.

  Returns:
    The list of trackable objects (including any discovered slot variables)
    produced by `objects_ids_and_slot_variables_and_paths`.
  """
  all_objects, *_ = objects_ids_and_slot_variables_and_paths(graph_view)
  return all_objects
177