# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dependency tracking for trackable objects."""

import warnings

from absl import logging

from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.trackable import base
from tensorflow.python.trackable import data_structures
from tensorflow.python.types import core as core_types
from tensorflow.python.util.tf_export import tf_export


@tf_export("__internal__.tracking.AutoTrackable", v1=[])
class AutoTrackable(base.Trackable):
  """Manages dependencies on other objects.

  `Trackable` objects may have dependencies: other `Trackable` objects
  which should be saved if the object declaring the dependency is saved. A
  correctly saveable program has a dependency graph such that if changing a
  global variable affects an object (e.g. changes the behavior of any of its
  methods) then there is a chain of dependencies from the influenced object to
  the variable.

  Dependency edges have names, and are created implicitly when a
  `Trackable` object is assigned to an attribute of another
  `Trackable` object. For example:

  ```
  obj = Trackable()
  obj.v = ResourceVariable(0.)
  ```

  The `Trackable` object `obj` now has a dependency named "v" on a
  variable.

  `Trackable` objects may specify `Tensor`s to be saved and restored
  directly (e.g. a `Variable` indicating how to save itself) rather than through
  dependencies on other objects. See
  `Trackable._gather_saveables_for_checkpoint` for details.
  """

  def __setattr__(self, name, value):
    """Support self.foo = trackable syntax.

    Re-assigning the object an attribute already holds is a no-op. Otherwise,
    unless `_self_setattr_tracking` is set to a falsy value on this object, the
    value is first passed through
    `data_structures.sticky_attribute_assignment` and the result of that call
    is what gets stored.

    Args:
      name: Name of the attribute being assigned.
      value: Object to assign.
    """
    try:
      if getattr(self, name) is value:
        # Short circuit for `self.$x = self.$x`.
        return
    except AttributeError:
      # The attribute does not exist yet; fall through to a normal assignment.
      pass

    # Tracking is on by default: a missing flag is treated as True.
    if getattr(self, "_self_setattr_tracking", True):
      value = data_structures.sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    super(AutoTrackable, self).__setattr__(name, value)

  def __delattr__(self, name):
    """Stops tracking `name`, then deletes the attribute itself."""
    self._delete_tracking(name)
    super(AutoTrackable, self).__delattr__(name)

  def _no_dependency(self, value):
    """Override to allow TrackableBase to disable dependency tracking."""
    return data_structures.NoDependency(value)

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns all children of a trackable, including functions.

    Args:
      save_type: A `base.SaveType` value. For anything other than
        `base.SaveType.SAVEDMODEL` the call is delegated unchanged to the
        parent class.
      **kwargs: Forwarded to the parent implementation on the delegated path;
        not used on the SavedModel path.

    Returns:
      On the SavedModel path, a dict mapping names to children: every entry of
      `self._checkpoint_dependencies` that is not itself a function, plus every
      attribute found through `dir(self)` whose value is a
      `def_function.Function` or `defun.ConcreteFunction`. Otherwise, whatever
      the parent implementation returns.

    Raises:
      ValueError: If a non-function dependency and a function attribute share
        the same name but are different objects.
    """
    if save_type != base.SaveType.SAVEDMODEL:
      return super(AutoTrackable, self)._trackable_children(
          save_type, **kwargs)

    functions = {}
    # Capture the verbosity *before* entering the `try`, so the `finally`
    # clause always has a value to restore and can never fail on an unbound
    # local (which would mask the original error).
    logging_verbosity = logging.get_verbosity()
    try:
      # We get the attributes, suppressing warnings and exceptions.
      logging.set_verbosity(logging.FATAL)
      for attribute_name in dir(self):
        try:
          with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            attribute_value = getattr(self, attribute_name, None)
        except Exception:  # pylint: disable=broad-except
          # NOTE: If we make the exception catching here less broad, we might
          # need to revisit `finally` block below.
          # We really don't want to throw an exception just because some
          # object's attribute accessor is broken.
          attribute_value = None
        if isinstance(attribute_value, (def_function.Function,
                                        defun.ConcreteFunction)):
          functions[attribute_name] = attribute_value
    finally:
      logging.set_verbosity(logging_verbosity)

    # Trace concrete functions to force side-effects:
    #   1. populate the cache for functions that have an input_signature
    #      and have not been called
    #   2. force side effects of creation of concrete functions, e.g. create
    #      variables on first run.
    for fn in functions.values():
      if isinstance(fn, core_types.GenericFunction):
        fn._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access

    # Additional dependencies may have been generated during function tracing
    # (e.g. captured variables). Make sure we return those too.
    children = {}
    for name, child in self._checkpoint_dependencies:
      if isinstance(child, (core_types.GenericFunction,
                            core_types.ConcreteFunction)):
        # Skip "tracked" functions for now since there may be objects that
        # automatically track functions that should not be saved.
        # TODO(kathywu): remove once `_list_functions_for_serialization` has
        # been fully deprecated.
        continue

      if name in functions and child is not functions[name]:
        raise ValueError(
            "Can't save object because it has multiple children with the same "
            f"name. Object: {self}, attribute name: {name}, child 1: "
            f"{child}, child 2: {functions[name]}")

      children[name] = child

    # Function attributes are merged in last.
    children.update(functions)
    return children

  def _delete_tracking(self, name):
    """Removes the tracking of name.

    Does nothing when `name` is not an unconditional dependency of this
    object.

    Args:
      name: Name of the dependency to stop tracking.
    """
    self._maybe_initialize_trackable()
    if name in self._unconditional_dependency_names:
      del self._unconditional_dependency_names[name]
      # Drop the first list entry recorded under the same name. Breaking right
      # after the deletion keeps the in-loop `del` safe.
      for index, (dep_name, _) in enumerate(
          self._unconditional_checkpoint_dependencies):
        if dep_name == name:
          del self._unconditional_checkpoint_dependencies[index]
          break

  def _add_trackable_child(self, name, value):
    """Assigns `value` to attribute `name` by way of `__setattr__`."""
    self.__setattr__(name, value)