# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities related to layer/model functionality."""

# TODO(b/110718070): Move these functions back to tensorflow/python/keras/utils
# once __init__ files no longer require all of tf.keras to be imported together.

import collections
import functools
import weakref

from tensorflow.python.util import object_identity

try:
  # typing module is only used for comment type annotations.
  import typing  # pylint: disable=g-import-not-at-top, unused-import
except ImportError:
  pass


def is_layer(obj):
  """Implicit check for Layer-like objects.

  Args:
    obj: Any object.

  Returns:
    True if `obj` exposes an `_is_layer` attribute and is an instance rather
    than a class; False otherwise.
  """
  # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
  return hasattr(obj, "_is_layer") and not isinstance(obj, type)


def has_weights(obj):
  """Implicit check for objects that expose trainable/non-trainable weights.

  The attributes are looked up on `type(obj)` rather than on `obj` itself, so
  only class-level attributes (such as properties) are considered; attributes
  set only on the instance do not count.

  Args:
    obj: Any object.

  Returns:
    True if the class of `obj` defines both `trainable_weights` and
    `non_trainable_weights` and `obj` is an instance rather than a class;
    False otherwise.
  """
  # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
  has_weight = (hasattr(type(obj), "trainable_weights")
                and hasattr(type(obj), "non_trainable_weights"))

  return has_weight and not isinstance(obj, type)


def invalidate_recursive_cache(key):
  """Convenience decorator to invalidate the cache when setting attributes.

  Args:
    key: Name of the cached attribute to invalidate before the decorated
      method runs.

  Returns:
    A decorator for methods with the signature `(self, value)`. The wrapped
    method invalidates `key` on `self._attribute_sentinel` and then returns
    the result of calling the original method.
  """
  def outer(f):
    @functools.wraps(f)
    def wrapped(self, value):
      sentinel = self._attribute_sentinel  # type: AttributeSentinel
      # Invalidate first so the cache is already stale when `f` runs.
      sentinel.invalidate(key)
      return f(self, value)
    return wrapped
  return outer


class MutationSentinel(object):
  """Container for tracking whether a property is in a cached state."""

  # Class-level default; `mark_as` stores the per-instance value.
  _in_cached_state = False

  def mark_as(self, value):  # type: (MutationSentinel, bool) -> bool
    """Records the new cache state.

    Args:
      value: The new cached state.

    Returns:
      True if `value` differs from the previously stored state, i.e. the
      change may need to be propagated upstream; False otherwise.
    """
    may_affect_upstream = (value != self._in_cached_state)
    self._in_cached_state = value
    return may_affect_upstream

  @property
  def in_cached_state(self):
    """Whether the tracked property is currently marked as cached."""
    return self._in_cached_state


class AttributeSentinel(object):
  """Container for managing attribute cache state within a Layer.

  The cache can be invalidated either on an individual basis (for instance when
  an attribute is mutated) or a layer-wide basis (such as when a new dependency
  is added).

  Attributes:
    attributes: Mapping from attribute key to its `MutationSentinel`. Entries
      are created on first access with an uncached state.
    always_propagate: If True, every state update is forwarded to the parents
      even when the local state did not change.
  """

  def __init__(self, always_propagate=False):
    """Creates a sentinel with no parents and no tracked attributes.

    Args:
      always_propagate: Whether updates should always be forwarded to parent
        sentinels, regardless of whether the local state changed.
    """
    # Parents are held weakly so that this sentinel does not keep them alive.
    self._parents = weakref.WeakSet()
    self.attributes = collections.defaultdict(MutationSentinel)

    # The trackable data structure containers are simple pass throughs. They
    # don't know or care about particular attributes. As a result, they will
    # consider themselves to be in a cached state, so it's up to the Layer
    # which contains them to terminate propagation.
    self.always_propagate = always_propagate

  def __repr__(self):
    """Returns the default repr followed by the per-key cached states."""
    return "{}\n {}".format(
        super(AttributeSentinel, self).__repr__(),
        {k: v.in_cached_state for k, v in self.attributes.items()})

  def add_parent(self, node):
    # type: (AttributeSentinel, AttributeSentinel) -> None
    """Registers `node` as a parent and invalidates all of its cached state.

    Args:
      node: The sentinel that should be notified of invalidations on `self`.
    """
    # Properly tracking removal is quite challenging; however since this is only
    # used to invalidate a cache it's alright to be overly conservative. We need
    # to invalidate the cache of `node` (since it has implicitly gained a child)
    # but we don't need to invalidate self since attributes should not depend on
    # parent Layers.
    self._parents.add(node)
    node.invalidate_all()

  def get(self, key):
    # type: (AttributeSentinel, str) -> bool
    """Returns whether `key` is currently marked as cached.

    Looking up an unseen key creates its entry in the uncached state.
    """
    return self.attributes[key].in_cached_state

  def _set(self, key, value):
    # type: (AttributeSentinel, str, bool) -> None
    """Stores the cache state for `key` and notifies parents when needed.

    Parents are always *invalidated* for `key` (never marked as cached),
    whichever `value` is stored locally. Notification happens when the local
    state changed or when `always_propagate` is set.
    """
    may_affect_upstream = self.attributes[key].mark_as(value)
    if may_affect_upstream or self.always_propagate:
      for node in self._parents:  # type: AttributeSentinel
        node.invalidate(key)

  def mark_cached(self, key):
    # type: (AttributeSentinel, str) -> None
    """Marks `key` as being in a cached state."""
    self._set(key, True)

  def invalidate(self, key):
    # type: (AttributeSentinel, str) -> None
    """Marks `key` as no longer being in a cached state."""
    self._set(key, False)

  def invalidate_all(self):
    """Invalidates every tracked key here and, recursively, in all parents."""
    # Parents may have different keys than their children, so we locally
    # invalidate but use the `invalidate_all` method of parents.
    for sentinel in self.attributes.values():
      sentinel.mark_as(False)

    for node in self._parents:
      node.invalidate_all()


def filter_empty_layer_containers(layer_list):
  """Filter out empty Layer-like containers and uniquify.

  Objects are visited depth first in their original order. Layer-like objects
  (see `is_layer`) are yielded; any other object is treated as a container and
  the contents of its `layers` attribute, if present and non-empty, are
  visited in its place. Each object is visited at most once, by identity.

  Args:
    layer_list: Iterable of layers and/or layer containers. Any iterable is
      accepted (list, tuple, ...); it is copied and never mutated.

  Yields:
    Each unique Layer-like object reachable from `layer_list`.
  """
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  existing = object_identity.ObjectIdentitySet()
  # Copy into a real list before reversing: slicing a tuple (or another
  # sequence type) would preserve that type, which has no `pop`/`extend`.
  # The list is used as a stack, so it is reversed to pop in input order.
  to_visit = list(layer_list)[::-1]
  while to_visit:
    obj = to_visit.pop()
    if obj in existing:
      continue
    existing.add(obj)
    if is_layer(obj):
      yield obj
    else:
      sub_layers = getattr(obj, "layers", None) or []

      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      to_visit.extend(list(sub_layers)[::-1])