xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/engine/training_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training-related utilities."""
16
17import numpy as np
18
19from tensorflow.python.framework import tensor_shape
20from tensorflow.python.framework import tensor_util
21from tensorflow.python.keras.utils import generic_utils
22from tensorflow.python.ops import array_ops
23from tensorflow.python.util import nest
24
25
def slice_arrays(arrays, indices, contiguous=True):
  """Slices batches out of provided arrays (workaround for eager tensors).

  Unfortunately eager tensors don't have the same slicing behavior as
  Numpy arrays (they follow the same slicing behavior as symbolic TF tensors),
  hence we cannot use `generic_utils.slice_arrays` directly
  and we have to implement this workaround based on `concat`. This has a
  performance cost.

  Args:
    arrays: Single array or list of arrays. When the TF-tensor path is taken,
      `None` entries are passed through unchanged.
    indices: Sequence (e.g. a list or 1-D NumPy array) of indices in the array
      that should be included in the output batch.
    contiguous: Boolean flag indicating whether the indices are contiguous.
      When True, only `indices[0]` and `indices[-1]` are used to build a
      single range slice.

  Returns:
    Slice of data (either single array or list of arrays). An empty `indices`
    yields empty slices on the TF-tensor path.
  """
  converted_to_list = False
  if not isinstance(arrays, list):
    converted_to_list = True
    arrays = [arrays]

  if any(tensor_util.is_tf_type(x) for x in arrays):
    # Use `len()` rather than truthiness so NumPy index arrays are accepted
    # (their truth value is ambiguous for more than one element).
    has_indices = len(indices) > 0

    def _slice_one(x):
      """Slices a single (possibly None) entry of `arrays`."""
      if x is None:
        # `None` placeholders cannot be sliced; pass them through instead of
        # raising a TypeError.
        return None
      if not has_indices:
        # An empty selection yields an empty batch; `indices[0]` and
        # `concat([])` would both fail here.
        return x[:0]
      if contiguous:
        # Only the first and last index matter for a contiguous range.
        return x[indices[0]:indices[-1] + 1]
      # Eager tensors can't be indexed with a list of indices, so take
      # one-element slices and concatenate them along the first axis.
      return array_ops.concat([x[i:i + 1] for i in indices], axis=0)

    slices = [_slice_one(x) for x in arrays]
  else:
    slices = generic_utils.slice_arrays(arrays, indices)

  if converted_to_list:
    slices = slices[0]
  return slices
60
61
def handle_partial_sample_weights(outputs, sample_weights, sample_weight_modes,
                                  check_all_flat=False):
  """Substitutes all-ones weights for outputs that lack a sample weight.

  Args:
    outputs: List of model outputs (NumPy arrays, or values accepted by
      `array_ops.shape`).
    sample_weights: List of per-output sample weights, or None. Individual
      entries may be None.
    sample_weight_modes: List of per-output sample weight modes, or None. A
      mode of 'temporal' makes a generated weight cover the first two
      dimensions of its output instead of only the first.
    check_all_flat: If True, assert that `outputs`, `sample_weights` and
      `sample_weight_modes` are flat (non-nested). The check is not free, so
      callers may prefer to skip it on every eager iteration.

  Returns:
    A 3-tuple `(weights, any_sample_weight, partial_sample_weight)`:
      - `weights` is None when no weight was provided, `sample_weights`
        itself when every output has a weight, and otherwise a tuple where
        each missing entry is replaced by an all-ones weight.
      - `any_sample_weight` is True if at least one weight is not None.
      - `partial_sample_weight` is True if some, but not all, weights are
        None.
  """
  if sample_weights is None:
    has_any = False
  else:
    has_any = any(w is not None for w in sample_weights)
  has_partial = has_any and not all(w is not None for w in sample_weights)

  # Fast paths: nothing to weight, or every output already has a weight.
  if not has_any:
    return None, has_any, has_partial
  if not has_partial:
    return sample_weights, has_any, has_partial

  if check_all_flat:
    for structure in (sample_weights, outputs):
      nest.assert_same_structure(
          list_to_tuple(structure), list_to_tuple(nest.flatten(structure)))
    if sample_weight_modes is not None:
      nest.assert_same_structure(
          sample_weight_modes, nest.flatten(sample_weight_modes))

  filled = []
  for idx, weight in enumerate(sample_weights):
    if weight is not None:
      filled.append(weight)
      continue

    output = outputs[idx]
    is_np = isinstance(output, np.ndarray)
    # Static shape for NumPy, dynamic shape tensor otherwise.
    shape = output.shape if is_np else array_ops.shape(output)
    temporal = (sample_weight_modes is not None and
                sample_weight_modes[idx] == 'temporal')
    if temporal:
      dims = (shape[0], shape[1])
    else:
      dims = (shape[0],)
    filled.append(np.ones(dims) if is_np else array_ops.ones(dims))

  return list_to_tuple(filled), has_any, has_partial
119
120
class RespectCompiledTrainableState(object):
  """Set and restore trainable state if it has changed since compile.

  The keras API guarantees that the value of each Layer's `trainable` property
  at `Model.compile` time will be used when training that model. In order to
  respect this requirement, it may be necessary to set the trainable value of
  layers to their compile time values before beginning a training endpoint and
  restore the values before returning from said endpoint. This scope checks if
  any layer's trainable state has changed since Model compile, and performs this
  set and un-set bookkeeping.

  However, the trainable state of a layer changes quite infrequently, if ever,
  for many kinds of workflows. Moreover, updating every layer in a model is an
  expensive operation. As a result, we will only explicitly set and unset the
  trainable state of a model if a trainable value has changed since compile.

  An instance may be entered more than once in sequence (not nested); every
  `__enter__` re-evaluates whether the set/restore is needed.
  """

  def __init__(self, model):
    """Creates the scope for `model`.

    Args:
      model: Object exposing `_get_trainable_state()`,
        `_set_trainable_state(state)` and a `_compiled_trainable_state`
        mapping of layer -> trainable value.
    """
    self._model = model
    self._current_trainable_state = None
    self._compiled_trainable_state = None
    self._should_set_trainable = False

  def __enter__(self):
    self._current_trainable_state = self._model._get_trainable_state()  # pylint: disable=protected-access
    self._compiled_trainable_state = self._model._compiled_trainable_state  # pylint: disable=protected-access

    # Check whether any layer's trainable state has changed since `compile`.
    # Recompute on every entry so a reused instance does not carry over a
    # stale decision from a previous `with` block.
    current = self._current_trainable_state
    self._should_set_trainable = any(
        layer in current and trainable != current[layer]
        for layer, trainable in self._compiled_trainable_state.items())

    # If so, restore the model to its compiled state.
    if self._should_set_trainable:
      self._model._set_trainable_state(self._compiled_trainable_state)  # pylint: disable=protected-access

  def __exit__(self, type_arg, value_arg, traceback_arg):
    # If we set the values to their compiled state in __enter__, we need to
    # restore the original values before leaving the scope.
    if self._should_set_trainable:
      self._model._set_trainable_state(self._current_trainable_state)  # pylint: disable=protected-access
    return False  # False values do not suppress exceptions
165
166
167# Allow use of methods not exposed to the user.
168# pylint: disable=protected-access
def get_input_shape_and_dtype(layer):
  """Returns the batch input shape and dtype of a layer, when known.

  Nested Sequential/Functional models are unwrapped by repeatedly descending
  into their first layer, and the shape/dtype of that innermost layer is
  reported. Subclassed models are not unwrapped, since they may not be built.

  Args:
    layer: Layer (or model) instance.

  Returns:
    `(batch_input_shape, dtype)` of the innermost first layer, or
    `(None, None)` if that layer has no (truthy) `_batch_input_shape`.

  Raises:
    ValueError: If an empty Sequential or Functional model is encountered.
  """

  def _is_graph_model(candidate):
    # Functional models flag themselves via `_is_graph_network`; Sequential
    # is recognized by class name.
    if getattr(candidate, '_is_graph_network', False):
      return True
    return candidate.__class__.__name__ == 'Sequential'

  current = layer
  while _is_graph_model(current):
    children = current.layers
    if not children:
      raise ValueError('An empty Model cannot be used as a Layer.')
    current = children[0]

  batch_input_shape = getattr(current, '_batch_input_shape', None)
  if not batch_input_shape:
    return None, None
  return batch_input_shape, current.dtype
198
199
200# pylint: enable=protected-access
201
202
def get_static_batch_size(layer):
  """Returns the statically known batch size of a Layer, if any.

  Args:
    layer: a `Layer` instance.

  Returns:
    The first dimension of the layer's batch input shape, normalized through
    `tensor_shape.Dimension(...).value`, or None when the layer has no
    defined input shape.
  """
  input_shape, _ = get_input_shape_and_dtype(layer)
  if input_shape is None:
    return None
  return tensor_shape.Dimension(input_shape[0]).value
216
217
def list_to_tuple(maybe_list):
  """Converts a `list` into a `tuple`; any other value is returned unchanged.

  Datasets stack a list of tensors, so lists are switched to tuples.
  """
  return tuple(maybe_list) if isinstance(maybe_list, list) else maybe_list
223