# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions shared between SavedModel saving/loading implementations."""

import itertools
import threading
import types

from tensorflow.python.eager import context
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.util import tf_decorator


# pylint:disable=g-inconsistent-quotes
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
# pylint:enable=g-inconsistent-quotes


def use_wrapped_call(layer, call_fn, default_training_value=None,
                     return_method=False):
  """Creates fn that adds the losses returned by call_fn & returns the outputs.

  Args:
    layer: A Keras layer object
    call_fn: tf.function that takes layer inputs (and possibly a training arg),
      and returns a tuple of (outputs, list of losses).
    default_training_value: Default value of the training kwarg. If `None`, the
      default is `K.learning_phase()`.
    return_method: Whether to return a method bound to the layer.

  Returns:
    function that calls call_fn and returns the outputs. Losses returned by
    call_fn are added to the layer losses.
  """
  expects_training_arg = layer_uses_training_bool(layer)
  if hasattr(call_fn, 'original_layer_call'):  # call_fn is a LayerCall object
    original_call = call_fn.original_layer_call
    # In Python 3, callable objects are not compatible with inspect.getargspec
    call_fn = call_fn.__call__
  else:
    original_call = call_fn
  fn, arg_spec = maybe_add_training_arg(
      original_call, call_fn, expects_training_arg, default_training_value)

  def return_outputs_and_add_losses(*args, **kwargs):
    """Returns the outputs from the layer call function, and adds the losses."""
    if return_method:
      # When bound as a method, the first positional arg is the layer itself;
      # `fn` does not expect it.
      args = args[1:]

    outputs, losses = fn(*args, **kwargs)
    layer.add_loss(losses, inputs=True)

    # TODO(kathywu): This is a temporary hack. When a network of layers is
    # revived from SavedModel, only the top-level layer will have losses. This
    # causes issues in eager mode because the child layers may have graph losses
    # (thus model.losses returns a mix of Eager and graph tensors). To fix this,
    # whenever eager losses are added to one layer, add eager losses to all
    # child layers. This causes `.losses` to only return eager losses.
    # pylint: disable=protected-access
    if context.executing_eagerly():
      for i in layer._flatten_layers():
        if i is not layer:
          i._eager_losses = [base_layer_utils.REVIVED_LOSS_PLACEHOLDER]
    # pylint: enable=protected-access
    return outputs

  decorated = tf_decorator.make_decorator(
      target=call_fn,
      decorator_func=return_outputs_and_add_losses,
      decorator_argspec=arg_spec)

  if return_method:
    return types.MethodType(decorated, layer)
  else:
    return decorated


def layer_uses_training_bool(layer):
  """Returns whether this layer or any of its children uses the training arg."""
  if layer._expects_training_arg:  # pylint: disable=protected-access
    return True
  visited = {layer}
  to_visit = list_all_layers(layer)
  while to_visit:
    layer = to_visit.pop()
    if layer in visited:
      continue
    # A sublayer that does not define `_expects_training_arg` is treated as if
    # it uses the training argument (the getattr default is True).
    if getattr(layer, '_expects_training_arg', True):
      return True
    visited.add(layer)
    to_visit.extend(list_all_layers(layer))
  return False


def list_all_layers(obj):
  """Returns the direct sublayers of `obj`.

  Args:
    obj: A Keras layer or model.

  Returns:
    `obj.layers` when `obj` is a `Model`; otherwise a list built from
    `obj._flatten_layers(include_self=False, recursive=False)`.
  """
  if isinstance(obj, training_lib.Model):
    # Handle special case of Sequential, which doesn't return
    # the `Input` layer.
    return obj.layers
  else:
    return list(obj._flatten_layers(include_self=False, recursive=False))  # pylint: disable=protected-access


def list_all_layers_and_sublayers(obj):
  """Returns a set containing `obj` and every layer reachable from it.

  Reachability is determined by recursively applying `list_all_layers`.

  Args:
    obj: A Keras layer or model.

  Returns:
    A set of `obj` plus all of its (transitive) sublayers.
  """
  s = {obj}
  s.update(itertools.chain.from_iterable(
      list_all_layers_and_sublayers(layer) for layer in list_all_layers(obj)))
  return s


def maybe_add_training_arg(
    original_call, wrapped_call, expects_training_arg, default_training_value):
  """Decorate call and optionally adds training argument.

  If a layer expects a training argument, this function ensures that 'training'
  is present in the layer args or kwonly args, with the default training value.

  Args:
    original_call: Original call function.
    wrapped_call: Wrapped call function.
    expects_training_arg: Whether to include 'training' argument.
    default_training_value: Default value of the training kwarg to include in
      the arg spec. If `None`, the default is `K.learning_phase()`.

  Returns:
    Tuple of (
      function that calls `wrapped_call` and sets the training arg,
      Argspec of returned function or `None` if the argspec is unchanged)
  """
  if not expects_training_arg:
    return wrapped_call, None

  def wrap_with_training_arg(*args, **kwargs):
    """Wrap the `wrapped_call` function, and set training argument."""
    training_arg_index = get_training_arg_index(original_call)
    training = get_training_arg(training_arg_index, args, kwargs)
    if training is None:
      # NOTE(review): because of `or`, any falsy `default_training_value`
      # (e.g. an explicit `False`) also falls through to `K.learning_phase()`,
      # not only `None`. Left unchanged; confirm whether this is intended.
      training = default_training_value or K.learning_phase()

    # Work on copies so the caller's args/kwargs are never modified.
    args = list(args)
    kwargs = kwargs.copy()

    def replace_training_and_call(training):
      set_training_arg(training, training_arg_index, args, kwargs)
      return wrapped_call(*args, **kwargs)

    return control_flow_util.smart_cond(
        training, lambda: replace_training_and_call(True),
        lambda: replace_training_and_call(False))

  # Create arg spec for decorated function. If 'training' is not defined in the
  # args of the original arg spec, then add it to kwonlyargs.
  arg_spec = tf_inspect.getfullargspec(original_call)
  defaults = list(arg_spec.defaults) if arg_spec.defaults is not None else []

  # Copy the keyword-only containers before modifying them. The spec returned
  # by `getfullargspec` may be an object stored on (and shared with) a
  # decorator, so mutating it in place would leak 'training' into that spec and
  # accumulate duplicate entries when a function is wrapped more than once.
  kwonlyargs = list(arg_spec.kwonlyargs or [])
  kwonlydefaults = dict(arg_spec.kwonlydefaults or {})
  # Add training arg if it does not exist, or set the default training value.
  if 'training' not in arg_spec.args:
    if 'training' not in kwonlyargs:
      kwonlyargs.append('training')
    kwonlydefaults['training'] = default_training_value
  else:
    index = arg_spec.args.index('training')
    # Position of 'training' counted from the end of the positional args;
    # defaults are aligned with the tail of `args`.
    training_default_index = len(arg_spec.args) - index
    if (arg_spec.defaults and
        len(arg_spec.defaults) >= training_default_index and
        defaults[-training_default_index] is None):
      defaults[-training_default_index] = default_training_value

  decorator_argspec = tf_inspect.FullArgSpec(
      args=arg_spec.args,
      varargs=arg_spec.varargs,
      varkw=arg_spec.varkw,
      defaults=defaults,
      kwonlyargs=kwonlyargs,
      kwonlydefaults=kwonlydefaults,
      annotations=arg_spec.annotations)
  return wrap_with_training_arg, decorator_argspec


def get_training_arg_index(call_fn):
  """Returns the index of 'training' in the layer call function arguments.

  Args:
    call_fn: Call function.

  Returns:
    - n: index of 'training' in the call function arguments.
    - -1: if 'training' is not found in the arguments, but layer.call accepts
          variable keyword arguments
    - None: if layer doesn't expect a training argument.
  """
  argspec = tf_inspect.getfullargspec(call_fn)
  if argspec.varargs:
    # When there are variable args, training must be a keyword arg.
    if 'training' in argspec.kwonlyargs or argspec.varkw:
      return -1
    return None
  else:
    # Try to find 'training' in the list of args or kwargs.
    arg_list = argspec.args
    if tf_inspect.ismethod(call_fn):
      # Drop the bound first argument so indices line up with call-site args.
      arg_list = arg_list[1:]

    if 'training' in arg_list:
      return arg_list.index('training')
    elif 'training' in argspec.kwonlyargs or argspec.varkw:
      return -1
    return None


def set_training_arg(training, index, args, kwargs):
  """Writes `training` into `args` or `kwargs`, modifying them in place.

  Args:
    training: Value to store for the training argument.
    index: Positional index of 'training', or `None`/negative if it is not a
      positional argument.
    args: Mutable list of positional arguments.
    kwargs: Dict of keyword arguments.

  Returns:
    The same `(args, kwargs)` objects that were passed in.
  """
  if index is None or index < 0 or len(args) <= index:  # index is invalid
    kwargs['training'] = training
  else:
    args[index] = training
  return args, kwargs


def get_training_arg(index, args, kwargs):
  """Reads the training value from `args[index]`, else from `kwargs`.

  Args:
    index: Positional index of 'training', or `None`/negative if it is not a
      positional argument.
    args: Sequence of positional arguments.
    kwargs: Dict of keyword arguments.

  Returns:
    The training value, or `None` if it was looked up in `kwargs` and is absent.
  """
  if index is None or index < 0 or len(args) <= index:  # index is invalid
    return kwargs.get('training', None)
  else:
    return args[index]


def remove_training_arg(index, args, kwargs):
  """Removes the training value from `args` or `kwargs` in place.

  Args:
    index: Positional index of 'training', or `None`/negative if it is not a
      positional argument.
    args: Mutable list of positional arguments.
    kwargs: Dict of keyword arguments.
  """
  if index is None or index < 0 or len(args) <= index:  # index is invalid
    kwargs.pop('training', None)
  else:
    args.pop(index)


class SaveOptionsContext(threading.local):
  """Thread-local holder for the `save_traces` saving option."""

  def __init__(self):
    super(SaveOptionsContext, self).__init__()
    # Defaults to True; overridden temporarily by `keras_option_scope`.
    self.save_traces = True


_save_options_context = SaveOptionsContext()


@tf_contextlib.contextmanager
def keras_option_scope(save_traces):
  """Sets the `save_traces` option within the scope, restoring it on exit.

  Args:
    save_traces: Value to assign to the `save_traces` option inside the scope.

  Yields:
    Nothing; the previous value is restored when the scope exits, including
    when the body raises.
  """
  previous_value = _save_options_context.save_traces
  try:
    _save_options_context.save_traces = save_traces
    yield
  finally:
    _save_options_context.save_traces = previous_value


def should_save_traces():
  """Whether to trace layer functions-can be disabled in the save_traces arg."""
  return _save_options_context.save_traces


@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
  """A context that disables automatic dependency tracking when assigning attrs.

  Objects that inherit from Autotrackable automatically create dependencies
  to trackable objects through attribute assignments, and wrap data structures
  (lists or dicts) with trackable classes. This scope may be used to temporarily
  disable this behavior. This works similarly to the decorator
  `no_automatic_dependency_tracking`.

  Example usage:
  ```
  model = tf.keras.Model()
  model.arr1 = []  # Creates a ListWrapper object
  with no_automatic_dependency_tracking_scope(model):
    model.arr2 = []  # Creates a regular, untracked python list
  ```

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies.
  """
  # Objects without the attribute are treated as having tracking enabled.
  previous_value = getattr(obj, '_setattr_tracking', True)
  obj._setattr_tracking = False  # pylint: disable=protected-access
  try:
    yield
  finally:
    obj._setattr_tracking = previous_value  # pylint: disable=protected-access