xref: /aosp_15_r20/external/tensorflow/tensorflow/python/trackable/resource.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Definitions for resource-type trackable object classes."""
16
17import contextlib
18import copy
19import weakref
20
21from tensorflow.python.eager import context
22from tensorflow.python.eager import def_function
23from tensorflow.python.framework import ops
24from tensorflow.python.trackable import base
25from tensorflow.python.util import tf_contextlib
26from tensorflow.python.util.tf_export import tf_export
27
# Module-level stack of the ResourceTrackers that are currently active.
# `resource_tracker_scope` pushes onto it and restores it on exit;
# `TrackableResource.__init__` registers new resources with every entry.
_RESOURCE_TRACKER_STACK = []
30
31
class ResourceTracker:
  """Collects the resources registered with it, in registration order."""

  __slots__ = ["_resources"]

  def __init__(self):
    # Backing storage; handed out directly (not copied) by `resources`.
    self._resources = []

  def add_resource(self, resource):
    """Appends `resource` to the tracked list."""
    self._resources.append(resource)

  @property
  def resources(self):
    """The live list of tracked resources (the internal list, not a copy)."""
    return self._resources
46
47
@tf_contextlib.contextmanager
def resource_tracker_scope(resource_tracker):
  """A context to manage resource trackers.

  Use this in order to collect up all resources created within a block of code.
  Example usage:

  ```python
  resource_tracker = ResourceTracker()
  with resource_tracker_scope(resource_tracker):
    resource = TrackableResource()

  assert resource_tracker.resources == [resource]
  ```

  Scopes nest. A `TrackableResource` created inside nested scopes is added to
  every tracker that is active at that point.

  Args:
    resource_tracker: The passed in ResourceTracker object

  Yields:
    A scope in which the resource_tracker is active.
  """
  global _RESOURCE_TRACKER_STACK
  # Snapshot the stack so it can be restored on exit, even if the body raises.
  # The module-level name is rebound to the snapshot rather than popped.
  old = list(_RESOURCE_TRACKER_STACK)
  _RESOURCE_TRACKER_STACK.append(resource_tracker)
  try:
    yield
  finally:
    _RESOURCE_TRACKER_STACK = old
75
76
def _make_getter(captured_getter, captured_previous):
  """Returns a callable that invokes `captured_getter(captured_previous, ...)`.

  Building the closure inside its own function gives each call fresh
  bindings, which avoids capturing loop variables by late binding.

  Args:
    captured_getter: Callable taking the next callable in the chain as its
      first argument, followed by the forwarded arguments.
    captured_previous: The callable passed as that first argument.

  Returns:
    A callable that forwards all its positional and keyword arguments.
  """

  def _bound(*positional, **keyword):
    return captured_getter(captured_previous, *positional, **keyword)

  return _bound
84
85
class _ResourceMetaclass(type):
  """Metaclass for CapturableResource.

  Construction goes through the creator functions stored in the default
  graph's `_resource_creator_stack` under the class's `_resource_type()` key.
  Each creator receives the next creator in the chain as its first argument.
  The innermost step performs the ordinary `__new__` + `__init__`.
  """

  def __call__(cls, *args, **kwargs):

    def _construct(next_creator, *a, **kw):
      # Innermost link of the chain: there is nothing further to delegate to.
      assert next_creator is None
      instance = cls.__new__(cls, *a, **kw)
      instance.__init__(*a, **kw)
      return instance

    def _innermost(*a, **kw):
      return _construct(None, *a, **kw)

    creator_stack = ops.get_default_graph()._resource_creator_stack  # pylint: disable=protected-access
    chain = _innermost
    # Later entries wrap earlier ones, so the last creator is called first.
    for creator in creator_stack[cls._resource_type()]:
      chain = _make_getter(creator, chain)

    return chain(*args, **kwargs)
103
104
class CapturableResource(base.Trackable, metaclass=_ResourceMetaclass):
  """Holds a Tensor which a tf.function can capture.

  `CapturableResource`s are discovered by traversing the graph of object
  attributes, e.g. during `tf.saved_model.save`. They are excluded from the
  scope-based tracking of `TrackableResource`; generally things that require
  initialization should inherit from `TrackableResource` instead of
  `CapturableResource` directly.

  Subclasses must override `_create_resource`. `_initialize` and
  `_destroy_resource` are optional hooks whose base implementations do nothing.
  """

  def __init__(self, device=""):
    """Initialize the `CapturableResource`.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
    """
    # The handle is created lazily by the `resource_handle` property.
    self._resource_handle_value = None
    self._resource_device = device
    # Remember the context this object was built in (eager mode, or the
    # then-default graph) so `__del__` destroys the resource in that context.
    self._self_destruction_context = (
        context.eager_mode if context.executing_eagerly()
        else ops.get_default_graph().as_default)

  @classmethod
  def _resource_type(cls):
    """Returns the key used to look up this class's resource creators."""
    return cls.__name__

  @property
  def _destruction_context(self):
    """Context-manager factory that `__del__` enters before destroying.

    Falls back to `contextlib.suppress` if `_self_destruction_context` was
    never set, e.g. by a subclass that skipped `CapturableResource.__init__`.
    Called with no arguments, `contextlib.suppress` is a no-op context.
    """
    return getattr(self, "_self_destruction_context",
                   # no-op context
                   contextlib.suppress)

  @_destruction_context.setter
  def _destruction_context(self, destruction_context):
    self._self_destruction_context = destruction_context

  def _create_resource(self):
    """A function that creates a resource handle. Must be overridden.

    Raises:
      NotImplementedError: Always, in this base implementation.
    """
    # Name the concrete class so the error points at the subclass that is
    # missing the override (this base class is not TrackableResource).
    raise NotImplementedError("%s._create_resource not implemented."
                              % type(self).__name__)

  @property
  def _resource_handle(self):
    return self._resource_handle_value

  @_resource_handle.setter
  def _resource_handle(self, value):
    # Tensor handles get a weak back-reference to their owner. The weakref
    # does not keep this object alive.
    if isinstance(value, (ops.Tensor, ops.EagerTensor)):
      value._parent_trackable = weakref.ref(self)  # pylint: disable=protected-access
    self._resource_handle_value = value

  def _initialize(self):
    """A function that initializes the resource. Optional."""
    pass

  def _destroy_resource(self):
    """A function that destroys the resource. Optional."""
    pass

  @property
  def resource_handle(self):
    """Returns the resource handle associated with this Resource.

    The handle is created on first access, under `ops.device` with the device
    given at construction, and cached afterwards.
    """
    if self._resource_handle is None:
      with ops.device(self._resource_device):
        self._resource_handle = self._create_resource()
    return self._resource_handle

  def _map_resources(self, _):
    """For implementing `Trackable`.

    Makes a shallow copy of this object that owns a freshly created resource.

    Returns:
      A tuple `(obj_map, resource_map)`:
        obj_map: Maps this object to the copy.
        resource_map: Maps this object's handle to the copy's new handle.
      Reading `self.resource_handle` here creates this object's own handle
      if it did not exist yet.
    """
    new_obj = copy.copy(self)
    # pylint: disable=protected-access
    with ops.device(self._resource_device):
      new_resource = new_obj._create_resource()
    new_obj._resource_handle = new_resource
    # pylint: enable=protected-access
    obj_map = {self: new_obj}
    resource_map = {self.resource_handle: new_resource}
    return obj_map, resource_map

  def _trackable_children(self, save_type, **kwargs):
    """Returns the trackable children of this resource.

    For SavedModel saving, this also adds the create, initialize and destroy
    hooks as argument-less `tf.function`s so they can be restored.
    """
    children = super()._trackable_children(save_type, **kwargs)
    if save_type == "savedmodel":
      @def_function.function(input_signature=[], autograph=False)
      def _creator():
        resource = self._create_resource()
        return resource

      @def_function.function(input_signature=[], autograph=False)
      def _initializer():
        self._initialize()
        return 1  # Dummy return

      @def_function.function(input_signature=[], autograph=False)
      def _destroyer():
        self._destroy_resource()
        return 1  # Dummy return

      children.update({
          "_create_resource": _creator,
          "_initialize": _initializer,
          "_destroy_resource": _destroyer,
      })
    return children

  def __del__(self):
    try:
      # Outer race condition: on program exit, the destruction context may be
      # deleted before this __del__ is called. At this point we can safely
      # exit without calling _destroy_resource() and let Python handle things.
      with self._destruction_context():
        # Inner race condition: possible between this and `ScopedTFFunction`
        # whereby if an entire garbage collection chain containing both
        # objects is moved to unreachable during the same garbage collection
        # cycle, the __del__ for `ScopedTFFunction` can be collected before
        # this method is called. In that case, we can't do much but
        # continue.
        self._destroy_resource()
    except Exception:  # pylint: disable=broad-except
      # Deliberate best-effort: silence all errors raised while attempting to
      # destroy this resource during finalization.
      pass
229
230
@tf_export("saved_model.experimental.TrackableResource")
class TrackableResource(CapturableResource):
  """Holds a Tensor which a tf.function can capture.

  A TrackableResource is most useful for stateful Tensors that require
  initialization, such as `tf.lookup.StaticHashTable`. `TrackableResource`s
  are discovered by traversing the graph of object attributes, e.g. during
  `tf.saved_model.save`.

  A TrackableResource has three methods to override:

  * `_create_resource` should create the resource tensor handle.
  * `_initialize` should initialize the resource held at `self.resource_handle`.
  * `_destroy_resource` is called upon a `TrackableResource`'s destruction
    and should decrement the resource's ref count. For most resources, this
    should be done with a call to `tf.raw_ops.DestroyResourceOp`.

  Example usage:

  >>> class DemoResource(tf.saved_model.experimental.TrackableResource):
  ...   def __init__(self):
  ...     super().__init__()
  ...     self._initialize()
  ...   def _create_resource(self):
  ...     return tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=[2])
  ...   def _initialize(self):
  ...     tf.raw_ops.AssignVariableOp(
  ...         resource=self.resource_handle, value=tf.ones([2]))
  ...   def _destroy_resource(self):
  ...     tf.raw_ops.DestroyResourceOp(resource=self.resource_handle)
  >>> class DemoModule(tf.Module):
  ...   def __init__(self):
  ...     self.resource = DemoResource()
  ...   def increment(self, tensor):
  ...     return tensor + tf.raw_ops.ReadVariableOp(
  ...         resource=self.resource.resource_handle, dtype=tf.float32)
  >>> demo = DemoModule()
  >>> demo.increment([5, 1])
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 2.], dtype=float32)>
  """

  def __init__(self, device=""):
    """Initialize the `TrackableResource`.

    Before the base initializer runs, the new resource is registered with
    every tracker on the active resource-tracker stack.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
    """
    # The module-level stack is only read here, so no `global` statement is
    # needed. Iterating the whole stack (outermost scope first) means every
    # enclosing `resource_tracker_scope` sees this resource.
    for tracker in _RESOURCE_TRACKER_STACK:
      tracker.add_resource(self)
    super().__init__(device=device)
285
286
287# TODO(b/124205571,b/124092991): Solve destruction of resources.
class RestoredResource(TrackableResource):
  """Resource object reconstructed when loading a SavedModel."""

  def __init__(self, device=""):
    super().__init__(device=device)

  @classmethod
  def _deserialize_from_proto(cls, object_proto, dependencies, **unused_kwargs):
    """Builds an instance from the saved object proto and its dependencies.

    The device comes from `object_proto.resource.device`. If `dependencies`
    contains a restored `_create_resource` function, that function replaces
    the new instance's `_create_resource` method.
    """
    restored = cls(device=object_proto.resource.device)
    creator_fn = dependencies.get("_create_resource")
    if creator_fn is None:
      return restored
    restored._create_resource = creator_fn  # pylint: disable=protected-access
    return restored

  def _add_trackable_child(self, name, value):
    """Sets `value` as attribute `name`, tracking it when appropriate.

    `value` is tracked as a dependency only if it is a `base.Trackable`
    and not a `def_function.Function`.
    """
    setattr(self, name, value)
    if not isinstance(value, base.Trackable):
      return
    if isinstance(value, def_function.Function):
      return
    self._track_trackable(value, name)
307