# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility methods for the trackable dependencies."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections


def pretty_print_node_path(path):
  """Formats a node path as a dotted string starting at "root".

  Args:
    path: a sequence of objects that each expose a `name` attribute.

  Returns:
    "root object" if `path` is empty, otherwise "root." followed by the `name`
    of every element of `path` joined with ".".
  """
  if not path:
    return "root object"
  return "root." + ".".join(p.name for p in path)


class CyclicDependencyError(Exception):
  """Raised by `order_by_dependency` when the dependency graph has a cycle.

  Attributes:
    leftover_dependency_map: the edges that could not be topologically sorted,
      as a map from a node to the list of its unsorted dependencies.
  """

  def __init__(self, leftover_dependency_map):
    """Creates a CyclicDependencyError.

    Args:
      leftover_dependency_map: the edges that were left over after sorting.
    """
    # Leftover edges that were not able to be topologically sorted.
    self.leftover_dependency_map = leftover_dependency_map
    # Pass a message along so that an uncaught error is not reported with an
    # empty description.
    super(CyclicDependencyError, self).__init__(
        "Cyclic dependency detected; these edges could not be topologically "
        f"sorted: {leftover_dependency_map!r}")


def order_by_dependency(dependency_map):
  """Topologically sorts the keys of a map so that dependencies appear first.

  Uses Kahn's algorithm:
  https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm

  For a given `dependency_map` (including the order of its keys and of each
  dependency list) the resulting order is always the same.

  Args:
    dependency_map: a dict mapping values to a list of dependencies (other keys
      in the map). All keys and dependencies must be hashable types.

  Returns:
    An iterator over the keys of dependency_map in which every key comes after
    all of its dependencies.

  Raises:
    CyclicDependencyError: if there is a cycle in the graph.
    ValueError: If there are values in the dependency map that are not keys in
      the map.
  """
  # Maps trackables -> trackables that depend on them. These are the edges used
  # in Kahn's algorithm.
  reverse_dependency_map = collections.defaultdict(set)
  for x, deps in dependency_map.items():
    for dep in deps:
      reverse_dependency_map[dep].add(x)

  # Validate that all values in the dependency map are also keys.
  unknown_keys = reverse_dependency_map.keys() - dependency_map.keys()
  if unknown_keys:
    raise ValueError("Found values in the dependency map which are not keys: "
                     f"{unknown_keys}")

  # Generate the list sorted by objects without dependencies -> dependencies.
  # The returned iterator will reverse this.
  reversed_dependency_arr = []

  # Prefill `to_visit` with all nodes that do not have other objects depending
  # on them. A deque keeps the FIFO order while making each dequeue O(1).
  to_visit = collections.deque(
      x for x in dependency_map if x not in reverse_dependency_map)

  while to_visit:
    x = to_visit.popleft()
    reversed_dependency_arr.append(x)
    # De-duplicate repeated dependencies while keeping their listed order, so
    # the output does not depend on hash-based set iteration order.
    for dep in dict.fromkeys(dependency_map[x]):
      edges = reverse_dependency_map[dep]
      edges.remove(x)
      if not edges:
        # Every object depending on `dep` has been emitted; `dep` is ready.
        to_visit.append(dep)
        del reverse_dependency_map[dep]

  if reverse_dependency_map:
    # Any edge still present belongs to a cycle (or hangs off one). Convert the
    # leftovers back to the node -> dependencies orientation of the input.
    leftover_dependency_map = collections.defaultdict(list)
    for dep, xs in reverse_dependency_map.items():
      for x in xs:
        leftover_dependency_map[x].append(dep)
    raise CyclicDependencyError(leftover_dependency_map)

  return reversed(reversed_dependency_arr)


_ESCAPE_CHAR = "."  # For avoiding conflicts with user-specified names.

# Keyword for identifying that the next bit of a checkpoint variable name is a
# slot name. Checkpoint names for slot variables look like:
#
#   <path to variable>/<_OPTIMIZER_SLOTS_NAME>/<path to optimizer>/<slot name>
#
# Where <path to variable> is a full path from the checkpoint root to the
# variable being slotted for.
_OPTIMIZER_SLOTS_NAME = f"{_ESCAPE_CHAR}OPTIMIZER_SLOT"
# Keyword that separates the path to an object from the name of one of its
# attributes in checkpoint names. Used like:
#   <path to variable>/<OBJECT_ATTRIBUTES_NAME>/<name of attribute>
OBJECT_ATTRIBUTES_NAME = f"{_ESCAPE_CHAR}ATTRIBUTES"

# A constant string that is used to reference the save and restore functions of
# Trackable objects that define `_serialize_to_tensors` and
# `_restore_from_tensors`. This is written as the key in the
# `SavedObject.saveable_objects<string, SaveableObject>` map in the SavedModel.
SERIALIZE_TO_TENSORS_NAME = f"{_ESCAPE_CHAR}TENSORS"


def escape_local_name(name):
  """Escapes a local name for use as one segment of a checkpoint key.

  Slashes must remain usable in local names for compatibility: this naming
  scheme is patched in to things like Layer.add_variable, where slashes were
  previously accepted. Because slashes are also used to indicate the edges
  traversed to reach a variable, forward slashes inside a name are escaped.

  Args:
    name: the local name to escape.

  Returns:
    `name` with every `_ESCAPE_CHAR` doubled and every "/" replaced by
    `_ESCAPE_CHAR` followed by "S".
  """
  # Double the escape character first, so that the escape character introduced
  # for slashes in the next step is not doubled as well.
  doubled = name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR * 2)
  return doubled.replace("/", _ESCAPE_CHAR + "S")


def object_path_to_string(node_path_arr):
  """Converts a list of nodes to a string.

  Args:
    node_path_arr: an iterable of objects that each expose a `name` attribute.

  Returns:
    The escaped `name` of every node, joined with "/".
  """
  escaped_names = (escape_local_name(node.name) for node in node_path_arr)
  return "/".join(escaped_names)


def checkpoint_key(object_path, local_name):
  """Returns the checkpoint key for a local attribute of an object.

  Args:
    object_path: the path of the owning object, used as the key prefix.
    local_name: the name of the attribute within that object.

  Returns:
    A string of the form
    "<object_path>/<OBJECT_ATTRIBUTES_NAME>/<escaped local_name>". The last
    part is empty when `local_name` is `SERIALIZE_TO_TENSORS_NAME`.
  """
  escaped_name = escape_local_name(local_name)
  # When a Trackable defines the tensors to save through the
  # `_serialize_to_tensors` API, the suffix should be the key(s) returned by
  # that method, so the suffix produced here is left empty.
  key_suffix = "" if local_name == SERIALIZE_TO_TENSORS_NAME else escaped_name
  return "{}/{}/{}".format(object_path, OBJECT_ATTRIBUTES_NAME, key_suffix)


def slot_variable_key(variable_path, optimizer_path, slot_name):
  """Returns checkpoint key for a slot variable.

  Slot variables are named:

    <variable name>/<_OPTIMIZER_SLOTS_NAME>/<optimizer path>/<slot name>

  where <variable name> is exactly the checkpoint name used for the original
  variable, including the path from the checkpoint root and the local name in
  the object which owns it. Note that we only save slot variables if the
  variable it's slotting for is also being saved.

  Args:
    variable_path: the checkpoint name of the variable being slotted for.
    optimizer_path: the path to the optimizer.
    slot_name: the name of the slot; it is escaped before being appended.

  Returns:
    The checkpoint key string for the slot variable.
  """
  return "{}/{}/{}/{}".format(
      variable_path,
      _OPTIMIZER_SLOTS_NAME,
      optimizer_path,
      escape_local_name(slot_name),
  )