# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Definitions for resource-type trackable object classes."""

import contextlib
import copy
import weakref

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.trackable import base
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export

# `ResourceTracker`s that are currently active, innermost scope last. Every
# `TrackableResource` constructed while this list is non-empty registers itself
# with each tracker in it (see `TrackableResource.__init__`). This list object
# is only ever mutated in place, never rebound.
_RESOURCE_TRACKER_STACK = []


class ResourceTracker:
  """An object that tracks a list of resources."""

  __slots__ = ["_resources"]

  def __init__(self):
    self._resources = []

  @property
  def resources(self):
    """The list of resources registered so far, in registration order."""
    return self._resources

  def add_resource(self, resource):
    """Appends `resource` to this tracker's list of resources."""
    self._resources.append(resource)


@tf_contextlib.contextmanager
def resource_tracker_scope(resource_tracker):
  """A context to manage resource trackers.

  Use this in order to collect up all resources created within a block of code.
  Example usage:

  ```python
  resource_tracker = ResourceTracker()
  with resource_tracker_scope(resource_tracker):
    resource = TrackableResource()

  assert resource_tracker.resources == [resource]
  ```

  Args:
    resource_tracker: The passed in ResourceTracker object

  Yields:
    A scope in which the resource_tracker is active.
  """
  old = list(_RESOURCE_TRACKER_STACK)
  _RESOURCE_TRACKER_STACK.append(resource_tracker)
  try:
    yield
  finally:
    # Restore the snapshot in place instead of rebinding the module global.
    # The append above mutated the shared list object, so the restore must
    # mutate that same object. Otherwise anything still referencing it would
    # keep a stale tracker.
    _RESOURCE_TRACKER_STACK[:] = old


def _make_getter(captured_getter, captured_previous):
  """Binds `captured_previous` as the first argument of `captured_getter`.

  This is a separate function rather than a closure created inside the caller's
  loop, so that each returned getter captures its own `captured_previous`
  value instead of the loop variable (avoids late binding).

  Args:
    captured_getter: A resource creator called as
      `captured_getter(next_creator, *args, **kwargs)`.
    captured_previous: The creator to pass as `next_creator`.

  Returns:
    A callable `getter(*args, **kwargs)`.
  """

  def getter(*args, **kwargs):
    return captured_getter(captured_previous, *args, **kwargs)

  return getter


class _ResourceMetaclass(type):
  """Metaclass for CapturableResource.

  Instantiating a class that uses this metaclass goes through the creators
  registered in the default graph's `_resource_creator_stack` under the class's
  `_resource_type()`. Each creator receives the next creator in the chain as its
  first argument. The last registered creator is called first. The innermost
  creator is a default that constructs the object with `__new__` +
  `__init__`.
  """

  def __call__(cls, *args, **kwargs):

    def default_resource_creator(next_creator, *a, **kw):
      # The default creator is always the end of the chain.
      assert next_creator is None
      obj = cls.__new__(cls, *a, **kw)
      obj.__init__(*a, **kw)
      return obj

    previous_getter = lambda *a, **kw: default_resource_creator(None, *a, **kw)
    resource_creator_stack = ops.get_default_graph()._resource_creator_stack
    # Wrap each registered creator around the previous one, so the last entry
    # in the stack ends up outermost.
    for getter in resource_creator_stack[cls._resource_type()]:
      previous_getter = _make_getter(getter, previous_getter)

    return previous_getter(*args, **kwargs)


class CapturableResource(base.Trackable, metaclass=_ResourceMetaclass):
  """Holds a Tensor which a tf.function can capture.

  `CapturableResource`s are discovered by traversing the graph of object
  attributes, e.g. during `tf.saved_model.save`. They are excluded from the
  scope-based tracking of `TrackableResource`; generally things that require
  initialization should inherit from `TrackableResource` instead of
  `CapturableResource` directly.
  """

  def __init__(self, device=""):
    """Initialize the `CapturableResource`.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
    """
    self._resource_handle_value = None
    self._resource_device = device
    # Remember the context this object was created in (eager mode, or the
    # current default graph), so `__del__` can destroy the resource there.
    self._self_destruction_context = (
        context.eager_mode if context.executing_eagerly()
        else ops.get_default_graph().as_default)

  @classmethod
  def _resource_type(cls):
    """Key used to look up resource creators for this class (its name)."""
    return cls.__name__

  @property
  def _destruction_context(self):
    """Context-manager factory that `__del__` enters before destroying.

    Falls back to `contextlib.suppress`, which suppresses nothing when called
    with no arguments, if `_self_destruction_context` was never set (e.g. an
    object created without running `__init__`).
    """
    return getattr(self, "_self_destruction_context",
                   # no-op context
                   contextlib.suppress)

  @_destruction_context.setter
  def _destruction_context(self, destruction_context):
    self._self_destruction_context = destruction_context

  def _create_resource(self):
    """A function that creates a resource handle."""
    raise NotImplementedError("TrackableResource._create_resource not "
                              "implemented.")

  @property
  def _resource_handle(self):
    """The cached resource handle, or None if not yet created."""
    return self._resource_handle_value

  @_resource_handle.setter
  def _resource_handle(self, value):
    if isinstance(value, (ops.Tensor, ops.EagerTensor)):
      # Back-pointer from the handle tensor to its owner. It is a weak
      # reference, so the tensor does not keep this object alive.
      value._parent_trackable = weakref.ref(self)  # pylint: disable=protected-access
    self._resource_handle_value = value

  def _initialize(self):
    """A function that initializes the resource. Optional."""
    pass

  def _destroy_resource(self):
    """A function that destroys the resource. Optional."""
    pass

  @property
  def resource_handle(self):
    """Returns the resource handle associated with this Resource.

    The handle is created lazily on first access by calling
    `_create_resource()` under `ops.device(self._resource_device)`, and is
    cached afterwards.
    """
    if self._resource_handle is None:
      with ops.device(self._resource_device):
        self._resource_handle = self._create_resource()
    return self._resource_handle

  def _map_resources(self, _):
    """For implementing `Trackable`.

    Makes a shallow copy of this object and gives the copy a freshly created
    resource handle.

    Returns:
      A tuple `(obj_map, resource_map)`:
        - `obj_map` maps `self` to the copy.
        - `resource_map` maps this object's handle to the new handle.
    """
    new_obj = copy.copy(self)
    # pylint: disable=protected-access
    with ops.device(self._resource_device):
      new_resource = new_obj._create_resource()
    new_obj._resource_handle = new_resource
    # pylint: enable=protected-access
    obj_map = {self: new_obj}
    # Note: accessing `self.resource_handle` creates this object's own handle
    # if it does not exist yet.
    resource_map = {self.resource_handle: new_resource}
    return obj_map, resource_map

  def _trackable_children(self, save_type, **kwargs):
    """Adds creator/initializer/destroyer tf.functions when saving a SavedModel.

    For `save_type == "savedmodel"`, `_create_resource`, `_initialize` and
    `_destroy_resource` are wrapped in argument-less `tf.function`s and exposed
    as children under those same names.
    """
    children = super()._trackable_children(save_type, **kwargs)
    if save_type == "savedmodel":
      @def_function.function(input_signature=[], autograph=False)
      def _creator():
        resource = self._create_resource()
        return resource

      @def_function.function(input_signature=[], autograph=False)
      def _initializer():
        self._initialize()
        return 1  # Dummy return

      @def_function.function(input_signature=[], autograph=False)
      def _destroyer():
        self._destroy_resource()
        return 1  # Dummy return

      children.update({
          "_create_resource": _creator,
          "_initialize": _initializer,
          "_destroy_resource": _destroyer,
      })
    return children

  def __del__(self):
    try:
      # Outer race condition: on program exit, the destruction context may be
      # deleted before this __del__ is called. At this point we can safely
      # exit without calling _destroy_resource() and let Python handle things.
      with self._destruction_context():
        # Inner race condition: possible between this and `ScopedTFFunction`
        # whereby if an entire garbage collection chain containing both
        # objects is moved to unreachable during the same garbage collection
        # cycle, the __del__ for `ScopedTFFunction` can be collected before
        # this method is called. In that case, we can't do much but
        # continue.
        self._destroy_resource()
    except Exception:  # pylint: disable=broad-except
      # Silence all error logs that occur when attempting to destroy this
      # resource.
      pass


@tf_export("saved_model.experimental.TrackableResource")
class TrackableResource(CapturableResource):
  """Holds a Tensor which a tf.function can capture.

  A TrackableResource is most useful for stateful Tensors that require
  initialization, such as `tf.lookup.StaticHashTable`. `TrackableResource`s
  are discovered by traversing the graph of object attributes, e.g. during
  `tf.saved_model.save`.

  A TrackableResource has three methods to override:

  * `_create_resource` should create the resource tensor handle.
  * `_initialize` should initialize the resource held at `self.resource_handle`.
  * `_destroy_resource` is called upon a `TrackableResource`'s destruction
    and should decrement the resource's ref count. For most resources, this
    should be done with a call to `tf.raw_ops.DestroyResourceOp`.

  Example usage:

  >>> class DemoResource(tf.saved_model.experimental.TrackableResource):
  ...   def __init__(self):
  ...     super().__init__()
  ...     self._initialize()
  ...   def _create_resource(self):
  ...     return tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=[2])
  ...   def _initialize(self):
  ...     tf.raw_ops.AssignVariableOp(
  ...         resource=self.resource_handle, value=tf.ones([2]))
  ...   def _destroy_resource(self):
  ...     tf.raw_ops.DestroyResourceOp(resource=self.resource_handle)
  >>> class DemoModule(tf.Module):
  ...   def __init__(self):
  ...     self.resource = DemoResource()
  ...   def increment(self, tensor):
  ...     return tensor + tf.raw_ops.ReadVariableOp(
  ...         resource=self.resource.resource_handle, dtype=tf.float32)
  >>> demo = DemoModule()
  >>> demo.increment([5, 1])
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 2.], dtype=float32)>
  """

  def __init__(self, device=""):
    """Initialize the `TrackableResource`.

    Registers this object with every `ResourceTracker` currently active via
    `resource_tracker_scope`, then initializes the base class.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
    """
    # Read-only access to the module-level stack; no `global` declaration is
    # needed.
    for resource_tracker in _RESOURCE_TRACKER_STACK:
      resource_tracker.add_resource(self)
    super().__init__(device=device)


# TODO(b/124205571,b/124092991): Solve destruction of resources.
class RestoredResource(TrackableResource):
  """Restored SavedResource."""

  def __init__(self, device=""):
    super().__init__(device=device)

  @classmethod
  def _deserialize_from_proto(cls, object_proto, dependencies, **unused_kwargs):
    """Builds a `RestoredResource` from a SavedObject proto.

    The device comes from `object_proto.resource.device`. If the restored
    dependencies include a `_create_resource` function, it replaces this
    instance's `_create_resource` method.
    """
    obj = cls(device=object_proto.resource.device)
    resource_creator = dependencies.get("_create_resource")
    if resource_creator is not None:
      obj._create_resource = resource_creator  # pylint: disable=protected-access
    return obj

  def _add_trackable_child(self, name, value):
    """Sets `value` as attribute `name`.

    If `value` is a `Trackable` and not a `tf.function`, it is also tracked
    as a dependency.
    """
    setattr(self, name, value)
    if (isinstance(value, base.Trackable) and
        not isinstance(value, def_function.Function)):
      self._track_trackable(value, name)