# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for extracting and writing checkpoint info."""

from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.util import object_identity


def serialize_slot_variables(trackable_objects, node_ids, object_names):
  """Gather and name slot variables.

  For every object in `trackable_objects` that exposes slots (a v1
  `Optimizer`, or any object with a `get_slot_names` attribute), looks up the
  slot variable for each (slot name, object) pair and registers every slot
  variable found as a new node.

  Args:
    trackable_objects: List of trackable objects, where each object's index is
      its node id. Slot variables that are found are appended to it in place.
    node_ids: ObjectIdentityDictionary mapping object -> node id. Updated in
      place with the ids of the new slot variable nodes.
    object_names: ObjectIdentityDictionary mapping object -> checkpoint name.
      Must have an entry for every slot-owning object and every variable that
      has a slot; updated in place with names for the slot variables.

  Returns:
    An ObjectIdentityDictionary mapping each slot-owning object to a list of
    `TrackableObjectGraph.TrackableObject.SlotVariableReference` protos.

  Raises:
    NotImplementedError: If a slot variable has trackable children of its own,
      or if it is already a node of the graph (i.e. it is also a regular
      dependency of some Trackable).
  """
  # Iterate over a snapshot so that slot variables appended to
  # `trackable_objects` below are not themselves treated as candidate
  # slot owners or slotted variables.
  non_slot_objects = list(trackable_objects)
  slot_variables = object_identity.ObjectIdentityDictionary()
  for trackable in non_slot_objects:
    if (isinstance(trackable, optimizer_v1.Optimizer)
        # TODO(b/110718070): Fix Keras imports.
        # Note: dir() is used rather than hasattr() here to avoid triggering
        # custom __getattr__ code, see b/152031870 for context.
        or "get_slot_names" in dir(trackable)):
      slot_names = trackable.get_slot_names()
      for slot_name in slot_names:
        for original_variable_node_id, original_variable in enumerate(
            non_slot_objects):
          try:
            slot_variable = trackable.get_slot(original_variable, slot_name)
          except (AttributeError, KeyError):
            # A failed lookup is treated the same as "no slot here".
            slot_variable = None
          if slot_variable is None:
            continue
          slot_variable._maybe_initialize_trackable()  # pylint: disable=protected-access
          if slot_variable._trackable_children():  # pylint: disable=protected-access
            # TODO(allenl): Gather dependencies of slot variables.
            raise NotImplementedError(
                "Currently only variables with no dependencies can be saved as "
                "slot variables. File a feature request if this limitation "
                "bothers you.")
          if slot_variable in node_ids:
            raise NotImplementedError(
                "A slot variable was re-used as a dependency of a Trackable "
                f"object: {slot_variable}. This is not currently allowed. "
                "File a feature request if this limitation bothers you.")
          checkpoint_name = trackable_utils.slot_variable_key(
              variable_path=object_names[original_variable],
              optimizer_path=object_names[trackable],
              slot_name=slot_name)
          object_names[slot_variable] = checkpoint_name
          # The new node's id is its index in `trackable_objects`.
          slot_variable_node_id = len(trackable_objects)
          node_ids[slot_variable] = slot_variable_node_id
          trackable_objects.append(slot_variable)
          slot_variable_proto = (
              trackable_object_graph_pb2.TrackableObjectGraph.TrackableObject
              .SlotVariableReference(
                  slot_name=slot_name,
                  original_variable_node_id=original_variable_node_id,
                  slot_variable_node_id=slot_variable_node_id))
          slot_variables.setdefault(trackable, []).append(slot_variable_proto)
  return slot_variables


def get_mapped_trackable(trackable, object_map):
  """Returns the mapped trackable if possible, otherwise returns trackable.

  Args:
    trackable: The object to look up.
    object_map: Optional mapping from original objects to replacement objects,
      or None.

  Returns:
    `object_map[trackable]` if `object_map` is given and contains `trackable`;
    otherwise `trackable` itself.
  """
  if object_map is None:
    return trackable
  else:
    return object_map.get(trackable, trackable)


def get_full_name(var):
  """Gets the full name of variable for name-based checkpoint compatibility.

  Args:
    var: Any object.

  Returns:
    The slice's full name if `var` carries `_save_slice_info`, otherwise its
    `_shared_name`. Returns "" if `var` is neither a `Variable` nor a resource
    variable.
  """
  # pylint: disable=protected-access
  if (not (isinstance(var, variables.Variable) or
           # Some objects do not subclass Variable but still act as one.
           resource_variable_ops.is_resource_variable(var))):
    return ""

  if getattr(var, "_save_slice_info", None) is not None:
    # Use getattr because `var._save_slice_info` may be set as `None`.
    return var._save_slice_info.full_name
  else:
    return var._shared_name
  # pylint: enable=protected-access


def add_checkpoint_values_check(object_graph_proto):
  """Determines which objects have checkpoint values and saves this to the proto.

  A node has checkpoint values if it directly has attributes, slot variables,
  or a registered saver, or if any node it (transitively) depends on through
  `children` does. The result is written to each node's
  `has_checkpoint_values` field.

  Args:
    object_graph_proto: A `TrackableObjectGraph` proto. Modified in place.
  """
  # Node id -> set of node ids that list it as a child (the "parents"). If a
  # node has checkpoint values, then all of its parents can be marked as
  # having checkpoint values too.
  parents = {}

  # First pass: build the parent map and the initial set of checkpointed node
  # ids. Node ids are plain ints, so a regular `set` is sufficient.
  checkpointed_trackables = set()
  for node_id, object_proto in enumerate(object_graph_proto.nodes):
    if (object_proto.attributes or object_proto.slot_variables or
        object_proto.HasField("registered_saver")):
      checkpointed_trackables.add(node_id)
    for child_proto in object_proto.children:
      parents.setdefault(child_proto.node_id, set()).add(node_id)

  # Second pass: add all connected parents to the set of checkpointed nodes.
  # Entries are popped from `parents`, so each node's parents are expanded at
  # most once; this also guarantees termination if the graph has cycles.
  to_visit = set(checkpointed_trackables)
  while to_visit:
    node_id = to_visit.pop()
    current_parents = parents.pop(node_id, None)
    if current_parents is None:
      # Some nodes have no parents (e.g. slot variables), or their parents
      # were already expanded.
      continue
    checkpointed_trackables.update(current_parents)
    # Only parents that themselves still have unexpanded parents need a visit.
    to_visit.update(parent for parent in current_parents if parent in parents)

  for node_id, object_proto in enumerate(object_graph_proto.nodes):
    object_proto.has_checkpoint_values.value = (
        node_id in checkpointed_trackables)


def objects_ids_and_slot_variables_and_paths(graph_view):
  """Traverse the object graph and list all accessible objects.

  Looks for `Trackable` objects reachable via
  `graph_view.breadth_first_traversal()`. Includes slot variables only if the
  variable they are slotting for and the optimizer are both reachable (i.e.
  if they would be saved with a checkpoint).

  Args:
    graph_view: A GraphView object.

  Returns:
    A tuple of (trackable objects, paths from root for each object,
    object -> node id, slot variables, object_names). The trackable objects
    list, node id map and object_names include the slot variables found.
  """
  trackable_objects, node_paths = graph_view.breadth_first_traversal()
  object_names = object_identity.ObjectIdentityDictionary()
  for obj, path in node_paths.items():
    object_names[obj] = trackable_utils.object_path_to_string(path)
  # A node's id is its position in the traversal order.
  node_ids = object_identity.ObjectIdentityDictionary()
  for node_id, node in enumerate(trackable_objects):
    node_ids[node] = node_id
  slot_variables = serialize_slot_variables(
      trackable_objects=trackable_objects,
      node_ids=node_ids,
      object_names=object_names)
  return (trackable_objects, node_paths, node_ids, slot_variables, object_names)


def list_objects(graph_view):
  """Traverse the object graph and list all accessible objects.

  Args:
    graph_view: A GraphView object.

  Returns:
    A list of the reachable trackable objects, including any slot variables.
  """
  trackable_objects = objects_ids_and_slot_variables_and_paths(graph_view)[0]
  return trackable_objects