# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities related to layer/model functionality."""

# TODO(b/110718070): Move these functions back to tensorflow/python/keras/utils
# once __init__ files no longer require all of tf.keras to be imported together.

import collections
import functools
import weakref

from tensorflow.python.util import object_identity

try:
  # typing module is only used for comment type annotations.
  import typing  # pylint: disable=g-import-not-at-top, unused-import
except ImportError:
  pass


def is_layer(obj):
  """Implicit check for Layer-like objects."""
  # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
  return hasattr(obj, "_is_layer") and not isinstance(obj, type)


def has_weights(obj):
  """Implicit check for Layer-like objects."""
  # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
  has_weight = (hasattr(type(obj), "trainable_weights")
                and hasattr(type(obj), "non_trainable_weights"))

  return has_weight and not isinstance(obj, type)

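# Example (illustrative only; `_FakeLayer` is a hypothetical stand-in for a
# Keras Layer, not part of this module). Both checks above are duck-typed: any
# instance exposing the expected attributes passes, while the class object
# itself is rejected by the `isinstance(obj, type)` guard.
#
#   >>> class _FakeLayer(object):
#   ...   _is_layer = True
#   ...   trainable_weights = []
#   ...   non_trainable_weights = []
#   >>> is_layer(_FakeLayer()), has_weights(_FakeLayer())
#   (True, True)
#   >>> is_layer(_FakeLayer), has_weights(_FakeLayer)
#   (False, False)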

def invalidate_recursive_cache(key):
  """Convenience decorator to invalidate the cache when setting attributes."""
  def outer(f):
    @functools.wraps(f)
    def wrapped(self, value):
      sentinel = getattr(self, "_attribute_sentinel")  # type: AttributeSentinel
      sentinel.invalidate(key)
      return f(self, value)
    return wrapped
  return outer

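# Example (a minimal sketch; `_MyLayer` and its `dynamic` property are
# hypothetical and only illustrate the intended decorator usage). Stacking
# `invalidate_recursive_cache` under a property setter invalidates the
# sentinel entry for the given key before the new value is stored; the
# invalidation propagates to parent sentinels whenever the cached state
# actually changes.
#
#   >>> class _MyLayer(object):
#   ...   def __init__(self):
#   ...     self._attribute_sentinel = AttributeSentinel()
#   ...     self._dynamic = False
#   ...
#   ...   @property
#   ...   def dynamic(self):
#   ...     return self._dynamic
#   ...
#   ...   @dynamic.setter
#   ...   @invalidate_recursive_cache("dynamic")
#   ...   def dynamic(self, value):
#   ...     self._dynamic = value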

class MutationSentinel(object):
  """Container for tracking whether a property is in a cached state."""
  _in_cached_state = False

  def mark_as(self, value):  # type: (MutationSentinel, bool) -> bool
    may_affect_upstream = (value != self._in_cached_state)
    self._in_cached_state = value
    return may_affect_upstream

  @property
  def in_cached_state(self):
    return self._in_cached_state

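# Example (illustrative): `mark_as` returns whether the call actually changed
# the flag, which is what lets AttributeSentinel skip redundant propagation to
# parent sentinels.
#
#   >>> s = MutationSentinel()
#   >>> s.mark_as(True)    # False -> True: upstream caches may be affected
#   True
#   >>> s.mark_as(True)    # already True: no change to report
#   False
#   >>> s.in_cached_state
#   True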

class AttributeSentinel(object):
  """Container for managing attribute cache state within a Layer.

  The cache can be invalidated either on an individual basis (for instance when
  an attribute is mutated) or a layer-wide basis (such as when a new dependency
  is added).
  """

  def __init__(self, always_propagate=False):
    self._parents = weakref.WeakSet()
    self.attributes = collections.defaultdict(MutationSentinel)

    # The trackable data structure containers are simple pass throughs. They
    # don't know or care about particular attributes. As a result, they will
    # consider themselves to be in a cached state, so it's up to the Layer
    # which contains them to terminate propagation.
    self.always_propagate = always_propagate

  def __repr__(self):
    return "{}\n  {}".format(
        super(AttributeSentinel, self).__repr__(),
        {k: v.in_cached_state for k, v in self.attributes.items()})

  def add_parent(self, node):
    # type: (AttributeSentinel, AttributeSentinel) -> None

    # Properly tracking removal is quite challenging; however since this is only
    # used to invalidate a cache it's alright to be overly conservative. We need
    # to invalidate the cache of `node` (since it has implicitly gained a child)
    # but we don't need to invalidate self since attributes should not depend on
    # parent Layers.
    self._parents.add(node)
    node.invalidate_all()

  def get(self, key):
    # type: (AttributeSentinel, str) -> bool
    return self.attributes[key].in_cached_state

  def _set(self, key, value):
    # type: (AttributeSentinel, str, bool) -> None
    may_affect_upstream = self.attributes[key].mark_as(value)
    if may_affect_upstream or self.always_propagate:
      for node in self._parents:  # type: AttributeSentinel
        node.invalidate(key)

  def mark_cached(self, key):
    # type: (AttributeSentinel, str) -> None
    self._set(key, True)

  def invalidate(self, key):
    # type: (AttributeSentinel, str) -> None
    self._set(key, False)

  def invalidate_all(self):
    # Parents may have different keys than their children, so we locally
    # invalidate but use the `invalidate_all` method of parents.
    for key in self.attributes.keys():
      self.attributes[key].mark_as(False)

    for node in self._parents:
      node.invalidate_all()

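# Example (illustrative; the key "losses" is an arbitrary placeholder).
# A child sentinel registers its parent with `add_parent`; afterwards any
# *change* in the child's cached state for a key invalidates the parent's
# entry for that key, while calls that do not change the child's state are
# not propagated.
#
#   >>> parent = AttributeSentinel()
#   >>> child = AttributeSentinel()
#   >>> child.add_parent(parent)     # also wholesale-invalidates `parent`
#   >>> parent.mark_cached("losses")
#   >>> child.invalidate("losses")   # child was already un-cached: no change,
#   >>> parent.get("losses")         # so the parent's entry is left alone
#   True
#   >>> child.mark_cached("losses")  # child's state changed, so the parent's
#   >>> parent.get("losses")         # cached entry is invalidated
#   False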

def filter_empty_layer_containers(layer_list):
  """Filter out empty Layer-like containers and uniquify."""
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  existing = object_identity.ObjectIdentitySet()
  to_visit = layer_list[::-1]
  while to_visit:
    obj = to_visit.pop()
    if obj in existing:
      continue
    existing.add(obj)
    if is_layer(obj):
      yield obj
    else:
      sub_layers = getattr(obj, "layers", None) or []

      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      to_visit.extend(sub_layers[::-1])

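# Example (illustrative sketch; `_Leaf` and `_Container` are hypothetical
# stand-ins for a Layer and a trackable container with a ".layers" list).
# Duplicates are dropped by object identity, containers are flattened in
# order, and containers without layers contribute nothing.
#
#   >>> class _Leaf(object):
#   ...   _is_layer = True
#   >>> class _Container(object):
#   ...   def __init__(self, layers):
#   ...     self.layers = layers
#   >>> a, b = _Leaf(), _Leaf()
#   >>> layers = [a, _Container([a, b]), _Container([])]
#   >>> list(filter_empty_layer_containers(layers)) == [a, b]
#   True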