xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/saving/saved_model/utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions shared between SavedModel saving/loading implementations."""
16
17import itertools
18import threading
19import types
20
21from tensorflow.python.eager import context
22from tensorflow.python.keras import backend as K
23from tensorflow.python.keras.engine import base_layer_utils
24from tensorflow.python.keras.utils import control_flow_util
25from tensorflow.python.keras.utils import tf_contextlib
26from tensorflow.python.keras.utils import tf_inspect
27from tensorflow.python.keras.utils.generic_utils import LazyLoader
28from tensorflow.python.util import tf_decorator
29
30
# pylint:disable=g-inconsistent-quotes
# `training_lib` is bound through `LazyLoader` instead of a regular import.
# NOTE(review): presumably this defers loading `keras.engine.training` until
# first attribute access (e.g. to avoid an import cycle) -- confirm against
# `generic_utils.LazyLoader`.
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
# pylint:enable=g-inconsistent-quotes
36
37
def use_wrapped_call(layer, call_fn, default_training_value=None,
                     return_method=False):
  """Builds a call function that records losses and returns only the outputs.

  Args:
    layer: A Keras layer object; losses produced by `call_fn` are added to it.
    call_fn: Callable that takes the layer inputs (and possibly a training
      arg) and returns a tuple of (outputs, list of losses). If it has an
      `original_layer_call` attribute (a LayerCall object), that attribute is
      used for argument inspection and `call_fn.__call__` is what gets wrapped.
    default_training_value: Default for the training argument, forwarded to
      `maybe_add_training_arg`.
    return_method: If True, the result is bound to `layer` as a method and the
      leading `self` positional argument is dropped before calling `call_fn`.

  Returns:
    A function (or bound method when `return_method` is True) that calls
    `call_fn`, adds the returned losses to `layer`, and returns the outputs.
  """
  original_call = call_fn
  if hasattr(call_fn, 'original_layer_call'):  # call_fn is a LayerCall object
    original_call = call_fn.original_layer_call
    # In Python 3, callable objects are not compatible with inspect.getargspec
    call_fn = call_fn.__call__

  fn, arg_spec = maybe_add_training_arg(
      original_call, call_fn, layer_uses_training_bool(layer),
      default_training_value)

  def return_outputs_and_add_losses(*args, **kwargs):
    """Returns the outputs from the layer call function, and adds the losses."""
    # When bound as a method, the first positional argument is `self`.
    call_args = args[1:] if return_method else args

    outputs, losses = fn(*call_args, **kwargs)
    layer.add_loss(losses, inputs=True)

    # TODO(kathywu): This is a temporary hack. When a network of layers is
    # revived from SavedModel, only the top-level layer will have losses. This
    # causes issues in eager mode because the child layers may have graph losses
    # (thus model.losses returns a mix of Eager and graph tensors). To fix this,
    # whenever eager losses are added to one layer, add eager losses to all
    # child layers. This causes `.losses` to only return eager losses.
    # pylint: disable=protected-access
    if context.executing_eagerly():
      for sublayer in layer._flatten_layers():
        if sublayer is layer:
          continue
        sublayer._eager_losses = [base_layer_utils.REVIVED_LOSS_PLACEHOLDER]
    # pylint: enable=protected-access
    return outputs

  decorated = tf_decorator.make_decorator(
      target=call_fn,
      decorator_func=return_outputs_and_add_losses,
      decorator_argspec=arg_spec)

  return types.MethodType(decorated, layer) if return_method else decorated
95
96
def layer_uses_training_bool(layer):
  """Reports whether `layer` or anything nested under it takes `training`.

  The top-level layer is checked through its `_expects_training_arg`
  attribute. Nested objects returned by `list_all_layers` are then walked
  (each at most once); one that lacks `_expects_training_arg` is treated as
  if it expected the argument.

  Args:
    layer: The layer to inspect.

  Returns:
    True as soon as any inspected object expects a training argument,
    otherwise False.
  """
  if layer._expects_training_arg:  # pylint: disable=protected-access
    return True

  seen = {layer}
  pending = list_all_layers(layer)
  while pending:
    candidate = pending.pop()
    if candidate not in seen:
      if getattr(candidate, '_expects_training_arg', True):
        return True
      seen.add(candidate)
      pending.extend(list_all_layers(candidate))
  return False
112
113
def list_all_layers(obj):
  """Returns the layers directly contained in `obj`.

  Args:
    obj: A Keras model or layer.

  Returns:
    `obj.layers` when `obj` is a `Model`; otherwise a list built from
    `obj._flatten_layers(include_self=False, recursive=False)`.
  """
  if not isinstance(obj, training_lib.Model):
    direct_children = obj._flatten_layers(  # pylint: disable=protected-access
        include_self=False, recursive=False)
    return list(direct_children)
  # Handle special case of Sequential, which doesn't return
  # the `Input` layer.
  return obj.layers
121
122
def list_all_layers_and_sublayers(obj):
  """Returns a set holding `obj` and every layer nested beneath it.

  Args:
    obj: A Keras model or layer.

  Returns:
    A set containing `obj` plus the recursive result for each entry of
    `list_all_layers(obj)`.
  """
  collected = {obj}
  for child in list_all_layers(obj):
    collected.update(list_all_layers_and_sublayers(child))
  return collected
128
129
def maybe_add_training_arg(
    original_call, wrapped_call, expects_training_arg, default_training_value):
  """Decorate call and optionally adds training argument.

  If a layer expects a training argument, this function ensures that 'training'
  is present in the layer args or kwonly args, with the default training value.

  Args:
    original_call: Original call function. Its argument spec is the basis of
      the returned spec and is used to locate the 'training' argument.
    wrapped_call: Wrapped call function that is actually invoked.
    expects_training_arg: Whether to include 'training' argument. When falsy,
      `wrapped_call` is returned untouched.
    default_training_value: Default value of the training kwarg to include in
      the arg spec. If `None`, the default is `K.learning_phase()`. Any other
      value, including `False`, is used as given.

  Returns:
    Tuple of (
      function that calls `wrapped_call` and sets the training arg,
      Argspec of returned function or `None` if the argspec is unchanged)
  """
  if not expects_training_arg:
    return wrapped_call, None

  def wrap_with_training_arg(*args, **kwargs):
    """Wrap the `wrapped_call` function, and set training argument."""
    training_arg_index = get_training_arg_index(original_call)
    training = get_training_arg(training_arg_index, args, kwargs)
    if training is None:
      # Compare against `None` explicitly: a falsy default such as `False`
      # must be honored instead of being replaced by the learning phase.
      if default_training_value is None:
        training = K.learning_phase()
      else:
        training = default_training_value

    # Work on copies so the caller's argument containers are not modified.
    args = list(args)
    kwargs = kwargs.copy()

    def replace_training_and_call(training):
      set_training_arg(training, training_arg_index, args, kwargs)
      return wrapped_call(*args, **kwargs)

    return control_flow_util.smart_cond(
        training, lambda: replace_training_and_call(True),
        lambda: replace_training_and_call(False))

  # Create arg spec for decorated function. If 'training' is not defined in the
  # args of the original arg spec, then add it to kwonlyargs.
  arg_spec = tf_inspect.getfullargspec(original_call)
  defaults = list(arg_spec.defaults) if arg_spec.defaults is not None else []

  # Copy the keyword-only containers before editing them so the spec object
  # returned by `tf_inspect.getfullargspec` is left untouched.
  kwonlyargs = list(arg_spec.kwonlyargs or [])
  kwonlydefaults = dict(arg_spec.kwonlydefaults or {})
  # Add training arg if it does not exist, or set the default training value.
  if 'training' not in arg_spec.args:
    # `call` may already declare a keyword-only 'training'; do not list it
    # twice.
    if 'training' not in kwonlyargs:
      kwonlyargs.append('training')
    kwonlydefaults['training'] = default_training_value
  else:
    index = arg_spec.args.index('training')
    # Distance of 'training' from the end of `args`; `defaults` aligns with
    # the tail of `args`, so this is also its negative index in `defaults`.
    training_default_index = len(arg_spec.args) - index
    if (arg_spec.defaults and
        len(arg_spec.defaults) >= training_default_index and
        defaults[-training_default_index] is None):
      defaults[-training_default_index] = default_training_value

  decorator_argspec = tf_inspect.FullArgSpec(
      args=arg_spec.args,
      varargs=arg_spec.varargs,
      varkw=arg_spec.varkw,
      defaults=defaults,
      kwonlyargs=kwonlyargs,
      kwonlydefaults=kwonlydefaults,
      annotations=arg_spec.annotations)
  return wrap_with_training_arg, decorator_argspec
197
198
def get_training_arg_index(call_fn):
  """Locates the 'training' argument in the signature of `call_fn`.

  Args:
    call_fn: Call function.

  Returns:
    - n: position of 'training' among the positional arguments (not counting
         the first argument when `call_fn` is a method). Only reported when
         `call_fn` takes no `*args`.
    - -1: if 'training' is keyword-only, or is not named but `call_fn` accepts
          variable keyword arguments.
    - None: if `call_fn` has no way to receive a training argument.
  """
  spec = tf_inspect.getfullargspec(call_fn)

  # With variable positional args, training must be passed as a keyword, so
  # the positional lookup only applies when there are none.
  if not spec.varargs:
    positional = spec.args
    if tf_inspect.ismethod(call_fn):
      positional = positional[1:]
    if 'training' in positional:
      return positional.index('training')

  if 'training' in spec.kwonlyargs or spec.varkw:
    return -1
  return None
228
229
def set_training_arg(training, index, args, kwargs):
  """Stores `training` in the call arguments, modifying them in place.

  Args:
    training: Value to store for the training argument.
    index: Position of 'training' in `args`. `None`, a negative value, or a
      position beyond the end of `args` means it is stored in `kwargs`.
    args: Mutable sequence of positional arguments.
    kwargs: Dict of keyword arguments.

  Returns:
    The same `(args, kwargs)` objects that were passed in.
  """
  is_positional = index is not None and 0 <= index < len(args)
  if is_positional:
    args[index] = training
  else:
    kwargs['training'] = training
  return args, kwargs
236
237
def get_training_arg(index, args, kwargs):
  """Reads the training value from the call arguments.

  Args:
    index: Position of 'training' in `args`. `None`, a negative value, or a
      position beyond the end of `args` means it is looked up in `kwargs`.
    args: Sequence of positional arguments.
    kwargs: Dict of keyword arguments.

  Returns:
    `args[index]` when `index` is a valid position, otherwise
    `kwargs['training']`, or `None` if that key is absent.
  """
  if index is not None and 0 <= index < len(args):
    return args[index]
  return kwargs.get('training')
243
244
def remove_training_arg(index, args, kwargs):
  """Drops the training value from the call arguments, in place.

  Args:
    index: Position of 'training' in `args`. `None`, a negative value, or a
      position beyond the end of `args` means it is removed from `kwargs`
      instead (a missing key is ignored).
    args: Mutable list of positional arguments.
    kwargs: Dict of keyword arguments.
  """
  if index is not None and 0 <= index < len(args):
    args.pop(index)
  else:
    kwargs.pop('training', None)
250
251
class SaveOptionsContext(threading.local):
  """Thread-local container for the options in effect while saving.

  Being a `threading.local` subclass, each thread sees its own attribute
  values, starting from the defaults assigned in `__init__`.
  """

  def __init__(self):
    super().__init__()
    # Defaults to True.
    self.save_traces = True


# Module-level instance read and written by the option helpers in this file.
_save_options_context = SaveOptionsContext()
260
261
@tf_contextlib.contextmanager
def keras_option_scope(save_traces):
  """Temporarily overrides the `save_traces` option for the current thread.

  Args:
    save_traces: Value to assign to `_save_options_context.save_traces` while
      the scope is active.

  Yields:
    Nothing. The value that was set before entering is restored on exit, even
    if the body raises.
  """
  saved = _save_options_context.save_traces
  _save_options_context.save_traces = save_traces
  try:
    yield
  finally:
    _save_options_context.save_traces = saved
270
271
def should_save_traces():
  """Returns the `save_traces` option currently set for this thread.

  The value lives on the module's thread-local options object and can be
  changed for a limited span with `keras_option_scope`.
  """
  options = _save_options_context
  return options.save_traces
275
276
@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
  """A context that disables automatic dependency tracking when assigning attrs.

  Objects that inherit from Autotrackable automatically create dependencies
  on trackable objects through attribute assignments, and wrap data structures
  (lists or dicts) with trackable classes. This scope may be used to
  temporarily disable this behavior. It works similarly to the decorator
  `no_automatic_dependency_tracking`.

  Example usage:
  ```
  model = tf.keras.Model()
  model.arr1 = []  # Creates a ListWrapper object
  with no_automatic_dependency_tracking_scope(model):
    model.arr2 = []  # Creates a regular, untracked python list
  ```

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies.
  """
  # An object without the flag is treated as having tracking enabled, so the
  # flag is set to True on exit in that case.
  tracking_before = getattr(obj, '_setattr_tracking', True)
  setattr(obj, '_setattr_tracking', False)
  try:
    yield
  finally:
    setattr(obj, '_setattr_tracking', tracking_before)
307