xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/engine/training.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training-related part of the Keras engine."""
16
17import copy
18import itertools
19import json
20import os
21import warnings
22import weakref
23
24from tensorflow.python.autograph.lang import directives
25from tensorflow.python.checkpoint import checkpoint as trackable_utils
26from tensorflow.python.checkpoint import checkpoint_management
27from tensorflow.python.data.ops import dataset_ops
28from tensorflow.python.data.ops import options as options_lib
29from tensorflow.python.distribute import collective_all_reduce_strategy
30from tensorflow.python.distribute import distribution_strategy_context as ds_context
31from tensorflow.python.distribute import values as ds_values
32from tensorflow.python.distribute.coordinator import cluster_coordinator
33from tensorflow.python.eager import backprop
34from tensorflow.python.eager import context
35from tensorflow.python.eager import def_function
36from tensorflow.python.framework import composite_tensor
37from tensorflow.python.framework import errors
38from tensorflow.python.framework import errors_impl
39from tensorflow.python.framework import func_graph
40from tensorflow.python.framework import ops
41from tensorflow.python.framework import sparse_tensor
42from tensorflow.python.framework import tensor_shape
43from tensorflow.python.keras import backend
44from tensorflow.python.keras import callbacks as callbacks_module
45from tensorflow.python.keras import optimizer_v1
46from tensorflow.python.keras import optimizers
47from tensorflow.python.keras.engine import base_layer
48from tensorflow.python.keras.engine import base_layer_utils
49from tensorflow.python.keras.engine import compile_utils
50from tensorflow.python.keras.engine import data_adapter
51from tensorflow.python.keras.engine import training_utils
52from tensorflow.python.keras.mixed_precision import loss_scale_optimizer as lso
53from tensorflow.python.keras.mixed_precision import policy
54from tensorflow.python.keras.saving import hdf5_format
55from tensorflow.python.keras.saving import save
56from tensorflow.python.keras.saving import saving_utils
57from tensorflow.python.keras.saving.saved_model import json_utils
58from tensorflow.python.keras.saving.saved_model import model_serialization
59from tensorflow.python.keras.utils import generic_utils
60from tensorflow.python.keras.utils import layer_utils
61from tensorflow.python.keras.utils import tf_utils
62from tensorflow.python.keras.utils import version_utils
63from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
64from tensorflow.python.keras.utils.io_utils import path_to_string
65from tensorflow.python.keras.utils.mode_keys import ModeKeys
66from tensorflow.python.ops import array_ops
67from tensorflow.python.ops import math_ops
68from tensorflow.python.ops import sparse_ops
69from tensorflow.python.ops import summary_ops_v2
70from tensorflow.python.ops import variables
71from tensorflow.python.platform import tf_logging as logging
72from tensorflow.python.profiler import trace
73from tensorflow.python.saved_model import constants as sm_constants
74from tensorflow.python.saved_model import loader_impl as sm_loader
75from tensorflow.python.trackable import base as trackable
76from tensorflow.python.training import py_checkpoint_reader
77from tensorflow.python.util import nest
78from tensorflow.python.util import tf_decorator
79from tensorflow.python.util.tf_export import keras_export
80from tensorflow.tools.docs import doc_controls
81
82
83# pylint: disable=g-import-not-at-top
84try:
85  import h5py
86except ImportError:
87  h5py = None
88# pylint: enable=g-import-not-at-top
89
90
def disable_multi_worker(method):
  """Wraps `method` so that it refuses to run in multi-worker mode.

  Args:
    method: The method to guard. The wrapper calls
      `self._in_multi_worker_mode()` on the instance it receives.

  Returns:
    The decorated method built by `tf_decorator.make_decorator`. Calling it
    raises `ValueError` when `self._in_multi_worker_mode()` is truthy and
    otherwise forwards all arguments to `method`.
  """

  def _method_wrapper(self, *args, **kwargs):
    in_multi_worker = self._in_multi_worker_mode()  # pylint: disable=protected-access
    if not in_multi_worker:
      return method(self, *args, **kwargs)
    raise ValueError(
        '{} is not supported in multi-worker mode.'.format(method.__name__))

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)
102
103
def inject_functional_model_class(cls):
  """Rewrites the bases of `cls` so `Functional` stands in for `Model`.

  The hierarchy is walked recursively through `__bases__`: the Keras `Model`
  classes (this module's and `training_v1`'s) are replaced by
  `functional.Functional`, `object` is returned untouched, and any other class
  has its `__bases__` reassigned in place with the processed bases.

  Args:
    cls: The class to process.

  Returns:
    `functional.Functional` when `cls` is `Model` or `training_v1.Model`,
    `object` when `cls` is `object`, and otherwise `cls` itself after its
    `__bases__` have been updated.
  """
  from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.engine import training_v1  # pylint: disable=g-import-not-at-top

  if cls in (Model, training_v1.Model):
    return functional.Functional

  # With multiple inheritance some branches never reach a Keras model; the
  # recursion bottoms out at `object`, which is left as is.
  if cls is object:
    return object

  processed_bases = [
      inject_functional_model_class(base) for base in cls.__bases__
  ]
  cls.__bases__ = tuple(processed_bases)

  # Trigger any `__new__` class swapping that needed to happen on `Functional`
  # but did not because functional was not in the class hierarchy.
  cls.__new__(cls)
  return cls
122
123
def is_functional_model_init_params(args, kwargs):
  """Returns whether constructor arguments look like a Functional model's.

  The arguments are treated as a functional-style signature when any of the
  following holds:
    * exactly two positional arguments are given;
    * exactly one positional argument is given together with an `outputs`
      keyword;
    * both `inputs` and `outputs` are given as keywords.

  Args:
    args: Sequence of positional arguments passed to the constructor.
    kwargs: Mapping of keyword arguments passed to the constructor.

  Returns:
    `True` if the arguments match one of the forms above, `False` otherwise.
  """
  num_positional = len(args)
  if num_positional == 2:
    return True
  has_outputs = 'outputs' in kwargs
  if num_positional == 1 and has_outputs:
    return True
  return has_outputs and 'inputs' in kwargs
128
129
130@keras_export('keras.Model', 'keras.models.Model')
131class Model(base_layer.Layer, version_utils.ModelVersionSelector):
132  """`Model` groups layers into an object with training and inference features.
133
134  Args:
135      inputs: The input(s) of the model: a `keras.Input` object or list of
136          `keras.Input` objects.
137      outputs: The output(s) of the model. See Functional API example below.
138      name: String, the name of the model.
139
140  There are two ways to instantiate a `Model`:
141
142  1 - With the "Functional API", where you start from `Input`,
143  you chain layer calls to specify the model's forward pass,
144  and finally you create your model from inputs and outputs:
145
146  ```python
147  import tensorflow as tf
148
149  inputs = tf.keras.Input(shape=(3,))
150  x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
151  outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
152  model = tf.keras.Model(inputs=inputs, outputs=outputs)
153  ```
154
155  Note: Only dicts, lists, and tuples of input tensors are supported. Nested
156  inputs are not supported (e.g. lists of list or dicts of dict).
157
158  2 - By subclassing the `Model` class: in that case, you should define your
159  layers in `__init__` and you should implement the model's forward pass
160  in `call`.
161
162  ```python
163  import tensorflow as tf
164
165  class MyModel(tf.keras.Model):
166
167    def __init__(self):
168      super(MyModel, self).__init__()
169      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
170      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
171
172    def call(self, inputs):
173      x = self.dense1(inputs)
174      return self.dense2(x)
175
176  model = MyModel()
177  ```
178
179  If you subclass `Model`, you can optionally have
180  a `training` argument (boolean) in `call`, which you can use to specify
181  a different behavior in training and inference:
182
183  ```python
184  import tensorflow as tf
185
186  class MyModel(tf.keras.Model):
187
188    def __init__(self):
189      super(MyModel, self).__init__()
190      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
191      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
192      self.dropout = tf.keras.layers.Dropout(0.5)
193
194    def call(self, inputs, training=False):
195      x = self.dense1(inputs)
196      if training:
197        x = self.dropout(x, training=training)
198      return self.dense2(x)
199
200  model = MyModel()
201  ```
202
  Once the model is created, you can configure the model with losses and
  metrics using `model.compile()`, train the model with `model.fit()`, or use
  the model to make predictions with `model.predict()`.
206  """
207  _TF_MODULE_IGNORED_PROPERTIES = frozenset(
208      itertools.chain(('_train_counter', '_test_counter', '_predict_counter',
209                       '_steps_per_execution'),
210                      base_layer.Layer._TF_MODULE_IGNORED_PROPERTIES))  # pylint: disable=protected-access
211
212  def __new__(cls, *args, **kwargs):
213    # Signature detection
214    if is_functional_model_init_params(args, kwargs) and cls == Model:
215      # Functional model
216      from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
217      return functional.Functional(skip_init=True, *args, **kwargs)
218    else:
219      return super(Model, cls).__new__(cls, *args, **kwargs)
220
221  @trackable.no_automatic_dependency_tracking
222  def __init__(self, *args, **kwargs):
223    self._is_model_for_instrumentation = True
224
225    # Special case for Subclassed Functional Model, which we couldn't detect
226    # when __new__ is called. We only realize it is a functional model when it
227    # calls super.__init__ with input and output tensor.
228    from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
229    if (is_functional_model_init_params(args, kwargs) and
230        not isinstance(self, functional.Functional)):
231      # Filter the kwargs for multiple inheritance.
232      supported_kwargs = ['inputs', 'outputs', 'name', 'trainable', 'skip_init']
233      model_kwargs = {k: kwargs[k] for k in kwargs if k in supported_kwargs}
234      other_kwargs = {k: kwargs[k] for k in kwargs if k not in supported_kwargs}
235      inject_functional_model_class(self.__class__)
236      functional.Functional.__init__(self, *args, **model_kwargs)
237
238      # In case there is any multiple inheritance here, we need to call the
239      # __init__ for any class that appears after the Functional class.
240      clz_to_init = []
241      found_functional_class = False
242      for clz in self.__class__.__bases__:
243        if issubclass(clz, functional.Functional):
244          found_functional_class = True
245          continue
246        if found_functional_class:
247          clz_to_init.append(clz)
248
249      if clz_to_init:
250        for clz in clz_to_init:
251          clz.__init__(self, *args, **other_kwargs)
252      elif other_kwargs:
253        # In case there are unused kwargs, we should raise an error to user, in
254        # case they have a typo in the param name.
255        raise TypeError(
256            'The following keyword arguments aren\'t supported: {}'.format(
257                other_kwargs))
258      return
259
260    # The following are implemented as property functions:
261    # self.trainable_weights
262    # self.non_trainable_weights
263    # `inputs` / `outputs` will only appear in kwargs if either are misspelled.
264    generic_utils.validate_kwargs(kwargs, {
265        'trainable', 'dtype', 'dynamic', 'name', 'autocast', 'inputs', 'outputs'
266    })
267    super(Model, self).__init__(**kwargs)
268    # By default, Model is a subclass model, which is not in graph network.
269    self._is_graph_network = False
270
271    self.inputs = None
272    self.outputs = None
273    self.input_names = None
274    self.output_names = None
275    # stop_training is used by callback to stop training when error happens
276    self.stop_training = False
277    self.history = None
278    # These objects are used in the default `Model.compile`. They are not
279    # guaranteed to be set after `Model.compile` is called, as users can
280    # override compile with custom logic.
281    self.compiled_loss = None
282    self.compiled_metrics = None
283
284    # This is True for Sequential networks and Functional networks.
285    self._compute_output_and_mask_jointly = False
286
287    # Don't reset compilation if already done. This may occur if calling
288    # `__init__` (or `_init_graph_network`) on an already-compiled model
289    # such as a Sequential model. Sequential models may need to rebuild
290    # themselves after compilation.
291    self._maybe_create_attribute('_is_compiled', False)
292    self._maybe_create_attribute('optimizer', None)
293
294    # Model must be created under scope of DistStrat it will be trained with.
295    if ds_context.has_strategy():
296      self._distribution_strategy = ds_context.get_strategy()
297    else:
298      self._distribution_strategy = None
299
300    self._cluster_coordinator = None
301
302    # Defaults to value of `tf.config.experimental_functions_run_eagerly`.
303    self._run_eagerly = None
304    # Initialize cache attrs.
305    self._reset_compile_cache()
306
307    # Fault-tolerance handler. Set in `ModelCheckpoint`.
308    self._training_state = None
309    self._saved_model_inputs_spec = None
310    self._checkpoint = trackable_utils.Checkpoint(root=weakref.ref(self))
311
312    self._steps_per_execution = None
313
314    self._init_batch_counters()
315    self._base_model_initialized = True
316
317  @trackable.no_automatic_dependency_tracking
318  def _init_batch_counters(self):
319    # Untracked Variables, used to keep track of mini-batches seen in `fit`,
320    # `evaluate`, and `predict`.
321    agg = variables.VariableAggregationV2.ONLY_FIRST_REPLICA
322    self._train_counter = variables.Variable(0, dtype='int64', aggregation=agg)
323    self._test_counter = variables.Variable(0, dtype='int64', aggregation=agg)
324    self._predict_counter = variables.Variable(
325        0, dtype='int64', aggregation=agg)
326
327  def __setattr__(self, name, value):
328    if not getattr(self, '_self_setattr_tracking', True):
329      super(Model, self).__setattr__(name, value)
330      return
331
332    if all(
333        isinstance(v, (base_layer.Layer, variables.Variable)) or
334        base_layer_utils.has_weights(v) for v in nest.flatten(value)):
335      try:
336        self._base_model_initialized
337      except AttributeError:
338        raise RuntimeError(
339            'It looks like you are subclassing `Model` and you '
340            'forgot to call `super().__init__()`.'
341            ' Always start with this line.')
342
343    super(Model, self).__setattr__(name, value)
344
345  @generic_utils.default
346  def build(self, input_shape):
347    """Builds the model based on input shapes received.
348
349    This is to be used for subclassed models, which do not know at instantiation
350    time what their inputs look like.
351
352    This method only exists for users who want to call `model.build()` in a
353    standalone way (as a substitute for calling the model on real data to
354    build it). It will never be called by the framework (and thus it will
355    never throw unexpected errors in an unrelated workflow).
356
357    Args:
358     input_shape: Single tuple, TensorShape, or list/dict of shapes, where
359         shapes are tuples, integers, or TensorShapes.
360
361    Raises:
362      ValueError:
363        1. In case of invalid user-provided data (not of type tuple,
364           list, TensorShape, or dict).
365        2. If the model requires call arguments that are agnostic
366           to the input shapes (positional or kwarg in call signature).
367        3. If not all layers were properly built.
368        4. If float type inputs are not supported within the layers.
369
370      In each of these cases, the user should build their model by calling it
371      on real tensor data.
372    """
373    if self._is_graph_network:
374      super(Model, self).build(input_shape)
375      return
376
377    if input_shape is None:
378      raise ValueError('Input shape must be defined when calling build on a '
379                       'model subclass network.')
380    valid_types = (tuple, list, tensor_shape.TensorShape, dict)
381    if not isinstance(input_shape, valid_types):
382      raise ValueError('Specified input shape is not one of the valid types. '
383                       'Please specify a batch input shape of type tuple or '
384                       'list of input shapes. User provided '
385                       'input type: {}'.format(type(input_shape)))
386
387    if input_shape and not self.inputs:
388      # We create placeholders for the `None`s in the shape and build the model
389      # in a Graph. Since tf.Variable is compatible with both eager execution
390      # and graph building, the variables created after building the model in
391      # a Graph are still valid when executing eagerly.
392      if context.executing_eagerly():
393        graph = func_graph.FuncGraph('build_graph')
394      else:
395        graph = backend.get_graph()
396      with graph.as_default():
397        if (isinstance(input_shape, list) and
398            all(d is None or isinstance(d, int) for d in input_shape)):
399          input_shape = tuple(input_shape)
400        if isinstance(input_shape, list):
401          x = [base_layer_utils.generate_placeholders_from_shape(shape)
402               for shape in input_shape]
403        elif isinstance(input_shape, dict):
404          x = {
405              k: base_layer_utils.generate_placeholders_from_shape(shape)
406              for k, shape in input_shape.items()
407          }
408        else:
409          x = base_layer_utils.generate_placeholders_from_shape(input_shape)
410
411        kwargs = {}
412        call_signature = self._call_full_argspec
413        call_args = call_signature.args
414        # Exclude `self`, `inputs`, and any argument with a default value.
415        if len(call_args) > 2:
416          if call_signature.defaults:
417            call_args = call_args[2:-len(call_signature.defaults)]
418          else:
419            call_args = call_args[2:]
420          for arg in call_args:
421            if arg == 'training':
422              # Case where `training` is a positional arg with no default.
423              kwargs['training'] = False
424            else:
425              # Has invalid call signature with unknown positional arguments.
426              raise ValueError(
427                  'Currently, you cannot build your model if it has '
428                  'positional or keyword arguments that are not '
429                  'inputs to the model, but are required for its '
430                  '`call` method. Instead, in order to instantiate '
431                  'and build your model, `call` your model on real '
432                  'tensor data with all expected call arguments.')
433        elif len(call_args) < 2:
434          # Signature without `inputs`.
435          raise ValueError('You can only call `build` on a model if its `call` '
436                           'method accepts an `inputs` argument.')
437        try:
438          self.call(x, **kwargs)
439        except (errors.InvalidArgumentError, TypeError):
440          raise ValueError('You cannot build your model by calling `build` '
441                           'if your layers do not support float type inputs. '
442                           'Instead, in order to instantiate and build your '
443                           'model, `call` your model on real tensor data (of '
444                           'the correct dtype).')
445    super(Model, self).build(input_shape)
446
447  @doc_controls.doc_in_current_and_subclasses
448  def call(self, inputs, training=None, mask=None):
449    """Calls the model on new inputs.
450
451    In this case `call` just reapplies
452    all ops in the graph to the new inputs
453    (e.g. build a new computational graph from the provided inputs).
454
455    Note: This method should not be called directly. It is only meant to be
456    overridden when subclassing `tf.keras.Model`.
457    To call a model on an input, always use the `__call__` method,
458    i.e. `model(inputs)`, which relies on the underlying `call` method.
459
460    Args:
461        inputs: Input tensor, or dict/list/tuple of input tensors.
462        training: Boolean or boolean scalar tensor, indicating whether to run
463          the `Network` in training mode or inference mode.
464        mask: A mask or list of masks. A mask can be
465            either a tensor or None (no mask).
466
467    Returns:
468        A tensor if there is a single output, or
469        a list of tensors if there are more than one outputs.
470    """
471    raise NotImplementedError('When subclassing the `Model` class, you should '
472                              'implement a `call` method.')
473
474  def compile(self,
475              optimizer='rmsprop',
476              loss=None,
477              metrics=None,
478              loss_weights=None,
479              weighted_metrics=None,
480              run_eagerly=None,
481              steps_per_execution=None,
482              **kwargs):
483    """Configures the model for training.
484
485    Args:
486        optimizer: String (name of optimizer) or optimizer instance. See
487          `tf.keras.optimizers`.
488        loss: String (name of objective function), objective function or
489          `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An objective
490          function is any callable with the signature `loss = fn(y_true,
491          y_pred)`, where y_true = ground truth values with shape =
492          `[batch_size, d0, .. dN]`, except sparse loss functions such as sparse
493          categorical crossentropy where shape = `[batch_size, d0, .. dN-1]`.
494          y_pred = predicted values with shape = `[batch_size, d0, .. dN]`. It
495          returns a weighted loss float tensor. If a custom `Loss` instance is
496          used and reduction is set to `None`, return value has the shape
497          `[batch_size, d0, .. dN-1]` i.e. per-sample or per-timestep loss
498          values; otherwise, it is a scalar. If the model has multiple outputs,
499          you can use a different loss on each output by passing a dictionary
500          or a list of losses. The loss value that will be minimized by the
501          model will then be the sum of all individual losses, unless
502          `loss_weights` is specified.
503        metrics: List of metrics to be evaluated by the model during training
504          and testing. Each of this can be a string (name of a built-in
505          function), function or a `tf.keras.metrics.Metric` instance. See
506          `tf.keras.metrics`. Typically you will use `metrics=['accuracy']`. A
507          function is any callable with the signature `result = fn(y_true,
508          y_pred)`. To specify different metrics for different outputs of a
509          multi-output model, you could also pass a dictionary, such as
510          `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`.
511          You can also pass a list to specify a metric or a list of metrics
512          for each output, such as `metrics=[['accuracy'], ['accuracy', 'mse']]`
513          or `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass the
514          strings 'accuracy' or 'acc', we convert this to one of
515          `tf.keras.metrics.BinaryAccuracy`,
516          `tf.keras.metrics.CategoricalAccuracy`,
517          `tf.keras.metrics.SparseCategoricalAccuracy` based on the loss
518          function used and the model output shape. We do a similar
519          conversion for the strings 'crossentropy' and 'ce' as well.
520        loss_weights: Optional list or dictionary specifying scalar coefficients
521          (Python floats) to weight the loss contributions of different model
522          outputs. The loss value that will be minimized by the model will then
523          be the *weighted sum* of all individual losses, weighted by the
524          `loss_weights` coefficients.
525            If a list, it is expected to have a 1:1 mapping to the model's
526              outputs. If a dict, it is expected to map output names (strings)
527              to scalar coefficients.
528        weighted_metrics: List of metrics to be evaluated and weighted by
529          `sample_weight` or `class_weight` during training and testing.
530        run_eagerly: Bool. Defaults to `False`. If `True`, this `Model`'s
531          logic will not be wrapped in a `tf.function`. Recommended to leave
532          this as `None` unless your `Model` cannot be run inside a
533          `tf.function`. `run_eagerly=True` is not supported when using
534          `tf.distribute.experimental.ParameterServerStrategy`.
535        steps_per_execution: Int. Defaults to 1. The number of batches to
536          run during each `tf.function` call. Running multiple batches
537          inside a single `tf.function` call can greatly improve performance
538          on TPUs or small models with a large Python overhead.
539          At most, one full epoch will be run each
540          execution. If a number larger than the size of the epoch is passed,
541          the execution will be truncated to the size of the epoch.
542          Note that if `steps_per_execution` is set to `N`,
543          `Callback.on_batch_begin` and `Callback.on_batch_end` methods
544          will only be called every `N` batches
545          (i.e. before/after each `tf.function` execution).
546        **kwargs: Arguments supported for backwards compatibility only.
547
548    Raises:
549        ValueError: In case of invalid arguments for
550            `optimizer`, `loss` or `metrics`.
551    """
552    with self.distribute_strategy.scope():
553      if 'experimental_steps_per_execution' in kwargs:
554        logging.warning('The argument `steps_per_execution` is no longer '
555                        'experimental. Pass `steps_per_execution` instead of '
556                        '`experimental_steps_per_execution`.')
557        if not steps_per_execution:
558          steps_per_execution = kwargs.pop('experimental_steps_per_execution')
559
560      # When compiling from an already-serialized model, we do not want to
561      # reapply some processing steps (e.g. metric renaming for multi-output
562      # models, which have prefixes added for each corresponding output name).
563      from_serialized = kwargs.pop('from_serialized', False)
564
565      self._validate_compile(optimizer, metrics, **kwargs)
566      self._run_eagerly = run_eagerly
567
568      self.optimizer = self._get_optimizer(optimizer)
569      self.compiled_loss = compile_utils.LossesContainer(
570          loss, loss_weights, output_names=self.output_names)
571      self.compiled_metrics = compile_utils.MetricsContainer(
572          metrics, weighted_metrics, output_names=self.output_names,
573          from_serialized=from_serialized)
574
575      self._configure_steps_per_execution(steps_per_execution or 1)
576
577      # Initializes attrs that are reset each time `compile` is called.
578      self._reset_compile_cache()
579      self._is_compiled = True
580
581      self.loss = loss or {}  # Backwards compat.
582
583  def _get_optimizer(self, optimizer):
584    """Wraps `optimizer` in `LossScaleOptimizer` if necessary."""
585    # The deprecated PolicyV1 has a loss_scale, which we use for backwards
586    # compatibility to match TF 2.3 behavior. The new Policy does not have a
587    # loss_scale, so we use dynamic loss scaling if the mixed_float16 policy is
588    # used.
589    if isinstance(self._dtype_policy, policy.PolicyV1):
590      loss_scale = self._dtype_policy.loss_scale
591    elif self._dtype_policy.name == 'mixed_float16':
592      loss_scale = 'dynamic'
593    else:
594      loss_scale = None
595
596    def _get_single_optimizer(opt):
597      opt = optimizers.get(opt)
598      if (loss_scale is not None and
599          not isinstance(opt, lso.LossScaleOptimizer)):
600        if loss_scale == 'dynamic':
601          opt = lso.LossScaleOptimizer(opt)
602        else:
603          opt = lso.LossScaleOptimizerV1(opt, loss_scale)
604      return opt
605
606    return nest.map_structure(_get_single_optimizer, optimizer)
607
608  @trackable.no_automatic_dependency_tracking
609  def _reset_compile_cache(self):
610    self.train_function = None
611    self.test_function = None
612    self.predict_function = None
613    # Used to cache the `tf.function`'ed `train_function` to be logged in
614    # TensorBoard, since the original `train_function` is not necessarily
615    # a `tf.function` (e.g., with ParameterServerStrategy, the `train_function`
616    # is a scheduling of the actual training function to a remote worker).
617    self.train_tf_function = None
618
619    # Used to cache `trainable` attr of `Layer`s for `fit`.
620    self._compiled_trainable_state = self._get_trainable_state()
621
622  @trackable.no_automatic_dependency_tracking
623  def _configure_steps_per_execution(self, steps_per_execution):
624    self._steps_per_execution = variables.Variable(
625        steps_per_execution,
626        dtype='int64',
627        aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA)
628
629  @property
630  def _should_compute_mask(self):
631    return False
632
633  @property
634  def metrics(self):
635    """Returns the model's metrics added using `compile`, `add_metric` APIs.
636
637    Note: Metrics passed to `compile()` are available only after a `keras.Model`
638    has been trained/evaluated on actual data.
639
640    Examples:
641
642    >>> inputs = tf.keras.layers.Input(shape=(3,))
643    >>> outputs = tf.keras.layers.Dense(2)(inputs)
644    >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
645    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
646    >>> [m.name for m in model.metrics]
647    []
648
649    >>> x = np.random.random((2, 3))
650    >>> y = np.random.randint(0, 2, (2, 2))
651    >>> model.fit(x, y)
652    >>> [m.name for m in model.metrics]
653    ['loss', 'mae']
654
655    >>> inputs = tf.keras.layers.Input(shape=(3,))
656    >>> d = tf.keras.layers.Dense(2, name='out')
657    >>> output_1 = d(inputs)
658    >>> output_2 = d(inputs)
659    >>> model = tf.keras.models.Model(
660    ...    inputs=inputs, outputs=[output_1, output_2])
661    >>> model.add_metric(
662    ...    tf.reduce_sum(output_2), name='mean', aggregation='mean')
663    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
664    >>> model.fit(x, (y, y))
665    >>> [m.name for m in model.metrics]
666    ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',
667    'out_1_acc', 'mean']
668
669    """
670    metrics = []
671    if self._is_compiled:
672      # TODO(omalleyt): Track `LossesContainer` and `MetricsContainer` objects
673      # so that attr names are not load-bearing.
674      if self.compiled_loss is not None:
675        metrics += self.compiled_loss.metrics
676      if self.compiled_metrics is not None:
677        metrics += self.compiled_metrics.metrics
678
679    for l in self._flatten_layers():
680      metrics.extend(l._metrics)  # pylint: disable=protected-access
681    return metrics
682
683  @property
684  def metrics_names(self):
685    """Returns the model's display labels for all outputs.
686
687    Note: `metrics_names` are available only after a `keras.Model` has been
688    trained/evaluated on actual data.
689
690    Examples:
691
692    >>> inputs = tf.keras.layers.Input(shape=(3,))
693    >>> outputs = tf.keras.layers.Dense(2)(inputs)
694    >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
695    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
696    >>> model.metrics_names
697    []
698
699    >>> x = np.random.random((2, 3))
700    >>> y = np.random.randint(0, 2, (2, 2))
701    >>> model.fit(x, y)
702    >>> model.metrics_names
703    ['loss', 'mae']
704
705    >>> inputs = tf.keras.layers.Input(shape=(3,))
706    >>> d = tf.keras.layers.Dense(2, name='out')
707    >>> output_1 = d(inputs)
708    >>> output_2 = d(inputs)
709    >>> model = tf.keras.models.Model(
710    ...    inputs=inputs, outputs=[output_1, output_2])
711    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
712    >>> model.fit(x, (y, y))
713    >>> model.metrics_names
714    ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',
715    'out_1_acc']
716
717    """
718
719    # This property includes all output names including `loss` and per-output
720    # losses for backward compatibility.
721    return [m.name for m in self.metrics]
722
723  @property
724  def distribute_strategy(self):
725    """The `tf.distribute.Strategy` this model was created under."""
726    return self._distribution_strategy or ds_context.get_strategy()
727
728  @property
729  def run_eagerly(self):
730    """Settable attribute indicating whether the model should run eagerly.
731
732    Running eagerly means that your model will be run step by step,
733    like Python code. Your model might run slower, but it should become easier
734    for you to debug it by stepping into individual layer calls.
735
736    By default, we will attempt to compile your model to a static graph to
737    deliver the best execution performance.
738
739    Returns:
740      Boolean, whether the model should run eagerly.
741    """
742    if self.dynamic and self._run_eagerly is False:  # pylint:disable=g-bool-id-comparison
743      # TODO(fchollet): consider using py_func to enable this.
744      raise ValueError('Your model contains layers that can only be '
745                       'successfully run in eager execution (layers '
746                       'constructed with `dynamic=True`). '
747                       'You cannot set `run_eagerly=False`.')
748
749    if self._cluster_coordinator and self._run_eagerly:
750      raise ValueError('When using `Model` with `ParameterServerStrategy`, '
751                       '`run_eagerly` is not supported.')
752
753    # Run eagerly logic, by priority:
754    # (1) Dynamic models must be run eagerly.
755    # (2) Explicitly setting run_eagerly causes a Model to be run eagerly.
756    # (3) Not explicitly setting run_eagerly defaults to TF's global setting.
757    return (self.dynamic or self._run_eagerly or
758            (def_function.functions_run_eagerly() and
759             self._run_eagerly is None))
760
761  @run_eagerly.setter
762  def run_eagerly(self, value):
763    self._run_eagerly = value
764
765  def train_step(self, data):
766    """The logic for one training step.
767
768    This method can be overridden to support custom training logic.
769    For concrete examples of how to override this method see
770    [Customizing what happends in fit](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit).
771    This method is called by `Model.make_train_function`.
772
773    This method should contain the mathematical logic for one step of training.
774    This typically includes the forward pass, loss calculation, backpropagation,
775    and metric updates.
776
777    Configuration details for *how* this logic is run (e.g. `tf.function` and
778    `tf.distribute.Strategy` settings), should be left to
779    `Model.make_train_function`, which can also be overridden.
780
781    Args:
782      data: A nested structure of `Tensor`s.
783
784    Returns:
785      A `dict` containing values that will be passed to
786      `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the
787      values of the `Model`'s metrics are returned. Example:
788      `{'loss': 0.2, 'accuracy': 0.7}`.
789
790    """
791    # These are the only transformations `Model.fit` applies to user-input
792    # data when a `tf.data.Dataset` is provided.
793    data = data_adapter.expand_1d(data)
794    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
795    # Run forward pass.
796    with backprop.GradientTape() as tape:
797      y_pred = self(x, training=True)
798      loss = self.compiled_loss(
799          y, y_pred, sample_weight, regularization_losses=self.losses)
800    # Run backwards pass.
801    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
802    self.compiled_metrics.update_state(y, y_pred, sample_weight)
803    # Collect metrics to return
804    return_metrics = {}
805    for metric in self.metrics:
806      result = metric.result()
807      if isinstance(result, dict):
808        return_metrics.update(result)
809      else:
810        return_metrics[metric.name] = result
811    return return_metrics
812
813  def make_train_function(self):
814    """Creates a function that executes one step of training.
815
816    This method can be overridden to support custom training logic.
817    This method is called by `Model.fit` and `Model.train_on_batch`.
818
819    Typically, this method directly controls `tf.function` and
820    `tf.distribute.Strategy` settings, and delegates the actual training
821    logic to `Model.train_step`.
822
823    This function is cached the first time `Model.fit` or
824    `Model.train_on_batch` is called. The cache is cleared whenever
825    `Model.compile` is called.
826
827    Returns:
828      Function. The function created by this method should accept a
829      `tf.data.Iterator`, and return a `dict` containing values that will
830      be passed to `tf.keras.Callbacks.on_train_batch_end`, such as
831      `{'loss': 0.2, 'accuracy': 0.7}`.
832    """
833    if self.train_function is not None:
834      return self.train_function
835
836    def step_function(model, iterator):
837      """Runs a single training step."""
838
839      def run_step(data):
840        outputs = model.train_step(data)
841        # Ensure counter is updated only if `train_step` succeeds.
842        with ops.control_dependencies(_minimum_control_deps(outputs)):
843          model._train_counter.assign_add(1)  # pylint: disable=protected-access
844        return outputs
845
846      data = next(iterator)
847      outputs = model.distribute_strategy.run(run_step, args=(data,))
848      outputs = reduce_per_replica(
849          outputs, self.distribute_strategy, reduction='first')
850      write_scalar_summaries(outputs, step=model._train_counter)  # pylint: disable=protected-access
851      return outputs
852
853    if self._steps_per_execution.numpy().item() == 1:
854
855      def train_function(iterator):
856        """Runs a training execution with one step."""
857        return step_function(self, iterator)
858
859    else:
860
861      def train_function(iterator):
862        """Runs a training execution with multiple steps."""
863        for _ in math_ops.range(self._steps_per_execution):
864          outputs = step_function(self, iterator)
865        return outputs
866
867    if not self.run_eagerly:
868      train_function = def_function.function(
869          train_function, experimental_relax_shapes=True)
870      self.train_tf_function = train_function
871
872    self.train_function = train_function
873
874    if self._cluster_coordinator:
875      self.train_function = lambda iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
876          train_function, args=(iterator,))
877
878    return self.train_function
879
880  def fit(self,
881          x=None,
882          y=None,
883          batch_size=None,
884          epochs=1,
885          verbose='auto',
886          callbacks=None,
887          validation_split=0.,
888          validation_data=None,
889          shuffle=True,
890          class_weight=None,
891          sample_weight=None,
892          initial_epoch=0,
893          steps_per_epoch=None,
894          validation_steps=None,
895          validation_batch_size=None,
896          validation_freq=1,
897          max_queue_size=10,
898          workers=1,
899          use_multiprocessing=False):
900    """Trains the model for a fixed number of epochs (iterations on a dataset).
901
902    Args:
903        x: Input data. It could be:
904          - A Numpy array (or array-like), or a list of arrays
905            (in case the model has multiple inputs).
906          - A TensorFlow tensor, or a list of tensors
907            (in case the model has multiple inputs).
908          - A dict mapping input names to the corresponding array/tensors,
909            if the model has named inputs.
910          - A `tf.data` dataset. Should return a tuple
911            of either `(inputs, targets)` or
912            `(inputs, targets, sample_weights)`.
913          - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
914            or `(inputs, targets, sample_weights)`.
915          - A `tf.keras.utils.experimental.DatasetCreator`, which wraps a
916            callable that takes a single argument of type
917            `tf.distribute.InputContext`, and returns a `tf.data.Dataset`.
918            `DatasetCreator` should be used when users prefer to specify the
919            per-replica batching and sharding logic for the `Dataset`.
920            See `tf.keras.utils.experimental.DatasetCreator` doc for more
921            information.
922          A more detailed description of unpacking behavior for iterator types
923          (Dataset, generator, Sequence) is given below. If using
924          `tf.distribute.experimental.ParameterServerStrategy`, only
925          `DatasetCreator` type is supported for `x`.
926        y: Target data. Like the input data `x`,
927          it could be either Numpy array(s) or TensorFlow tensor(s).
928          It should be consistent with `x` (you cannot have Numpy inputs and
929          tensor targets, or inversely). If `x` is a dataset, generator,
930          or `keras.utils.Sequence` instance, `y` should
931          not be specified (since targets will be obtained from `x`).
932        batch_size: Integer or `None`.
933            Number of samples per gradient update.
934            If unspecified, `batch_size` will default to 32.
935            Do not specify the `batch_size` if your data is in the
936            form of datasets, generators, or `keras.utils.Sequence` instances
937            (since they generate batches).
938        epochs: Integer. Number of epochs to train the model.
939            An epoch is an iteration over the entire `x` and `y`
940            data provided.
941            Note that in conjunction with `initial_epoch`,
942            `epochs` is to be understood as "final epoch".
943            The model is not trained for a number of iterations
944            given by `epochs`, but merely until the epoch
945            of index `epochs` is reached.
946        verbose: 'auto', 0, 1, or 2. Verbosity mode.
947            0 = silent, 1 = progress bar, 2 = one line per epoch.
948            'auto' defaults to 1 for most cases, but 2 when used with
949            `ParameterServerStrategy`. Note that the progress bar is not
950            particularly useful when logged to a file, so verbose=2 is
951            recommended when not running interactively (eg, in a production
952            environment).
953        callbacks: List of `keras.callbacks.Callback` instances.
954            List of callbacks to apply during training.
955            See `tf.keras.callbacks`. Note `tf.keras.callbacks.ProgbarLogger`
956            and `tf.keras.callbacks.History` callbacks are created automatically
957            and need not be passed into `model.fit`.
958            `tf.keras.callbacks.ProgbarLogger` is created or not based on
959            `verbose` argument to `model.fit`.
960            Callbacks with batch-level calls are currently unsupported with
961            `tf.distribute.experimental.ParameterServerStrategy`, and users are
962            advised to implement epoch-level calls instead with an appropriate
963            `steps_per_epoch` value.
964        validation_split: Float between 0 and 1.
965            Fraction of the training data to be used as validation data.
966            The model will set apart this fraction of the training data,
967            will not train on it, and will evaluate
968            the loss and any model metrics
969            on this data at the end of each epoch.
970            The validation data is selected from the last samples
971            in the `x` and `y` data provided, before shuffling. This argument is
972            not supported when `x` is a dataset, generator or
973           `keras.utils.Sequence` instance.
974            `validation_split` is not yet supported with
975            `tf.distribute.experimental.ParameterServerStrategy`.
976        validation_data: Data on which to evaluate
977            the loss and any model metrics at the end of each epoch.
978            The model will not be trained on this data. Thus, note the fact
979            that the validation loss of data provided using `validation_split`
980            or `validation_data` is not affected by regularization layers like
981            noise and dropout.
982            `validation_data` will override `validation_split`.
983            `validation_data` could be:
984              - A tuple `(x_val, y_val)` of Numpy arrays or tensors.
985              - A tuple `(x_val, y_val, val_sample_weights)` of NumPy arrays.
986              - A `tf.data.Dataset`.
987              - A Python generator or `keras.utils.Sequence` returning
988              `(inputs, targets)` or `(inputs, targets, sample_weights)`.
989            `validation_data` is not yet supported with
990            `tf.distribute.experimental.ParameterServerStrategy`.
991        shuffle: Boolean (whether to shuffle the training data
992            before each epoch) or str (for 'batch'). This argument is ignored
993            when `x` is a generator or an object of tf.data.Dataset.
994            'batch' is a special option for dealing
995            with the limitations of HDF5 data; it shuffles in batch-sized
996            chunks. Has no effect when `steps_per_epoch` is not `None`.
997        class_weight: Optional dictionary mapping class indices (integers)
998            to a weight (float) value, used for weighting the loss function
999            (during training only).
1000            This can be useful to tell the model to
1001            "pay more attention" to samples from
1002            an under-represented class.
1003        sample_weight: Optional Numpy array of weights for
1004            the training samples, used for weighting the loss function
1005            (during training only). You can either pass a flat (1D)
1006            Numpy array with the same length as the input samples
1007            (1:1 mapping between weights and samples),
1008            or in the case of temporal data,
1009            you can pass a 2D array with shape
1010            `(samples, sequence_length)`,
1011            to apply a different weight to every timestep of every sample. This
1012            argument is not supported when `x` is a dataset, generator, or
1013           `keras.utils.Sequence` instance, instead provide the sample_weights
1014            as the third element of `x`.
1015        initial_epoch: Integer.
1016            Epoch at which to start training
1017            (useful for resuming a previous training run).
1018        steps_per_epoch: Integer or `None`.
1019            Total number of steps (batches of samples)
1020            before declaring one epoch finished and starting the
1021            next epoch. When training with input tensors such as
1022            TensorFlow data tensors, the default `None` is equal to
1023            the number of samples in your dataset divided by
1024            the batch size, or 1 if that cannot be determined. If x is a
1025            `tf.data` dataset, and 'steps_per_epoch'
1026            is None, the epoch will run until the input dataset is exhausted.
1027            When passing an infinitely repeating dataset, you must specify the
1028            `steps_per_epoch` argument. If `steps_per_epoch=-1` the training
1029            will run indefinitely with an infinitely repeating dataset.
1030            This argument is not supported with array inputs.
1031            When using `tf.distribute.experimental.ParameterServerStrategy`:
1032              * `steps_per_epoch=None` is not supported.
1033        validation_steps: Only relevant if `validation_data` is provided and
1034            is a `tf.data` dataset. Total number of steps (batches of
1035            samples) to draw before stopping when performing validation
1036            at the end of every epoch. If 'validation_steps' is None, validation
1037            will run until the `validation_data` dataset is exhausted. In the
1038            case of an infinitely repeated dataset, it will run into an
1039            infinite loop. If 'validation_steps' is specified and only part of
1040            the dataset will be consumed, the evaluation will start from the
1041            beginning of the dataset at each epoch. This ensures that the same
1042            validation samples are used every time.
1043        validation_batch_size: Integer or `None`.
1044            Number of samples per validation batch.
1045            If unspecified, will default to `batch_size`.
1046            Do not specify the `validation_batch_size` if your data is in the
1047            form of datasets, generators, or `keras.utils.Sequence` instances
1048            (since they generate batches).
1049        validation_freq: Only relevant if validation data is provided. Integer
1050            or `collections.abc.Container` instance (e.g. list, tuple, etc.).
1051            If an integer, specifies how many training epochs to run before a
1052            new validation run is performed, e.g. `validation_freq=2` runs
1053            validation every 2 epochs. If a Container, specifies the epochs on
1054            which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
1055            validation at the end of the 1st, 2nd, and 10th epochs.
1056        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
1057            input only. Maximum size for the generator queue.
1058            If unspecified, `max_queue_size` will default to 10.
1059        workers: Integer. Used for generator or `keras.utils.Sequence` input
1060            only. Maximum number of processes to spin up
1061            when using process-based threading. If unspecified, `workers`
1062            will default to 1.
1063        use_multiprocessing: Boolean. Used for generator or
1064            `keras.utils.Sequence` input only. If `True`, use process-based
1065            threading. If unspecified, `use_multiprocessing` will default to
1066            `False`. Note that because this implementation relies on
1067            multiprocessing, you should not pass non-picklable arguments to
1068            the generator as they can't be passed easily to children processes.
1069
1070    Unpacking behavior for iterator-like inputs:
1071        A common pattern is to pass a tf.data.Dataset, generator, or
1072      tf.keras.utils.Sequence to the `x` argument of fit, which will in fact
1073      yield not only features (x) but optionally targets (y) and sample weights.
1074      Keras requires that the output of such iterator-likes be unambiguous. The
1075      iterator should return a tuple of length 1, 2, or 3, where the optional
1076      second and third elements will be used for y and sample_weight
1077      respectively. Any other type provided will be wrapped in a length one
1078      tuple, effectively treating everything as 'x'. When yielding dicts, they
1079      should still adhere to the top-level tuple structure.
1080      e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate
1081      features, targets, and weights from the keys of a single dict.
1082        A notable unsupported data type is the namedtuple. The reason is that
1083      it behaves like both an ordered datatype (tuple) and a mapping
1084      datatype (dict). So given a namedtuple of the form:
1085          `namedtuple("example_tuple", ["y", "x"])`
1086      it is ambiguous whether to reverse the order of the elements when
1087      interpreting the value. Even worse is a tuple of the form:
1088          `namedtuple("other_tuple", ["x", "y", "z"])`
1089      where it is unclear if the tuple was intended to be unpacked into x, y,
1090      and sample_weight or passed through as a single element to `x`. As a
1091      result the data processing code will simply raise a ValueError if it
1092      encounters a namedtuple. (Along with instructions to remedy the issue.)
1093
1094    Returns:
1095        A `History` object. Its `History.history` attribute is
1096        a record of training loss values and metrics values
1097        at successive epochs, as well as validation loss values
1098        and validation metrics values (if applicable).
1099
1100    Raises:
1101        RuntimeError: 1. If the model was never compiled or,
1102        2. If `model.fit` is  wrapped in `tf.function`.
1103
1104        ValueError: In case of mismatch between the provided input data
1105            and what the model expects or when the input data is empty.
1106    """
1107    # Legacy graph support is contained in `training_v1.Model`.
1108    version_utils.disallow_legacy_graph('Model', 'fit')
1109    self._assert_compile_was_called()
1110    self._check_call_args('fit')
1111    _disallow_inside_tf_function('fit')
1112
1113    if verbose == 'auto':
1114      if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
1115        verbose = 2  # Default to epoch-level logging for PSStrategy.
1116      else:
1117        verbose = 1  # Default to batch-level logging otherwise.
1118
1119    if validation_split:
1120      # Create the validation data using the training data. Only supported for
1121      # `Tensor` and `NumPy` input.
1122      (x, y, sample_weight), validation_data = (
1123          data_adapter.train_validation_split(
1124              (x, y, sample_weight), validation_split=validation_split))
1125
1126    if validation_data:
1127      val_x, val_y, val_sample_weight = (
1128          data_adapter.unpack_x_y_sample_weight(validation_data))
1129
1130    if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
1131      self._cluster_coordinator = cluster_coordinator.ClusterCoordinator(
1132          self.distribute_strategy)
1133
1134    with self.distribute_strategy.scope(), \
1135         training_utils.RespectCompiledTrainableState(self):
1136      # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
1137      data_handler = data_adapter.get_data_handler(
1138          x=x,
1139          y=y,
1140          sample_weight=sample_weight,
1141          batch_size=batch_size,
1142          steps_per_epoch=steps_per_epoch,
1143          initial_epoch=initial_epoch,
1144          epochs=epochs,
1145          shuffle=shuffle,
1146          class_weight=class_weight,
1147          max_queue_size=max_queue_size,
1148          workers=workers,
1149          use_multiprocessing=use_multiprocessing,
1150          model=self,
1151          steps_per_execution=self._steps_per_execution)
1152
1153      # Container that configures and calls `tf.keras.Callback`s.
1154      if not isinstance(callbacks, callbacks_module.CallbackList):
1155        callbacks = callbacks_module.CallbackList(
1156            callbacks,
1157            add_history=True,
1158            add_progbar=verbose != 0,
1159            model=self,
1160            verbose=verbose,
1161            epochs=epochs,
1162            steps=data_handler.inferred_steps)
1163
1164      self.stop_training = False
1165      self.train_function = self.make_train_function()
1166      self._train_counter.assign(0)
1167      callbacks.on_train_begin()
1168      training_logs = None
1169      # Handle fault-tolerance for multi-worker.
1170      # TODO(omalleyt): Fix the ordering issues that mean this has to
1171      # happen after `callbacks.on_train_begin`.
1172      data_handler._initial_epoch = (  # pylint: disable=protected-access
1173          self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
1174      logs = None
1175      for epoch, iterator in data_handler.enumerate_epochs():
1176        self.reset_metrics()
1177        callbacks.on_epoch_begin(epoch)
1178        with data_handler.catch_stop_iteration():
1179          for step in data_handler.steps():
1180            with trace.Trace(
1181                'train',
1182                epoch_num=epoch,
1183                step_num=step,
1184                batch_size=batch_size,
1185                _r=1):
1186              callbacks.on_train_batch_begin(step)
1187              tmp_logs = self.train_function(iterator)
1188              if data_handler.should_sync:
1189                context.async_wait()
1190              logs = tmp_logs  # No error, now safe to assign to logs.
1191              end_step = step + data_handler.step_increment
1192              callbacks.on_train_batch_end(end_step, logs)
1193              if self.stop_training:
1194                break
1195
1196        logs = tf_utils.sync_to_numpy_or_python_type(logs)
1197        if logs is None:
1198          raise ValueError('Expect x to be a non-empty array or dataset.')
1199        epoch_logs = copy.copy(logs)
1200
1201        # Run validation.
1202        if validation_data and self._should_eval(epoch, validation_freq):
1203          # Create data_handler for evaluation and cache it.
1204          if getattr(self, '_eval_data_handler', None) is None:
1205            self._eval_data_handler = data_adapter.get_data_handler(
1206                x=val_x,
1207                y=val_y,
1208                sample_weight=val_sample_weight,
1209                batch_size=validation_batch_size or batch_size,
1210                steps_per_epoch=validation_steps,
1211                initial_epoch=0,
1212                epochs=1,
1213                max_queue_size=max_queue_size,
1214                workers=workers,
1215                use_multiprocessing=use_multiprocessing,
1216                model=self,
1217                steps_per_execution=self._steps_per_execution)
1218          val_logs = self.evaluate(
1219              x=val_x,
1220              y=val_y,
1221              sample_weight=val_sample_weight,
1222              batch_size=validation_batch_size or batch_size,
1223              steps=validation_steps,
1224              callbacks=callbacks,
1225              max_queue_size=max_queue_size,
1226              workers=workers,
1227              use_multiprocessing=use_multiprocessing,
1228              return_dict=True,
1229              _use_cached_eval_dataset=True)
1230          val_logs = {'val_' + name: val for name, val in val_logs.items()}
1231          epoch_logs.update(val_logs)
1232
1233        callbacks.on_epoch_end(epoch, epoch_logs)
1234        training_logs = epoch_logs
1235        if self.stop_training:
1236          break
1237
1238      # If eval data_hanlder exists, delete it after all epochs are done.
1239      if getattr(self, '_eval_data_handler', None) is not None:
1240        del self._eval_data_handler
1241      callbacks.on_train_end(logs=training_logs)
1242      return self.history
1243
1244  def test_step(self, data):
1245    """The logic for one evaluation step.
1246
1247    This method can be overridden to support custom evaluation logic.
1248    This method is called by `Model.make_test_function`.
1249
1250    This function should contain the mathematical logic for one step of
1251    evaluation.
1252    This typically includes the forward pass, loss calculation, and metrics
1253    updates.
1254
1255    Configuration details for *how* this logic is run (e.g. `tf.function` and
1256    `tf.distribute.Strategy` settings), should be left to
1257    `Model.make_test_function`, which can also be overridden.
1258
1259    Args:
1260      data: A nested structure of `Tensor`s.
1261
1262    Returns:
1263      A `dict` containing values that will be passed to
1264      `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the
1265      values of the `Model`'s metrics are returned.
1266    """
1267    data = data_adapter.expand_1d(data)
1268    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
1269
1270    y_pred = self(x, training=False)
1271    # Updates stateful loss metrics.
1272    self.compiled_loss(
1273        y, y_pred, sample_weight, regularization_losses=self.losses)
1274    self.compiled_metrics.update_state(y, y_pred, sample_weight)
1275    # Collect metrics to return
1276    return_metrics = {}
1277    for metric in self.metrics:
1278      result = metric.result()
1279      if isinstance(result, dict):
1280        return_metrics.update(result)
1281      else:
1282        return_metrics[metric.name] = result
1283    return return_metrics
1284
1285  def make_test_function(self):
1286    """Creates a function that executes one step of evaluation.
1287
1288    This method can be overridden to support custom evaluation logic.
1289    This method is called by `Model.evaluate` and `Model.test_on_batch`.
1290
1291    Typically, this method directly controls `tf.function` and
1292    `tf.distribute.Strategy` settings, and delegates the actual evaluation
1293    logic to `Model.test_step`.
1294
1295    This function is cached the first time `Model.evaluate` or
1296    `Model.test_on_batch` is called. The cache is cleared whenever
1297    `Model.compile` is called.
1298
1299    Returns:
1300      Function. The function created by this method should accept a
1301      `tf.data.Iterator`, and return a `dict` containing values that will
1302      be passed to `tf.keras.Callbacks.on_test_batch_end`.
1303    """
1304    if self.test_function is not None:
1305      return self.test_function
1306
1307    def step_function(model, iterator):
1308      """Runs a single evaluation step."""
1309
1310      def run_step(data):
1311        outputs = model.test_step(data)
1312        # Ensure counter is updated only if `test_step` succeeds.
1313        with ops.control_dependencies(_minimum_control_deps(outputs)):
1314          model._test_counter.assign_add(1)  # pylint: disable=protected-access
1315        return outputs
1316
1317      data = next(iterator)
1318      outputs = model.distribute_strategy.run(run_step, args=(data,))
1319      outputs = reduce_per_replica(
1320          outputs, self.distribute_strategy, reduction='first')
1321      return outputs
1322
1323    if self._steps_per_execution.numpy().item() == 1:
1324
1325      def test_function(iterator):
1326        """Runs an evaluation execution with one step."""
1327        return step_function(self, iterator)
1328
1329    else:
1330
1331      def test_function(iterator):
1332        """Runs an evaluation execution with multiple steps."""
1333        for _ in math_ops.range(self._steps_per_execution):
1334          outputs = step_function(self, iterator)
1335        return outputs
1336
1337    if not self.run_eagerly:
1338      test_function = def_function.function(
1339          test_function, experimental_relax_shapes=True)
1340
1341    self.test_function = test_function
1342
1343    if self._cluster_coordinator:
1344      self.test_function = lambda iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
1345          test_function, args=(iterator,))
1346
1347    return self.test_function
1348
1349  def evaluate(self,
1350               x=None,
1351               y=None,
1352               batch_size=None,
1353               verbose=1,
1354               sample_weight=None,
1355               steps=None,
1356               callbacks=None,
1357               max_queue_size=10,
1358               workers=1,
1359               use_multiprocessing=False,
1360               return_dict=False,
1361               **kwargs):
1362    """Returns the loss value & metrics values for the model in test mode.
1363
1364    Computation is done in batches (see the `batch_size` arg.)
1365
1366    Args:
1367        x: Input data. It could be:
1368          - A Numpy array (or array-like), or a list of arrays
1369            (in case the model has multiple inputs).
1370          - A TensorFlow tensor, or a list of tensors
1371            (in case the model has multiple inputs).
1372          - A dict mapping input names to the corresponding array/tensors,
1373            if the model has named inputs.
1374          - A `tf.data` dataset. Should return a tuple
1375            of either `(inputs, targets)` or
1376            `(inputs, targets, sample_weights)`.
1377          - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
1378            or `(inputs, targets, sample_weights)`.
1379          A more detailed description of unpacking behavior for iterator types
1380          (Dataset, generator, Sequence) is given in the `Unpacking behavior
1381          for iterator-like inputs` section of `Model.fit`.
1382        y: Target data. Like the input data `x`, it could be either Numpy
1383          array(s) or TensorFlow tensor(s). It should be consistent with `x`
1384          (you cannot have Numpy inputs and tensor targets, or inversely). If
1385          `x` is a dataset, generator or `keras.utils.Sequence` instance, `y`
1386          should not be specified (since targets will be obtained from the
1387          iterator/dataset).
1388        batch_size: Integer or `None`. Number of samples per batch of
1389          computation. If unspecified, `batch_size` will default to 32. Do not
1390          specify the `batch_size` if your data is in the form of a dataset,
1391          generators, or `keras.utils.Sequence` instances (since they generate
1392          batches).
1393        verbose: 0 or 1. Verbosity mode. 0 = silent, 1 = progress bar.
1394        sample_weight: Optional Numpy array of weights for the test samples,
1395          used for weighting the loss function. You can either pass a flat (1D)
1396          Numpy array with the same length as the input samples
1397            (1:1 mapping between weights and samples), or in the case of
1398              temporal data, you can pass a 2D array with shape `(samples,
1399              sequence_length)`, to apply a different weight to every timestep
1400              of every sample. This argument is not supported when `x` is a
1401              dataset, instead pass sample weights as the third element of `x`.
1402        steps: Integer or `None`. Total number of steps (batches of samples)
1403          before declaring the evaluation round finished. Ignored with the
1404          default value of `None`. If x is a `tf.data` dataset and `steps` is
1405          None, 'evaluate' will run until the dataset is exhausted. This
1406          argument is not supported with array inputs.
1407        callbacks: List of `keras.callbacks.Callback` instances. List of
1408          callbacks to apply during evaluation. See
1409          [callbacks](/api_docs/python/tf/keras/callbacks).
1410        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
1411          input only. Maximum size for the generator queue. If unspecified,
1412          `max_queue_size` will default to 10.
1413        workers: Integer. Used for generator or `keras.utils.Sequence` input
1414          only. Maximum number of processes to spin up when using process-based
1415          threading. If unspecified, `workers` will default to 1.
1416        use_multiprocessing: Boolean. Used for generator or
1417          `keras.utils.Sequence` input only. If `True`, use process-based
1418          threading. If unspecified, `use_multiprocessing` will default to
1419          `False`. Note that because this implementation relies on
1420          multiprocessing, you should not pass non-picklable arguments to the
1421          generator as they can't be passed easily to children processes.
1422        return_dict: If `True`, loss and metric results are returned as a dict,
1423          with each key being the name of the metric. If `False`, they are
1424          returned as a list.
1425        **kwargs: Unused at this time.
1426
1427    See the discussion of `Unpacking behavior for iterator-like inputs` for
1428    `Model.fit`.
1429
1430    `Model.evaluate` is not yet supported with
1431    `tf.distribute.experimental.ParameterServerStrategy`.
1432
1433    Returns:
1434        Scalar test loss (if the model has a single output and no metrics)
1435        or list of scalars (if the model has multiple outputs
1436        and/or metrics). The attribute `model.metrics_names` will give you
1437        the display labels for the scalar outputs.
1438
1439    Raises:
1440        RuntimeError: If `model.evaluate` is wrapped in `tf.function`.
1441        ValueError: in case of invalid arguments.
1442    """
1443    version_utils.disallow_legacy_graph('Model', 'evaluate')
1444    self._assert_compile_was_called()
1445    self._check_call_args('evaluate')
1446    _disallow_inside_tf_function('evaluate')
1447    use_cached_eval_dataset = kwargs.pop('_use_cached_eval_dataset', False)
1448    if kwargs:
1449      raise TypeError('Invalid keyword arguments: %s' % (kwargs,))
1450
1451    if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
1452      self._cluster_coordinator = cluster_coordinator.ClusterCoordinator(
1453          self.distribute_strategy)
1454
1455    with self.distribute_strategy.scope():
1456      # Use cached evaluation data only when it's called in `Model.fit`
1457      if (use_cached_eval_dataset
1458          and getattr(self, '_eval_data_handler', None) is not None):
1459        data_handler = self._eval_data_handler
1460      else:
1461        # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
1462        data_handler = data_adapter.get_data_handler(
1463            x=x,
1464            y=y,
1465            sample_weight=sample_weight,
1466            batch_size=batch_size,
1467            steps_per_epoch=steps,
1468            initial_epoch=0,
1469            epochs=1,
1470            max_queue_size=max_queue_size,
1471            workers=workers,
1472            use_multiprocessing=use_multiprocessing,
1473            model=self,
1474            steps_per_execution=self._steps_per_execution)
1475
1476      # Container that configures and calls `tf.keras.Callback`s.
1477      if not isinstance(callbacks, callbacks_module.CallbackList):
1478        callbacks = callbacks_module.CallbackList(
1479            callbacks,
1480            add_history=True,
1481            add_progbar=verbose != 0,
1482            model=self,
1483            verbose=verbose,
1484            epochs=1,
1485            steps=data_handler.inferred_steps)
1486
1487      logs = {}
1488      self.test_function = self.make_test_function()
1489      self._test_counter.assign(0)
1490      callbacks.on_test_begin()
1491      for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
1492        self.reset_metrics()
1493        with data_handler.catch_stop_iteration():
1494          for step in data_handler.steps():
1495            with trace.Trace('test', step_num=step, _r=1):
1496              callbacks.on_test_batch_begin(step)
1497              tmp_logs = self.test_function(iterator)
1498              if data_handler.should_sync:
1499                context.async_wait()
1500              logs = tmp_logs  # No error, now safe to assign to logs.
1501              end_step = step + data_handler.step_increment
1502              callbacks.on_test_batch_end(end_step, logs)
1503      logs = tf_utils.sync_to_numpy_or_python_type(logs)
1504      callbacks.on_test_end(logs=logs)
1505
1506      if return_dict:
1507        return logs
1508      else:
1509        return flatten_metrics_in_order(logs, self.metrics_names)
1510
1511  def predict_step(self, data):
1512    """The logic for one inference step.
1513
1514    This method can be overridden to support custom inference logic.
1515    This method is called by `Model.make_predict_function`.
1516
1517    This method should contain the mathematical logic for one step of inference.
1518    This typically includes the forward pass.
1519
1520    Configuration details for *how* this logic is run (e.g. `tf.function` and
1521    `tf.distribute.Strategy` settings), should be left to
1522    `Model.make_predict_function`, which can also be overridden.
1523
1524    Args:
1525      data: A nested structure of `Tensor`s.
1526
1527    Returns:
1528      The result of one inference step, typically the output of calling the
1529      `Model` on data.
1530    """
1531    data = data_adapter.expand_1d(data)
1532    x, _, _ = data_adapter.unpack_x_y_sample_weight(data)
1533    return self(x, training=False)
1534
  def make_predict_function(self):
    """Creates a function that executes one step of inference.

    This method can be overridden to support custom inference logic.
    This method is called by `Model.predict` and `Model.predict_on_batch`.

    Typically, this method directly controls `tf.function` and
    `tf.distribute.Strategy` settings, and delegates the actual inference
    logic to `Model.predict_step`.

    This function is cached the first time `Model.predict` or
    `Model.predict_on_batch` is called. The cache is cleared whenever
    `Model.compile` is called.

    Returns:
      Function. The function created by this method should accept a
      `tf.data.Iterator`, and return the outputs of the `Model`.
    """
    # Reuse the previously built function if there is one.
    if self.predict_function is not None:
      return self.predict_function

    def step_function(model, iterator):
      """Runs a single inference step."""

      def run_step(data):
        outputs = model.predict_step(data)
        # Ensure counter is updated only if `predict_step` succeeds.
        with ops.control_dependencies(_minimum_control_deps(outputs)):
          model._predict_counter.assign_add(1)  # pylint: disable=protected-access
        return outputs

      data = next(iterator)
      outputs = model.distribute_strategy.run(run_step, args=(data,))
      # Combine the per-replica outputs using the 'concat' reduction.
      outputs = reduce_per_replica(
          outputs, self.distribute_strategy, reduction='concat')
      return outputs

    # A missing `_steps_per_execution` is treated like a value of 1.
    if (self._steps_per_execution is None or
        self._steps_per_execution.numpy().item() == 1):

      def predict_function(iterator):
        """Runs an inference execution with one step."""
        return step_function(self, iterator)

    else:

      def predict_function(iterator):
        """Runs an inference execution with multiple steps."""
        # The first step runs outside the loop so that `outputs` exists before
        # the remaining `_steps_per_execution - 1` steps are appended to it.
        outputs = step_function(self, iterator)
        for _ in math_ops.range(self._steps_per_execution - 1):
          # Each iteration concatenates onto `outputs`, so every flattened
          # output tensor is given a shape invariant taken from its tensor
          # spec with `dynamic_batch=True`.
          directives.set_loop_options(
              shape_invariants=[(
                  t, tf_utils.get_tensor_spec(t, dynamic_batch=True).shape)
                                for t in nest.flatten(outputs)])
          step_outputs = step_function(self, iterator)
          outputs = nest.map_structure(lambda t1, t2: concat([t1, t2]), outputs,
                                       step_outputs)
        return outputs

    # Unless running eagerly, wrap the function in a `tf.function`.
    if not self.run_eagerly:
      predict_function = def_function.function(
          predict_function, experimental_relax_shapes=True)

    # Cache the function on the model for subsequent calls.
    self.predict_function = predict_function
    return self.predict_function
1600
1601  def predict(self,
1602              x,
1603              batch_size=None,
1604              verbose=0,
1605              steps=None,
1606              callbacks=None,
1607              max_queue_size=10,
1608              workers=1,
1609              use_multiprocessing=False):
1610    """Generates output predictions for the input samples.
1611
1612    Computation is done in batches. This method is designed for performance in
1613    large scale inputs. For small amount of inputs that fit in one batch,
1614    directly using `__call__` is recommended for faster execution, e.g.,
1615    `model(x)`, or `model(x, training=False)` if you have layers such as
1616    `tf.keras.layers.BatchNormalization` that behaves differently during
1617    inference. Also, note the fact that test loss is not affected by
1618    regularization layers like noise and dropout.
1619
1620    Args:
1621        x: Input samples. It could be:
1622          - A Numpy array (or array-like), or a list of arrays
1623            (in case the model has multiple inputs).
1624          - A TensorFlow tensor, or a list of tensors
1625            (in case the model has multiple inputs).
1626          - A `tf.data` dataset.
1627          - A generator or `keras.utils.Sequence` instance.
1628          A more detailed description of unpacking behavior for iterator types
1629          (Dataset, generator, Sequence) is given in the `Unpacking behavior
1630          for iterator-like inputs` section of `Model.fit`.
1631        batch_size: Integer or `None`.
1632            Number of samples per batch.
1633            If unspecified, `batch_size` will default to 32.
1634            Do not specify the `batch_size` if your data is in the
1635            form of dataset, generators, or `keras.utils.Sequence` instances
1636            (since they generate batches).
1637        verbose: Verbosity mode, 0 or 1.
1638        steps: Total number of steps (batches of samples)
1639            before declaring the prediction round finished.
1640            Ignored with the default value of `None`. If x is a `tf.data`
1641            dataset and `steps` is None, `predict` will
1642            run until the input dataset is exhausted.
1643        callbacks: List of `keras.callbacks.Callback` instances.
1644            List of callbacks to apply during prediction.
1645            See [callbacks](/api_docs/python/tf/keras/callbacks).
1646        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
1647            input only. Maximum size for the generator queue.
1648            If unspecified, `max_queue_size` will default to 10.
1649        workers: Integer. Used for generator or `keras.utils.Sequence` input
1650            only. Maximum number of processes to spin up when using
1651            process-based threading. If unspecified, `workers` will default
1652            to 1.
1653        use_multiprocessing: Boolean. Used for generator or
1654            `keras.utils.Sequence` input only. If `True`, use process-based
1655            threading. If unspecified, `use_multiprocessing` will default to
1656            `False`. Note that because this implementation relies on
1657            multiprocessing, you should not pass non-picklable arguments to
1658            the generator as they can't be passed easily to children processes.
1659
1660    See the discussion of `Unpacking behavior for iterator-like inputs` for
1661    `Model.fit`. Note that Model.predict uses the same interpretation rules as
1662    `Model.fit` and `Model.evaluate`, so inputs must be unambiguous for all
1663    three methods.
1664
1665    Returns:
1666        Numpy array(s) of predictions.
1667
1668    Raises:
1669        RuntimeError: If `model.predict` is wrapped in `tf.function`.
1670        ValueError: In case of mismatch between the provided
1671            input data and the model's expectations,
1672            or in case a stateful model receives a number of samples
1673            that is not a multiple of the batch size.
1674    """
1675    version_utils.disallow_legacy_graph('Model', 'predict')
1676    self._check_call_args('predict')
1677    _disallow_inside_tf_function('predict')
1678
1679    # TODO(yashkatariya): Cache model on the coordinator for faster prediction.
1680    # If running under PSS, then swap it with OneDeviceStrategy so that
1681    # execution will run on the coordinator.
1682    original_pss_strategy = None
1683    if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
1684      original_pss_strategy = self.distribute_strategy
1685      self._distribution_strategy = None
1686
1687    # Cluster coordinator is set by `.fit()` and `.evaluate()` which is not
1688    # needed in `.predict()` because all the predictions happen on the
1689    # coordinator/locally.
1690    if self._cluster_coordinator:
1691      self._cluster_coordinator = None
1692
1693    outputs = None
1694    with self.distribute_strategy.scope():
1695      # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
1696      dataset_types = (dataset_ops.DatasetV1, dataset_ops.DatasetV2)
1697      if (self._in_multi_worker_mode() or _is_tpu_multi_host(
1698          self.distribute_strategy)) and isinstance(x, dataset_types):
1699        try:
1700          options = options_lib.Options()
1701          data_option = options_lib.AutoShardPolicy.DATA
1702          options.experimental_distribute.auto_shard_policy = data_option
1703          x = x.with_options(options)
1704        except ValueError:
1705          warnings.warn('Using Model.predict with '
1706                        'MultiWorkerDistributionStrategy or TPUStrategy and '
1707                        'AutoShardPolicy.FILE might lead to out-of-order result'
1708                        '. Consider setting it to AutoShardPolicy.DATA.')
1709
1710      data_handler = data_adapter.get_data_handler(
1711          x=x,
1712          batch_size=batch_size,
1713          steps_per_epoch=steps,
1714          initial_epoch=0,
1715          epochs=1,
1716          max_queue_size=max_queue_size,
1717          workers=workers,
1718          use_multiprocessing=use_multiprocessing,
1719          model=self,
1720          steps_per_execution=self._steps_per_execution)
1721
1722      # Container that configures and calls `tf.keras.Callback`s.
1723      if not isinstance(callbacks, callbacks_module.CallbackList):
1724        callbacks = callbacks_module.CallbackList(
1725            callbacks,
1726            add_history=True,
1727            add_progbar=verbose != 0,
1728            model=self,
1729            verbose=verbose,
1730            epochs=1,
1731            steps=data_handler.inferred_steps)
1732
1733      self.predict_function = self.make_predict_function()
1734      self._predict_counter.assign(0)
1735      callbacks.on_predict_begin()
1736      batch_outputs = None
1737      for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
1738        with data_handler.catch_stop_iteration():
1739          for step in data_handler.steps():
1740            callbacks.on_predict_batch_begin(step)
1741            tmp_batch_outputs = self.predict_function(iterator)
1742            if data_handler.should_sync:
1743              context.async_wait()
1744            batch_outputs = tmp_batch_outputs  # No error, now safe to assign.
1745            if outputs is None:
1746              outputs = nest.map_structure(lambda batch_output: [batch_output],
1747                                           batch_outputs)
1748            else:
1749              nest.map_structure_up_to(
1750                  batch_outputs,
1751                  lambda output, batch_output: output.append(batch_output),
1752                  outputs, batch_outputs)
1753            end_step = step + data_handler.step_increment
1754            callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
1755      if batch_outputs is None:
1756        raise ValueError('Expect x to be a non-empty array or dataset.')
1757      callbacks.on_predict_end()
1758    all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
1759
1760    # If originally PSS strategy was used, then replace it back since predict
1761    # is running under `OneDeviceStrategy` after the swap and once its done
1762    # we need to replace it back to PSS again.
1763    if original_pss_strategy is not None:
1764      self._distribution_strategy = original_pss_strategy
1765
1766    return tf_utils.sync_to_numpy_or_python_type(all_outputs)
1767
1768  def reset_metrics(self):
1769    """Resets the state of all the metrics in the model.
1770
1771    Examples:
1772
1773    >>> inputs = tf.keras.layers.Input(shape=(3,))
1774    >>> outputs = tf.keras.layers.Dense(2)(inputs)
1775    >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
1776    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
1777
1778    >>> x = np.random.random((2, 3))
1779    >>> y = np.random.randint(0, 2, (2, 2))
1780    >>> _ = model.fit(x, y, verbose=0)
1781    >>> assert all(float(m.result()) for m in model.metrics)
1782
1783    >>> model.reset_metrics()
1784    >>> assert all(float(m.result()) == 0 for m in model.metrics)
1785
1786    """
1787    for m in self.metrics:
1788      m.reset_state()
1789
1790  def train_on_batch(self,
1791                     x,
1792                     y=None,
1793                     sample_weight=None,
1794                     class_weight=None,
1795                     reset_metrics=True,
1796                     return_dict=False):
1797    """Runs a single gradient update on a single batch of data.
1798
1799    Args:
1800        x: Input data. It could be:
1801          - A Numpy array (or array-like), or a list of arrays
1802              (in case the model has multiple inputs).
1803          - A TensorFlow tensor, or a list of tensors
1804              (in case the model has multiple inputs).
1805          - A dict mapping input names to the corresponding array/tensors,
1806              if the model has named inputs.
1807        y: Target data. Like the input data `x`, it could be either Numpy
1808          array(s) or TensorFlow tensor(s). It should be consistent with `x`
1809          (you cannot have Numpy inputs and tensor targets, or inversely).
1810        sample_weight: Optional array of the same length as x, containing
1811          weights to apply to the model's loss for each sample. In the case of
1812          temporal data, you can pass a 2D array with shape (samples,
1813          sequence_length), to apply a different weight to every timestep of
1814          every sample.
1815        class_weight: Optional dictionary mapping class indices (integers) to a
1816          weight (float) to apply to the model's loss for the samples from this
1817          class during training. This can be useful to tell the model to "pay
1818          more attention" to samples from an under-represented class.
1819        reset_metrics: If `True`, the metrics returned will be only for this
1820          batch. If `False`, the metrics will be statefully accumulated across
1821          batches.
1822        return_dict: If `True`, loss and metric results are returned as a dict,
1823          with each key being the name of the metric. If `False`, they are
1824          returned as a list.
1825
1826    Returns:
1827        Scalar training loss
1828        (if the model has a single output and no metrics)
1829        or list of scalars (if the model has multiple outputs
1830        and/or metrics). The attribute `model.metrics_names` will give you
1831        the display labels for the scalar outputs.
1832
1833    Raises:
1834      RuntimeError: If `model.train_on_batch` is wrapped in `tf.function`.
1835      ValueError: In case of invalid user-provided arguments.
1836    """
1837    self._assert_compile_was_called()
1838    self._check_call_args('train_on_batch')
1839    _disallow_inside_tf_function('train_on_batch')
1840    with self.distribute_strategy.scope(), \
1841         training_utils.RespectCompiledTrainableState(self):
1842      iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x,
1843                                                    y, sample_weight,
1844                                                    class_weight)
1845      self.train_function = self.make_train_function()
1846      logs = self.train_function(iterator)
1847
1848    if reset_metrics:
1849      self.reset_metrics()
1850    logs = tf_utils.sync_to_numpy_or_python_type(logs)
1851    if return_dict:
1852      return logs
1853    else:
1854      return flatten_metrics_in_order(logs, self.metrics_names)
1855
1856  def test_on_batch(self,
1857                    x,
1858                    y=None,
1859                    sample_weight=None,
1860                    reset_metrics=True,
1861                    return_dict=False):
1862    """Test the model on a single batch of samples.
1863
1864    Args:
1865        x: Input data. It could be:
1866          - A Numpy array (or array-like), or a list of arrays (in case the
1867              model has multiple inputs).
1868          - A TensorFlow tensor, or a list of tensors (in case the model has
1869              multiple inputs).
1870          - A dict mapping input names to the corresponding array/tensors, if
1871              the model has named inputs.
1872        y: Target data. Like the input data `x`, it could be either Numpy
1873          array(s) or TensorFlow tensor(s). It should be consistent with `x`
1874          (you cannot have Numpy inputs and tensor targets, or inversely).
1875        sample_weight: Optional array of the same length as x, containing
1876          weights to apply to the model's loss for each sample. In the case of
1877          temporal data, you can pass a 2D array with shape (samples,
1878          sequence_length), to apply a different weight to every timestep of
1879          every sample.
1880        reset_metrics: If `True`, the metrics returned will be only for this
1881          batch. If `False`, the metrics will be statefully accumulated across
1882          batches.
1883        return_dict: If `True`, loss and metric results are returned as a dict,
1884          with each key being the name of the metric. If `False`, they are
1885          returned as a list.
1886
1887    Returns:
1888        Scalar test loss (if the model has a single output and no metrics)
1889        or list of scalars (if the model has multiple outputs
1890        and/or metrics). The attribute `model.metrics_names` will give you
1891        the display labels for the scalar outputs.
1892
1893    Raises:
1894        RuntimeError: If `model.test_on_batch` is wrapped in `tf.function`.
1895        ValueError: In case of invalid user-provided arguments.
1896    """
1897    self._assert_compile_was_called()
1898    self._check_call_args('test_on_batch')
1899    _disallow_inside_tf_function('test_on_batch')
1900    with self.distribute_strategy.scope():
1901      iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x,
1902                                                    y, sample_weight)
1903      self.test_function = self.make_test_function()
1904      logs = self.test_function(iterator)
1905
1906    if reset_metrics:
1907      self.reset_metrics()
1908    logs = tf_utils.sync_to_numpy_or_python_type(logs)
1909    if return_dict:
1910      return logs
1911    else:
1912      return flatten_metrics_in_order(logs, self.metrics_names)
1913
1914  def predict_on_batch(self, x):
1915    """Returns predictions for a single batch of samples.
1916
1917    Args:
1918        x: Input data. It could be:
1919          - A Numpy array (or array-like), or a list of arrays (in case the
1920              model has multiple inputs).
1921          - A TensorFlow tensor, or a list of tensors (in case the model has
1922              multiple inputs).
1923
1924    Returns:
1925        Numpy array(s) of predictions.
1926
1927    Raises:
1928        RuntimeError: If `model.predict_on_batch` is wrapped in `tf.function`.
1929        ValueError: In case of mismatch between given number of inputs and
1930          expectations of the model.
1931    """
1932    self._check_call_args('predict_on_batch')
1933    _disallow_inside_tf_function('predict_on_batch')
1934    with self.distribute_strategy.scope():
1935      iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x)
1936      self.predict_function = self.make_predict_function()
1937      outputs = self.predict_function(iterator)
1938    return tf_utils.sync_to_numpy_or_python_type(outputs)
1939
1940  def fit_generator(self,
1941                    generator,
1942                    steps_per_epoch=None,
1943                    epochs=1,
1944                    verbose=1,
1945                    callbacks=None,
1946                    validation_data=None,
1947                    validation_steps=None,
1948                    validation_freq=1,
1949                    class_weight=None,
1950                    max_queue_size=10,
1951                    workers=1,
1952                    use_multiprocessing=False,
1953                    shuffle=True,
1954                    initial_epoch=0):
1955    """Fits the model on data yielded batch-by-batch by a Python generator.
1956
1957    DEPRECATED:
1958      `Model.fit` now supports generators, so there is no longer any need to use
1959      this endpoint.
1960    """
1961    warnings.warn('`Model.fit_generator` is deprecated and '
1962                  'will be removed in a future version. '
1963                  'Please use `Model.fit`, which supports generators.')
1964    return self.fit(
1965        generator,
1966        steps_per_epoch=steps_per_epoch,
1967        epochs=epochs,
1968        verbose=verbose,
1969        callbacks=callbacks,
1970        validation_data=validation_data,
1971        validation_steps=validation_steps,
1972        validation_freq=validation_freq,
1973        class_weight=class_weight,
1974        max_queue_size=max_queue_size,
1975        workers=workers,
1976        use_multiprocessing=use_multiprocessing,
1977        shuffle=shuffle,
1978        initial_epoch=initial_epoch)
1979
1980  def evaluate_generator(self,
1981                         generator,
1982                         steps=None,
1983                         callbacks=None,
1984                         max_queue_size=10,
1985                         workers=1,
1986                         use_multiprocessing=False,
1987                         verbose=0):
1988    """Evaluates the model on a data generator.
1989
1990    DEPRECATED:
1991      `Model.evaluate` now supports generators, so there is no longer any need
1992      to use this endpoint.
1993    """
1994    warnings.warn('`Model.evaluate_generator` is deprecated and '
1995                  'will be removed in a future version. '
1996                  'Please use `Model.evaluate`, which supports generators.')
1997    self._check_call_args('evaluate_generator')
1998
1999    return self.evaluate(
2000        generator,
2001        steps=steps,
2002        max_queue_size=max_queue_size,
2003        workers=workers,
2004        use_multiprocessing=use_multiprocessing,
2005        verbose=verbose,
2006        callbacks=callbacks)
2007
2008  def predict_generator(self,
2009                        generator,
2010                        steps=None,
2011                        callbacks=None,
2012                        max_queue_size=10,
2013                        workers=1,
2014                        use_multiprocessing=False,
2015                        verbose=0):
2016    """Generates predictions for the input samples from a data generator.
2017
2018    DEPRECATED:
2019      `Model.predict` now supports generators, so there is no longer any need
2020      to use this endpoint.
2021    """
2022    warnings.warn('`Model.predict_generator` is deprecated and '
2023                  'will be removed in a future version. '
2024                  'Please use `Model.predict`, which supports generators.')
2025    return self.predict(
2026        generator,
2027        steps=steps,
2028        max_queue_size=max_queue_size,
2029        workers=workers,
2030        use_multiprocessing=use_multiprocessing,
2031        verbose=verbose,
2032        callbacks=callbacks)
2033
2034  ######################################################################
2035  # Functions below are not training related. They are for model weights
2036  # tracking, save/load, serialization, etc.
2037  ######################################################################
2038
2039  @property
2040  def trainable_weights(self):
2041    self._assert_weights_created()
2042    if not self._trainable:
2043      return []
2044    trainable_variables = []
2045    for trackable_obj in self._self_tracked_trackables:
2046      trainable_variables += trackable_obj.trainable_variables
2047    trainable_variables += self._trainable_weights
2048    return self._dedup_weights(trainable_variables)
2049
2050  @property
2051  def non_trainable_weights(self):
2052    self._assert_weights_created()
2053    non_trainable_variables = []
2054    for trackable_obj in self._self_tracked_trackables:
2055      non_trainable_variables += trackable_obj.non_trainable_variables
2056
2057    if not self._trainable:
2058      # Return order is all trainable vars, then all non-trainable vars.
2059      trainable_variables = []
2060      for trackable_obj in self._self_tracked_trackables:
2061        trainable_variables += trackable_obj.trainable_variables
2062
2063      non_trainable_variables = (
2064          trainable_variables + self._trainable_weights +
2065          non_trainable_variables + self._non_trainable_weights)
2066    else:
2067      non_trainable_variables = (
2068          non_trainable_variables + self._non_trainable_weights)
2069
2070    return self._dedup_weights(non_trainable_variables)
2071
2072  def get_weights(self):
2073    """Retrieves the weights of the model.
2074
2075    Returns:
2076        A flat list of Numpy arrays.
2077    """
2078    with self.distribute_strategy.scope():
2079      return super(Model, self).get_weights()
2080
2081  def save(self,
2082           filepath,
2083           overwrite=True,
2084           include_optimizer=True,
2085           save_format=None,
2086           signatures=None,
2087           options=None,
2088           save_traces=True):
2089    # pylint: disable=line-too-long
2090    """Saves the model to Tensorflow SavedModel or a single HDF5 file.
2091
2092    Please see `tf.keras.models.save_model` or the
2093    [Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/)
2094    for details.
2095
2096    Args:
2097        filepath: String, PathLike, path to SavedModel or H5 file to save the
2098            model.
2099        overwrite: Whether to silently overwrite any existing file at the
2100            target location, or provide the user with a manual prompt.
2101        include_optimizer: If True, save optimizer's state together.
2102        save_format: Either `'tf'` or `'h5'`, indicating whether to save the
2103            model to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X,
2104            and 'h5' in TF 1.X.
2105        signatures: Signatures to save with the SavedModel. Applicable to the
2106            'tf' format only. Please see the `signatures` argument in
2107            `tf.saved_model.save` for details.
2108        options: (only applies to SavedModel format)
2109            `tf.saved_model.SaveOptions` object that specifies options for
2110            saving to SavedModel.
2111        save_traces: (only applies to SavedModel format) When enabled, the
2112            SavedModel will store the function traces for each layer. This
2113            can be disabled, so that only the configs of each layer are stored.
2114            Defaults to `True`. Disabling this will decrease serialization time
2115            and reduce file size, but it requires that all custom layers/models
2116            implement a `get_config()` method.
2117
2118    Example:
2119
2120    ```python
2121    from keras.models import load_model
2122
2123    model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
2124    del model  # deletes the existing model
2125
2126    # returns a compiled model
2127    # identical to the previous one
2128    model = load_model('my_model.h5')
2129    ```
2130    """
2131    # pylint: enable=line-too-long
2132    save.save_model(self, filepath, overwrite, include_optimizer, save_format,
2133                    signatures, options, save_traces)
2134
2135  def save_weights(self,
2136                   filepath,
2137                   overwrite=True,
2138                   save_format=None,
2139                   options=None):
2140    """Saves all layer weights.
2141
2142    Either saves in HDF5 or in TensorFlow format based on the `save_format`
2143    argument.
2144
2145    When saving in HDF5 format, the weight file has:
2146      - `layer_names` (attribute), a list of strings
2147          (ordered names of model layers).
2148      - For every layer, a `group` named `layer.name`
2149          - For every such layer group, a group attribute `weight_names`,
2150              a list of strings
2151              (ordered names of weights tensor of the layer).
2152          - For every weight in the layer, a dataset
2153              storing the weight value, named after the weight tensor.
2154
2155    When saving in TensorFlow format, all objects referenced by the network are
2156    saved in the same format as `tf.train.Checkpoint`, including any `Layer`
2157    instances or `Optimizer` instances assigned to object attributes. For
2158    networks constructed from inputs and outputs using `tf.keras.Model(inputs,
2159    outputs)`, `Layer` instances used by the network are tracked/saved
2160    automatically. For user-defined classes which inherit from `tf.keras.Model`,
2161    `Layer` instances must be assigned to object attributes, typically in the
2162    constructor. See the documentation of `tf.train.Checkpoint` and
2163    `tf.keras.Model` for details.
2164
2165    While the formats are the same, do not mix `save_weights` and
2166    `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should be
2167    loaded using `Model.load_weights`. Checkpoints saved using
2168    `tf.train.Checkpoint.save` should be restored using the corresponding
2169    `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over
2170    `save_weights` for training checkpoints.
2171
2172    The TensorFlow format matches objects and variables by starting at a root
2173    object, `self` for `save_weights`, and greedily matching attribute
2174    names. For `Model.save` this is the `Model`, and for `Checkpoint.save` this
2175    is the `Checkpoint` even if the `Checkpoint` has a model attached. This
2176    means saving a `tf.keras.Model` using `save_weights` and loading into a
2177    `tf.train.Checkpoint` with a `Model` attached (or vice versa) will not match
2178    the `Model`'s variables. See the [guide to training
2179    checkpoints](https://www.tensorflow.org/guide/checkpoint) for details
2180    on the TensorFlow format.
2181
2182    Args:
2183        filepath: String or PathLike, path to the file to save the weights to.
2184            When saving in TensorFlow format, this is the prefix used for
2185            checkpoint files (multiple files are generated). Note that the '.h5'
2186            suffix causes weights to be saved in HDF5 format.
2187        overwrite: Whether to silently overwrite any existing file at the
2188            target location, or provide the user with a manual prompt.
2189        save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
2190            '.keras' will default to HDF5 if `save_format` is `None`. Otherwise
2191            `None` defaults to 'tf'.
2192        options: Optional `tf.train.CheckpointOptions` object that specifies
2193            options for saving weights.
2194
2195    Raises:
2196        ImportError: If h5py is not available when attempting to save in HDF5
2197            format.
2198        ValueError: For invalid/unknown format arguments.
2199    """
2200    self._assert_weights_created()
2201    filepath = path_to_string(filepath)
2202    filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
2203    if save_format is None:
2204      if filepath_is_h5:
2205        save_format = 'h5'
2206      else:
2207        save_format = 'tf'
2208    else:
2209      user_format = save_format.lower().strip()
2210      if user_format in ('tensorflow', 'tf'):
2211        save_format = 'tf'
2212      elif user_format in ('hdf5', 'h5', 'keras'):
2213        save_format = 'h5'
2214      else:
2215        raise ValueError(
2216            'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % (
2217                save_format,))
2218    if save_format == 'tf' and filepath_is_h5:
2219      raise ValueError(
2220          ('save_weights got save_format="tf"/"tensorflow", but the '
2221           'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" '
2222           'when saving in TensorFlow format.')
2223          % filepath)
2224
2225    if save_format == 'h5' and h5py is None:
2226      raise ImportError(
2227          '`save_weights` requires h5py when saving in hdf5.')
2228    if save_format == 'tf':
2229      check_filepath = filepath + '.index'
2230    else:
2231      check_filepath = filepath
2232    # If file exists and should not be overwritten:
2233    if not overwrite and os.path.isfile(check_filepath):
2234      proceed = ask_to_proceed_with_overwrite(check_filepath)
2235      if not proceed:
2236        return
2237    if save_format == 'h5':
2238      with h5py.File(filepath, 'w') as f:
2239        hdf5_format.save_weights_to_hdf5_group(f, self.layers)
2240    else:
2241      if not context.executing_eagerly():
2242        # Call `get_session` to initialize any uninitialized variables.
2243        backend.get_session()
2244      self._checkpoint.write(filepath, options=options)
2245      # Record this checkpoint so it's visible from tf.train.latest_checkpoint.
2246      checkpoint_management.update_checkpoint_state_internal(
2247          save_dir=os.path.dirname(filepath),
2248          model_checkpoint_path=filepath,
2249          save_relative_paths=True,
2250          all_model_checkpoint_paths=[filepath])
2251
2252  def load_weights(self,
2253                   filepath,
2254                   by_name=False,
2255                   skip_mismatch=False,
2256                   options=None):
2257    """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
2258
2259    If `by_name` is False weights are loaded based on the network's
2260    topology. This means the architecture should be the same as when the weights
2261    were saved.  Note that layers that don't have weights are not taken into
2262    account in the topological ordering, so adding or removing layers is fine as
2263    long as they don't have weights.
2264
2265    If `by_name` is True, weights are loaded into layers only if they share the
2266    same name. This is useful for fine-tuning or transfer-learning models where
2267    some of the layers have changed.
2268
2269    Only topological loading (`by_name=False`) is supported when loading weights
2270    from the TensorFlow format. Note that topological loading differs slightly
2271    between TensorFlow and HDF5 formats for user-defined classes inheriting from
2272    `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
2273    TensorFlow format loads based on the object-local names of attributes to
2274    which layers are assigned in the `Model`'s constructor.
2275
2276    Args:
2277        filepath: String, path to the weights file to load. For weight files in
2278            TensorFlow format, this is the file prefix (the same as was passed
2279            to `save_weights`). This can also be a path to a SavedModel
2280            saved from `model.save`.
2281        by_name: Boolean, whether to load weights by name or by topological
2282            order. Only topological loading is supported for weight files in
2283            TensorFlow format.
2284        skip_mismatch: Boolean, whether to skip loading of layers where there is
2285            a mismatch in the number of weights, or a mismatch in the shape of
2286            the weight (only valid when `by_name=True`).
2287        options: Optional `tf.train.CheckpointOptions` object that specifies
2288            options for loading weights.
2289
2290    Returns:
2291        When loading a weight file in TensorFlow format, returns the same status
2292        object as `tf.train.Checkpoint.restore`. When graph building, restore
2293        ops are run automatically as soon as the network is built (on first call
2294        for user-defined classes inheriting from `Model`, immediately if it is
2295        already built).
2296
2297        When loading weights in HDF5 format, returns `None`.
2298
2299    Raises:
2300        ImportError: If h5py is not available and the weight file is in HDF5
2301            format.
2302        ValueError: If `skip_mismatch` is set to `True` when `by_name` is
2303          `False`.
2304    """
2305    if backend.is_tpu_strategy(self._distribution_strategy):
2306      if (self._distribution_strategy.extended.steps_per_run > 1 and
2307          (not saving_utils.is_hdf5_filepath(filepath))):
2308        raise ValueError('Load weights is not yet supported with TPUStrategy '
2309                         'with steps_per_run greater than 1.')
2310    if skip_mismatch and not by_name:
2311      raise ValueError(
2312          'When calling model.load_weights, skip_mismatch can only be set to '
2313          'True when by_name is True.')
2314
2315    filepath, save_format = _detect_save_format(filepath)
2316    if save_format == 'tf':
2317      status = self._checkpoint.read(filepath, options)
2318      if by_name:
2319        raise NotImplementedError(
2320            'Weights may only be loaded based on topology into Models when '
2321            'loading TensorFlow-formatted weights (got by_name=True to '
2322            'load_weights).')
2323      if not context.executing_eagerly():
2324        session = backend.get_session()
2325        # Restore existing variables (if any) immediately, and set up a
2326        # streaming restore for any variables created in the future.
2327        trackable_utils.streaming_restore(status=status, session=session)
2328      status.assert_nontrivial_match()
2329    else:
2330      status = None
2331      if h5py is None:
2332        raise ImportError(
2333            '`load_weights` requires h5py when loading weights from HDF5.')
2334      if not self._is_graph_network and not self.built:
2335        raise ValueError(
2336            'Unable to load weights saved in HDF5 format into a subclassed '
2337            'Model which has not created its variables yet. Call the Model '
2338            'first, then load the weights.')
2339      self._assert_weights_created()
2340      with h5py.File(filepath, 'r') as f:
2341        if 'layer_names' not in f.attrs and 'model_weights' in f:
2342          f = f['model_weights']
2343        if by_name:
2344          hdf5_format.load_weights_from_hdf5_group_by_name(
2345              f, self.layers, skip_mismatch=skip_mismatch)
2346        else:
2347          hdf5_format.load_weights_from_hdf5_group(f, self.layers)
2348
2349    # Perform any layer defined finalization of the layer state.
2350    for layer in self.layers:
2351      layer.finalize_state()
2352    return status
2353
2354  def _updated_config(self):
2355    """Util shared between different serialization methods.
2356
2357    Returns:
2358        Model config with Keras version information added.
2359    """
2360    from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
2361
2362    config = self.get_config()
2363    model_config = {
2364        'class_name': self.__class__.__name__,
2365        'config': config,
2366        'keras_version': keras_version,
2367        'backend': backend.backend()
2368    }
2369    return model_config
2370
2371  def get_config(self):
2372    raise NotImplementedError
2373
2374  @classmethod
2375  def from_config(cls, config, custom_objects=None):
2376    # `from_config` assumes `cls` is either `Functional` or a child class of
2377    # `Functional`. In the case that `cls` is meant to behave like a child class
2378    # of `Functional` but only inherits from the `Model` class, we have to call
2379    # `cls(...)` instead of `Functional.from_config`.
2380    from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
2381    with generic_utils.SharedObjectLoadingScope():
2382      input_tensors, output_tensors, created_layers = (
2383          functional.reconstruct_from_config(config, custom_objects))
2384      # Initialize a model belonging to `cls`, which can be user-defined or
2385      # `Functional`.
2386      model = cls(inputs=input_tensors, outputs=output_tensors,
2387                  name=config.get('name'))
2388      functional.connect_ancillary_layers(model, created_layers)
2389      return model
2390
2391  def to_json(self, **kwargs):
2392    """Returns a JSON string containing the network configuration.
2393
2394    To load a network from a JSON save file, use
2395    `keras.models.model_from_json(json_string, custom_objects={})`.
2396
2397    Args:
2398        **kwargs: Additional keyword arguments
2399            to be passed to `json.dumps()`.
2400
2401    Returns:
2402        A JSON string.
2403    """
2404    model_config = self._updated_config()
2405    return json.dumps(
2406        model_config, default=json_utils.get_json_type, **kwargs)
2407
2408  def to_yaml(self, **kwargs):
2409    """Returns a yaml string containing the network configuration.
2410
2411    Note: Since TF 2.6, this method is no longer supported and will raise a
2412    RuntimeError.
2413
2414    To load a network from a yaml save file, use
2415    `keras.models.model_from_yaml(yaml_string, custom_objects={})`.
2416
2417    `custom_objects` should be a dictionary mapping
2418    the names of custom losses / layers / etc to the corresponding
2419    functions / classes.
2420
2421    Args:
2422        **kwargs: Additional keyword arguments
2423            to be passed to `yaml.dump()`.
2424
2425    Returns:
2426        A YAML string.
2427
2428    Raises:
2429        RuntimeError: announces that the method poses a security risk
2430    """
2431    raise RuntimeError(
2432        'Method `model.to_yaml()` has been removed due to security risk of '
2433        'arbitrary code execution. Please use `model.to_json()` instead.'
2434    )
2435
2436  def reset_states(self):
2437    for layer in self.layers:
2438      if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
2439        layer.reset_states()
2440
2441  @property
2442  @doc_controls.do_not_generate_docs
2443  def state_updates(self):
2444    """Deprecated, do NOT use!
2445
2446    Returns the `updates` from all layers that are stateful.
2447
2448    This is useful for separating training updates and
2449    state updates, e.g. when we need to update a layer's internal state
2450    during prediction.
2451
2452    Returns:
2453        A list of update ops.
2454    """
2455    warnings.warn('`Model.state_updates` will be removed in a future version. '
2456                  'This property should not be used in TensorFlow 2.0, '
2457                  'as `updates` are applied automatically.')
2458    state_updates = []
2459    for layer in self.layers:
2460      if getattr(layer, 'stateful', False):
2461        if hasattr(layer, 'updates'):
2462          state_updates += layer.updates
2463    return state_updates
2464
2465  @property
2466  def weights(self):
2467    """Returns the list of all layer variables/weights.
2468
2469    Note: This will not track the weights of nested `tf.Modules` that are not
2470    themselves Keras layers.
2471
2472    Returns:
2473      A list of variables.
2474    """
2475    return self._dedup_weights(self._undeduplicated_weights)
2476
2477  @property
2478  def _undeduplicated_weights(self):
2479    """Returns the undeduplicated list of all layer variables/weights."""
2480    self._assert_weights_created()
2481    weights = []
2482    for layer in self._self_tracked_trackables:
2483      weights += layer.variables
2484    weights += (self._trainable_weights + self._non_trainable_weights)
2485    return weights
2486
2487  def summary(self, line_length=None, positions=None, print_fn=None):
2488    """Prints a string summary of the network.
2489
2490    Args:
2491        line_length: Total length of printed lines
2492            (e.g. set this to adapt the display to different
2493            terminal window sizes).
2494        positions: Relative or absolute positions of log elements
2495            in each line. If not provided,
2496            defaults to `[.33, .55, .67, 1.]`.
2497        print_fn: Print function to use. Defaults to `print`.
2498            It will be called on each line of the summary.
2499            You can set it to a custom function
2500            in order to capture the string summary.
2501
2502    Raises:
2503        ValueError: if `summary()` is called before the model is built.
2504    """
2505    if not self.built:
2506      raise ValueError('This model has not yet been built. '
2507                       'Build the model first by calling `build()` or calling '
2508                       '`fit()` with some data, or specify '
2509                       'an `input_shape` argument in the first layer(s) for '
2510                       'automatic build.')
2511    layer_utils.print_summary(self,
2512                              line_length=line_length,
2513                              positions=positions,
2514                              print_fn=print_fn)
2515
2516  @property
2517  def layers(self):
2518    return list(self._flatten_layers(include_self=False, recursive=False))
2519
2520  def get_layer(self, name=None, index=None):
2521    """Retrieves a layer based on either its name (unique) or index.
2522
2523    If `name` and `index` are both provided, `index` will take precedence.
2524    Indices are based on order of horizontal graph traversal (bottom-up).
2525
2526    Args:
2527        name: String, name of layer.
2528        index: Integer, index of layer.
2529
2530    Returns:
2531        A layer instance.
2532
2533    Raises:
2534        ValueError: In case of invalid layer name or index.
2535    """
2536    # TODO(fchollet): We could build a dictionary based on layer names
2537    # since they are constant, but we have not done that yet.
2538    if index is not None and name is not None:
2539      raise ValueError('Provide only a layer name or a layer index.')
2540
2541    if index is not None:
2542      if len(self.layers) <= index:
2543        raise ValueError('Was asked to retrieve layer at index ' + str(index) +
2544                         ' but model only has ' + str(len(self.layers)) +
2545                         ' layers.')
2546      else:
2547        return self.layers[index]
2548
2549    if name is not None:
2550      for layer in self.layers:
2551        if layer.name == name:
2552          return layer
2553      raise ValueError('No such layer: ' + name + '.')
2554    raise ValueError('Provide either a layer name or layer index.')
2555
2556  @trackable.no_automatic_dependency_tracking
2557  def _set_save_spec(self, inputs):
2558    if self._saved_model_inputs_spec is not None:
2559      return  # Already set.
2560
2561    input_names = self.input_names
2562    if not input_names:
2563      input_names = compile_utils.create_pseudo_input_names(inputs)
2564
2565    flat_inputs = nest.flatten(inputs)
2566    specs = []
2567    for name, tensor in zip(input_names, flat_inputs):
2568      specs.append(
2569          tf_utils.get_tensor_spec(tensor, dynamic_batch=False, name=name))
2570    specs = nest.pack_sequence_as(inputs, specs)
2571
2572    self._saved_model_inputs_spec = specs
2573
2574    # Store the input shapes
2575    if (self.__class__.__name__ == 'Sequential' and
2576        self._build_input_shape is None):
2577      self._build_input_shape = nest.map_structure(
2578          lambda x: None if x is None else x.shape, specs)
2579
2580  def _assert_weights_created(self):
2581    """Asserts that all the weights for the model have been created.
2582
2583    For a non-dynamic model, the weights must already be created after the
2584    layer has been called. For a dynamic model, the exact list of weights can
2585    never be known for certain since it may change at any time during execution.
2586
2587    We run this check right before accessing weights or getting the Numpy value
2588    for the current weights. Otherwise, if the layer has never been called,
2589    the user would just get an empty list, which is misleading.
2590
2591    Raises:
2592      ValueError: if the weights of the network has not yet been created.
2593    """
2594    if self.dynamic:
2595      return
2596
2597    if ('build' in self.__class__.__dict__ and
2598        self.__class__ != Model and
2599        not self.built):
2600      # For any model that has customized build() method but hasn't
2601      # been invoked yet, this will cover both sequential and subclass model.
2602      # Also make sure to exclude Model class itself which has build() defined.
2603      raise ValueError('Weights for model %s have not yet been created. '
2604                       'Weights are created when the Model is first called on '
2605                       'inputs or `build()` is called with an `input_shape`.' %
2606                       self.name)
2607
2608  def _check_call_args(self, method_name):
2609    """Check that `call` has only one positional arg."""
2610    # Always allow first arg, regardless of arg name.
2611    fullargspec = self._call_full_argspec
2612    if fullargspec.defaults:
2613      positional_args = fullargspec.args[:-len(fullargspec.defaults)]
2614    else:
2615      positional_args = fullargspec.args
2616    if 'training' in positional_args:
2617      positional_args.remove('training')
2618
2619    # self and first arg can be positional.
2620    if len(positional_args) > 2:
2621      extra_args = positional_args[2:]
2622      raise ValueError(
2623          'Models passed to `' + method_name + '` can only have `training` '
2624          'and the first argument in `call` as positional arguments, '
2625          'found: ' + str(extra_args) + '.')
2626
2627  def _validate_compile(self, optimizer, metrics, **kwargs):
2628    """Performs validation checks for the default `compile`."""
2629    if any(
2630        isinstance(opt, optimizer_v1.Optimizer)
2631        for opt in nest.flatten(optimizer)):
2632      raise ValueError(
2633          '`tf.compat.v1.keras` Optimizer (', optimizer, ') is '
2634          'not supported when eager execution is enabled. Use a '
2635          '`tf.keras` Optimizer instead, or disable eager '
2636          'execution.')
2637
2638    kwargs.pop('cloning', None)  # Legacy DistStrat argument, never used.
2639    kwargs.pop('experimental_run_tf_function', None)  # Always `True`.
2640    if kwargs.pop('distribute', None) is not None:
2641      raise ValueError(
2642          'Distribute argument in compile is not available in TF 2.0 please '
2643          'create the model under the distribution strategy scope.')
2644    if kwargs.pop('target_tensors', None) is not None:
2645      raise ValueError(
2646          'target_tensors argument is not supported when executing eagerly.')
2647    invalid_kwargs = set(kwargs) - {'sample_weight_mode'}
2648    if invalid_kwargs:
2649      raise TypeError('Invalid keyword argument(s) in `compile`: %s' %
2650                      (invalid_kwargs,))
2651
2652    # Model must be created and compiled with the same DistStrat.
2653    if self.built and ds_context.has_strategy():
2654      strategy = ds_context.get_strategy()
2655      for v in self.variables:
2656        if not strategy.extended.variable_created_in_scope(v):
2657          raise ValueError(
2658              'Variable (%s) was not created in the distribution strategy '
2659              'scope of (%s). It is most likely due to not all layers or '
2660              'the model or optimizer being created outside the distribution '
2661              'strategy scope. Try to make sure your code looks similar '
2662              'to the following.\n'
2663              'with strategy.scope():\n'
2664              '  model=_create_model()\n'
2665              '  model.compile(...)' % (v, strategy))
2666
2667    # Model metrics must be created in the same distribution strategy scope
2668    # as the model.
2669    strategy = self.distribute_strategy
2670    for metric in nest.flatten(metrics):
2671      for v in getattr(metric, 'variables', []):
2672        if not strategy.extended.variable_created_in_scope(v):
2673          raise ValueError(
2674              'Metric (%s) passed to model.compile was created inside of a '
2675              'different distribution strategy scope than the model. All '
2676              'metrics must be created in the same distribution strategy '
2677              'scope as the model (in this case %s). If you pass in a string '
2678              'identifier for a metric to compile the metric will '
2679              'automatically be created in the correct distribution '
2680              'strategy scope.' % (metric, strategy)
2681          )
2682
2683    # Model metrics must be created in the same distribution strategy scope
2684    # as the model.
2685    for opt in nest.flatten(optimizer):
2686      for v in getattr(opt, '_weights', []):
2687        if not strategy.extended.variable_created_in_scope(v):
2688          raise ValueError(
2689              'Optimizer (%s) passed to model.compile was created inside of a '
2690              'different distribution strategy scope than the model. All '
2691              'optimizers must be created in the same distribution strategy '
2692              'scope as the model (in this case %s). If you pass in a string '
2693              'identifier for an optimizer to compile the optimizer will '
2694              'automatically be created in the correct distribution '
2695              'strategy scope.' % (opt, strategy))
2696
2697  def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch):
2698    """Maybe load initial epoch from ckpt considering possible worker recovery.
2699
2700    Refer to tensorflow/python/keras/distribute/worker_training_state.py
2701    for more information.
2702
2703    Args:
2704      initial_epoch: The original initial_epoch user passes in in `fit()`.
2705
2706    Returns:
2707      If the training is recovering from previous failure under multi-worker
2708      training setting, return the epoch the training is supposed to continue
2709      at. Otherwise, return the `initial_epoch` the user passes in.
2710    """
2711    if self._training_state is not None:
2712      return self._training_state.maybe_load_initial_epoch_from_ckpt(
2713          initial_epoch, mode=ModeKeys.TRAIN)
2714    return initial_epoch
2715
2716  def _assert_compile_was_called(self):
2717    # Checks whether `compile` has been called. If it has been called,
2718    # then the optimizer is set. This is different from whether the
2719    # model is compiled
2720    # (i.e. whether the model is built and its inputs/outputs are set).
2721    if not self._is_compiled:
2722      raise RuntimeError('You must compile your model before '
2723                         'training/testing. '
2724                         'Use `model.compile(optimizer, loss)`.')
2725
  def _set_inputs(self, inputs, outputs=None, training=None):
    """This method is for compat with Modelv1. Only inputs are needed here.

    Args:
      inputs: Forwarded to `self._set_save_spec`.
      outputs: Unused; accepted only to keep the signature compatible.
      training: Unused; accepted only to keep the signature compatible.
    """
    self._set_save_spec(inputs)
2729
  @property
  def _trackable_saved_model_saver(self):
    """Returns a `model_serialization.ModelSavedModelSaver` wrapping this model.

    A new saver object is constructed on every access.
    """
    return model_serialization.ModelSavedModelSaver(self)
2733
2734  def _trackable_children(self, save_type='checkpoint', **kwargs):
2735    if save_type == 'savedmodel':
2736      # SavedModel needs to ignore the execution functions.
2737      train_function = self.train_function
2738      test_function = self.test_function
2739      predict_function = self.predict_function
2740      train_tf_function = self.train_tf_function
2741      self.train_function = None
2742      self.test_function = None
2743      self.predict_function = None
2744      self.train_tf_function = None
2745
2746    children = super(Model, self)._trackable_children(save_type, **kwargs)
2747
2748    if save_type == 'savedmodel':
2749      self.train_function = train_function
2750      self.test_function = test_function
2751      self.predict_function = predict_function
2752      self.train_tf_function = train_tf_function
2753
2754    return children
2755
2756  def _should_eval(self, epoch, validation_freq):
2757    epoch = epoch + 1  # one-index the user-facing epoch.
2758    if isinstance(validation_freq, int):
2759      return epoch % validation_freq == 0
2760    elif isinstance(validation_freq, list):
2761      return epoch in validation_freq
2762    else:
2763      raise ValueError('Expected `validation_freq` to be a list or int.')
2764
2765  ######################################################################
2766  # Functions below exist only as v1 / v2 compatibility shims.
2767  ######################################################################
2768
2769  def _get_compile_args(self, user_metrics=True):
2770    """Used for saving or cloning a Model.
2771
2772    Args:
2773      user_metrics: Whether to return user-supplied metrics or `Metric` objects.
2774        Defaults to returning the user-supplied metrics.
2775
2776    Returns:
2777      Dictionary of arguments that were used when compiling the model.
2778    """
2779    self._assert_compile_was_called()
2780    # pylint: disable=protected-access
2781
2782    saved_metrics = self.compiled_metrics._user_metrics
2783    saved_weighted_metrics = self.compiled_metrics._user_weighted_metrics
2784
2785    if not user_metrics:
2786      if saved_metrics is not None:
2787        saved_metrics = self.compiled_metrics._metrics
2788      if saved_weighted_metrics is not None:
2789        saved_weighted_metrics = self.compiled_metrics._weighted_metrics
2790
2791    compile_args = {
2792        'optimizer': self.optimizer,
2793        'loss': self.compiled_loss._user_losses,
2794        'metrics': saved_metrics,
2795        'weighted_metrics': saved_weighted_metrics,
2796        'loss_weights': self.compiled_loss._user_loss_weights,
2797    }
2798    # pylint: enable=protected-access
2799    return compile_args
2800
  def _get_callback_model(self):
    """Returns the model that callbacks should operate on: this model itself."""
    return self
2803
  def _in_multi_worker_mode(self):
    """Returns what the distribution strategy reports for multi-worker mode.

    Delegates to `self.distribute_strategy.extended._in_multi_worker_mode()`.
    """
    return self.distribute_strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
2806
  @property
  def _compile_was_called(self):
    """Returns `self._is_compiled`, the flag `_assert_compile_was_called` checks."""
    return self._is_compiled
2810
2811
def reduce_per_replica(values, strategy, reduction='first'):
  """Collapses every `PerReplica` leaf of `values` into a single tensor.

  Args:
    values: Structure of `PerReplica` objects or `Tensor`s. Leaves that are
      not `PerReplica` are returned as-is, except in the multi-worker
      collective 'concat' case where every leaf is gathered.
    strategy: `tf.distribute.Strategy` object.
    reduction: One of 'first', 'concat'. 'first' keeps the first replica's
      value; 'concat' joins the values of all replicas.

  Returns:
    Structure of `Tensor`s mirroring `values`.

  Raises:
    ValueError: If a `PerReplica` leaf is encountered and `reduction` is
      neither 'first' nor 'concat'.
  """

  def _reduce(v):
    """Reduce a single leaf of the structure."""
    wants_concat = reduction == 'concat'
    # Multi-worker collective strategies take a dedicated gather path that
    # also handles plain (non-PerReplica) tensors.
    if wants_concat and _collective_all_reduce_multi_worker(strategy):
      return _multi_worker_concat(v, strategy)
    if not _is_per_replica_instance(v):
      return v
    if reduction == 'first':
      return strategy.unwrap(v)[0]
    if not wants_concat:
      raise ValueError('`reduction` must be "first" or "concat".')
    if _is_tpu_multi_host(strategy):
      return _tpu_multi_host_concat(v, strategy)
    return concat(strategy.unwrap(v))

  return nest.map_structure(_reduce, values)
2842
2843
def concat(tensors, axis=0):
  """Concats `tensor`s along `axis`.

  The first element decides how the sequence is joined: a `SparseTensor`
  selects sparse concatenation, a scalar selects stacking, and anything else
  selects regular dense concatenation.

  Args:
    tensors: Non-empty sequence of values to join.
    axis: Axis along which to join. Defaults to 0.

  Returns:
    The joined tensor.
  """
  head = tensors[0]
  if isinstance(head, sparse_tensor.SparseTensor):
    return sparse_ops.sparse_concat_v2(axis=axis, sp_inputs=tensors)
  if _is_scalar(head):
    # Scalars have no axis to concatenate along, so stack them instead.
    return array_ops.stack(tensors, axis=axis)
  return array_ops.concat(tensors, axis=axis)
2852
2853
def _is_tpu_multi_host(strategy):
  """Whether `strategy` is a TPU strategy reporting more than one host."""
  is_tpu = backend.is_tpu_strategy(strategy)
  return is_tpu and strategy.extended.num_hosts > 1
2857
2858
def _tpu_multi_host_concat(v, strategy):
  """Concats a TPU `PerReplica` value after reordering its replicas.

  Args:
    v: `PerReplica` value to unwrap and join.
    strategy: The TPU strategy that produced `v`.

  Returns:
    The concatenation of the reordered per-replica values.
  """
  per_replica_values = strategy.unwrap(v)
  # When distributed datasets are created from Tensors / NumPy,
  # TPUStrategy.experimental_distribute_dataset shards data in
  # (Replica, Host) order, and TPUStrategy.unwrap returns it in
  # (Host, Replica) order.
  # TODO(b/150317897): Figure out long-term plan here.
  stride = strategy.extended.num_replicas_per_host
  # For each replica slot, take that slot's value from every host in turn.
  reordered = [
      value
      for offset in range(stride)
      for value in per_replica_values[offset::stride]
  ]
  return concat(reordered)
2872
2873
def _collective_all_reduce_multi_worker(strategy):
  """Whether `strategy` is a `CollectiveAllReduceStrategy` in multi-worker mode."""
  if not isinstance(
      strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy):
    return False
  return strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
2878
2879
2880# TODO(wxinyi): merge this with _tpu_multi_host_concat once we have all_gather
2881# for all strategies
def _multi_worker_concat(v, strategy):
  """Order PerReplica objects for CollectiveAllReduceStrategy and concat.

  The value is gathered along axis 0, split back into one piece per replica
  using the gathered leading-dimension sizes, reordered, and concatenated
  again.

  Args:
    v: A `PerReplica` value or a plain tensor.
    strategy: The strategy used for `gather`; its `num_replicas_in_sync` and
      `extended.worker_devices` drive the split and the reordering.

  Returns:
    The result of `concat` over the reordered per-replica pieces.
  """
  replicas = strategy.gather(v, axis=0)
  # v might not have the same shape on different replicas
  if _is_per_replica_instance(v):
    # Despite the name, `shapes` holds only the size of the first dimension
    # of each local replica value, packed into a 1-D tensor.
    shapes = array_ops.concat([
        array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
        for single_value in v.values
    ],
                              axis=0)
    all_shapes = strategy.gather(shapes, axis=0)
  else:
    # v is a tensor. This may happen when, say, we have 2x1 multi-worker.
    all_shapes = strategy.gather(
        array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), axis=0)

  # Undo the gather: cut the gathered tensor back into one piece per replica,
  # using the gathered leading-dimension sizes as the split sizes.
  replicas = array_ops.split(
      replicas,
      num_or_size_splits=all_shapes,
      num=strategy.num_replicas_in_sync)
  ordered_replicas = []
  num_replicas_per_worker = len(strategy.extended.worker_devices)
  # Take every `num_replicas_per_worker`-th piece, starting from each local
  # replica index in turn.
  # NOTE(review): presumably this mirrors the reordering done in
  # `_tpu_multi_host_concat` to match dataset sharding order -- confirm.
  for replica_id in range(num_replicas_per_worker):
    ordered_replicas += replicas[replica_id::num_replicas_per_worker]
  return concat(ordered_replicas)
2907
2908
def _is_scalar(x):
  """Whether `x` is a `Tensor` or `Variable` whose static rank is 0."""
  if not isinstance(x, (ops.Tensor, variables.Variable)):
    return False
  return x.shape.rank == 0
2911
2912
def write_scalar_summaries(logs, step):
  """Writes a scalar summary for every scalar entry of `logs`.

  Args:
    logs: Dict mapping names to values. Entries for which `_is_scalar` is
      false are skipped.
    step: Step value passed to each summary.
  """
  for key, entry in logs.items():
    if not _is_scalar(entry):
      continue
    summary_ops_v2.scalar('batch_' + key, entry, step=step)
2917
2918
def _minimum_control_deps(outputs):
  """Returns the minimum control dependencies to ensure step succeeded.

  Args:
    outputs: Arbitrary structure of step outputs.

  Returns:
    An empty list when executing eagerly or when every flattened output is a
    `Variable`; otherwise a one-element list holding the first flattened
    output that is not a `Variable`.
  """
  if context.executing_eagerly():
    return []  # Control dependencies not needed.
  flat_outputs = nest.flatten(outputs, expand_composites=True)
  # Variables can't be control dependencies.
  usable = [
      out for out in flat_outputs if not isinstance(out, variables.Variable)
  ]
  # Keep only the first viable Tensor or Op, if there is one.
  return usable[:1]
2929
2930
def _disallow_inside_tf_function(method_name):
  """Raises if a high-level `Model` endpoint is called inside a `tf.function`.

  Args:
    method_name: Name of the `Model` method being guarded; interpolated into
      the error message.

  Raises:
    RuntimeError: If `ops.inside_function()` is true.
  """
  if not ops.inside_function():
    return
  error_msg = (
      'Detected a call to `Model.{method_name}` inside a `tf.function`. '
      '`Model.{method_name}` is a high-level endpoint that manages its own '
      '`tf.function`. Please move the call to `Model.{method_name}` outside '
      'of all enclosing `tf.function`s. Note that you can call a `Model` '
      'directly on `Tensor`s inside a `tf.function` like: `model(x)`.'
  ).format(method_name=method_name)
  raise RuntimeError(error_msg)
2941
2942
def _detect_save_format(filepath):
  """Returns path to weights file and save format.

  Args:
    filepath: Path-like location of the weights to load.

  Returns:
    A `(filepath, save_format)` tuple where `save_format` is 'h5' or 'tf'.
    For a SavedModel directory the returned path points at the checkpoint
    inside its variables directory.

  Raises:
    ValueError: If `filepath` contains a SavedModel whose variables
      checkpoint cannot be read.
  """
  filepath = path_to_string(filepath)
  if saving_utils.is_hdf5_filepath(filepath):
    return filepath, 'h5'

  # Filepath could be a TensorFlow checkpoint file prefix or SavedModel
  # directory. It's possible for filepath to be both a prefix and directory.
  # Prioritize checkpoint over SavedModel.
  if _is_readable_tf_checkpoint(filepath):
    return filepath, 'tf'

  if not sm_loader.contains_saved_model(filepath):
    # Not a TensorFlow checkpoint. This filepath is likely an H5 file that
    # doesn't have the hdf5/keras extensions.
    return filepath, 'h5'

  variables_ckpt = os.path.join(filepath, sm_constants.VARIABLES_DIRECTORY,
                                sm_constants.VARIABLES_FILENAME)
  if _is_readable_tf_checkpoint(variables_ckpt):
    return variables_ckpt, 'tf'

  raise ValueError('Unable to load weights. filepath {} appears to be a '
                   'SavedModel directory, but checkpoint either doesn\'t '
                   'exist, or is incorrectly formatted.'.format(filepath))
2970
2971
def _is_readable_tf_checkpoint(filepath):
  """Whether a checkpoint reader can be opened on `filepath`.

  Only `DataLossError` is treated as "not readable"; any other error raised
  while opening the reader propagates to the caller.
  """
  try:
    py_checkpoint_reader.NewCheckpointReader(filepath)
  except errors_impl.DataLossError:
    # The checkpoint is not readable in TensorFlow format.
    return False
  else:
    return True
2979
2980
def flatten_metrics_in_order(logs, metrics_names):
  """Turns the `logs` dict into a list as per key order of `metrics_names`.

  Args:
    logs: Dict mapping metric names to values.
    metrics_names: Names that fix the order of the leading results. Names
      missing from `logs` are skipped.

  Returns:
    The values for `metrics_names` in that order, followed by the values of
    any remaining `logs` keys in sorted key order. If exactly one value
    results, that value itself is returned rather than a list.
  """
  ordered = [logs[name] for name in metrics_names if name in logs]
  # Entries not named in `metrics_names` go last, in deterministic key order.
  ordered.extend(
      logs[key] for key in sorted(logs) if key not in metrics_names)
  return ordered[0] if len(ordered) == 1 else ordered
2993
2994
def _is_per_replica_instance(obj):
  """Whether `obj` is both a `DistributedValues` and a `CompositeTensor`."""
  if not isinstance(obj, ds_values.DistributedValues):
    return False
  return isinstance(obj, composite_tensor.CompositeTensor)
2998