xref: /aosp_15_r20/external/tensorflow/tensorflow/python/trackable/trackable_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility methods for the trackable dependencies."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21
22
def pretty_print_node_path(path):
  """Formats a node path as a human-readable, dot-separated string.

  Args:
    path: a sequence of objects exposing a `name` attribute. May be empty or
      `None`.

  Returns:
    "root object" when `path` is falsy, otherwise "root." followed by the
    `name` of every element joined with ".".
  """
  if path:
    dotted_names = ".".join([node.name for node in path])
    return "root." + dotted_names
  return "root object"
28
29
class CyclicDependencyError(Exception):
  """Error carrying the dependency edges left over after a failed sort.

  Attributes:
    leftover_dependency_map: the map passed to the constructor, holding the
      edges that could not be topologically sorted.
  """

  def __init__(self, leftover_dependency_map):
    """Creates a CyclicDependencyError.

    Args:
      leftover_dependency_map: leftover edges that were not able to be
        topologically sorted.
    """
    # The base exception is initialised without a message; the details live
    # on the attribute below.
    super().__init__()
    self.leftover_dependency_map = leftover_dependency_map
37
38
def order_by_dependency(dependency_map):
  """Topologically sorts the keys of a map so that dependencies appear first.

  Uses Kahn's algorithm:
  https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm

  Args:
    dependency_map: a dict mapping values to a list of dependencies (other keys
      in the map). All keys and dependencies must be hashable types. The map
      itself is only read, never modified.

  Returns:
    An iterator (the result of `reversed()` on an internal list) over the keys
    of `dependency_map`, ordered so that every key is yielded after all of its
    dependencies. Wrap it in `list()` if a sequence is needed.

  Raises:
    CyclicDependencyError: if there is a cycle in the graph.
    ValueError: If there are values in the dependency map that are not keys in
      the map.
  """
  # Maps trackables -> trackables that depend on them. These are the edges used
  # in Kahn's algorithm.
  reverse_dependency_map = collections.defaultdict(set)
  for x, deps in dependency_map.items():
    for dep in deps:
      reverse_dependency_map[dep].add(x)

  # Validate that all values in the dependency map are also keys.
  unknown_keys = reverse_dependency_map.keys() - dependency_map.keys()
  if unknown_keys:
    raise ValueError("Found values in the dependency map which are not keys: "
                     f"{unknown_keys}")

  # Filled in "dependents first" order: a key is appended only once every key
  # that depends on it has been appended. The returned iterator reverses this.
  reversed_dependency_arr = []

  # Prefill `to_visit` with all nodes that do not have other objects depending
  # on them. A deque gives O(1) pops from the front; `list.pop(0)` is O(n) and
  # would make the traversal quadratic on large maps.
  to_visit = collections.deque(
      x for x in dependency_map if x not in reverse_dependency_map)

  while to_visit:
    x = to_visit.popleft()
    reversed_dependency_arr.append(x)
    # `set()` drops duplicate entries so each edge is removed exactly once.
    for dep in set(dependency_map[x]):
      edges = reverse_dependency_map[dep]
      edges.remove(x)
      if not edges:
        # Nothing unvisited depends on `dep` any more, so it can be emitted.
        to_visit.append(dep)
        reverse_dependency_map.pop(dep)

  # Any edge still present belongs to a cycle (or hangs off one). Convert the
  # leftovers back to the "key -> dependencies" orientation for the error.
  if reverse_dependency_map:
    leftover_dependency_map = collections.defaultdict(list)
    for dep, xs in reverse_dependency_map.items():
      for x in xs:
        leftover_dependency_map[x].append(dep)
    raise CyclicDependencyError(leftover_dependency_map)

  return reversed(reversed_dependency_arr)
96
97
# Prefix for reserved name components; used to avoid conflicts with
# user-specified names.
_ESCAPE_CHAR = "."

# Keyword for identifying that the next bit of a checkpoint variable name is a
# slot name. Checkpoint names for slot variables look like:
#
#   <path to variable>/<_OPTIMIZER_SLOTS_NAME>/<path to optimizer>/<slot name>
#
# Where <path to variable> is a full path from the checkpoint root to the
# variable being slotted for.
_OPTIMIZER_SLOTS_NAME = f"{_ESCAPE_CHAR}OPTIMIZER_SLOT"

# Keyword for separating the path to an object from the name of an
# attribute in checkpoint names. Used like:
#
#   <path to variable>/<OBJECT_ATTRIBUTES_NAME>/<name of attribute>
OBJECT_ATTRIBUTES_NAME = f"{_ESCAPE_CHAR}ATTRIBUTES"

# A constant string that is used to reference the save and restore functions of
# Trackable objects that define `_serialize_to_tensors` and
# `_restore_from_tensors`. This is written as the key in the
# `SavedObject.saveable_objects<string, SaveableObject>` map in the SavedModel.
SERIALIZE_TO_TENSORS_NAME = f"{_ESCAPE_CHAR}TENSORS"
118
119
def escape_local_name(name):
  """Escapes a local name so it can be used as one component of a path.

  Slashes must stay legal in local names for compatibility, since this naming
  scheme is patched in to things like Layer.add_variable where slashes were
  previously accepted. Slashes are also used to indicate edges traversed to
  reach the variable, so forward slashes inside a name are escaped.

  Args:
    name: the local name to escape.

  Returns:
    `name` with every escape character doubled and every "/" replaced by the
    escape character followed by "S".
  """
  # Double existing escape characters first, so that the escape sequences
  # introduced for slashes below are not themselves doubled.
  doubled_escapes = name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR)
  escaped_slash = _ESCAPE_CHAR + "S"
  return doubled_escapes.replace(r"/", escaped_slash)
128
129
def object_path_to_string(node_path_arr):
  """Converts a list of nodes to a string.

  Args:
    node_path_arr: an iterable of objects exposing a `name` attribute.

  Returns:
    The escaped `name` of every node, joined with "/".
  """
  escaped_names = (escape_local_name(node.name) for node in node_path_arr)
  return "/".join(escaped_names)
134
135
def checkpoint_key(object_path, local_name):
  """Returns the checkpoint key for a local attribute of an object.

  Args:
    object_path: the path of the object owning the attribute; it is formatted
      into the key as-is.
    local_name: the attribute's local name, or `SERIALIZE_TO_TENSORS_NAME`.

  Returns:
    A string of the form
    `<object_path>/<OBJECT_ATTRIBUTES_NAME>/<escaped local name>`. The last
    component is empty when `local_name` equals `SERIALIZE_TO_TENSORS_NAME`.
  """
  if local_name == SERIALIZE_TO_TENSORS_NAME:
    # In the case that Trackable uses the _serialize_to_tensors API for
    # defining tensors to save to the checkpoint, the suffix should be the
    # key(s) returned by `_serialize_to_tensors`. The suffix used here is
    # empty.
    key_suffix = ""
  else:
    key_suffix = escape_local_name(local_name)

  return f"{object_path}/{OBJECT_ATTRIBUTES_NAME}/{key_suffix}"
146
147
def slot_variable_key(variable_path, optimizer_path, slot_name):
  """Returns checkpoint key for a slot variable.

  Slot variables are named:

    <variable name>/<_OPTIMIZER_SLOTS_NAME>/<optimizer path>/<slot name>

  where <variable name> is exactly the checkpoint name used for the original
  variable, including the path from the checkpoint root and the local name in
  the object which owns it. Note that we only save slot variables if the
  variable it's slotting for is also being saved.

  Args:
    variable_path: checkpoint name of the variable being slotted for.
    optimizer_path: path to the optimizer; formatted into the key as-is.
    slot_name: name of the slot; escaped before being appended.

  Returns:
    The checkpoint key string for the slot variable.
  """
  key_prefix = f"{variable_path}/{_OPTIMIZER_SLOTS_NAME}/{optimizer_path}"
  escaped_slot_name = escape_local_name(slot_name)
  return f"{key_prefix}/{escaped_slot_name}"
161