1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Training-related part of the Keras engine.""" 16 17import copy 18import itertools 19import json 20import os 21import warnings 22import weakref 23 24from tensorflow.python.autograph.lang import directives 25from tensorflow.python.checkpoint import checkpoint as trackable_utils 26from tensorflow.python.checkpoint import checkpoint_management 27from tensorflow.python.data.ops import dataset_ops 28from tensorflow.python.data.ops import options as options_lib 29from tensorflow.python.distribute import collective_all_reduce_strategy 30from tensorflow.python.distribute import distribution_strategy_context as ds_context 31from tensorflow.python.distribute import values as ds_values 32from tensorflow.python.distribute.coordinator import cluster_coordinator 33from tensorflow.python.eager import backprop 34from tensorflow.python.eager import context 35from tensorflow.python.eager import def_function 36from tensorflow.python.framework import composite_tensor 37from tensorflow.python.framework import errors 38from tensorflow.python.framework import errors_impl 39from tensorflow.python.framework import func_graph 40from tensorflow.python.framework import ops 41from tensorflow.python.framework import sparse_tensor 42from tensorflow.python.framework import tensor_shape 43from tensorflow.python.keras import backend 44from tensorflow.python.keras import callbacks as callbacks_module 45from tensorflow.python.keras import optimizer_v1 46from tensorflow.python.keras import optimizers 47from tensorflow.python.keras.engine import base_layer 48from tensorflow.python.keras.engine import base_layer_utils 49from tensorflow.python.keras.engine import compile_utils 50from tensorflow.python.keras.engine import data_adapter 51from tensorflow.python.keras.engine import training_utils 52from tensorflow.python.keras.mixed_precision import loss_scale_optimizer as lso 53from tensorflow.python.keras.mixed_precision import policy 54from tensorflow.python.keras.saving import hdf5_format 55from tensorflow.python.keras.saving import save 56from tensorflow.python.keras.saving import saving_utils 57from tensorflow.python.keras.saving.saved_model import json_utils 58from tensorflow.python.keras.saving.saved_model import model_serialization 59from tensorflow.python.keras.utils import generic_utils 60from tensorflow.python.keras.utils import layer_utils 61from tensorflow.python.keras.utils import tf_utils 62from tensorflow.python.keras.utils import version_utils 63from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite 64from tensorflow.python.keras.utils.io_utils import path_to_string 65from tensorflow.python.keras.utils.mode_keys import ModeKeys 66from tensorflow.python.ops import array_ops 67from tensorflow.python.ops import math_ops 68from tensorflow.python.ops import sparse_ops 69from tensorflow.python.ops import summary_ops_v2 70from tensorflow.python.ops import variables 71from tensorflow.python.platform import tf_logging as logging 72from tensorflow.python.profiler import trace 73from tensorflow.python.saved_model import constants as sm_constants 74from tensorflow.python.saved_model import loader_impl as sm_loader 75from tensorflow.python.trackable import base as trackable 76from tensorflow.python.training import py_checkpoint_reader 77from tensorflow.python.util import nest 78from tensorflow.python.util import tf_decorator 79from tensorflow.python.util.tf_export import keras_export 80from tensorflow.tools.docs import doc_controls 81 82 83# pylint: disable=g-import-not-at-top 84try: 85 import h5py 86except ImportError: 87 h5py = None 88# pylint: enable=g-import-not-at-top 89 90 91def disable_multi_worker(method): 92 """Decorator that disallows multi-worker use of `method`.""" 93 94 def _method_wrapper(self, *args, **kwargs): 95 if self._in_multi_worker_mode(): # pylint: disable=protected-access 96 raise ValueError('{} is not supported in multi-worker mode.'.format( 97 method.__name__)) 98 return method(self, *args, **kwargs) 99 100 return tf_decorator.make_decorator( 101 target=method, decorator_func=_method_wrapper) 102 103 104def inject_functional_model_class(cls): 105 """Inject `Functional` into the hierarchy of this class if needed.""" 106 from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top 107 from tensorflow.python.keras.engine import training_v1 # pylint: disable=g-import-not-at-top 108 if cls == Model or cls == training_v1.Model: 109 return functional.Functional 110 # In case there is any multiple inheritance, we stop injecting the 111 # class if keras model is not in its class hierarchy. 112 if cls == object: 113 return object 114 115 cls.__bases__ = tuple(inject_functional_model_class(base) 116 for base in cls.__bases__) 117 # Trigger any `__new__` class swapping that needed to happen on `Functional` 118 # but did not because functional was not in the class hierarchy. 119 cls.__new__(cls) 120 121 return cls 122 123 124def is_functional_model_init_params(args, kwargs): 125 return (len(args) == 2 or 126 len(args) == 1 and 'outputs' in kwargs or 127 'inputs' in kwargs and 'outputs' in kwargs) 128 129 130@keras_export('keras.Model', 'keras.models.Model') 131class Model(base_layer.Layer, version_utils.ModelVersionSelector): 132 """`Model` groups layers into an object with training and inference features. 133 134 Args: 135 inputs: The input(s) of the model: a `keras.Input` object or list of 136 `keras.Input` objects. 137 outputs: The output(s) of the model. See Functional API example below. 138 name: String, the name of the model. 139 140 There are two ways to instantiate a `Model`: 141 142 1 - With the "Functional API", where you start from `Input`, 143 you chain layer calls to specify the model's forward pass, 144 and finally you create your model from inputs and outputs: 145 146 ```python 147 import tensorflow as tf 148 149 inputs = tf.keras.Input(shape=(3,)) 150 x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) 151 outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x) 152 model = tf.keras.Model(inputs=inputs, outputs=outputs) 153 ``` 154 155 Note: Only dicts, lists, and tuples of input tensors are supported. Nested 156 inputs are not supported (e.g. lists of list or dicts of dict). 157 158 2 - By subclassing the `Model` class: in that case, you should define your 159 layers in `__init__` and you should implement the model's forward pass 160 in `call`. 161 162 ```python 163 import tensorflow as tf 164 165 class MyModel(tf.keras.Model): 166 167 def __init__(self): 168 super(MyModel, self).__init__() 169 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 170 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 171 172 def call(self, inputs): 173 x = self.dense1(inputs) 174 return self.dense2(x) 175 176 model = MyModel() 177 ``` 178 179 If you subclass `Model`, you can optionally have 180 a `training` argument (boolean) in `call`, which you can use to specify 181 a different behavior in training and inference: 182 183 ```python 184 import tensorflow as tf 185 186 class MyModel(tf.keras.Model): 187 188 def __init__(self): 189 super(MyModel, self).__init__() 190 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 191 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 192 self.dropout = tf.keras.layers.Dropout(0.5) 193 194 def call(self, inputs, training=False): 195 x = self.dense1(inputs) 196 if training: 197 x = self.dropout(x, training=training) 198 return self.dense2(x) 199 200 model = MyModel() 201 ``` 202 203 Once the model is created, you can config the model with losses and metrics 204 with `model.compile()`, train the model with `model.fit()`, or use the model 205 to do prediction with `model.predict()`. 206 """ 207 _TF_MODULE_IGNORED_PROPERTIES = frozenset( 208 itertools.chain(('_train_counter', '_test_counter', '_predict_counter', 209 '_steps_per_execution'), 210 base_layer.Layer._TF_MODULE_IGNORED_PROPERTIES)) # pylint: disable=protected-access 211 212 def __new__(cls, *args, **kwargs): 213 # Signature detection 214 if is_functional_model_init_params(args, kwargs) and cls == Model: 215 # Functional model 216 from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top 217 return functional.Functional(skip_init=True, *args, **kwargs) 218 else: 219 return super(Model, cls).__new__(cls, *args, **kwargs) 220 221 @trackable.no_automatic_dependency_tracking 222 def __init__(self, *args, **kwargs): 223 self._is_model_for_instrumentation = True 224 225 # Special case for Subclassed Functional Model, which we couldn't detect 226 # when __new__ is called. We only realize it is a functional model when it 227 # calls super.__init__ with input and output tensor. 228 from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top 229 if (is_functional_model_init_params(args, kwargs) and 230 not isinstance(self, functional.Functional)): 231 # Filter the kwargs for multiple inheritance. 232 supported_kwargs = ['inputs', 'outputs', 'name', 'trainable', 'skip_init'] 233 model_kwargs = {k: kwargs[k] for k in kwargs if k in supported_kwargs} 234 other_kwargs = {k: kwargs[k] for k in kwargs if k not in supported_kwargs} 235 inject_functional_model_class(self.__class__) 236 functional.Functional.__init__(self, *args, **model_kwargs) 237 238 # In case there is any multiple inheritance here, we need to call the 239 # __init__ for any class that appears after the Functional class. 240 clz_to_init = [] 241 found_functional_class = False 242 for clz in self.__class__.__bases__: 243 if issubclass(clz, functional.Functional): 244 found_functional_class = True 245 continue 246 if found_functional_class: 247 clz_to_init.append(clz) 248 249 if clz_to_init: 250 for clz in clz_to_init: 251 clz.__init__(self, *args, **other_kwargs) 252 elif other_kwargs: 253 # In case there are unused kwargs, we should raise an error to user, in 254 # case they have a typo in the param name. 255 raise TypeError( 256 'The following keyword arguments aren\'t supported: {}'.format( 257 other_kwargs)) 258 return 259 260 # The following are implemented as property functions: 261 # self.trainable_weights 262 # self.non_trainable_weights 263 # `inputs` / `outputs` will only appear in kwargs if either are misspelled. 264 generic_utils.validate_kwargs(kwargs, { 265 'trainable', 'dtype', 'dynamic', 'name', 'autocast', 'inputs', 'outputs' 266 }) 267 super(Model, self).__init__(**kwargs) 268 # By default, Model is a subclass model, which is not in graph network. 269 self._is_graph_network = False 270 271 self.inputs = None 272 self.outputs = None 273 self.input_names = None 274 self.output_names = None 275 # stop_training is used by callback to stop training when error happens 276 self.stop_training = False 277 self.history = None 278 # These objects are used in the default `Model.compile`. They are not 279 # guaranteed to be set after `Model.compile` is called, as users can 280 # override compile with custom logic. 281 self.compiled_loss = None 282 self.compiled_metrics = None 283 284 # This is True for Sequential networks and Functional networks. 285 self._compute_output_and_mask_jointly = False 286 287 # Don't reset compilation if already done. This may occur if calling 288 # `__init__` (or `_init_graph_network`) on an already-compiled model 289 # such as a Sequential model. Sequential models may need to rebuild 290 # themselves after compilation. 291 self._maybe_create_attribute('_is_compiled', False) 292 self._maybe_create_attribute('optimizer', None) 293 294 # Model must be created under scope of DistStrat it will be trained with. 295 if ds_context.has_strategy(): 296 self._distribution_strategy = ds_context.get_strategy() 297 else: 298 self._distribution_strategy = None 299 300 self._cluster_coordinator = None 301 302 # Defaults to value of `tf.config.experimental_functions_run_eagerly`. 303 self._run_eagerly = None 304 # Initialize cache attrs. 305 self._reset_compile_cache() 306 307 # Fault-tolerance handler. Set in `ModelCheckpoint`. 308 self._training_state = None 309 self._saved_model_inputs_spec = None 310 self._checkpoint = trackable_utils.Checkpoint(root=weakref.ref(self)) 311 312 self._steps_per_execution = None 313 314 self._init_batch_counters() 315 self._base_model_initialized = True 316 317 @trackable.no_automatic_dependency_tracking 318 def _init_batch_counters(self): 319 # Untracked Variables, used to keep track of mini-batches seen in `fit`, 320 # `evaluate`, and `predict`. 321 agg = variables.VariableAggregationV2.ONLY_FIRST_REPLICA 322 self._train_counter = variables.Variable(0, dtype='int64', aggregation=agg) 323 self._test_counter = variables.Variable(0, dtype='int64', aggregation=agg) 324 self._predict_counter = variables.Variable( 325 0, dtype='int64', aggregation=agg) 326 327 def __setattr__(self, name, value): 328 if not getattr(self, '_self_setattr_tracking', True): 329 super(Model, self).__setattr__(name, value) 330 return 331 332 if all( 333 isinstance(v, (base_layer.Layer, variables.Variable)) or 334 base_layer_utils.has_weights(v) for v in nest.flatten(value)): 335 try: 336 self._base_model_initialized 337 except AttributeError: 338 raise RuntimeError( 339 'It looks like you are subclassing `Model` and you ' 340 'forgot to call `super().__init__()`.' 341 ' Always start with this line.') 342 343 super(Model, self).__setattr__(name, value) 344 345 @generic_utils.default 346 def build(self, input_shape): 347 """Builds the model based on input shapes received. 348 349 This is to be used for subclassed models, which do not know at instantiation 350 time what their inputs look like. 351 352 This method only exists for users who want to call `model.build()` in a 353 standalone way (as a substitute for calling the model on real data to 354 build it). It will never be called by the framework (and thus it will 355 never throw unexpected errors in an unrelated workflow). 356 357 Args: 358 input_shape: Single tuple, TensorShape, or list/dict of shapes, where 359 shapes are tuples, integers, or TensorShapes. 360 361 Raises: 362 ValueError: 363 1. In case of invalid user-provided data (not of type tuple, 364 list, TensorShape, or dict). 365 2. If the model requires call arguments that are agnostic 366 to the input shapes (positional or kwarg in call signature). 367 3. If not all layers were properly built. 368 4. If float type inputs are not supported within the layers. 369 370 In each of these cases, the user should build their model by calling it 371 on real tensor data. 372 """ 373 if self._is_graph_network: 374 super(Model, self).build(input_shape) 375 return 376 377 if input_shape is None: 378 raise ValueError('Input shape must be defined when calling build on a ' 379 'model subclass network.') 380 valid_types = (tuple, list, tensor_shape.TensorShape, dict) 381 if not isinstance(input_shape, valid_types): 382 raise ValueError('Specified input shape is not one of the valid types. ' 383 'Please specify a batch input shape of type tuple or ' 384 'list of input shapes. User provided ' 385 'input type: {}'.format(type(input_shape))) 386 387 if input_shape and not self.inputs: 388 # We create placeholders for the `None`s in the shape and build the model 389 # in a Graph. Since tf.Variable is compatible with both eager execution 390 # and graph building, the variables created after building the model in 391 # a Graph are still valid when executing eagerly. 392 if context.executing_eagerly(): 393 graph = func_graph.FuncGraph('build_graph') 394 else: 395 graph = backend.get_graph() 396 with graph.as_default(): 397 if (isinstance(input_shape, list) and 398 all(d is None or isinstance(d, int) for d in input_shape)): 399 input_shape = tuple(input_shape) 400 if isinstance(input_shape, list): 401 x = [base_layer_utils.generate_placeholders_from_shape(shape) 402 for shape in input_shape] 403 elif isinstance(input_shape, dict): 404 x = { 405 k: base_layer_utils.generate_placeholders_from_shape(shape) 406 for k, shape in input_shape.items() 407 } 408 else: 409 x = base_layer_utils.generate_placeholders_from_shape(input_shape) 410 411 kwargs = {} 412 call_signature = self._call_full_argspec 413 call_args = call_signature.args 414 # Exclude `self`, `inputs`, and any argument with a default value. 415 if len(call_args) > 2: 416 if call_signature.defaults: 417 call_args = call_args[2:-len(call_signature.defaults)] 418 else: 419 call_args = call_args[2:] 420 for arg in call_args: 421 if arg == 'training': 422 # Case where `training` is a positional arg with no default. 423 kwargs['training'] = False 424 else: 425 # Has invalid call signature with unknown positional arguments. 426 raise ValueError( 427 'Currently, you cannot build your model if it has ' 428 'positional or keyword arguments that are not ' 429 'inputs to the model, but are required for its ' 430 '`call` method. Instead, in order to instantiate ' 431 'and build your model, `call` your model on real ' 432 'tensor data with all expected call arguments.') 433 elif len(call_args) < 2: 434 # Signature without `inputs`. 435 raise ValueError('You can only call `build` on a model if its `call` ' 436 'method accepts an `inputs` argument.') 437 try: 438 self.call(x, **kwargs) 439 except (errors.InvalidArgumentError, TypeError): 440 raise ValueError('You cannot build your model by calling `build` ' 441 'if your layers do not support float type inputs. ' 442 'Instead, in order to instantiate and build your ' 443 'model, `call` your model on real tensor data (of ' 444 'the correct dtype).') 445 super(Model, self).build(input_shape) 446 447 @doc_controls.doc_in_current_and_subclasses 448 def call(self, inputs, training=None, mask=None): 449 """Calls the model on new inputs. 450 451 In this case `call` just reapplies 452 all ops in the graph to the new inputs 453 (e.g. build a new computational graph from the provided inputs). 454 455 Note: This method should not be called directly. It is only meant to be 456 overridden when subclassing `tf.keras.Model`. 457 To call a model on an input, always use the `__call__` method, 458 i.e. `model(inputs)`, which relies on the underlying `call` method. 459 460 Args: 461 inputs: Input tensor, or dict/list/tuple of input tensors. 462 training: Boolean or boolean scalar tensor, indicating whether to run 463 the `Network` in training mode or inference mode. 464 mask: A mask or list of masks. A mask can be 465 either a tensor or None (no mask). 466 467 Returns: 468 A tensor if there is a single output, or 469 a list of tensors if there are more than one outputs. 470 """ 471 raise NotImplementedError('When subclassing the `Model` class, you should ' 472 'implement a `call` method.') 473 474 def compile(self, 475 optimizer='rmsprop', 476 loss=None, 477 metrics=None, 478 loss_weights=None, 479 weighted_metrics=None, 480 run_eagerly=None, 481 steps_per_execution=None, 482 **kwargs): 483 """Configures the model for training. 484 485 Args: 486 optimizer: String (name of optimizer) or optimizer instance. See 487 `tf.keras.optimizers`. 488 loss: String (name of objective function), objective function or 489 `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An objective 490 function is any callable with the signature `loss = fn(y_true, 491 y_pred)`, where y_true = ground truth values with shape = 492 `[batch_size, d0, .. dN]`, except sparse loss functions such as sparse 493 categorical crossentropy where shape = `[batch_size, d0, .. dN-1]`. 494 y_pred = predicted values with shape = `[batch_size, d0, .. dN]`. It 495 returns a weighted loss float tensor. If a custom `Loss` instance is 496 used and reduction is set to `None`, return value has the shape 497 `[batch_size, d0, .. dN-1]` i.e. per-sample or per-timestep loss 498 values; otherwise, it is a scalar. If the model has multiple outputs, 499 you can use a different loss on each output by passing a dictionary 500 or a list of losses. The loss value that will be minimized by the 501 model will then be the sum of all individual losses, unless 502 `loss_weights` is specified. 503 metrics: List of metrics to be evaluated by the model during training 504 and testing. Each of this can be a string (name of a built-in 505 function), function or a `tf.keras.metrics.Metric` instance. See 506 `tf.keras.metrics`. Typically you will use `metrics=['accuracy']`. A 507 function is any callable with the signature `result = fn(y_true, 508 y_pred)`. To specify different metrics for different outputs of a 509 multi-output model, you could also pass a dictionary, such as 510 `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`. 511 You can also pass a list to specify a metric or a list of metrics 512 for each output, such as `metrics=[['accuracy'], ['accuracy', 'mse']]` 513 or `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass the 514 strings 'accuracy' or 'acc', we convert this to one of 515 `tf.keras.metrics.BinaryAccuracy`, 516 `tf.keras.metrics.CategoricalAccuracy`, 517 `tf.keras.metrics.SparseCategoricalAccuracy` based on the loss 518 function used and the model output shape. We do a similar 519 conversion for the strings 'crossentropy' and 'ce' as well. 520 loss_weights: Optional list or dictionary specifying scalar coefficients 521 (Python floats) to weight the loss contributions of different model 522 outputs. The loss value that will be minimized by the model will then 523 be the *weighted sum* of all individual losses, weighted by the 524 `loss_weights` coefficients. 525 If a list, it is expected to have a 1:1 mapping to the model's 526 outputs. If a dict, it is expected to map output names (strings) 527 to scalar coefficients. 528 weighted_metrics: List of metrics to be evaluated and weighted by 529 `sample_weight` or `class_weight` during training and testing. 530 run_eagerly: Bool. Defaults to `False`. If `True`, this `Model`'s 531 logic will not be wrapped in a `tf.function`. Recommended to leave 532 this as `None` unless your `Model` cannot be run inside a 533 `tf.function`. `run_eagerly=True` is not supported when using 534 `tf.distribute.experimental.ParameterServerStrategy`. 535 steps_per_execution: Int. Defaults to 1. The number of batches to 536 run during each `tf.function` call. Running multiple batches 537 inside a single `tf.function` call can greatly improve performance 538 on TPUs or small models with a large Python overhead. 539 At most, one full epoch will be run each 540 execution. If a number larger than the size of the epoch is passed, 541 the execution will be truncated to the size of the epoch. 542 Note that if `steps_per_execution` is set to `N`, 543 `Callback.on_batch_begin` and `Callback.on_batch_end` methods 544 will only be called every `N` batches 545 (i.e. before/after each `tf.function` execution). 546 **kwargs: Arguments supported for backwards compatibility only. 547 548 Raises: 549 ValueError: In case of invalid arguments for 550 `optimizer`, `loss` or `metrics`. 551 """ 552 with self.distribute_strategy.scope(): 553 if 'experimental_steps_per_execution' in kwargs: 554 logging.warning('The argument `steps_per_execution` is no longer ' 555 'experimental. Pass `steps_per_execution` instead of ' 556 '`experimental_steps_per_execution`.') 557 if not steps_per_execution: 558 steps_per_execution = kwargs.pop('experimental_steps_per_execution') 559 560 # When compiling from an already-serialized model, we do not want to 561 # reapply some processing steps (e.g. metric renaming for multi-output 562 # models, which have prefixes added for each corresponding output name). 563 from_serialized = kwargs.pop('from_serialized', False) 564 565 self._validate_compile(optimizer, metrics, **kwargs) 566 self._run_eagerly = run_eagerly 567 568 self.optimizer = self._get_optimizer(optimizer) 569 self.compiled_loss = compile_utils.LossesContainer( 570 loss, loss_weights, output_names=self.output_names) 571 self.compiled_metrics = compile_utils.MetricsContainer( 572 metrics, weighted_metrics, output_names=self.output_names, 573 from_serialized=from_serialized) 574 575 self._configure_steps_per_execution(steps_per_execution or 1) 576 577 # Initializes attrs that are reset each time `compile` is called. 578 self._reset_compile_cache() 579 self._is_compiled = True 580 581 self.loss = loss or {} # Backwards compat. 582 583 def _get_optimizer(self, optimizer): 584 """Wraps `optimizer` in `LossScaleOptimizer` if necessary.""" 585 # The deprecated PolicyV1 has a loss_scale, which we use for backwards 586 # compatibility to match TF 2.3 behavior. The new Policy does not have a 587 # loss_scale, so we use dynamic loss scaling if the mixed_float16 policy is 588 # used. 589 if isinstance(self._dtype_policy, policy.PolicyV1): 590 loss_scale = self._dtype_policy.loss_scale 591 elif self._dtype_policy.name == 'mixed_float16': 592 loss_scale = 'dynamic' 593 else: 594 loss_scale = None 595 596 def _get_single_optimizer(opt): 597 opt = optimizers.get(opt) 598 if (loss_scale is not None and 599 not isinstance(opt, lso.LossScaleOptimizer)): 600 if loss_scale == 'dynamic': 601 opt = lso.LossScaleOptimizer(opt) 602 else: 603 opt = lso.LossScaleOptimizerV1(opt, loss_scale) 604 return opt 605 606 return nest.map_structure(_get_single_optimizer, optimizer) 607 608 @trackable.no_automatic_dependency_tracking 609 def _reset_compile_cache(self): 610 self.train_function = None 611 self.test_function = None 612 self.predict_function = None 613 # Used to cache the `tf.function`'ed `train_function` to be logged in 614 # TensorBoard, since the original `train_function` is not necessarily 615 # a `tf.function` (e.g., with ParameterServerStrategy, the `train_function` 616 # is a scheduling of the actual training function to a remote worker). 617 self.train_tf_function = None 618 619 # Used to cache `trainable` attr of `Layer`s for `fit`. 620 self._compiled_trainable_state = self._get_trainable_state() 621 622 @trackable.no_automatic_dependency_tracking 623 def _configure_steps_per_execution(self, steps_per_execution): 624 self._steps_per_execution = variables.Variable( 625 steps_per_execution, 626 dtype='int64', 627 aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA) 628 629 @property 630 def _should_compute_mask(self): 631 return False 632 633 @property 634 def metrics(self): 635 """Returns the model's metrics added using `compile`, `add_metric` APIs. 636 637 Note: Metrics passed to `compile()` are available only after a `keras.Model` 638 has been trained/evaluated on actual data. 639 640 Examples: 641 642 >>> inputs = tf.keras.layers.Input(shape=(3,)) 643 >>> outputs = tf.keras.layers.Dense(2)(inputs) 644 >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs) 645 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"]) 646 >>> [m.name for m in model.metrics] 647 [] 648 649 >>> x = np.random.random((2, 3)) 650 >>> y = np.random.randint(0, 2, (2, 2)) 651 >>> model.fit(x, y) 652 >>> [m.name for m in model.metrics] 653 ['loss', 'mae'] 654 655 >>> inputs = tf.keras.layers.Input(shape=(3,)) 656 >>> d = tf.keras.layers.Dense(2, name='out') 657 >>> output_1 = d(inputs) 658 >>> output_2 = d(inputs) 659 >>> model = tf.keras.models.Model( 660 ... inputs=inputs, outputs=[output_1, output_2]) 661 >>> model.add_metric( 662 ... tf.reduce_sum(output_2), name='mean', aggregation='mean') 663 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"]) 664 >>> model.fit(x, (y, y)) 665 >>> [m.name for m in model.metrics] 666 ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae', 667 'out_1_acc', 'mean'] 668 669 """ 670 metrics = [] 671 if self._is_compiled: 672 # TODO(omalleyt): Track `LossesContainer` and `MetricsContainer` objects 673 # so that attr names are not load-bearing. 674 if self.compiled_loss is not None: 675 metrics += self.compiled_loss.metrics 676 if self.compiled_metrics is not None: 677 metrics += self.compiled_metrics.metrics 678 679 for l in self._flatten_layers(): 680 metrics.extend(l._metrics) # pylint: disable=protected-access 681 return metrics 682 683 @property 684 def metrics_names(self): 685 """Returns the model's display labels for all outputs. 686 687 Note: `metrics_names` are available only after a `keras.Model` has been 688 trained/evaluated on actual data. 689 690 Examples: 691 692 >>> inputs = tf.keras.layers.Input(shape=(3,)) 693 >>> outputs = tf.keras.layers.Dense(2)(inputs) 694 >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs) 695 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"]) 696 >>> model.metrics_names 697 [] 698 699 >>> x = np.random.random((2, 3)) 700 >>> y = np.random.randint(0, 2, (2, 2)) 701 >>> model.fit(x, y) 702 >>> model.metrics_names 703 ['loss', 'mae'] 704 705 >>> inputs = tf.keras.layers.Input(shape=(3,)) 706 >>> d = tf.keras.layers.Dense(2, name='out') 707 >>> output_1 = d(inputs) 708 >>> output_2 = d(inputs) 709 >>> model = tf.keras.models.Model( 710 ... inputs=inputs, outputs=[output_1, output_2]) 711 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"]) 712 >>> model.fit(x, (y, y)) 713 >>> model.metrics_names 714 ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae', 715 'out_1_acc'] 716 717 """ 718 719 # This property includes all output names including `loss` and per-output 720 # losses for backward compatibility. 721 return [m.name for m in self.metrics] 722 723 @property 724 def distribute_strategy(self): 725 """The `tf.distribute.Strategy` this model was created under.""" 726 return self._distribution_strategy or ds_context.get_strategy() 727 728 @property 729 def run_eagerly(self): 730 """Settable attribute indicating whether the model should run eagerly. 731 732 Running eagerly means that your model will be run step by step, 733 like Python code. Your model might run slower, but it should become easier 734 for you to debug it by stepping into individual layer calls. 735 736 By default, we will attempt to compile your model to a static graph to 737 deliver the best execution performance. 738 739 Returns: 740 Boolean, whether the model should run eagerly. 741 """ 742 if self.dynamic and self._run_eagerly is False: # pylint:disable=g-bool-id-comparison 743 # TODO(fchollet): consider using py_func to enable this. 744 raise ValueError('Your model contains layers that can only be ' 745 'successfully run in eager execution (layers ' 746 'constructed with `dynamic=True`). ' 747 'You cannot set `run_eagerly=False`.') 748 749 if self._cluster_coordinator and self._run_eagerly: 750 raise ValueError('When using `Model` with `ParameterServerStrategy`, ' 751 '`run_eagerly` is not supported.') 752 753 # Run eagerly logic, by priority: 754 # (1) Dynamic models must be run eagerly. 755 # (2) Explicitly setting run_eagerly causes a Model to be run eagerly. 756 # (3) Not explicitly setting run_eagerly defaults to TF's global setting. 757 return (self.dynamic or self._run_eagerly or 758 (def_function.functions_run_eagerly() and 759 self._run_eagerly is None)) 760 761 @run_eagerly.setter 762 def run_eagerly(self, value): 763 self._run_eagerly = value 764 765 def train_step(self, data): 766 """The logic for one training step. 767 768 This method can be overridden to support custom training logic. 769 For concrete examples of how to override this method see 770 [Customizing what happends in fit](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit). 771 This method is called by `Model.make_train_function`. 772 773 This method should contain the mathematical logic for one step of training. 774 This typically includes the forward pass, loss calculation, backpropagation, 775 and metric updates. 776 777 Configuration details for *how* this logic is run (e.g. `tf.function` and 778 `tf.distribute.Strategy` settings), should be left to 779 `Model.make_train_function`, which can also be overridden. 780 781 Args: 782 data: A nested structure of `Tensor`s. 783 784 Returns: 785 A `dict` containing values that will be passed to 786 `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the 787 values of the `Model`'s metrics are returned. Example: 788 `{'loss': 0.2, 'accuracy': 0.7}`. 789 790 """ 791 # These are the only transformations `Model.fit` applies to user-input 792 # data when a `tf.data.Dataset` is provided. 793 data = data_adapter.expand_1d(data) 794 x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data) 795 # Run forward pass. 796 with backprop.GradientTape() as tape: 797 y_pred = self(x, training=True) 798 loss = self.compiled_loss( 799 y, y_pred, sample_weight, regularization_losses=self.losses) 800 # Run backwards pass. 801 self.optimizer.minimize(loss, self.trainable_variables, tape=tape) 802 self.compiled_metrics.update_state(y, y_pred, sample_weight) 803 # Collect metrics to return 804 return_metrics = {} 805 for metric in self.metrics: 806 result = metric.result() 807 if isinstance(result, dict): 808 return_metrics.update(result) 809 else: 810 return_metrics[metric.name] = result 811 return return_metrics 812 813 def make_train_function(self): 814 """Creates a function that executes one step of training. 815 816 This method can be overridden to support custom training logic. 817 This method is called by `Model.fit` and `Model.train_on_batch`. 818 819 Typically, this method directly controls `tf.function` and 820 `tf.distribute.Strategy` settings, and delegates the actual training 821 logic to `Model.train_step`. 822 823 This function is cached the first time `Model.fit` or 824 `Model.train_on_batch` is called. The cache is cleared whenever 825 `Model.compile` is called. 826 827 Returns: 828 Function. The function created by this method should accept a 829 `tf.data.Iterator`, and return a `dict` containing values that will 830 be passed to `tf.keras.Callbacks.on_train_batch_end`, such as 831 `{'loss': 0.2, 'accuracy': 0.7}`. 832 """ 833 if self.train_function is not None: 834 return self.train_function 835 836 def step_function(model, iterator): 837 """Runs a single training step.""" 838 839 def run_step(data): 840 outputs = model.train_step(data) 841 # Ensure counter is updated only if `train_step` succeeds. 842 with ops.control_dependencies(_minimum_control_deps(outputs)): 843 model._train_counter.assign_add(1) # pylint: disable=protected-access 844 return outputs 845 846 data = next(iterator) 847 outputs = model.distribute_strategy.run(run_step, args=(data,)) 848 outputs = reduce_per_replica( 849 outputs, self.distribute_strategy, reduction='first') 850 write_scalar_summaries(outputs, step=model._train_counter) # pylint: disable=protected-access 851 return outputs 852 853 if self._steps_per_execution.numpy().item() == 1: 854 855 def train_function(iterator): 856 """Runs a training execution with one step.""" 857 return step_function(self, iterator) 858 859 else: 860 861 def train_function(iterator): 862 """Runs a training execution with multiple steps.""" 863 for _ in math_ops.range(self._steps_per_execution): 864 outputs = step_function(self, iterator) 865 return outputs 866 867 if not self.run_eagerly: 868 train_function = def_function.function( 869 train_function, experimental_relax_shapes=True) 870 self.train_tf_function = train_function 871 872 self.train_function = train_function 873 874 if self._cluster_coordinator: 875 self.train_function = lambda iterator: self._cluster_coordinator.schedule( # pylint: disable=g-long-lambda 876 train_function, args=(iterator,)) 877 878 return self.train_function 879 880 def fit(self, 881 x=None, 882 y=None, 883 batch_size=None, 884 epochs=1, 885 verbose='auto', 886 callbacks=None, 887 validation_split=0., 888 validation_data=None, 889 shuffle=True, 890 class_weight=None, 891 sample_weight=None, 892 initial_epoch=0, 893 steps_per_epoch=None, 894 validation_steps=None, 895 validation_batch_size=None, 896 validation_freq=1, 897 max_queue_size=10, 898 workers=1, 899 use_multiprocessing=False): 900 """Trains the model for a fixed number of epochs (iterations on a dataset). 901 902 Args: 903 x: Input data. It could be: 904 - A Numpy array (or array-like), or a list of arrays 905 (in case the model has multiple inputs). 906 - A TensorFlow tensor, or a list of tensors 907 (in case the model has multiple inputs). 908 - A dict mapping input names to the corresponding array/tensors, 909 if the model has named inputs. 910 - A `tf.data` dataset. Should return a tuple 911 of either `(inputs, targets)` or 912 `(inputs, targets, sample_weights)`. 913 - A generator or `keras.utils.Sequence` returning `(inputs, targets)` 914 or `(inputs, targets, sample_weights)`. 915 - A `tf.keras.utils.experimental.DatasetCreator`, which wraps a 916 callable that takes a single argument of type 917 `tf.distribute.InputContext`, and returns a `tf.data.Dataset`. 918 `DatasetCreator` should be used when users prefer to specify the 919 per-replica batching and sharding logic for the `Dataset`. 920 See `tf.keras.utils.experimental.DatasetCreator` doc for more 921 information. 922 A more detailed description of unpacking behavior for iterator types 923 (Dataset, generator, Sequence) is given below. If using 924 `tf.distribute.experimental.ParameterServerStrategy`, only 925 `DatasetCreator` type is supported for `x`. 926 y: Target data. Like the input data `x`, 927 it could be either Numpy array(s) or TensorFlow tensor(s). 928 It should be consistent with `x` (you cannot have Numpy inputs and 929 tensor targets, or inversely). If `x` is a dataset, generator, 930 or `keras.utils.Sequence` instance, `y` should 931 not be specified (since targets will be obtained from `x`). 932 batch_size: Integer or `None`. 933 Number of samples per gradient update. 934 If unspecified, `batch_size` will default to 32. 935 Do not specify the `batch_size` if your data is in the 936 form of datasets, generators, or `keras.utils.Sequence` instances 937 (since they generate batches). 938 epochs: Integer. Number of epochs to train the model. 939 An epoch is an iteration over the entire `x` and `y` 940 data provided. 941 Note that in conjunction with `initial_epoch`, 942 `epochs` is to be understood as "final epoch". 943 The model is not trained for a number of iterations 944 given by `epochs`, but merely until the epoch 945 of index `epochs` is reached. 946 verbose: 'auto', 0, 1, or 2. Verbosity mode. 947 0 = silent, 1 = progress bar, 2 = one line per epoch. 948 'auto' defaults to 1 for most cases, but 2 when used with 949 `ParameterServerStrategy`. Note that the progress bar is not 950 particularly useful when logged to a file, so verbose=2 is 951 recommended when not running interactively (eg, in a production 952 environment). 953 callbacks: List of `keras.callbacks.Callback` instances. 954 List of callbacks to apply during training. 955 See `tf.keras.callbacks`. Note `tf.keras.callbacks.ProgbarLogger` 956 and `tf.keras.callbacks.History` callbacks are created automatically 957 and need not be passed into `model.fit`. 958 `tf.keras.callbacks.ProgbarLogger` is created or not based on 959 `verbose` argument to `model.fit`. 960 Callbacks with batch-level calls are currently unsupported with 961 `tf.distribute.experimental.ParameterServerStrategy`, and users are 962 advised to implement epoch-level calls instead with an appropriate 963 `steps_per_epoch` value. 964 validation_split: Float between 0 and 1. 965 Fraction of the training data to be used as validation data. 966 The model will set apart this fraction of the training data, 967 will not train on it, and will evaluate 968 the loss and any model metrics 969 on this data at the end of each epoch. 970 The validation data is selected from the last samples 971 in the `x` and `y` data provided, before shuffling. This argument is 972 not supported when `x` is a dataset, generator or 973 `keras.utils.Sequence` instance. 974 `validation_split` is not yet supported with 975 `tf.distribute.experimental.ParameterServerStrategy`. 976 validation_data: Data on which to evaluate 977 the loss and any model metrics at the end of each epoch. 978 The model will not be trained on this data. Thus, note the fact 979 that the validation loss of data provided using `validation_split` 980 or `validation_data` is not affected by regularization layers like 981 noise and dropout. 982 `validation_data` will override `validation_split`. 983 `validation_data` could be: 984 - A tuple `(x_val, y_val)` of Numpy arrays or tensors. 985 - A tuple `(x_val, y_val, val_sample_weights)` of NumPy arrays. 986 - A `tf.data.Dataset`. 987 - A Python generator or `keras.utils.Sequence` returning 988 `(inputs, targets)` or `(inputs, targets, sample_weights)`. 989 `validation_data` is not yet supported with 990 `tf.distribute.experimental.ParameterServerStrategy`. 991 shuffle: Boolean (whether to shuffle the training data 992 before each epoch) or str (for 'batch'). This argument is ignored 993 when `x` is a generator or an object of tf.data.Dataset. 994 'batch' is a special option for dealing 995 with the limitations of HDF5 data; it shuffles in batch-sized 996 chunks. Has no effect when `steps_per_epoch` is not `None`. 997 class_weight: Optional dictionary mapping class indices (integers) 998 to a weight (float) value, used for weighting the loss function 999 (during training only). 1000 This can be useful to tell the model to 1001 "pay more attention" to samples from 1002 an under-represented class. 1003 sample_weight: Optional Numpy array of weights for 1004 the training samples, used for weighting the loss function 1005 (during training only). You can either pass a flat (1D) 1006 Numpy array with the same length as the input samples 1007 (1:1 mapping between weights and samples), 1008 or in the case of temporal data, 1009 you can pass a 2D array with shape 1010 `(samples, sequence_length)`, 1011 to apply a different weight to every timestep of every sample. This 1012 argument is not supported when `x` is a dataset, generator, or 1013 `keras.utils.Sequence` instance, instead provide the sample_weights 1014 as the third element of `x`. 1015 initial_epoch: Integer. 1016 Epoch at which to start training 1017 (useful for resuming a previous training run). 1018 steps_per_epoch: Integer or `None`. 1019 Total number of steps (batches of samples) 1020 before declaring one epoch finished and starting the 1021 next epoch. When training with input tensors such as 1022 TensorFlow data tensors, the default `None` is equal to 1023 the number of samples in your dataset divided by 1024 the batch size, or 1 if that cannot be determined. If x is a 1025 `tf.data` dataset, and 'steps_per_epoch' 1026 is None, the epoch will run until the input dataset is exhausted. 1027 When passing an infinitely repeating dataset, you must specify the 1028 `steps_per_epoch` argument. If `steps_per_epoch=-1` the training 1029 will run indefinitely with an infinitely repeating dataset. 1030 This argument is not supported with array inputs. 1031 When using `tf.distribute.experimental.ParameterServerStrategy`: 1032 * `steps_per_epoch=None` is not supported. 1033 validation_steps: Only relevant if `validation_data` is provided and 1034 is a `tf.data` dataset. Total number of steps (batches of 1035 samples) to draw before stopping when performing validation 1036 at the end of every epoch. If 'validation_steps' is None, validation 1037 will run until the `validation_data` dataset is exhausted. In the 1038 case of an infinitely repeated dataset, it will run into an 1039 infinite loop. If 'validation_steps' is specified and only part of 1040 the dataset will be consumed, the evaluation will start from the 1041 beginning of the dataset at each epoch. This ensures that the same 1042 validation samples are used every time. 1043 validation_batch_size: Integer or `None`. 1044 Number of samples per validation batch. 1045 If unspecified, will default to `batch_size`. 1046 Do not specify the `validation_batch_size` if your data is in the 1047 form of datasets, generators, or `keras.utils.Sequence` instances 1048 (since they generate batches). 1049 validation_freq: Only relevant if validation data is provided. Integer 1050 or `collections.abc.Container` instance (e.g. list, tuple, etc.). 1051 If an integer, specifies how many training epochs to run before a 1052 new validation run is performed, e.g. `validation_freq=2` runs 1053 validation every 2 epochs. If a Container, specifies the epochs on 1054 which to run validation, e.g. `validation_freq=[1, 2, 10]` runs 1055 validation at the end of the 1st, 2nd, and 10th epochs. 1056 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 1057 input only. Maximum size for the generator queue. 1058 If unspecified, `max_queue_size` will default to 10. 1059 workers: Integer. Used for generator or `keras.utils.Sequence` input 1060 only. Maximum number of processes to spin up 1061 when using process-based threading. If unspecified, `workers` 1062 will default to 1. 1063 use_multiprocessing: Boolean. Used for generator or 1064 `keras.utils.Sequence` input only. If `True`, use process-based 1065 threading. If unspecified, `use_multiprocessing` will default to 1066 `False`. Note that because this implementation relies on 1067 multiprocessing, you should not pass non-picklable arguments to 1068 the generator as they can't be passed easily to children processes. 1069 1070 Unpacking behavior for iterator-like inputs: 1071 A common pattern is to pass a tf.data.Dataset, generator, or 1072 tf.keras.utils.Sequence to the `x` argument of fit, which will in fact 1073 yield not only features (x) but optionally targets (y) and sample weights. 1074 Keras requires that the output of such iterator-likes be unambiguous. The 1075 iterator should return a tuple of length 1, 2, or 3, where the optional 1076 second and third elements will be used for y and sample_weight 1077 respectively. Any other type provided will be wrapped in a length one 1078 tuple, effectively treating everything as 'x'. When yielding dicts, they 1079 should still adhere to the top-level tuple structure. 1080 e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate 1081 features, targets, and weights from the keys of a single dict. 1082 A notable unsupported data type is the namedtuple. The reason is that 1083 it behaves like both an ordered datatype (tuple) and a mapping 1084 datatype (dict). So given a namedtuple of the form: 1085 `namedtuple("example_tuple", ["y", "x"])` 1086 it is ambiguous whether to reverse the order of the elements when 1087 interpreting the value. Even worse is a tuple of the form: 1088 `namedtuple("other_tuple", ["x", "y", "z"])` 1089 where it is unclear if the tuple was intended to be unpacked into x, y, 1090 and sample_weight or passed through as a single element to `x`. As a 1091 result the data processing code will simply raise a ValueError if it 1092 encounters a namedtuple. (Along with instructions to remedy the issue.) 1093 1094 Returns: 1095 A `History` object. Its `History.history` attribute is 1096 a record of training loss values and metrics values 1097 at successive epochs, as well as validation loss values 1098 and validation metrics values (if applicable). 1099 1100 Raises: 1101 RuntimeError: 1. If the model was never compiled or, 1102 2. If `model.fit` is wrapped in `tf.function`. 1103 1104 ValueError: In case of mismatch between the provided input data 1105 and what the model expects or when the input data is empty. 1106 """ 1107 # Legacy graph support is contained in `training_v1.Model`. 1108 version_utils.disallow_legacy_graph('Model', 'fit') 1109 self._assert_compile_was_called() 1110 self._check_call_args('fit') 1111 _disallow_inside_tf_function('fit') 1112 1113 if verbose == 'auto': 1114 if self.distribute_strategy._should_use_with_coordinator: # pylint: disable=protected-access 1115 verbose = 2 # Default to epoch-level logging for PSStrategy. 1116 else: 1117 verbose = 1 # Default to batch-level logging otherwise. 1118 1119 if validation_split: 1120 # Create the validation data using the training data. Only supported for 1121 # `Tensor` and `NumPy` input. 1122 (x, y, sample_weight), validation_data = ( 1123 data_adapter.train_validation_split( 1124 (x, y, sample_weight), validation_split=validation_split)) 1125 1126 if validation_data: 1127 val_x, val_y, val_sample_weight = ( 1128 data_adapter.unpack_x_y_sample_weight(validation_data)) 1129 1130 if self.distribute_strategy._should_use_with_coordinator: # pylint: disable=protected-access 1131 self._cluster_coordinator = cluster_coordinator.ClusterCoordinator( 1132 self.distribute_strategy) 1133 1134 with self.distribute_strategy.scope(), \ 1135 training_utils.RespectCompiledTrainableState(self): 1136 # Creates a `tf.data.Dataset` and handles batch and epoch iteration. 1137 data_handler = data_adapter.get_data_handler( 1138 x=x, 1139 y=y, 1140 sample_weight=sample_weight, 1141 batch_size=batch_size, 1142 steps_per_epoch=steps_per_epoch, 1143 initial_epoch=initial_epoch, 1144 epochs=epochs, 1145 shuffle=shuffle, 1146 class_weight=class_weight, 1147 max_queue_size=max_queue_size, 1148 workers=workers, 1149 use_multiprocessing=use_multiprocessing, 1150 model=self, 1151 steps_per_execution=self._steps_per_execution) 1152 1153 # Container that configures and calls `tf.keras.Callback`s. 1154 if not isinstance(callbacks, callbacks_module.CallbackList): 1155 callbacks = callbacks_module.CallbackList( 1156 callbacks, 1157 add_history=True, 1158 add_progbar=verbose != 0, 1159 model=self, 1160 verbose=verbose, 1161 epochs=epochs, 1162 steps=data_handler.inferred_steps) 1163 1164 self.stop_training = False 1165 self.train_function = self.make_train_function() 1166 self._train_counter.assign(0) 1167 callbacks.on_train_begin() 1168 training_logs = None 1169 # Handle fault-tolerance for multi-worker. 1170 # TODO(omalleyt): Fix the ordering issues that mean this has to 1171 # happen after `callbacks.on_train_begin`. 1172 data_handler._initial_epoch = ( # pylint: disable=protected-access 1173 self._maybe_load_initial_epoch_from_ckpt(initial_epoch)) 1174 logs = None 1175 for epoch, iterator in data_handler.enumerate_epochs(): 1176 self.reset_metrics() 1177 callbacks.on_epoch_begin(epoch) 1178 with data_handler.catch_stop_iteration(): 1179 for step in data_handler.steps(): 1180 with trace.Trace( 1181 'train', 1182 epoch_num=epoch, 1183 step_num=step, 1184 batch_size=batch_size, 1185 _r=1): 1186 callbacks.on_train_batch_begin(step) 1187 tmp_logs = self.train_function(iterator) 1188 if data_handler.should_sync: 1189 context.async_wait() 1190 logs = tmp_logs # No error, now safe to assign to logs. 1191 end_step = step + data_handler.step_increment 1192 callbacks.on_train_batch_end(end_step, logs) 1193 if self.stop_training: 1194 break 1195 1196 logs = tf_utils.sync_to_numpy_or_python_type(logs) 1197 if logs is None: 1198 raise ValueError('Expect x to be a non-empty array or dataset.') 1199 epoch_logs = copy.copy(logs) 1200 1201 # Run validation. 1202 if validation_data and self._should_eval(epoch, validation_freq): 1203 # Create data_handler for evaluation and cache it. 1204 if getattr(self, '_eval_data_handler', None) is None: 1205 self._eval_data_handler = data_adapter.get_data_handler( 1206 x=val_x, 1207 y=val_y, 1208 sample_weight=val_sample_weight, 1209 batch_size=validation_batch_size or batch_size, 1210 steps_per_epoch=validation_steps, 1211 initial_epoch=0, 1212 epochs=1, 1213 max_queue_size=max_queue_size, 1214 workers=workers, 1215 use_multiprocessing=use_multiprocessing, 1216 model=self, 1217 steps_per_execution=self._steps_per_execution) 1218 val_logs = self.evaluate( 1219 x=val_x, 1220 y=val_y, 1221 sample_weight=val_sample_weight, 1222 batch_size=validation_batch_size or batch_size, 1223 steps=validation_steps, 1224 callbacks=callbacks, 1225 max_queue_size=max_queue_size, 1226 workers=workers, 1227 use_multiprocessing=use_multiprocessing, 1228 return_dict=True, 1229 _use_cached_eval_dataset=True) 1230 val_logs = {'val_' + name: val for name, val in val_logs.items()} 1231 epoch_logs.update(val_logs) 1232 1233 callbacks.on_epoch_end(epoch, epoch_logs) 1234 training_logs = epoch_logs 1235 if self.stop_training: 1236 break 1237 1238 # If eval data_hanlder exists, delete it after all epochs are done. 1239 if getattr(self, '_eval_data_handler', None) is not None: 1240 del self._eval_data_handler 1241 callbacks.on_train_end(logs=training_logs) 1242 return self.history 1243 1244 def test_step(self, data): 1245 """The logic for one evaluation step. 1246 1247 This method can be overridden to support custom evaluation logic. 1248 This method is called by `Model.make_test_function`. 1249 1250 This function should contain the mathematical logic for one step of 1251 evaluation. 1252 This typically includes the forward pass, loss calculation, and metrics 1253 updates. 1254 1255 Configuration details for *how* this logic is run (e.g. `tf.function` and 1256 `tf.distribute.Strategy` settings), should be left to 1257 `Model.make_test_function`, which can also be overridden. 1258 1259 Args: 1260 data: A nested structure of `Tensor`s. 1261 1262 Returns: 1263 A `dict` containing values that will be passed to 1264 `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the 1265 values of the `Model`'s metrics are returned. 1266 """ 1267 data = data_adapter.expand_1d(data) 1268 x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data) 1269 1270 y_pred = self(x, training=False) 1271 # Updates stateful loss metrics. 1272 self.compiled_loss( 1273 y, y_pred, sample_weight, regularization_losses=self.losses) 1274 self.compiled_metrics.update_state(y, y_pred, sample_weight) 1275 # Collect metrics to return 1276 return_metrics = {} 1277 for metric in self.metrics: 1278 result = metric.result() 1279 if isinstance(result, dict): 1280 return_metrics.update(result) 1281 else: 1282 return_metrics[metric.name] = result 1283 return return_metrics 1284 1285 def make_test_function(self): 1286 """Creates a function that executes one step of evaluation. 1287 1288 This method can be overridden to support custom evaluation logic. 1289 This method is called by `Model.evaluate` and `Model.test_on_batch`. 1290 1291 Typically, this method directly controls `tf.function` and 1292 `tf.distribute.Strategy` settings, and delegates the actual evaluation 1293 logic to `Model.test_step`. 1294 1295 This function is cached the first time `Model.evaluate` or 1296 `Model.test_on_batch` is called. The cache is cleared whenever 1297 `Model.compile` is called. 1298 1299 Returns: 1300 Function. The function created by this method should accept a 1301 `tf.data.Iterator`, and return a `dict` containing values that will 1302 be passed to `tf.keras.Callbacks.on_test_batch_end`. 1303 """ 1304 if self.test_function is not None: 1305 return self.test_function 1306 1307 def step_function(model, iterator): 1308 """Runs a single evaluation step.""" 1309 1310 def run_step(data): 1311 outputs = model.test_step(data) 1312 # Ensure counter is updated only if `test_step` succeeds. 1313 with ops.control_dependencies(_minimum_control_deps(outputs)): 1314 model._test_counter.assign_add(1) # pylint: disable=protected-access 1315 return outputs 1316 1317 data = next(iterator) 1318 outputs = model.distribute_strategy.run(run_step, args=(data,)) 1319 outputs = reduce_per_replica( 1320 outputs, self.distribute_strategy, reduction='first') 1321 return outputs 1322 1323 if self._steps_per_execution.numpy().item() == 1: 1324 1325 def test_function(iterator): 1326 """Runs an evaluation execution with one step.""" 1327 return step_function(self, iterator) 1328 1329 else: 1330 1331 def test_function(iterator): 1332 """Runs an evaluation execution with multiple steps.""" 1333 for _ in math_ops.range(self._steps_per_execution): 1334 outputs = step_function(self, iterator) 1335 return outputs 1336 1337 if not self.run_eagerly: 1338 test_function = def_function.function( 1339 test_function, experimental_relax_shapes=True) 1340 1341 self.test_function = test_function 1342 1343 if self._cluster_coordinator: 1344 self.test_function = lambda iterator: self._cluster_coordinator.schedule( # pylint: disable=g-long-lambda 1345 test_function, args=(iterator,)) 1346 1347 return self.test_function 1348 1349 def evaluate(self, 1350 x=None, 1351 y=None, 1352 batch_size=None, 1353 verbose=1, 1354 sample_weight=None, 1355 steps=None, 1356 callbacks=None, 1357 max_queue_size=10, 1358 workers=1, 1359 use_multiprocessing=False, 1360 return_dict=False, 1361 **kwargs): 1362 """Returns the loss value & metrics values for the model in test mode. 1363 1364 Computation is done in batches (see the `batch_size` arg.) 1365 1366 Args: 1367 x: Input data. It could be: 1368 - A Numpy array (or array-like), or a list of arrays 1369 (in case the model has multiple inputs). 1370 - A TensorFlow tensor, or a list of tensors 1371 (in case the model has multiple inputs). 1372 - A dict mapping input names to the corresponding array/tensors, 1373 if the model has named inputs. 1374 - A `tf.data` dataset. Should return a tuple 1375 of either `(inputs, targets)` or 1376 `(inputs, targets, sample_weights)`. 1377 - A generator or `keras.utils.Sequence` returning `(inputs, targets)` 1378 or `(inputs, targets, sample_weights)`. 1379 A more detailed description of unpacking behavior for iterator types 1380 (Dataset, generator, Sequence) is given in the `Unpacking behavior 1381 for iterator-like inputs` section of `Model.fit`. 1382 y: Target data. Like the input data `x`, it could be either Numpy 1383 array(s) or TensorFlow tensor(s). It should be consistent with `x` 1384 (you cannot have Numpy inputs and tensor targets, or inversely). If 1385 `x` is a dataset, generator or `keras.utils.Sequence` instance, `y` 1386 should not be specified (since targets will be obtained from the 1387 iterator/dataset). 1388 batch_size: Integer or `None`. Number of samples per batch of 1389 computation. If unspecified, `batch_size` will default to 32. Do not 1390 specify the `batch_size` if your data is in the form of a dataset, 1391 generators, or `keras.utils.Sequence` instances (since they generate 1392 batches). 1393 verbose: 0 or 1. Verbosity mode. 0 = silent, 1 = progress bar. 1394 sample_weight: Optional Numpy array of weights for the test samples, 1395 used for weighting the loss function. You can either pass a flat (1D) 1396 Numpy array with the same length as the input samples 1397 (1:1 mapping between weights and samples), or in the case of 1398 temporal data, you can pass a 2D array with shape `(samples, 1399 sequence_length)`, to apply a different weight to every timestep 1400 of every sample. This argument is not supported when `x` is a 1401 dataset, instead pass sample weights as the third element of `x`. 1402 steps: Integer or `None`. Total number of steps (batches of samples) 1403 before declaring the evaluation round finished. Ignored with the 1404 default value of `None`. If x is a `tf.data` dataset and `steps` is 1405 None, 'evaluate' will run until the dataset is exhausted. This 1406 argument is not supported with array inputs. 1407 callbacks: List of `keras.callbacks.Callback` instances. List of 1408 callbacks to apply during evaluation. See 1409 [callbacks](/api_docs/python/tf/keras/callbacks). 1410 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 1411 input only. Maximum size for the generator queue. If unspecified, 1412 `max_queue_size` will default to 10. 1413 workers: Integer. Used for generator or `keras.utils.Sequence` input 1414 only. Maximum number of processes to spin up when using process-based 1415 threading. If unspecified, `workers` will default to 1. 1416 use_multiprocessing: Boolean. Used for generator or 1417 `keras.utils.Sequence` input only. If `True`, use process-based 1418 threading. If unspecified, `use_multiprocessing` will default to 1419 `False`. Note that because this implementation relies on 1420 multiprocessing, you should not pass non-picklable arguments to the 1421 generator as they can't be passed easily to children processes. 1422 return_dict: If `True`, loss and metric results are returned as a dict, 1423 with each key being the name of the metric. If `False`, they are 1424 returned as a list. 1425 **kwargs: Unused at this time. 1426 1427 See the discussion of `Unpacking behavior for iterator-like inputs` for 1428 `Model.fit`. 1429 1430 `Model.evaluate` is not yet supported with 1431 `tf.distribute.experimental.ParameterServerStrategy`. 1432 1433 Returns: 1434 Scalar test loss (if the model has a single output and no metrics) 1435 or list of scalars (if the model has multiple outputs 1436 and/or metrics). The attribute `model.metrics_names` will give you 1437 the display labels for the scalar outputs. 1438 1439 Raises: 1440 RuntimeError: If `model.evaluate` is wrapped in `tf.function`. 1441 ValueError: in case of invalid arguments. 1442 """ 1443 version_utils.disallow_legacy_graph('Model', 'evaluate') 1444 self._assert_compile_was_called() 1445 self._check_call_args('evaluate') 1446 _disallow_inside_tf_function('evaluate') 1447 use_cached_eval_dataset = kwargs.pop('_use_cached_eval_dataset', False) 1448 if kwargs: 1449 raise TypeError('Invalid keyword arguments: %s' % (kwargs,)) 1450 1451 if self.distribute_strategy._should_use_with_coordinator: # pylint: disable=protected-access 1452 self._cluster_coordinator = cluster_coordinator.ClusterCoordinator( 1453 self.distribute_strategy) 1454 1455 with self.distribute_strategy.scope(): 1456 # Use cached evaluation data only when it's called in `Model.fit` 1457 if (use_cached_eval_dataset 1458 and getattr(self, '_eval_data_handler', None) is not None): 1459 data_handler = self._eval_data_handler 1460 else: 1461 # Creates a `tf.data.Dataset` and handles batch and epoch iteration. 1462 data_handler = data_adapter.get_data_handler( 1463 x=x, 1464 y=y, 1465 sample_weight=sample_weight, 1466 batch_size=batch_size, 1467 steps_per_epoch=steps, 1468 initial_epoch=0, 1469 epochs=1, 1470 max_queue_size=max_queue_size, 1471 workers=workers, 1472 use_multiprocessing=use_multiprocessing, 1473 model=self, 1474 steps_per_execution=self._steps_per_execution) 1475 1476 # Container that configures and calls `tf.keras.Callback`s. 1477 if not isinstance(callbacks, callbacks_module.CallbackList): 1478 callbacks = callbacks_module.CallbackList( 1479 callbacks, 1480 add_history=True, 1481 add_progbar=verbose != 0, 1482 model=self, 1483 verbose=verbose, 1484 epochs=1, 1485 steps=data_handler.inferred_steps) 1486 1487 logs = {} 1488 self.test_function = self.make_test_function() 1489 self._test_counter.assign(0) 1490 callbacks.on_test_begin() 1491 for _, iterator in data_handler.enumerate_epochs(): # Single epoch. 1492 self.reset_metrics() 1493 with data_handler.catch_stop_iteration(): 1494 for step in data_handler.steps(): 1495 with trace.Trace('test', step_num=step, _r=1): 1496 callbacks.on_test_batch_begin(step) 1497 tmp_logs = self.test_function(iterator) 1498 if data_handler.should_sync: 1499 context.async_wait() 1500 logs = tmp_logs # No error, now safe to assign to logs. 1501 end_step = step + data_handler.step_increment 1502 callbacks.on_test_batch_end(end_step, logs) 1503 logs = tf_utils.sync_to_numpy_or_python_type(logs) 1504 callbacks.on_test_end(logs=logs) 1505 1506 if return_dict: 1507 return logs 1508 else: 1509 return flatten_metrics_in_order(logs, self.metrics_names) 1510 1511 def predict_step(self, data): 1512 """The logic for one inference step. 1513 1514 This method can be overridden to support custom inference logic. 1515 This method is called by `Model.make_predict_function`. 1516 1517 This method should contain the mathematical logic for one step of inference. 1518 This typically includes the forward pass. 1519 1520 Configuration details for *how* this logic is run (e.g. `tf.function` and 1521 `tf.distribute.Strategy` settings), should be left to 1522 `Model.make_predict_function`, which can also be overridden. 1523 1524 Args: 1525 data: A nested structure of `Tensor`s. 1526 1527 Returns: 1528 The result of one inference step, typically the output of calling the 1529 `Model` on data. 1530 """ 1531 data = data_adapter.expand_1d(data) 1532 x, _, _ = data_adapter.unpack_x_y_sample_weight(data) 1533 return self(x, training=False) 1534 1535 def make_predict_function(self): 1536 """Creates a function that executes one step of inference. 1537 1538 This method can be overridden to support custom inference logic. 1539 This method is called by `Model.predict` and `Model.predict_on_batch`. 1540 1541 Typically, this method directly controls `tf.function` and 1542 `tf.distribute.Strategy` settings, and delegates the actual evaluation 1543 logic to `Model.predict_step`. 1544 1545 This function is cached the first time `Model.predict` or 1546 `Model.predict_on_batch` is called. The cache is cleared whenever 1547 `Model.compile` is called. 1548 1549 Returns: 1550 Function. The function created by this method should accept a 1551 `tf.data.Iterator`, and return the outputs of the `Model`. 1552 """ 1553 if self.predict_function is not None: 1554 return self.predict_function 1555 1556 def step_function(model, iterator): 1557 """Runs a single evaluation step.""" 1558 1559 def run_step(data): 1560 outputs = model.predict_step(data) 1561 # Ensure counter is updated only if `test_step` succeeds. 1562 with ops.control_dependencies(_minimum_control_deps(outputs)): 1563 model._predict_counter.assign_add(1) # pylint: disable=protected-access 1564 return outputs 1565 1566 data = next(iterator) 1567 outputs = model.distribute_strategy.run(run_step, args=(data,)) 1568 outputs = reduce_per_replica( 1569 outputs, self.distribute_strategy, reduction='concat') 1570 return outputs 1571 1572 if (self._steps_per_execution is None or 1573 self._steps_per_execution.numpy().item() == 1): 1574 1575 def predict_function(iterator): 1576 """Runs an evaluation execution with one step.""" 1577 return step_function(self, iterator) 1578 1579 else: 1580 1581 def predict_function(iterator): 1582 """Runs an evaluation execution with multiple steps.""" 1583 outputs = step_function(self, iterator) 1584 for _ in math_ops.range(self._steps_per_execution - 1): 1585 directives.set_loop_options( 1586 shape_invariants=[( 1587 t, tf_utils.get_tensor_spec(t, dynamic_batch=True).shape) 1588 for t in nest.flatten(outputs)]) 1589 step_outputs = step_function(self, iterator) 1590 outputs = nest.map_structure(lambda t1, t2: concat([t1, t2]), outputs, 1591 step_outputs) 1592 return outputs 1593 1594 if not self.run_eagerly: 1595 predict_function = def_function.function( 1596 predict_function, experimental_relax_shapes=True) 1597 1598 self.predict_function = predict_function 1599 return self.predict_function 1600 1601 def predict(self, 1602 x, 1603 batch_size=None, 1604 verbose=0, 1605 steps=None, 1606 callbacks=None, 1607 max_queue_size=10, 1608 workers=1, 1609 use_multiprocessing=False): 1610 """Generates output predictions for the input samples. 1611 1612 Computation is done in batches. This method is designed for performance in 1613 large scale inputs. For small amount of inputs that fit in one batch, 1614 directly using `__call__` is recommended for faster execution, e.g., 1615 `model(x)`, or `model(x, training=False)` if you have layers such as 1616 `tf.keras.layers.BatchNormalization` that behaves differently during 1617 inference. Also, note the fact that test loss is not affected by 1618 regularization layers like noise and dropout. 1619 1620 Args: 1621 x: Input samples. It could be: 1622 - A Numpy array (or array-like), or a list of arrays 1623 (in case the model has multiple inputs). 1624 - A TensorFlow tensor, or a list of tensors 1625 (in case the model has multiple inputs). 1626 - A `tf.data` dataset. 1627 - A generator or `keras.utils.Sequence` instance. 1628 A more detailed description of unpacking behavior for iterator types 1629 (Dataset, generator, Sequence) is given in the `Unpacking behavior 1630 for iterator-like inputs` section of `Model.fit`. 1631 batch_size: Integer or `None`. 1632 Number of samples per batch. 1633 If unspecified, `batch_size` will default to 32. 1634 Do not specify the `batch_size` if your data is in the 1635 form of dataset, generators, or `keras.utils.Sequence` instances 1636 (since they generate batches). 1637 verbose: Verbosity mode, 0 or 1. 1638 steps: Total number of steps (batches of samples) 1639 before declaring the prediction round finished. 1640 Ignored with the default value of `None`. If x is a `tf.data` 1641 dataset and `steps` is None, `predict` will 1642 run until the input dataset is exhausted. 1643 callbacks: List of `keras.callbacks.Callback` instances. 1644 List of callbacks to apply during prediction. 1645 See [callbacks](/api_docs/python/tf/keras/callbacks). 1646 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 1647 input only. Maximum size for the generator queue. 1648 If unspecified, `max_queue_size` will default to 10. 1649 workers: Integer. Used for generator or `keras.utils.Sequence` input 1650 only. Maximum number of processes to spin up when using 1651 process-based threading. If unspecified, `workers` will default 1652 to 1. 1653 use_multiprocessing: Boolean. Used for generator or 1654 `keras.utils.Sequence` input only. If `True`, use process-based 1655 threading. If unspecified, `use_multiprocessing` will default to 1656 `False`. Note that because this implementation relies on 1657 multiprocessing, you should not pass non-picklable arguments to 1658 the generator as they can't be passed easily to children processes. 1659 1660 See the discussion of `Unpacking behavior for iterator-like inputs` for 1661 `Model.fit`. Note that Model.predict uses the same interpretation rules as 1662 `Model.fit` and `Model.evaluate`, so inputs must be unambiguous for all 1663 three methods. 1664 1665 Returns: 1666 Numpy array(s) of predictions. 1667 1668 Raises: 1669 RuntimeError: If `model.predict` is wrapped in `tf.function`. 1670 ValueError: In case of mismatch between the provided 1671 input data and the model's expectations, 1672 or in case a stateful model receives a number of samples 1673 that is not a multiple of the batch size. 1674 """ 1675 version_utils.disallow_legacy_graph('Model', 'predict') 1676 self._check_call_args('predict') 1677 _disallow_inside_tf_function('predict') 1678 1679 # TODO(yashkatariya): Cache model on the coordinator for faster prediction. 1680 # If running under PSS, then swap it with OneDeviceStrategy so that 1681 # execution will run on the coordinator. 1682 original_pss_strategy = None 1683 if self.distribute_strategy._should_use_with_coordinator: # pylint: disable=protected-access 1684 original_pss_strategy = self.distribute_strategy 1685 self._distribution_strategy = None 1686 1687 # Cluster coordinator is set by `.fit()` and `.evaluate()` which is not 1688 # needed in `.predict()` because all the predictions happen on the 1689 # coordinator/locally. 1690 if self._cluster_coordinator: 1691 self._cluster_coordinator = None 1692 1693 outputs = None 1694 with self.distribute_strategy.scope(): 1695 # Creates a `tf.data.Dataset` and handles batch and epoch iteration. 1696 dataset_types = (dataset_ops.DatasetV1, dataset_ops.DatasetV2) 1697 if (self._in_multi_worker_mode() or _is_tpu_multi_host( 1698 self.distribute_strategy)) and isinstance(x, dataset_types): 1699 try: 1700 options = options_lib.Options() 1701 data_option = options_lib.AutoShardPolicy.DATA 1702 options.experimental_distribute.auto_shard_policy = data_option 1703 x = x.with_options(options) 1704 except ValueError: 1705 warnings.warn('Using Model.predict with ' 1706 'MultiWorkerDistributionStrategy or TPUStrategy and ' 1707 'AutoShardPolicy.FILE might lead to out-of-order result' 1708 '. Consider setting it to AutoShardPolicy.DATA.') 1709 1710 data_handler = data_adapter.get_data_handler( 1711 x=x, 1712 batch_size=batch_size, 1713 steps_per_epoch=steps, 1714 initial_epoch=0, 1715 epochs=1, 1716 max_queue_size=max_queue_size, 1717 workers=workers, 1718 use_multiprocessing=use_multiprocessing, 1719 model=self, 1720 steps_per_execution=self._steps_per_execution) 1721 1722 # Container that configures and calls `tf.keras.Callback`s. 1723 if not isinstance(callbacks, callbacks_module.CallbackList): 1724 callbacks = callbacks_module.CallbackList( 1725 callbacks, 1726 add_history=True, 1727 add_progbar=verbose != 0, 1728 model=self, 1729 verbose=verbose, 1730 epochs=1, 1731 steps=data_handler.inferred_steps) 1732 1733 self.predict_function = self.make_predict_function() 1734 self._predict_counter.assign(0) 1735 callbacks.on_predict_begin() 1736 batch_outputs = None 1737 for _, iterator in data_handler.enumerate_epochs(): # Single epoch. 1738 with data_handler.catch_stop_iteration(): 1739 for step in data_handler.steps(): 1740 callbacks.on_predict_batch_begin(step) 1741 tmp_batch_outputs = self.predict_function(iterator) 1742 if data_handler.should_sync: 1743 context.async_wait() 1744 batch_outputs = tmp_batch_outputs # No error, now safe to assign. 1745 if outputs is None: 1746 outputs = nest.map_structure(lambda batch_output: [batch_output], 1747 batch_outputs) 1748 else: 1749 nest.map_structure_up_to( 1750 batch_outputs, 1751 lambda output, batch_output: output.append(batch_output), 1752 outputs, batch_outputs) 1753 end_step = step + data_handler.step_increment 1754 callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs}) 1755 if batch_outputs is None: 1756 raise ValueError('Expect x to be a non-empty array or dataset.') 1757 callbacks.on_predict_end() 1758 all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs) 1759 1760 # If originally PSS strategy was used, then replace it back since predict 1761 # is running under `OneDeviceStrategy` after the swap and once its done 1762 # we need to replace it back to PSS again. 1763 if original_pss_strategy is not None: 1764 self._distribution_strategy = original_pss_strategy 1765 1766 return tf_utils.sync_to_numpy_or_python_type(all_outputs) 1767 1768 def reset_metrics(self): 1769 """Resets the state of all the metrics in the model. 1770 1771 Examples: 1772 1773 >>> inputs = tf.keras.layers.Input(shape=(3,)) 1774 >>> outputs = tf.keras.layers.Dense(2)(inputs) 1775 >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs) 1776 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"]) 1777 1778 >>> x = np.random.random((2, 3)) 1779 >>> y = np.random.randint(0, 2, (2, 2)) 1780 >>> _ = model.fit(x, y, verbose=0) 1781 >>> assert all(float(m.result()) for m in model.metrics) 1782 1783 >>> model.reset_metrics() 1784 >>> assert all(float(m.result()) == 0 for m in model.metrics) 1785 1786 """ 1787 for m in self.metrics: 1788 m.reset_state() 1789 1790 def train_on_batch(self, 1791 x, 1792 y=None, 1793 sample_weight=None, 1794 class_weight=None, 1795 reset_metrics=True, 1796 return_dict=False): 1797 """Runs a single gradient update on a single batch of data. 1798 1799 Args: 1800 x: Input data. It could be: 1801 - A Numpy array (or array-like), or a list of arrays 1802 (in case the model has multiple inputs). 1803 - A TensorFlow tensor, or a list of tensors 1804 (in case the model has multiple inputs). 1805 - A dict mapping input names to the corresponding array/tensors, 1806 if the model has named inputs. 1807 y: Target data. Like the input data `x`, it could be either Numpy 1808 array(s) or TensorFlow tensor(s). It should be consistent with `x` 1809 (you cannot have Numpy inputs and tensor targets, or inversely). 1810 sample_weight: Optional array of the same length as x, containing 1811 weights to apply to the model's loss for each sample. In the case of 1812 temporal data, you can pass a 2D array with shape (samples, 1813 sequence_length), to apply a different weight to every timestep of 1814 every sample. 1815 class_weight: Optional dictionary mapping class indices (integers) to a 1816 weight (float) to apply to the model's loss for the samples from this 1817 class during training. This can be useful to tell the model to "pay 1818 more attention" to samples from an under-represented class. 1819 reset_metrics: If `True`, the metrics returned will be only for this 1820 batch. If `False`, the metrics will be statefully accumulated across 1821 batches. 1822 return_dict: If `True`, loss and metric results are returned as a dict, 1823 with each key being the name of the metric. If `False`, they are 1824 returned as a list. 1825 1826 Returns: 1827 Scalar training loss 1828 (if the model has a single output and no metrics) 1829 or list of scalars (if the model has multiple outputs 1830 and/or metrics). The attribute `model.metrics_names` will give you 1831 the display labels for the scalar outputs. 1832 1833 Raises: 1834 RuntimeError: If `model.train_on_batch` is wrapped in `tf.function`. 1835 ValueError: In case of invalid user-provided arguments. 1836 """ 1837 self._assert_compile_was_called() 1838 self._check_call_args('train_on_batch') 1839 _disallow_inside_tf_function('train_on_batch') 1840 with self.distribute_strategy.scope(), \ 1841 training_utils.RespectCompiledTrainableState(self): 1842 iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x, 1843 y, sample_weight, 1844 class_weight) 1845 self.train_function = self.make_train_function() 1846 logs = self.train_function(iterator) 1847 1848 if reset_metrics: 1849 self.reset_metrics() 1850 logs = tf_utils.sync_to_numpy_or_python_type(logs) 1851 if return_dict: 1852 return logs 1853 else: 1854 return flatten_metrics_in_order(logs, self.metrics_names) 1855 1856 def test_on_batch(self, 1857 x, 1858 y=None, 1859 sample_weight=None, 1860 reset_metrics=True, 1861 return_dict=False): 1862 """Test the model on a single batch of samples. 1863 1864 Args: 1865 x: Input data. It could be: 1866 - A Numpy array (or array-like), or a list of arrays (in case the 1867 model has multiple inputs). 1868 - A TensorFlow tensor, or a list of tensors (in case the model has 1869 multiple inputs). 1870 - A dict mapping input names to the corresponding array/tensors, if 1871 the model has named inputs. 1872 y: Target data. Like the input data `x`, it could be either Numpy 1873 array(s) or TensorFlow tensor(s). It should be consistent with `x` 1874 (you cannot have Numpy inputs and tensor targets, or inversely). 1875 sample_weight: Optional array of the same length as x, containing 1876 weights to apply to the model's loss for each sample. In the case of 1877 temporal data, you can pass a 2D array with shape (samples, 1878 sequence_length), to apply a different weight to every timestep of 1879 every sample. 1880 reset_metrics: If `True`, the metrics returned will be only for this 1881 batch. If `False`, the metrics will be statefully accumulated across 1882 batches. 1883 return_dict: If `True`, loss and metric results are returned as a dict, 1884 with each key being the name of the metric. If `False`, they are 1885 returned as a list. 1886 1887 Returns: 1888 Scalar test loss (if the model has a single output and no metrics) 1889 or list of scalars (if the model has multiple outputs 1890 and/or metrics). The attribute `model.metrics_names` will give you 1891 the display labels for the scalar outputs. 1892 1893 Raises: 1894 RuntimeError: If `model.test_on_batch` is wrapped in `tf.function`. 1895 ValueError: In case of invalid user-provided arguments. 1896 """ 1897 self._assert_compile_was_called() 1898 self._check_call_args('test_on_batch') 1899 _disallow_inside_tf_function('test_on_batch') 1900 with self.distribute_strategy.scope(): 1901 iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x, 1902 y, sample_weight) 1903 self.test_function = self.make_test_function() 1904 logs = self.test_function(iterator) 1905 1906 if reset_metrics: 1907 self.reset_metrics() 1908 logs = tf_utils.sync_to_numpy_or_python_type(logs) 1909 if return_dict: 1910 return logs 1911 else: 1912 return flatten_metrics_in_order(logs, self.metrics_names) 1913 1914 def predict_on_batch(self, x): 1915 """Returns predictions for a single batch of samples. 1916 1917 Args: 1918 x: Input data. It could be: 1919 - A Numpy array (or array-like), or a list of arrays (in case the 1920 model has multiple inputs). 1921 - A TensorFlow tensor, or a list of tensors (in case the model has 1922 multiple inputs). 1923 1924 Returns: 1925 Numpy array(s) of predictions. 1926 1927 Raises: 1928 RuntimeError: If `model.predict_on_batch` is wrapped in `tf.function`. 1929 ValueError: In case of mismatch between given number of inputs and 1930 expectations of the model. 1931 """ 1932 self._check_call_args('predict_on_batch') 1933 _disallow_inside_tf_function('predict_on_batch') 1934 with self.distribute_strategy.scope(): 1935 iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x) 1936 self.predict_function = self.make_predict_function() 1937 outputs = self.predict_function(iterator) 1938 return tf_utils.sync_to_numpy_or_python_type(outputs) 1939 1940 def fit_generator(self, 1941 generator, 1942 steps_per_epoch=None, 1943 epochs=1, 1944 verbose=1, 1945 callbacks=None, 1946 validation_data=None, 1947 validation_steps=None, 1948 validation_freq=1, 1949 class_weight=None, 1950 max_queue_size=10, 1951 workers=1, 1952 use_multiprocessing=False, 1953 shuffle=True, 1954 initial_epoch=0): 1955 """Fits the model on data yielded batch-by-batch by a Python generator. 1956 1957 DEPRECATED: 1958 `Model.fit` now supports generators, so there is no longer any need to use 1959 this endpoint. 1960 """ 1961 warnings.warn('`Model.fit_generator` is deprecated and ' 1962 'will be removed in a future version. ' 1963 'Please use `Model.fit`, which supports generators.') 1964 return self.fit( 1965 generator, 1966 steps_per_epoch=steps_per_epoch, 1967 epochs=epochs, 1968 verbose=verbose, 1969 callbacks=callbacks, 1970 validation_data=validation_data, 1971 validation_steps=validation_steps, 1972 validation_freq=validation_freq, 1973 class_weight=class_weight, 1974 max_queue_size=max_queue_size, 1975 workers=workers, 1976 use_multiprocessing=use_multiprocessing, 1977 shuffle=shuffle, 1978 initial_epoch=initial_epoch) 1979 1980 def evaluate_generator(self, 1981 generator, 1982 steps=None, 1983 callbacks=None, 1984 max_queue_size=10, 1985 workers=1, 1986 use_multiprocessing=False, 1987 verbose=0): 1988 """Evaluates the model on a data generator. 1989 1990 DEPRECATED: 1991 `Model.evaluate` now supports generators, so there is no longer any need 1992 to use this endpoint. 1993 """ 1994 warnings.warn('`Model.evaluate_generator` is deprecated and ' 1995 'will be removed in a future version. ' 1996 'Please use `Model.evaluate`, which supports generators.') 1997 self._check_call_args('evaluate_generator') 1998 1999 return self.evaluate( 2000 generator, 2001 steps=steps, 2002 max_queue_size=max_queue_size, 2003 workers=workers, 2004 use_multiprocessing=use_multiprocessing, 2005 verbose=verbose, 2006 callbacks=callbacks) 2007 2008 def predict_generator(self, 2009 generator, 2010 steps=None, 2011 callbacks=None, 2012 max_queue_size=10, 2013 workers=1, 2014 use_multiprocessing=False, 2015 verbose=0): 2016 """Generates predictions for the input samples from a data generator. 2017 2018 DEPRECATED: 2019 `Model.predict` now supports generators, so there is no longer any need 2020 to use this endpoint. 2021 """ 2022 warnings.warn('`Model.predict_generator` is deprecated and ' 2023 'will be removed in a future version. ' 2024 'Please use `Model.predict`, which supports generators.') 2025 return self.predict( 2026 generator, 2027 steps=steps, 2028 max_queue_size=max_queue_size, 2029 workers=workers, 2030 use_multiprocessing=use_multiprocessing, 2031 verbose=verbose, 2032 callbacks=callbacks) 2033 2034 ###################################################################### 2035 # Functions below are not training related. They are for model weights 2036 # tracking, save/load, serialization, etc. 2037 ###################################################################### 2038 2039 @property 2040 def trainable_weights(self): 2041 self._assert_weights_created() 2042 if not self._trainable: 2043 return [] 2044 trainable_variables = [] 2045 for trackable_obj in self._self_tracked_trackables: 2046 trainable_variables += trackable_obj.trainable_variables 2047 trainable_variables += self._trainable_weights 2048 return self._dedup_weights(trainable_variables) 2049 2050 @property 2051 def non_trainable_weights(self): 2052 self._assert_weights_created() 2053 non_trainable_variables = [] 2054 for trackable_obj in self._self_tracked_trackables: 2055 non_trainable_variables += trackable_obj.non_trainable_variables 2056 2057 if not self._trainable: 2058 # Return order is all trainable vars, then all non-trainable vars. 2059 trainable_variables = [] 2060 for trackable_obj in self._self_tracked_trackables: 2061 trainable_variables += trackable_obj.trainable_variables 2062 2063 non_trainable_variables = ( 2064 trainable_variables + self._trainable_weights + 2065 non_trainable_variables + self._non_trainable_weights) 2066 else: 2067 non_trainable_variables = ( 2068 non_trainable_variables + self._non_trainable_weights) 2069 2070 return self._dedup_weights(non_trainable_variables) 2071 2072 def get_weights(self): 2073 """Retrieves the weights of the model. 2074 2075 Returns: 2076 A flat list of Numpy arrays. 2077 """ 2078 with self.distribute_strategy.scope(): 2079 return super(Model, self).get_weights() 2080 2081 def save(self, 2082 filepath, 2083 overwrite=True, 2084 include_optimizer=True, 2085 save_format=None, 2086 signatures=None, 2087 options=None, 2088 save_traces=True): 2089 # pylint: disable=line-too-long 2090 """Saves the model to Tensorflow SavedModel or a single HDF5 file. 2091 2092 Please see `tf.keras.models.save_model` or the 2093 [Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/) 2094 for details. 2095 2096 Args: 2097 filepath: String, PathLike, path to SavedModel or H5 file to save the 2098 model. 2099 overwrite: Whether to silently overwrite any existing file at the 2100 target location, or provide the user with a manual prompt. 2101 include_optimizer: If True, save optimizer's state together. 2102 save_format: Either `'tf'` or `'h5'`, indicating whether to save the 2103 model to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X, 2104 and 'h5' in TF 1.X. 2105 signatures: Signatures to save with the SavedModel. Applicable to the 2106 'tf' format only. Please see the `signatures` argument in 2107 `tf.saved_model.save` for details. 2108 options: (only applies to SavedModel format) 2109 `tf.saved_model.SaveOptions` object that specifies options for 2110 saving to SavedModel. 2111 save_traces: (only applies to SavedModel format) When enabled, the 2112 SavedModel will store the function traces for each layer. This 2113 can be disabled, so that only the configs of each layer are stored. 2114 Defaults to `True`. Disabling this will decrease serialization time 2115 and reduce file size, but it requires that all custom layers/models 2116 implement a `get_config()` method. 2117 2118 Example: 2119 2120 ```python 2121 from keras.models import load_model 2122 2123 model.save('my_model.h5') # creates a HDF5 file 'my_model.h5' 2124 del model # deletes the existing model 2125 2126 # returns a compiled model 2127 # identical to the previous one 2128 model = load_model('my_model.h5') 2129 ``` 2130 """ 2131 # pylint: enable=line-too-long 2132 save.save_model(self, filepath, overwrite, include_optimizer, save_format, 2133 signatures, options, save_traces) 2134 2135 def save_weights(self, 2136 filepath, 2137 overwrite=True, 2138 save_format=None, 2139 options=None): 2140 """Saves all layer weights. 2141 2142 Either saves in HDF5 or in TensorFlow format based on the `save_format` 2143 argument. 2144 2145 When saving in HDF5 format, the weight file has: 2146 - `layer_names` (attribute), a list of strings 2147 (ordered names of model layers). 2148 - For every layer, a `group` named `layer.name` 2149 - For every such layer group, a group attribute `weight_names`, 2150 a list of strings 2151 (ordered names of weights tensor of the layer). 2152 - For every weight in the layer, a dataset 2153 storing the weight value, named after the weight tensor. 2154 2155 When saving in TensorFlow format, all objects referenced by the network are 2156 saved in the same format as `tf.train.Checkpoint`, including any `Layer` 2157 instances or `Optimizer` instances assigned to object attributes. For 2158 networks constructed from inputs and outputs using `tf.keras.Model(inputs, 2159 outputs)`, `Layer` instances used by the network are tracked/saved 2160 automatically. For user-defined classes which inherit from `tf.keras.Model`, 2161 `Layer` instances must be assigned to object attributes, typically in the 2162 constructor. See the documentation of `tf.train.Checkpoint` and 2163 `tf.keras.Model` for details. 2164 2165 While the formats are the same, do not mix `save_weights` and 2166 `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should be 2167 loaded using `Model.load_weights`. Checkpoints saved using 2168 `tf.train.Checkpoint.save` should be restored using the corresponding 2169 `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over 2170 `save_weights` for training checkpoints. 2171 2172 The TensorFlow format matches objects and variables by starting at a root 2173 object, `self` for `save_weights`, and greedily matching attribute 2174 names. For `Model.save` this is the `Model`, and for `Checkpoint.save` this 2175 is the `Checkpoint` even if the `Checkpoint` has a model attached. This 2176 means saving a `tf.keras.Model` using `save_weights` and loading into a 2177 `tf.train.Checkpoint` with a `Model` attached (or vice versa) will not match 2178 the `Model`'s variables. See the [guide to training 2179 checkpoints](https://www.tensorflow.org/guide/checkpoint) for details 2180 on the TensorFlow format. 2181 2182 Args: 2183 filepath: String or PathLike, path to the file to save the weights to. 2184 When saving in TensorFlow format, this is the prefix used for 2185 checkpoint files (multiple files are generated). Note that the '.h5' 2186 suffix causes weights to be saved in HDF5 format. 2187 overwrite: Whether to silently overwrite any existing file at the 2188 target location, or provide the user with a manual prompt. 2189 save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or 2190 '.keras' will default to HDF5 if `save_format` is `None`. Otherwise 2191 `None` defaults to 'tf'. 2192 options: Optional `tf.train.CheckpointOptions` object that specifies 2193 options for saving weights. 2194 2195 Raises: 2196 ImportError: If h5py is not available when attempting to save in HDF5 2197 format. 2198 ValueError: For invalid/unknown format arguments. 2199 """ 2200 self._assert_weights_created() 2201 filepath = path_to_string(filepath) 2202 filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath) 2203 if save_format is None: 2204 if filepath_is_h5: 2205 save_format = 'h5' 2206 else: 2207 save_format = 'tf' 2208 else: 2209 user_format = save_format.lower().strip() 2210 if user_format in ('tensorflow', 'tf'): 2211 save_format = 'tf' 2212 elif user_format in ('hdf5', 'h5', 'keras'): 2213 save_format = 'h5' 2214 else: 2215 raise ValueError( 2216 'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % ( 2217 save_format,)) 2218 if save_format == 'tf' and filepath_is_h5: 2219 raise ValueError( 2220 ('save_weights got save_format="tf"/"tensorflow", but the ' 2221 'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" ' 2222 'when saving in TensorFlow format.') 2223 % filepath) 2224 2225 if save_format == 'h5' and h5py is None: 2226 raise ImportError( 2227 '`save_weights` requires h5py when saving in hdf5.') 2228 if save_format == 'tf': 2229 check_filepath = filepath + '.index' 2230 else: 2231 check_filepath = filepath 2232 # If file exists and should not be overwritten: 2233 if not overwrite and os.path.isfile(check_filepath): 2234 proceed = ask_to_proceed_with_overwrite(check_filepath) 2235 if not proceed: 2236 return 2237 if save_format == 'h5': 2238 with h5py.File(filepath, 'w') as f: 2239 hdf5_format.save_weights_to_hdf5_group(f, self.layers) 2240 else: 2241 if not context.executing_eagerly(): 2242 # Call `get_session` to initialize any uninitialized variables. 2243 backend.get_session() 2244 self._checkpoint.write(filepath, options=options) 2245 # Record this checkpoint so it's visible from tf.train.latest_checkpoint. 2246 checkpoint_management.update_checkpoint_state_internal( 2247 save_dir=os.path.dirname(filepath), 2248 model_checkpoint_path=filepath, 2249 save_relative_paths=True, 2250 all_model_checkpoint_paths=[filepath]) 2251 2252 def load_weights(self, 2253 filepath, 2254 by_name=False, 2255 skip_mismatch=False, 2256 options=None): 2257 """Loads all layer weights, either from a TensorFlow or an HDF5 weight file. 2258 2259 If `by_name` is False weights are loaded based on the network's 2260 topology. This means the architecture should be the same as when the weights 2261 were saved. Note that layers that don't have weights are not taken into 2262 account in the topological ordering, so adding or removing layers is fine as 2263 long as they don't have weights. 2264 2265 If `by_name` is True, weights are loaded into layers only if they share the 2266 same name. This is useful for fine-tuning or transfer-learning models where 2267 some of the layers have changed. 2268 2269 Only topological loading (`by_name=False`) is supported when loading weights 2270 from the TensorFlow format. Note that topological loading differs slightly 2271 between TensorFlow and HDF5 formats for user-defined classes inheriting from 2272 `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the 2273 TensorFlow format loads based on the object-local names of attributes to 2274 which layers are assigned in the `Model`'s constructor. 2275 2276 Args: 2277 filepath: String, path to the weights file to load. For weight files in 2278 TensorFlow format, this is the file prefix (the same as was passed 2279 to `save_weights`). This can also be a path to a SavedModel 2280 saved from `model.save`. 2281 by_name: Boolean, whether to load weights by name or by topological 2282 order. Only topological loading is supported for weight files in 2283 TensorFlow format. 2284 skip_mismatch: Boolean, whether to skip loading of layers where there is 2285 a mismatch in the number of weights, or a mismatch in the shape of 2286 the weight (only valid when `by_name=True`). 2287 options: Optional `tf.train.CheckpointOptions` object that specifies 2288 options for loading weights. 2289 2290 Returns: 2291 When loading a weight file in TensorFlow format, returns the same status 2292 object as `tf.train.Checkpoint.restore`. When graph building, restore 2293 ops are run automatically as soon as the network is built (on first call 2294 for user-defined classes inheriting from `Model`, immediately if it is 2295 already built). 2296 2297 When loading weights in HDF5 format, returns `None`. 2298 2299 Raises: 2300 ImportError: If h5py is not available and the weight file is in HDF5 2301 format. 2302 ValueError: If `skip_mismatch` is set to `True` when `by_name` is 2303 `False`. 2304 """ 2305 if backend.is_tpu_strategy(self._distribution_strategy): 2306 if (self._distribution_strategy.extended.steps_per_run > 1 and 2307 (not saving_utils.is_hdf5_filepath(filepath))): 2308 raise ValueError('Load weights is not yet supported with TPUStrategy ' 2309 'with steps_per_run greater than 1.') 2310 if skip_mismatch and not by_name: 2311 raise ValueError( 2312 'When calling model.load_weights, skip_mismatch can only be set to ' 2313 'True when by_name is True.') 2314 2315 filepath, save_format = _detect_save_format(filepath) 2316 if save_format == 'tf': 2317 status = self._checkpoint.read(filepath, options) 2318 if by_name: 2319 raise NotImplementedError( 2320 'Weights may only be loaded based on topology into Models when ' 2321 'loading TensorFlow-formatted weights (got by_name=True to ' 2322 'load_weights).') 2323 if not context.executing_eagerly(): 2324 session = backend.get_session() 2325 # Restore existing variables (if any) immediately, and set up a 2326 # streaming restore for any variables created in the future. 2327 trackable_utils.streaming_restore(status=status, session=session) 2328 status.assert_nontrivial_match() 2329 else: 2330 status = None 2331 if h5py is None: 2332 raise ImportError( 2333 '`load_weights` requires h5py when loading weights from HDF5.') 2334 if not self._is_graph_network and not self.built: 2335 raise ValueError( 2336 'Unable to load weights saved in HDF5 format into a subclassed ' 2337 'Model which has not created its variables yet. Call the Model ' 2338 'first, then load the weights.') 2339 self._assert_weights_created() 2340 with h5py.File(filepath, 'r') as f: 2341 if 'layer_names' not in f.attrs and 'model_weights' in f: 2342 f = f['model_weights'] 2343 if by_name: 2344 hdf5_format.load_weights_from_hdf5_group_by_name( 2345 f, self.layers, skip_mismatch=skip_mismatch) 2346 else: 2347 hdf5_format.load_weights_from_hdf5_group(f, self.layers) 2348 2349 # Perform any layer defined finalization of the layer state. 2350 for layer in self.layers: 2351 layer.finalize_state() 2352 return status 2353 2354 def _updated_config(self): 2355 """Util shared between different serialization methods. 2356 2357 Returns: 2358 Model config with Keras version information added. 2359 """ 2360 from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top 2361 2362 config = self.get_config() 2363 model_config = { 2364 'class_name': self.__class__.__name__, 2365 'config': config, 2366 'keras_version': keras_version, 2367 'backend': backend.backend() 2368 } 2369 return model_config 2370 2371 def get_config(self): 2372 raise NotImplementedError 2373 2374 @classmethod 2375 def from_config(cls, config, custom_objects=None): 2376 # `from_config` assumes `cls` is either `Functional` or a child class of 2377 # `Functional`. In the case that `cls` is meant to behave like a child class 2378 # of `Functional` but only inherits from the `Model` class, we have to call 2379 # `cls(...)` instead of `Functional.from_config`. 2380 from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top 2381 with generic_utils.SharedObjectLoadingScope(): 2382 input_tensors, output_tensors, created_layers = ( 2383 functional.reconstruct_from_config(config, custom_objects)) 2384 # Initialize a model belonging to `cls`, which can be user-defined or 2385 # `Functional`. 2386 model = cls(inputs=input_tensors, outputs=output_tensors, 2387 name=config.get('name')) 2388 functional.connect_ancillary_layers(model, created_layers) 2389 return model 2390 2391 def to_json(self, **kwargs): 2392 """Returns a JSON string containing the network configuration. 2393 2394 To load a network from a JSON save file, use 2395 `keras.models.model_from_json(json_string, custom_objects={})`. 2396 2397 Args: 2398 **kwargs: Additional keyword arguments 2399 to be passed to `json.dumps()`. 2400 2401 Returns: 2402 A JSON string. 2403 """ 2404 model_config = self._updated_config() 2405 return json.dumps( 2406 model_config, default=json_utils.get_json_type, **kwargs) 2407 2408 def to_yaml(self, **kwargs): 2409 """Returns a yaml string containing the network configuration. 2410 2411 Note: Since TF 2.6, this method is no longer supported and will raise a 2412 RuntimeError. 2413 2414 To load a network from a yaml save file, use 2415 `keras.models.model_from_yaml(yaml_string, custom_objects={})`. 2416 2417 `custom_objects` should be a dictionary mapping 2418 the names of custom losses / layers / etc to the corresponding 2419 functions / classes. 2420 2421 Args: 2422 **kwargs: Additional keyword arguments 2423 to be passed to `yaml.dump()`. 2424 2425 Returns: 2426 A YAML string. 2427 2428 Raises: 2429 RuntimeError: announces that the method poses a security risk 2430 """ 2431 raise RuntimeError( 2432 'Method `model.to_yaml()` has been removed due to security risk of ' 2433 'arbitrary code execution. Please use `model.to_json()` instead.' 2434 ) 2435 2436 def reset_states(self): 2437 for layer in self.layers: 2438 if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False): 2439 layer.reset_states() 2440 2441 @property 2442 @doc_controls.do_not_generate_docs 2443 def state_updates(self): 2444 """Deprecated, do NOT use! 2445 2446 Returns the `updates` from all layers that are stateful. 2447 2448 This is useful for separating training updates and 2449 state updates, e.g. when we need to update a layer's internal state 2450 during prediction. 2451 2452 Returns: 2453 A list of update ops. 2454 """ 2455 warnings.warn('`Model.state_updates` will be removed in a future version. ' 2456 'This property should not be used in TensorFlow 2.0, ' 2457 'as `updates` are applied automatically.') 2458 state_updates = [] 2459 for layer in self.layers: 2460 if getattr(layer, 'stateful', False): 2461 if hasattr(layer, 'updates'): 2462 state_updates += layer.updates 2463 return state_updates 2464 2465 @property 2466 def weights(self): 2467 """Returns the list of all layer variables/weights. 2468 2469 Note: This will not track the weights of nested `tf.Modules` that are not 2470 themselves Keras layers. 2471 2472 Returns: 2473 A list of variables. 2474 """ 2475 return self._dedup_weights(self._undeduplicated_weights) 2476 2477 @property 2478 def _undeduplicated_weights(self): 2479 """Returns the undeduplicated list of all layer variables/weights.""" 2480 self._assert_weights_created() 2481 weights = [] 2482 for layer in self._self_tracked_trackables: 2483 weights += layer.variables 2484 weights += (self._trainable_weights + self._non_trainable_weights) 2485 return weights 2486 2487 def summary(self, line_length=None, positions=None, print_fn=None): 2488 """Prints a string summary of the network. 2489 2490 Args: 2491 line_length: Total length of printed lines 2492 (e.g. set this to adapt the display to different 2493 terminal window sizes). 2494 positions: Relative or absolute positions of log elements 2495 in each line. If not provided, 2496 defaults to `[.33, .55, .67, 1.]`. 2497 print_fn: Print function to use. Defaults to `print`. 2498 It will be called on each line of the summary. 2499 You can set it to a custom function 2500 in order to capture the string summary. 2501 2502 Raises: 2503 ValueError: if `summary()` is called before the model is built. 2504 """ 2505 if not self.built: 2506 raise ValueError('This model has not yet been built. ' 2507 'Build the model first by calling `build()` or calling ' 2508 '`fit()` with some data, or specify ' 2509 'an `input_shape` argument in the first layer(s) for ' 2510 'automatic build.') 2511 layer_utils.print_summary(self, 2512 line_length=line_length, 2513 positions=positions, 2514 print_fn=print_fn) 2515 2516 @property 2517 def layers(self): 2518 return list(self._flatten_layers(include_self=False, recursive=False)) 2519 2520 def get_layer(self, name=None, index=None): 2521 """Retrieves a layer based on either its name (unique) or index. 2522 2523 If `name` and `index` are both provided, `index` will take precedence. 2524 Indices are based on order of horizontal graph traversal (bottom-up). 2525 2526 Args: 2527 name: String, name of layer. 2528 index: Integer, index of layer. 2529 2530 Returns: 2531 A layer instance. 2532 2533 Raises: 2534 ValueError: In case of invalid layer name or index. 2535 """ 2536 # TODO(fchollet): We could build a dictionary based on layer names 2537 # since they are constant, but we have not done that yet. 2538 if index is not None and name is not None: 2539 raise ValueError('Provide only a layer name or a layer index.') 2540 2541 if index is not None: 2542 if len(self.layers) <= index: 2543 raise ValueError('Was asked to retrieve layer at index ' + str(index) + 2544 ' but model only has ' + str(len(self.layers)) + 2545 ' layers.') 2546 else: 2547 return self.layers[index] 2548 2549 if name is not None: 2550 for layer in self.layers: 2551 if layer.name == name: 2552 return layer 2553 raise ValueError('No such layer: ' + name + '.') 2554 raise ValueError('Provide either a layer name or layer index.') 2555 2556 @trackable.no_automatic_dependency_tracking 2557 def _set_save_spec(self, inputs): 2558 if self._saved_model_inputs_spec is not None: 2559 return # Already set. 2560 2561 input_names = self.input_names 2562 if not input_names: 2563 input_names = compile_utils.create_pseudo_input_names(inputs) 2564 2565 flat_inputs = nest.flatten(inputs) 2566 specs = [] 2567 for name, tensor in zip(input_names, flat_inputs): 2568 specs.append( 2569 tf_utils.get_tensor_spec(tensor, dynamic_batch=False, name=name)) 2570 specs = nest.pack_sequence_as(inputs, specs) 2571 2572 self._saved_model_inputs_spec = specs 2573 2574 # Store the input shapes 2575 if (self.__class__.__name__ == 'Sequential' and 2576 self._build_input_shape is None): 2577 self._build_input_shape = nest.map_structure( 2578 lambda x: None if x is None else x.shape, specs) 2579 2580 def _assert_weights_created(self): 2581 """Asserts that all the weights for the model have been created. 2582 2583 For a non-dynamic model, the weights must already be created after the 2584 layer has been called. For a dynamic model, the exact list of weights can 2585 never be known for certain since it may change at any time during execution. 2586 2587 We run this check right before accessing weights or getting the Numpy value 2588 for the current weights. Otherwise, if the layer has never been called, 2589 the user would just get an empty list, which is misleading. 2590 2591 Raises: 2592 ValueError: if the weights of the network has not yet been created. 2593 """ 2594 if self.dynamic: 2595 return 2596 2597 if ('build' in self.__class__.__dict__ and 2598 self.__class__ != Model and 2599 not self.built): 2600 # For any model that has customized build() method but hasn't 2601 # been invoked yet, this will cover both sequential and subclass model. 2602 # Also make sure to exclude Model class itself which has build() defined. 2603 raise ValueError('Weights for model %s have not yet been created. ' 2604 'Weights are created when the Model is first called on ' 2605 'inputs or `build()` is called with an `input_shape`.' % 2606 self.name) 2607 2608 def _check_call_args(self, method_name): 2609 """Check that `call` has only one positional arg.""" 2610 # Always allow first arg, regardless of arg name. 2611 fullargspec = self._call_full_argspec 2612 if fullargspec.defaults: 2613 positional_args = fullargspec.args[:-len(fullargspec.defaults)] 2614 else: 2615 positional_args = fullargspec.args 2616 if 'training' in positional_args: 2617 positional_args.remove('training') 2618 2619 # self and first arg can be positional. 2620 if len(positional_args) > 2: 2621 extra_args = positional_args[2:] 2622 raise ValueError( 2623 'Models passed to `' + method_name + '` can only have `training` ' 2624 'and the first argument in `call` as positional arguments, ' 2625 'found: ' + str(extra_args) + '.') 2626 2627 def _validate_compile(self, optimizer, metrics, **kwargs): 2628 """Performs validation checks for the default `compile`.""" 2629 if any( 2630 isinstance(opt, optimizer_v1.Optimizer) 2631 for opt in nest.flatten(optimizer)): 2632 raise ValueError( 2633 '`tf.compat.v1.keras` Optimizer (', optimizer, ') is ' 2634 'not supported when eager execution is enabled. Use a ' 2635 '`tf.keras` Optimizer instead, or disable eager ' 2636 'execution.') 2637 2638 kwargs.pop('cloning', None) # Legacy DistStrat argument, never used. 2639 kwargs.pop('experimental_run_tf_function', None) # Always `True`. 2640 if kwargs.pop('distribute', None) is not None: 2641 raise ValueError( 2642 'Distribute argument in compile is not available in TF 2.0 please ' 2643 'create the model under the distribution strategy scope.') 2644 if kwargs.pop('target_tensors', None) is not None: 2645 raise ValueError( 2646 'target_tensors argument is not supported when executing eagerly.') 2647 invalid_kwargs = set(kwargs) - {'sample_weight_mode'} 2648 if invalid_kwargs: 2649 raise TypeError('Invalid keyword argument(s) in `compile`: %s' % 2650 (invalid_kwargs,)) 2651 2652 # Model must be created and compiled with the same DistStrat. 2653 if self.built and ds_context.has_strategy(): 2654 strategy = ds_context.get_strategy() 2655 for v in self.variables: 2656 if not strategy.extended.variable_created_in_scope(v): 2657 raise ValueError( 2658 'Variable (%s) was not created in the distribution strategy ' 2659 'scope of (%s). It is most likely due to not all layers or ' 2660 'the model or optimizer being created outside the distribution ' 2661 'strategy scope. Try to make sure your code looks similar ' 2662 'to the following.\n' 2663 'with strategy.scope():\n' 2664 ' model=_create_model()\n' 2665 ' model.compile(...)' % (v, strategy)) 2666 2667 # Model metrics must be created in the same distribution strategy scope 2668 # as the model. 2669 strategy = self.distribute_strategy 2670 for metric in nest.flatten(metrics): 2671 for v in getattr(metric, 'variables', []): 2672 if not strategy.extended.variable_created_in_scope(v): 2673 raise ValueError( 2674 'Metric (%s) passed to model.compile was created inside of a ' 2675 'different distribution strategy scope than the model. All ' 2676 'metrics must be created in the same distribution strategy ' 2677 'scope as the model (in this case %s). If you pass in a string ' 2678 'identifier for a metric to compile the metric will ' 2679 'automatically be created in the correct distribution ' 2680 'strategy scope.' % (metric, strategy) 2681 ) 2682 2683 # Model metrics must be created in the same distribution strategy scope 2684 # as the model. 2685 for opt in nest.flatten(optimizer): 2686 for v in getattr(opt, '_weights', []): 2687 if not strategy.extended.variable_created_in_scope(v): 2688 raise ValueError( 2689 'Optimizer (%s) passed to model.compile was created inside of a ' 2690 'different distribution strategy scope than the model. All ' 2691 'optimizers must be created in the same distribution strategy ' 2692 'scope as the model (in this case %s). If you pass in a string ' 2693 'identifier for an optimizer to compile the optimizer will ' 2694 'automatically be created in the correct distribution ' 2695 'strategy scope.' % (opt, strategy)) 2696 2697 def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch): 2698 """Maybe load initial epoch from ckpt considering possible worker recovery. 2699 2700 Refer to tensorflow/python/keras/distribute/worker_training_state.py 2701 for more information. 2702 2703 Args: 2704 initial_epoch: The original initial_epoch user passes in in `fit()`. 2705 2706 Returns: 2707 If the training is recovering from previous failure under multi-worker 2708 training setting, return the epoch the training is supposed to continue 2709 at. Otherwise, return the `initial_epoch` the user passes in. 2710 """ 2711 if self._training_state is not None: 2712 return self._training_state.maybe_load_initial_epoch_from_ckpt( 2713 initial_epoch, mode=ModeKeys.TRAIN) 2714 return initial_epoch 2715 2716 def _assert_compile_was_called(self): 2717 # Checks whether `compile` has been called. If it has been called, 2718 # then the optimizer is set. This is different from whether the 2719 # model is compiled 2720 # (i.e. whether the model is built and its inputs/outputs are set). 2721 if not self._is_compiled: 2722 raise RuntimeError('You must compile your model before ' 2723 'training/testing. ' 2724 'Use `model.compile(optimizer, loss)`.') 2725 2726 def _set_inputs(self, inputs, outputs=None, training=None): 2727 """This method is for compat with Modelv1. Only inputs are needed here.""" 2728 self._set_save_spec(inputs) 2729 2730 @property 2731 def _trackable_saved_model_saver(self): 2732 return model_serialization.ModelSavedModelSaver(self) 2733 2734 def _trackable_children(self, save_type='checkpoint', **kwargs): 2735 if save_type == 'savedmodel': 2736 # SavedModel needs to ignore the execution functions. 2737 train_function = self.train_function 2738 test_function = self.test_function 2739 predict_function = self.predict_function 2740 train_tf_function = self.train_tf_function 2741 self.train_function = None 2742 self.test_function = None 2743 self.predict_function = None 2744 self.train_tf_function = None 2745 2746 children = super(Model, self)._trackable_children(save_type, **kwargs) 2747 2748 if save_type == 'savedmodel': 2749 self.train_function = train_function 2750 self.test_function = test_function 2751 self.predict_function = predict_function 2752 self.train_tf_function = train_tf_function 2753 2754 return children 2755 2756 def _should_eval(self, epoch, validation_freq): 2757 epoch = epoch + 1 # one-index the user-facing epoch. 2758 if isinstance(validation_freq, int): 2759 return epoch % validation_freq == 0 2760 elif isinstance(validation_freq, list): 2761 return epoch in validation_freq 2762 else: 2763 raise ValueError('Expected `validation_freq` to be a list or int.') 2764 2765 ###################################################################### 2766 # Functions below exist only as v1 / v2 compatibility shims. 2767 ###################################################################### 2768 2769 def _get_compile_args(self, user_metrics=True): 2770 """Used for saving or cloning a Model. 2771 2772 Args: 2773 user_metrics: Whether to return user-supplied metrics or `Metric` objects. 2774 Defaults to returning the user-supplied metrics. 2775 2776 Returns: 2777 Dictionary of arguments that were used when compiling the model. 2778 """ 2779 self._assert_compile_was_called() 2780 # pylint: disable=protected-access 2781 2782 saved_metrics = self.compiled_metrics._user_metrics 2783 saved_weighted_metrics = self.compiled_metrics._user_weighted_metrics 2784 2785 if not user_metrics: 2786 if saved_metrics is not None: 2787 saved_metrics = self.compiled_metrics._metrics 2788 if saved_weighted_metrics is not None: 2789 saved_weighted_metrics = self.compiled_metrics._weighted_metrics 2790 2791 compile_args = { 2792 'optimizer': self.optimizer, 2793 'loss': self.compiled_loss._user_losses, 2794 'metrics': saved_metrics, 2795 'weighted_metrics': saved_weighted_metrics, 2796 'loss_weights': self.compiled_loss._user_loss_weights, 2797 } 2798 # pylint: enable=protected-access 2799 return compile_args 2800 2801 def _get_callback_model(self): 2802 return self 2803 2804 def _in_multi_worker_mode(self): 2805 return self.distribute_strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access 2806 2807 @property 2808 def _compile_was_called(self): 2809 return self._is_compiled 2810 2811 2812def reduce_per_replica(values, strategy, reduction='first'): 2813 """Reduce PerReplica objects. 2814 2815 Args: 2816 values: Structure of `PerReplica` objects or `Tensor`s. `Tensor`s are 2817 returned as-is. 2818 strategy: `tf.distribute.Strategy` object. 2819 reduction: One of 'first', 'concat'. 2820 2821 Returns: 2822 Structure of `Tensor`s. 2823 """ 2824 2825 def _reduce(v): 2826 """Reduce a single `PerReplica` object.""" 2827 if reduction == 'concat' and _collective_all_reduce_multi_worker(strategy): 2828 return _multi_worker_concat(v, strategy) 2829 if not _is_per_replica_instance(v): 2830 return v 2831 elif reduction == 'first': 2832 return strategy.unwrap(v)[0] 2833 elif reduction == 'concat': 2834 if _is_tpu_multi_host(strategy): 2835 return _tpu_multi_host_concat(v, strategy) 2836 else: 2837 return concat(strategy.unwrap(v)) 2838 else: 2839 raise ValueError('`reduction` must be "first" or "concat".') 2840 2841 return nest.map_structure(_reduce, values) 2842 2843 2844def concat(tensors, axis=0): 2845 """Concats `tensor`s along `axis`.""" 2846 if isinstance(tensors[0], sparse_tensor.SparseTensor): 2847 return sparse_ops.sparse_concat_v2(axis=axis, sp_inputs=tensors) 2848 elif _is_scalar(tensors[0]): 2849 return array_ops.stack(tensors, axis=axis) 2850 else: 2851 return array_ops.concat(tensors, axis=axis) 2852 2853 2854def _is_tpu_multi_host(strategy): 2855 return (backend.is_tpu_strategy(strategy) and 2856 strategy.extended.num_hosts > 1) 2857 2858 2859def _tpu_multi_host_concat(v, strategy): 2860 """Correctly order TPU PerReplica objects.""" 2861 replicas = strategy.unwrap(v) 2862 # When distributed datasets are created from Tensors / NumPy, 2863 # TPUStrategy.experimental_distribute_dataset shards data in 2864 # (Replica, Host) order, and TPUStrategy.unwrap returns it in 2865 # (Host, Replica) order. 2866 # TODO(b/150317897): Figure out long-term plan here. 2867 num_replicas_per_host = strategy.extended.num_replicas_per_host 2868 ordered_replicas = [] 2869 for replica_id in range(num_replicas_per_host): 2870 ordered_replicas += replicas[replica_id::num_replicas_per_host] 2871 return concat(ordered_replicas) 2872 2873 2874def _collective_all_reduce_multi_worker(strategy): 2875 return (isinstance(strategy, 2876 collective_all_reduce_strategy.CollectiveAllReduceStrategy) 2877 ) and strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access 2878 2879 2880# TODO(wxinyi): merge this with _tpu_multi_host_concat once we have all_gather 2881# for all strategies 2882def _multi_worker_concat(v, strategy): 2883 """Order PerReplica objects for CollectiveAllReduceStrategy and concat.""" 2884 replicas = strategy.gather(v, axis=0) 2885 # v might not have the same shape on different replicas 2886 if _is_per_replica_instance(v): 2887 shapes = array_ops.concat([ 2888 array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0) 2889 for single_value in v.values 2890 ], 2891 axis=0) 2892 all_shapes = strategy.gather(shapes, axis=0) 2893 else: 2894 # v is a tensor. This may happen when, say, we have 2x1 multi-worker. 2895 all_shapes = strategy.gather( 2896 array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), axis=0) 2897 2898 replicas = array_ops.split( 2899 replicas, 2900 num_or_size_splits=all_shapes, 2901 num=strategy.num_replicas_in_sync) 2902 ordered_replicas = [] 2903 num_replicas_per_worker = len(strategy.extended.worker_devices) 2904 for replica_id in range(num_replicas_per_worker): 2905 ordered_replicas += replicas[replica_id::num_replicas_per_worker] 2906 return concat(ordered_replicas) 2907 2908 2909def _is_scalar(x): 2910 return isinstance(x, (ops.Tensor, variables.Variable)) and x.shape.rank == 0 2911 2912 2913def write_scalar_summaries(logs, step): 2914 for name, value in logs.items(): 2915 if _is_scalar(value): 2916 summary_ops_v2.scalar('batch_' + name, value, step=step) 2917 2918 2919def _minimum_control_deps(outputs): 2920 """Returns the minimum control dependencies to ensure step succeeded.""" 2921 if context.executing_eagerly(): 2922 return [] # Control dependencies not needed. 2923 outputs = nest.flatten(outputs, expand_composites=True) 2924 for out in outputs: 2925 # Variables can't be control dependencies. 2926 if not isinstance(out, variables.Variable): 2927 return [out] # Return first Tensor or Op from outputs. 2928 return [] # No viable Tensor or Op to use for control deps. 2929 2930 2931def _disallow_inside_tf_function(method_name): 2932 if ops.inside_function(): 2933 error_msg = ( 2934 'Detected a call to `Model.{method_name}` inside a `tf.function`. ' 2935 '`Model.{method_name} is a high-level endpoint that manages its own ' 2936 '`tf.function`. Please move the call to `Model.{method_name}` outside ' 2937 'of all enclosing `tf.function`s. Note that you can call a `Model` ' 2938 'directly on `Tensor`s inside a `tf.function` like: `model(x)`.' 2939 ).format(method_name=method_name) 2940 raise RuntimeError(error_msg) 2941 2942 2943def _detect_save_format(filepath): 2944 """Returns path to weights file and save format.""" 2945 2946 filepath = path_to_string(filepath) 2947 if saving_utils.is_hdf5_filepath(filepath): 2948 return filepath, 'h5' 2949 2950 # Filepath could be a TensorFlow checkpoint file prefix or SavedModel 2951 # directory. It's possible for filepath to be both a prefix and directory. 2952 # Prioritize checkpoint over SavedModel. 2953 if _is_readable_tf_checkpoint(filepath): 2954 save_format = 'tf' 2955 elif sm_loader.contains_saved_model(filepath): 2956 ckpt_path = os.path.join(filepath, sm_constants.VARIABLES_DIRECTORY, 2957 sm_constants.VARIABLES_FILENAME) 2958 if _is_readable_tf_checkpoint(ckpt_path): 2959 filepath = ckpt_path 2960 save_format = 'tf' 2961 else: 2962 raise ValueError('Unable to load weights. filepath {} appears to be a ' 2963 'SavedModel directory, but checkpoint either doesn\'t ' 2964 'exist, or is incorrectly formatted.'.format(filepath)) 2965 else: 2966 # Not a TensorFlow checkpoint. This filepath is likely an H5 file that 2967 # doesn't have the hdf5/keras extensions. 2968 save_format = 'h5' 2969 return filepath, save_format 2970 2971 2972def _is_readable_tf_checkpoint(filepath): 2973 try: 2974 py_checkpoint_reader.NewCheckpointReader(filepath) 2975 return True 2976 except errors_impl.DataLossError: 2977 # The checkpoint is not readable in TensorFlow format. 2978 return False 2979 2980 2981def flatten_metrics_in_order(logs, metrics_names): 2982 """Turns the `logs` dict into a list as per key order of `metrics_names`.""" 2983 results = [] 2984 for name in metrics_names: 2985 if name in logs: 2986 results.append(logs[name]) 2987 for key in sorted(logs.keys()): 2988 if key not in metrics_names: 2989 results.append(logs[key]) 2990 if len(results) == 1: 2991 return results[0] 2992 return results 2993 2994 2995def _is_per_replica_instance(obj): 2996 return (isinstance(obj, ds_values.DistributedValues) and 2997 isinstance(obj, composite_tensor.CompositeTensor)) 2998