1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""V1 Training-related part of the Keras engine.""" 16 17import collections 18import warnings 19 20import numpy as np 21 22from tensorflow.python import tf2 23from tensorflow.python.data.ops import dataset_ops 24from tensorflow.python.data.ops import iterator_ops 25from tensorflow.python.distribute import distribution_strategy_context 26from tensorflow.python.distribute import parameter_server_strategy 27from tensorflow.python.distribute import parameter_server_strategy_v2 28from tensorflow.python.eager import context 29from tensorflow.python.eager import def_function 30from tensorflow.python.framework import constant_op 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import sparse_tensor 33from tensorflow.python.framework import tensor_shape 34from tensorflow.python.framework import tensor_spec 35from tensorflow.python.framework import tensor_util 36from tensorflow.python.framework import type_spec 37from tensorflow.python.keras import backend 38from tensorflow.python.keras import losses 39from tensorflow.python.keras import metrics as metrics_module 40from tensorflow.python.keras import optimizer_v1 41from tensorflow.python.keras import optimizers 42from tensorflow.python.keras.distribute import distributed_training_utils 43from tensorflow.python.keras.distribute import distributed_training_utils_v1 44from tensorflow.python.keras.engine import base_layer 45from tensorflow.python.keras.engine import training as training_lib 46from tensorflow.python.keras.engine import training_arrays_v1 47from tensorflow.python.keras.engine import training_distributed_v1 48from tensorflow.python.keras.engine import training_eager_v1 49from tensorflow.python.keras.engine import training_generator_v1 50from tensorflow.python.keras.engine import training_utils 51from tensorflow.python.keras.engine import training_utils_v1 52from tensorflow.python.keras.mixed_precision import loss_scale_optimizer 53from tensorflow.python.keras.mixed_precision import policy 54from tensorflow.python.keras.optimizer_v2 import optimizer_v2 55from tensorflow.python.keras.saving import saving_utils 56from tensorflow.python.keras.saving.saved_model import model_serialization 57from tensorflow.python.keras.utils import data_utils 58from tensorflow.python.keras.utils import layer_utils 59from tensorflow.python.keras.utils import losses_utils 60from tensorflow.python.keras.utils import tf_inspect 61from tensorflow.python.keras.utils import tf_utils 62from tensorflow.python.keras.utils.mode_keys import ModeKeys 63from tensorflow.python.ops import array_ops 64from tensorflow.python.ops import math_ops 65from tensorflow.python.platform import tf_logging as logging 66from tensorflow.python.trackable import base as trackable 67from tensorflow.python.util import nest 68 69try: 70 from scipy.sparse import issparse # 
pylint: disable=g-import-not-at-top 71except ImportError: 72 issparse = None 73 74 75class Model(training_lib.Model): 76 """`Model` groups layers into an object with training and inference features. 77 78 There are two ways to instantiate a `Model`: 79 80 1 - With the "functional API", where you start from `Input`, 81 you chain layer calls to specify the model's forward pass, 82 and finally you create your model from inputs and outputs: 83 84 ```python 85 import tensorflow as tf 86 87 inputs = tf.keras.Input(shape=(3,)) 88 x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) 89 outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x) 90 model = tf.keras.Model(inputs=inputs, outputs=outputs) 91 ``` 92 93 2 - By subclassing the `Model` class: in that case, you should define your 94 layers in `__init__` and you should implement the model's forward pass 95 in `call`. 96 97 ```python 98 import tensorflow as tf 99 100 class MyModel(tf.keras.Model): 101 102 def __init__(self): 103 super(MyModel, self).__init__() 104 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 105 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 106 107 def call(self, inputs): 108 x = self.dense1(inputs) 109 return self.dense2(x) 110 111 model = MyModel() 112 ``` 113 114 If you subclass `Model`, you can optionally have 115 a `training` argument (boolean) in `call`, which you can use to specify 116 a different behavior in training and inference: 117 118 ```python 119 import tensorflow as tf 120 121 class MyModel(tf.keras.Model): 122 123 def __init__(self): 124 super(MyModel, self).__init__() 125 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 126 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 127 self.dropout = tf.keras.layers.Dropout(0.5) 128 129 def call(self, inputs, training=False): 130 x = self.dense1(inputs) 131 if training: 132 x = self.dropout(x, training=training) 133 return self.dense2(x) 134 135 model = MyModel() 136 ``` 137 """ 138 139 def __init__(self, *args, **kwargs): 140 super(Model, self).__init__(*args, **kwargs) 141 # initializing _distribution_strategy here since it is possible to call 142 # predict on a model without compiling it. 143 self._distribution_strategy = None 144 self._compile_time_distribution_strategy = None 145 if (ops.executing_eagerly_outside_functions() and 146 distribution_strategy_context.has_strategy()): 147 self._set_strategy( 148 distribution_strategy_context.get_strategy()) 149 150 # This flag is used to track if the user is using the deprecated path of 151 # passing distribution strategy to compile rather than creating the model 152 # under distribution strategy scope. 153 self._compile_distribution = False 154 155 self._run_eagerly = None 156 self._experimental_run_tf_function = ( 157 ops.executing_eagerly_outside_functions()) 158 159 self._v1_compile_was_called = False 160 161 def _init_batch_counters(self): 162 pass # Batch counters should not be created in legacy graph mode. 163 164 @trackable.no_automatic_dependency_tracking 165 def _set_strategy(self, strategy): 166 self._compile_time_distribution_strategy = strategy 167 168 def get_weights(self): 169 """Retrieves the weights of the model. 170 171 Returns: 172 A flat list of Numpy arrays. 
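
    Example (a minimal sketch; the layer size and input shape below are
    purely illustrative):

    ```python
    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(2, input_shape=(3,))
    ])
    weights = model.get_weights()
    # `weights` is a list such as [kernel of shape (3, 2), bias of shape (2,)].
    ```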
173 """ 174 strategy = (self._distribution_strategy or 175 self._compile_time_distribution_strategy) 176 if strategy: 177 with strategy.scope(): 178 return base_layer.Layer.get_weights(self) 179 return base_layer.Layer.get_weights(self) 180 181 def load_weights(self, filepath, by_name=False, skip_mismatch=False): 182 """Loads all layer weights, either from a TensorFlow or an HDF5 weight file. 183 184 If `by_name` is False weights are loaded based on the network's 185 topology. This means the architecture should be the same as when the weights 186 were saved. Note that layers that don't have weights are not taken into 187 account in the topological ordering, so adding or removing layers is fine as 188 long as they don't have weights. 189 190 If `by_name` is True, weights are loaded into layers only if they share the 191 same name. This is useful for fine-tuning or transfer-learning models where 192 some of the layers have changed. 193 194 Only topological loading (`by_name=False`) is supported when loading weights 195 from the TensorFlow format. Note that topological loading differs slightly 196 between TensorFlow and HDF5 formats for user-defined classes inheriting from 197 `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the 198 TensorFlow format loads based on the object-local names of attributes to 199 which layers are assigned in the `Model`'s constructor. 200 201 Args: 202 filepath: String, path to the weights file to load. For weight files in 203 TensorFlow format, this is the file prefix (the same as was passed 204 to `save_weights`). 205 by_name: Boolean, whether to load weights by name or by topological 206 order. Only topological loading is supported for weight files in 207 TensorFlow format. 208 skip_mismatch: Boolean, whether to skip loading of layers where there is 209 a mismatch in the number of weights, or a mismatch in the shape of 210 the weight (only valid when `by_name=True`). 211 212 Returns: 213 When loading a weight file in TensorFlow format, returns the same status 214 object as `tf.train.Checkpoint.restore`. When graph building, restore 215 ops are run automatically as soon as the network is built (on first call 216 for user-defined classes inheriting from `Model`, immediately if it is 217 already built). 218 219 When loading weights in HDF5 format, returns `None`. 220 221 Raises: 222 ImportError: If h5py is not available and the weight file is in HDF5 223 format. 224 ValueError: If `skip_mismatch` is set to `True` when `by_name` is 225 `False`. 226 """ 227 if backend.is_tpu_strategy(self._distribution_strategy): 228 if (self._distribution_strategy.extended.steps_per_run > 1 and 229 (not saving_utils.is_hdf5_filepath(filepath))): # pylint: disable=protected-access 230 raise ValueError('Load weights is not yet supported with TPUStrategy ' 231 'with steps_per_run greater than 1.') 232 return super(Model, self).load_weights(filepath, by_name, skip_mismatch) 233 234 @trackable.no_automatic_dependency_tracking 235 def compile(self, 236 optimizer='rmsprop', 237 loss=None, 238 metrics=None, 239 loss_weights=None, 240 sample_weight_mode=None, 241 weighted_metrics=None, 242 target_tensors=None, 243 distribute=None, 244 **kwargs): 245 """Configures the model for training. 246 247 Args: 248 optimizer: String (name of optimizer) or optimizer instance. 249 See `tf.keras.optimizers`. 250 loss: String (name of objective function), objective function or 251 `tf.keras.losses.Loss` instance. See `tf.keras.losses`. 
        An objective function is any callable with the signature
        `scalar_loss = fn(y_true, y_pred)`. If the model has multiple
        outputs, you can use a different loss on each output by passing a
        dictionary or a list of losses. The loss value that will be
        minimized by the model will then be the sum of all individual
        losses.
      metrics: List of metrics to be evaluated by the model during training
        and testing. Typically you will use `metrics=['accuracy']`.
        To specify different metrics for different outputs of a
        multi-output model, you could also pass a dictionary, such as
        `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`.
        You can also pass a list (len = len(outputs)) of lists of metrics
        such as `metrics=[['accuracy'], ['accuracy', 'mse']]` or
        `metrics=['accuracy', ['accuracy', 'mse']]`.
      loss_weights: Optional list or dictionary specifying scalar
        coefficients (Python floats) to weight the loss contributions
        of different model outputs.
        The loss value that will be minimized by the model
        will then be the *weighted sum* of all individual losses,
        weighted by the `loss_weights` coefficients.
        If a list, it is expected to have a 1:1 mapping
        to the model's outputs. If a dict, it is expected to map
        output names (strings) to scalar coefficients.
      sample_weight_mode: If you need to do timestep-wise
        sample weighting (2D weights), set this to `"temporal"`.
        `None` defaults to sample-wise weights (1D).
        If the model has multiple outputs, you can use a different
        `sample_weight_mode` on each output by passing a
        dictionary or a list of modes.
      weighted_metrics: List of metrics to be evaluated and weighted
        by `sample_weight` or `class_weight` during training and testing.
      target_tensors: By default, Keras will create placeholders for the
        model's target, which will be fed with the target data during
        training. If instead you would like to use your own
        target tensors (in turn, Keras will not expect external
        Numpy data for these targets at training time), you
        can specify them via the `target_tensors` argument. It can be
        a single tensor (for a single-output model), a list of tensors,
        or a dict mapping output names to target tensors.
      distribute: NOT SUPPORTED IN TF 2.0. Please create and compile the
        model under distribution strategy scope instead of passing it to
        compile.
      **kwargs: Any additional arguments.

    Raises:
      ValueError: In case of invalid arguments for
        `optimizer`, `loss`, `metrics` or `sample_weight_mode`.
    """
    self._assert_built_as_v1()
    self._run_eagerly = kwargs.pop('run_eagerly', None)
    self._experimental_run_tf_function = kwargs.pop(
        'experimental_run_tf_function', True)
    self._v1_compile_was_called = True

    # Prepare Session arguments (legacy).
    kwargs.pop('cloning', None)  # Legacy DistStrat argument, never used.
    self._from_serialized = kwargs.pop('from_serialized', False)
    allowed_kwargs = {'feed_dict', 'fetches', 'options', 'run_metadata'}
    unknown_kwargs = set(kwargs.keys()) - allowed_kwargs
    if unknown_kwargs:
      raise TypeError(
          'Invalid keyword argument(s) in `compile`: %s' % (unknown_kwargs,))
    self._function_kwargs = kwargs
    if self._function_kwargs:
      self._experimental_run_tf_function = False
      if self.run_eagerly:
        raise ValueError(
            'Session keyword arguments are not supported '
            'when `run_eagerly=True`. You passed the following '
            'Session arguments: %s' % (self._function_kwargs,))

    self._set_optimizer(optimizer)
    is_any_keras_optimizer_v1 = any(
        (isinstance(opt, optimizer_v1.Optimizer)
         and not isinstance(opt, optimizer_v1.TFOptimizer)
        ) for opt in nest.flatten(self.optimizer))

    if is_any_keras_optimizer_v1 and ops.executing_eagerly_outside_functions():
      raise ValueError('`tf.compat.v1.keras` Optimizer (%s) is '
                       'not supported when eager execution is enabled. Use a '
                       '`tf.keras` Optimizer instead, or disable eager '
                       'execution.' % (optimizer,))

    if ((target_tensors is not None)
        or not ops.executing_eagerly_outside_functions()):
      # Fall back out of things that aren't supported with v2 loops.
      self._experimental_run_tf_function = False

    if distribute is not None:
      if tf2.enabled() or self._experimental_run_tf_function:
        raise ValueError(
            'Distribute argument in compile is not available in TF 2.0. '
            'Please create the model under the distribution strategy scope.')
      logging.warning('Distribute argument in compile is deprecated. Please '
                      'create the model under the distribution strategy '
                      'scope.')
      self._distribution_strategy = distribute
      self._compile_distribution = True
    else:
      if distribution_strategy_context.has_strategy():
        # When the user builds the model in the DS scope and cross replica
        # context we want distribution strategy to be set, but when building
        # the replica copies of the models internally we should not be
        # compiling with distribution strategy and should use the default
        # compilation path.
        if distribution_strategy_context.in_cross_replica_context():
          self._distribution_strategy = (
              distribution_strategy_context.get_strategy())

    if isinstance(self._distribution_strategy,
                  parameter_server_strategy.ParameterServerStrategyV1):
      raise NotImplementedError(
          '`tf.compat.v1.distribute.experimental.ParameterServerStrategy` '
          'currently only works with the tf.Estimator API')

    if isinstance(self._distribution_strategy,
                  parameter_server_strategy_v2.ParameterServerStrategyV2):
      raise NotImplementedError(
          '`tf.distribute.experimental.ParameterServerStrategy` is only '
          'supported in TF2.')

    if not self._experimental_run_tf_function:
      self._validate_compile_param_for_distribution_strategy(
          self.run_eagerly, sample_weight_mode, target_tensors,
          weighted_metrics)
    # We've disabled automatic dependency tracking for this method, but do want
    # to add a checkpoint dependency on the optimizer if it's trackable.
    if isinstance(self.optimizer, trackable.Trackable):
      self._track_trackable(
          self.optimizer, name='optimizer', overwrite=True)
    self.loss = loss or {}
    self.loss_weights = loss_weights
    self.sample_weight_mode = sample_weight_mode
    self._compile_metrics = metrics or []
    self._compile_weighted_metrics = weighted_metrics
    if self.run_eagerly and target_tensors is not None:
      raise ValueError(
          'target_tensors argument is not supported when '
          'running a model eagerly.')

    # _training_endpoints contains a list of _TrainingEndpoint objects, which
    # hold all the model output/target/loss and related metadata.
    self._training_endpoints = []

    # Used to freeze the behavior of the Model once `compile` has been called.
    self._compiled_trainable_state = self._get_trainable_state()

    # Set tf.distribute.Strategy specific parameters.
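    # (These caches hold the strategy-specific clones of the model and its
    # train/test/predict functions; they are populated lazily the first time
    # the model is run under a distribution strategy.)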
399 self._distributed_model_cache = {} 400 self._distributed_function_cache = {} 401 402 # Clear any `_eager_losses` that was added. 403 self._clear_losses() 404 405 if (not context.executing_eagerly() and 406 self._distribution_strategy is not None): 407 # Ensures a Session is created and configured correctly for Distribution 408 # Strategy. 409 backend.configure_and_create_distributed_session( 410 self._distribution_strategy) 411 # Initialize model metric attributes. 412 self._init_metric_attributes() 413 if not self.built or not self.inputs or not self.outputs: 414 # Model is not compilable because it does not know its number of inputs 415 # and outputs, nor their shapes and names. We will compile after the first 416 # time the model gets called on training data. 417 return 418 self._is_compiled = True 419 420 # Prepare list of loss functions, same size of model outputs. 421 self.loss_functions = training_utils_v1.prepare_loss_functions( 422 self.loss, self.output_names) 423 424 target_tensors = self._process_target_tensor_for_compile(target_tensors) 425 426 for o, n, l, t in zip(self.outputs, self.output_names, 427 self.loss_functions, target_tensors): 428 endpoint = _TrainingEndpoint(o, n, l) 429 endpoint.create_training_target(t, run_eagerly=self.run_eagerly) 430 self._training_endpoints.append(endpoint) 431 432 # Prepare list loss weights, same size of model outputs. 433 training_utils_v1.prepare_loss_weights(self._training_endpoints, 434 loss_weights) 435 436 # Initialization for Eager mode execution. 437 if self.run_eagerly: 438 self._compile_eagerly(metrics, weighted_metrics, sample_weight_mode) 439 return 440 441 with backend.get_graph().as_default(): 442 # Save all metric attributes per output of the model. 443 self._cache_output_metric_attributes(metrics, weighted_metrics) 444 445 # Set metric attributes on model. 446 self._set_metric_attributes() 447 448 # Invoke metric functions (unweighted) for all the outputs. 449 self._handle_metrics( 450 self.outputs, 451 targets=self._targets, 452 skip_target_masks=self._prepare_skip_target_masks(), 453 masks=self._prepare_output_masks()) 454 455 # Prepare sample weight modes. List with the same length as model outputs. 456 training_utils_v1.prepare_sample_weight_modes( 457 self._training_endpoints, sample_weight_mode) 458 459 # Creates the model loss and weighted metrics sub-graphs. 460 self._compile_weights_loss_and_weighted_metrics() 461 462 # Functions for train, test and predict will 463 # be compiled lazily when required. 464 # This saves time when the user is not using all functions. 465 self.train_function = None 466 self.test_function = None 467 self.predict_function = None 468 469 # Collected trainable weights, sorted in topological order. 470 self._collected_trainable_weights = self.trainable_weights 471 472 # Validate all variables were correctly created in distribution scope. 473 if self._distribution_strategy and not self._compile_distribution: 474 for v in self.variables: 475 strategy = self._distribution_strategy 476 if not strategy.extended.variable_created_in_scope(v): 477 raise ValueError( 478 'Variable (%s) was not created in the distribution strategy ' 479 'scope of (%s). It is most likely due to not all layers or ' 480 'the model or optimizer being created outside the distribution ' 481 'strategy scope. 
Try to make sure your code looks similar ' 482 'to the following.\n' 483 'with strategy.scope():\n' 484 ' model=_create_model()\n' 485 ' model.compile(...)'% (v, strategy)) 486 487 @trackable.no_automatic_dependency_tracking 488 def _init_distributed_function_cache_if_not_compiled(self): 489 if not hasattr(self, '_distributed_function_cache'): 490 self._distributed_function_cache = {} 491 492 @property 493 def metrics(self): 494 """Returns the model's metrics added using `compile`, `add_metric` APIs.""" 495 metrics = [] 496 if self._is_compiled: 497 if not hasattr(self, '_v1_compile_was_called'): 498 # See b/155687393 for more details, the model is created as a v2 499 # instance but converted to v1. Fallback to use base Model to retrieve 500 # the metrics. 501 return super(Model, self).metrics 502 metrics += self._compile_metric_functions 503 metrics.extend(self._metrics) 504 metrics.extend( 505 _get_metrics_from_layers( 506 list(self._flatten_layers(include_self=False, recursive=False)))) 507 return metrics 508 509 @property 510 def metrics_names(self): 511 """Returns the model's display labels for all outputs.""" 512 513 # This property includes all output names including `loss` and per-output 514 # losses for backward compatibility. 515 metrics_names = ['loss'] 516 if self._is_compiled: 517 if not hasattr(self, '_v1_compile_was_called'): 518 # See b/155687393 for more details, the model is created as a v2 519 # instance but converted to v1. Fallback to use base Model to retrieve 520 # the metrics name 521 return super(Model, self).metrics_names 522 523 # Add output loss metric names to the metric names list. 524 if len(self._training_endpoints) > 1: 525 metrics_names.extend([ 526 e.loss_name() 527 for e in self._training_endpoints 528 if not e.should_skip_target() 529 ]) 530 531 # Add all metric names. 532 metrics_names += [m.name for m in self.metrics] 533 return metrics_names 534 535 @property 536 def run_eagerly(self): 537 """Settable attribute indicating whether the model should run eagerly. 538 539 Running eagerly means that your model will be run step by step, 540 like Python code. Your model might run slower, but it should become easier 541 for you to debug it by stepping into individual layer calls. 542 543 By default, we will attempt to compile your model to a static graph to 544 deliver the best execution performance. 545 546 Returns: 547 Boolean, whether the model should run eagerly. 548 """ 549 if self._run_eagerly is True and not context.executing_eagerly(): 550 raise ValueError('You can only set `run_eagerly=True` if eager execution ' 551 'is enabled.') 552 if not self.dynamic: 553 if self._run_eagerly is None: 554 # Respect `tf.config.run_functions_eagerly` unless 555 # `run_eagerly` was explicitly passed to `compile`. 556 return def_function.functions_run_eagerly() 557 else: 558 return self._run_eagerly 559 else: 560 if not context.executing_eagerly(): 561 raise ValueError('Your model contains layers that can only be ' 562 'successfully run in eager execution (layers ' 563 'constructed with `dynamic=True`). ' 564 'You must enable eager execution with ' 565 '`tf.enable_eager_execution()`.') 566 if self._run_eagerly is False: 567 # TODO(fchollet): consider using py_func to enable this. 568 raise ValueError('Your model contains layers that can only be ' 569 'successfully run in eager execution (layers ' 570 'constructed with `dynamic=True`). 
' 571 'You cannot set `run_eagerly=False`.') 572 return context.executing_eagerly() 573 574 @run_eagerly.setter 575 def run_eagerly(self, value): 576 self._run_eagerly = value 577 578 def _select_training_loop(self, inputs): 579 """Select training loop for fit/eval/predict based on the inputs.""" 580 # TODO(kaftan) or TODO(scottzhu): This check should eventually be nicely 581 # integrated into the data adapters in the v2 loop. We can't do this yet 582 # because we currently have to fall back for unhandled data types. 583 if isinstance(inputs, (iterator_ops.Iterator, 584 iterator_ops.IteratorBase)): 585 raise ValueError('For performance reasons Keras `fit`, `evaluate` and' 586 '`predict` accept tf.data `Datasets` as input but not ' 587 'iterators that have been manually generated from ' 588 'Datasets by users. Please directly pass in the ' 589 'original `Dataset` object instead of passing in ' 590 '`iter(dataset)`.') 591 592 # Case 1: distribution strategy. 593 if self._distribution_strategy: 594 if self._in_multi_worker_mode(): 595 return training_distributed_v1.DistributionMultiWorkerTrainingLoop( 596 training_distributed_v1.DistributionSingleWorkerTrainingLoop()) 597 else: 598 return training_distributed_v1.DistributionSingleWorkerTrainingLoop() 599 600 # Case 2: generator-like. Input is Python generator, or Sequence object, 601 # or a non-distributed Dataset or iterator in eager execution. 602 if data_utils.is_generator_or_sequence(inputs): 603 return training_generator_v1.GeneratorOrSequenceTrainingLoop() 604 if training_utils_v1.is_eager_dataset_or_iterator(inputs): 605 return training_generator_v1.EagerDatasetOrIteratorTrainingLoop() 606 607 # Case 3: Symbolic tensors or Numpy array-like. 608 # This includes Datasets and iterators in graph mode (since they 609 # generate symbolic tensors). 610 if self.run_eagerly: 611 return training_generator_v1.GeneratorLikeTrainingLoop() 612 else: 613 return training_arrays_v1.ArrayLikeTrainingLoop() 614 615 def fit(self, 616 x=None, 617 y=None, 618 batch_size=None, 619 epochs=1, 620 verbose=1, 621 callbacks=None, 622 validation_split=0., 623 validation_data=None, 624 shuffle=True, 625 class_weight=None, 626 sample_weight=None, 627 initial_epoch=0, 628 steps_per_epoch=None, 629 validation_steps=None, 630 validation_freq=1, 631 max_queue_size=10, 632 workers=1, 633 use_multiprocessing=False, 634 **kwargs): 635 """Trains the model for a fixed number of epochs (iterations on a dataset). 636 637 Args: 638 x: Input data. It could be: 639 - A Numpy array (or array-like), or a list of arrays 640 (in case the model has multiple inputs). 641 - A TensorFlow tensor, or a list of tensors 642 (in case the model has multiple inputs). 643 - A dict mapping input names to the corresponding array/tensors, 644 if the model has named inputs. 645 - A `tf.data` dataset. Should return a tuple 646 of either `(inputs, targets)` or 647 `(inputs, targets, sample_weights)`. 648 - A generator or `keras.utils.Sequence` returning `(inputs, targets)` 649 or `(inputs, targets, sample weights)`. 650 y: Target data. Like the input data `x`, 651 it could be either Numpy array(s) or TensorFlow tensor(s). 652 It should be consistent with `x` (you cannot have Numpy inputs and 653 tensor targets, or inversely). If `x` is a dataset, generator, 654 or `keras.utils.Sequence` instance, `y` should 655 not be specified (since targets will be obtained from `x`). 656 batch_size: Integer or `None`. 657 Number of samples per gradient update. 658 If unspecified, `batch_size` will default to 32. 
659 Do not specify the `batch_size` if your data is in the 660 form of symbolic tensors, datasets, 661 generators, or `keras.utils.Sequence` instances (since they generate 662 batches). 663 epochs: Integer. Number of epochs to train the model. 664 An epoch is an iteration over the entire `x` and `y` 665 data provided. 666 Note that in conjunction with `initial_epoch`, 667 `epochs` is to be understood as "final epoch". 668 The model is not trained for a number of iterations 669 given by `epochs`, but merely until the epoch 670 of index `epochs` is reached. 671 verbose: 0, 1, or 2. Verbosity mode. 672 0 = silent, 1 = progress bar, 2 = one line per epoch. 673 Note that the progress bar is not particularly useful when 674 logged to a file, so verbose=2 is recommended when not running 675 interactively (eg, in a production environment). 676 callbacks: List of `keras.callbacks.Callback` instances. 677 List of callbacks to apply during training. 678 See `tf.keras.callbacks`. 679 validation_split: Float between 0 and 1. 680 Fraction of the training data to be used as validation data. 681 The model will set apart this fraction of the training data, 682 will not train on it, and will evaluate 683 the loss and any model metrics 684 on this data at the end of each epoch. 685 The validation data is selected from the last samples 686 in the `x` and `y` data provided, before shuffling. This argument is 687 not supported when `x` is a dataset, generator or 688 `keras.utils.Sequence` instance. 689 validation_data: Data on which to evaluate 690 the loss and any model metrics at the end of each epoch. 691 The model will not be trained on this data. 692 `validation_data` will override `validation_split`. 693 `validation_data` could be: 694 - tuple `(x_val, y_val)` of Numpy arrays or tensors 695 - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays 696 - dataset 697 For the first two cases, `batch_size` must be provided. 698 For the last case, `validation_steps` could be provided. 699 shuffle: Boolean (whether to shuffle the training data 700 before each epoch) or str (for 'batch'). 701 'batch' is a special option for dealing with the 702 limitations of HDF5 data; it shuffles in batch-sized chunks. 703 Has no effect when `steps_per_epoch` is not `None`. 704 class_weight: Optional dictionary mapping class indices (integers) 705 to a weight (float) value, used for weighting the loss function 706 (during training only). 707 This can be useful to tell the model to 708 "pay more attention" to samples from 709 an under-represented class. 710 sample_weight: Optional Numpy array of weights for 711 the training samples, used for weighting the loss function 712 (during training only). You can either pass a flat (1D) 713 Numpy array with the same length as the input samples 714 (1:1 mapping between weights and samples), 715 or in the case of temporal data, 716 you can pass a 2D array with shape 717 `(samples, sequence_length)`, 718 to apply a different weight to every timestep of every sample. 719 In this case you should make sure to specify 720 `sample_weight_mode="temporal"` in `compile()`. This argument is not 721 supported when `x` is a dataset, generator, or 722 `keras.utils.Sequence` instance, instead provide the sample_weights 723 as the third element of `x`. 724 initial_epoch: Integer. 725 Epoch at which to start training 726 (useful for resuming a previous training run). 727 steps_per_epoch: Integer or `None`. 
        Total number of steps (batches of samples)
        before declaring one epoch finished and starting the
        next epoch. When training with input tensors such as
        TensorFlow data tensors, the default `None` is equal to
        the number of samples in your dataset divided by
        the batch size, or 1 if that cannot be determined. If `x` is a
        `tf.data` dataset, and `steps_per_epoch`
        is None, the epoch will run until the input dataset is exhausted.
        This argument is not supported with array inputs.
      validation_steps: Only relevant if `validation_data` is provided and
        is a `tf.data` dataset. Total number of steps (batches of
        samples) to draw before stopping when performing validation
        at the end of every epoch. If `validation_steps` is None, validation
        will run until the `validation_data` dataset is exhausted. In the
        case of an infinite dataset, it will run into an infinite loop.
        If `validation_steps` is specified and only part of the dataset
        is consumed, the evaluation will start from the beginning of
        the dataset at each epoch. This ensures that the same validation
        samples are used every time.
      validation_freq: Only relevant if validation data is provided. Integer
        or `collections.abc.Container` instance (e.g. list, tuple, etc.).
        If an integer, specifies how many training epochs to run before a
        new validation run is performed, e.g. `validation_freq=2` runs
        validation every 2 epochs. If a Container, specifies the epochs on
        which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
        validation at the end of the 1st, 2nd, and 10th epochs.
      max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
        input only. Maximum size for the generator queue.
        If unspecified, `max_queue_size` will default to 10.
      workers: Integer. Used for generator or `keras.utils.Sequence` input
        only. Maximum number of processes to spin up
        when using process-based threading. If unspecified, `workers`
        will default to 1. If 0, will execute the generator on the main
        thread.
      use_multiprocessing: Boolean. Used for generator or
        `keras.utils.Sequence` input only. If `True`, use process-based
        threading. If unspecified, `use_multiprocessing` will default to
        `False`. Note that because this implementation relies on
        multiprocessing, you should not pass non-picklable arguments to
        the generator as they can't be passed easily to children processes.
      **kwargs: Used for backwards compatibility.

    Returns:
      A `History` object. Its `History.history` attribute is
      a record of training loss values and metrics values
      at successive epochs, as well as validation loss values
      and validation metrics values (if applicable).

    Raises:
      RuntimeError: If the model was never compiled.
      ValueError: In case of mismatch between the provided input data
        and what the model expects.
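
    Example (a minimal sketch; the layer sizes and random data below are
    purely illustrative):

    ```python
    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
        tf.keras.layers.Dense(1)
    ])
    model.compile(optimizer='rmsprop', loss='mse')

    x = np.random.random((32, 4))
    y = np.random.random((32, 1))
    history = model.fit(x, y, batch_size=8, epochs=2, validation_split=0.25)
    print(history.history['loss'])
    ```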
780 """ 781 self._assert_built_as_v1() 782 # Legacy support 783 if 'nb_epoch' in kwargs: 784 logging.warning( 785 'The `nb_epoch` argument in `fit` has been renamed `epochs`.') 786 epochs = kwargs.pop('nb_epoch') 787 if kwargs: 788 raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) 789 self._assert_compile_was_called() 790 self._check_call_args('fit') 791 792 func = self._select_training_loop(x) 793 return func.fit( 794 self, 795 x=x, 796 y=y, 797 batch_size=batch_size, 798 epochs=epochs, 799 verbose=verbose, 800 callbacks=callbacks, 801 validation_split=validation_split, 802 validation_data=validation_data, 803 shuffle=shuffle, 804 class_weight=class_weight, 805 sample_weight=sample_weight, 806 initial_epoch=initial_epoch, 807 steps_per_epoch=steps_per_epoch, 808 validation_steps=validation_steps, 809 validation_freq=validation_freq, 810 max_queue_size=max_queue_size, 811 workers=workers, 812 use_multiprocessing=use_multiprocessing) 813 814 def evaluate(self, 815 x=None, 816 y=None, 817 batch_size=None, 818 verbose=1, 819 sample_weight=None, 820 steps=None, 821 callbacks=None, 822 max_queue_size=10, 823 workers=1, 824 use_multiprocessing=False): 825 """Returns the loss value & metrics values for the model in test mode. 826 827 Computation is done in batches (see the `batch_size` arg.) 828 829 Args: 830 x: Input data. It could be: 831 - A Numpy array (or array-like), or a list of arrays 832 (in case the model has multiple inputs). 833 - A TensorFlow tensor, or a list of tensors 834 (in case the model has multiple inputs). 835 - A dict mapping input names to the corresponding array/tensors, 836 if the model has named inputs. 837 - A `tf.data` dataset. 838 - A generator or `keras.utils.Sequence` instance. 839 y: Target data. Like the input data `x`, 840 it could be either Numpy array(s) or TensorFlow tensor(s). 841 It should be consistent with `x` (you cannot have Numpy inputs and 842 tensor targets, or inversely). 843 If `x` is a dataset, generator or 844 `keras.utils.Sequence` instance, `y` should not be specified (since 845 targets will be obtained from the iterator/dataset). 846 batch_size: Integer or `None`. 847 Number of samples per batch of computation. 848 If unspecified, `batch_size` will default to 32. 849 Do not specify the `batch_size` if your data is in the 850 form of symbolic tensors, dataset, 851 generators, or `keras.utils.Sequence` instances (since they generate 852 batches). 853 verbose: 0 or 1. Verbosity mode. 854 0 = silent, 1 = progress bar. 855 sample_weight: Optional Numpy array of weights for 856 the test samples, used for weighting the loss function. 857 You can either pass a flat (1D) 858 Numpy array with the same length as the input samples 859 (1:1 mapping between weights and samples), 860 or in the case of temporal data, 861 you can pass a 2D array with shape 862 `(samples, sequence_length)`, 863 to apply a different weight to every timestep of every sample. 864 In this case you should make sure to specify 865 `sample_weight_mode="temporal"` in `compile()`. This argument is not 866 supported when `x` is a dataset, instead pass 867 sample weights as the third element of `x`. 868 steps: Integer or `None`. 869 Total number of steps (batches of samples) 870 before declaring the evaluation round finished. 871 Ignored with the default value of `None`. 872 If x is a `tf.data` dataset and `steps` is 873 None, 'evaluate' will run until the dataset is exhausted. 874 This argument is not supported with array inputs. 
875 callbacks: List of `keras.callbacks.Callback` instances. 876 List of callbacks to apply during evaluation. 877 See [callbacks](/api_docs/python/tf/keras/callbacks). 878 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 879 input only. Maximum size for the generator queue. 880 If unspecified, `max_queue_size` will default to 10. 881 workers: Integer. Used for generator or `keras.utils.Sequence` input 882 only. Maximum number of processes to spin up when using 883 process-based threading. If unspecified, `workers` will default 884 to 1. If 0, will execute the generator on the main thread. 885 use_multiprocessing: Boolean. Used for generator or 886 `keras.utils.Sequence` input only. If `True`, use process-based 887 threading. If unspecified, `use_multiprocessing` will default to 888 `False`. Note that because this implementation relies on 889 multiprocessing, you should not pass non-picklable arguments to 890 the generator as they can't be passed easily to children processes. 891 892 Returns: 893 Scalar test loss (if the model has a single output and no metrics) 894 or list of scalars (if the model has multiple outputs 895 and/or metrics). The attribute `model.metrics_names` will give you 896 the display labels for the scalar outputs. 897 898 Raises: 899 ValueError: in case of invalid arguments. 900 """ 901 self._assert_built_as_v1() 902 self._assert_compile_was_called() 903 self._check_call_args('evaluate') 904 905 func = self._select_training_loop(x) 906 return func.evaluate( 907 self, 908 x=x, 909 y=y, 910 batch_size=batch_size, 911 verbose=verbose, 912 sample_weight=sample_weight, 913 steps=steps, 914 callbacks=callbacks, 915 max_queue_size=max_queue_size, 916 workers=workers, 917 use_multiprocessing=use_multiprocessing) 918 919 def predict(self, 920 x, 921 batch_size=None, 922 verbose=0, 923 steps=None, 924 callbacks=None, 925 max_queue_size=10, 926 workers=1, 927 use_multiprocessing=False): 928 """Generates output predictions for the input samples. 929 930 Computation is done in batches (see the `batch_size` arg.) 931 932 Args: 933 x: Input samples. It could be: 934 - A Numpy array (or array-like), or a list of arrays 935 (in case the model has multiple inputs). 936 - A TensorFlow tensor, or a list of tensors 937 (in case the model has multiple inputs). 938 - A `tf.data` dataset. 939 - A generator or `keras.utils.Sequence` instance. 940 batch_size: Integer or `None`. 941 Number of samples per batch of computation. 942 If unspecified, `batch_size` will default to 32. 943 Do not specify the `batch_size` if your data is in the 944 form of symbolic tensors, dataset, 945 generators, or `keras.utils.Sequence` instances (since they generate 946 batches). 947 verbose: Verbosity mode, 0 or 1. 948 steps: Total number of steps (batches of samples) 949 before declaring the prediction round finished. 950 Ignored with the default value of `None`. If x is a `tf.data` 951 dataset and `steps` is None, `predict` will 952 run until the input dataset is exhausted. 953 callbacks: List of `keras.callbacks.Callback` instances. 954 List of callbacks to apply during prediction. 955 See [callbacks](/api_docs/python/tf/keras/callbacks). 956 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 957 input only. Maximum size for the generator queue. 958 If unspecified, `max_queue_size` will default to 10. 959 workers: Integer. Used for generator or `keras.utils.Sequence` input 960 only. Maximum number of processes to spin up when using 961 process-based threading. 
If unspecified, `workers` will default 962 to 1. If 0, will execute the generator on the main thread. 963 use_multiprocessing: Boolean. Used for generator or 964 `keras.utils.Sequence` input only. If `True`, use process-based 965 threading. If unspecified, `use_multiprocessing` will default to 966 `False`. Note that because this implementation relies on 967 multiprocessing, you should not pass non-picklable arguments to 968 the generator as they can't be passed easily to children processes. 969 970 971 Returns: 972 Numpy array(s) of predictions. 973 974 Raises: 975 ValueError: In case of mismatch between the provided 976 input data and the model's expectations, 977 or in case a stateful model receives a number of samples 978 that is not a multiple of the batch size. 979 """ 980 self._assert_built_as_v1() 981 self._check_call_args('predict') 982 983 func = self._select_training_loop(x) 984 return func.predict( 985 self, 986 x=x, 987 batch_size=batch_size, 988 verbose=verbose, 989 steps=steps, 990 callbacks=callbacks, 991 max_queue_size=max_queue_size, 992 workers=workers, 993 use_multiprocessing=use_multiprocessing) 994 995 def reset_metrics(self): 996 """Resets the state of metrics.""" 997 metrics = self._get_training_eval_metrics() 998 for m in metrics: 999 m.reset_state() 1000 1001 # Reset metrics on all the distributed (cloned) models. 1002 if self._distribution_strategy: 1003 distributed_training_utils_v1._reset_metrics(self) # pylint: disable=protected-access 1004 1005 def train_on_batch(self, 1006 x, 1007 y=None, 1008 sample_weight=None, 1009 class_weight=None, 1010 reset_metrics=True): 1011 """Runs a single gradient update on a single batch of data. 1012 1013 Args: 1014 x: Input data. It could be: 1015 - A Numpy array (or array-like), or a list of arrays 1016 (in case the model has multiple inputs). 1017 - A TensorFlow tensor, or a list of tensors 1018 (in case the model has multiple inputs). 1019 - A dict mapping input names to the corresponding array/tensors, 1020 if the model has named inputs. 1021 - A `tf.data` dataset. 1022 y: Target data. Like the input data `x`, it could be either Numpy 1023 array(s) or TensorFlow tensor(s). It should be consistent with `x` 1024 (you cannot have Numpy inputs and tensor targets, or inversely). If 1025 `x` is a dataset, `y` should not be specified 1026 (since targets will be obtained from the iterator). 1027 sample_weight: Optional array of the same length as x, containing 1028 weights to apply to the model's loss for each sample. In the case of 1029 temporal data, you can pass a 2D array with shape (samples, 1030 sequence_length), to apply a different weight to every timestep of 1031 every sample. In this case you should make sure to specify 1032 sample_weight_mode="temporal" in compile(). This argument is not 1033 supported when `x` is a dataset. 1034 class_weight: Optional dictionary mapping class indices (integers) to a 1035 weight (float) to apply to the model's loss for the samples from this 1036 class during training. This can be useful to tell the model to "pay 1037 more attention" to samples from an under-represented class. 1038 reset_metrics: If `True`, the metrics returned will be only for this 1039 batch. If `False`, the metrics will be statefully accumulated across 1040 batches. 1041 1042 Returns: 1043 Scalar training loss 1044 (if the model has a single output and no metrics) 1045 or list of scalars (if the model has multiple outputs 1046 and/or metrics). 
The attribute `model.metrics_names` will give you 1047 the display labels for the scalar outputs. 1048 1049 Raises: 1050 ValueError: In case of invalid user-provided arguments. 1051 """ 1052 self._assert_compile_was_called() 1053 self._check_call_args('train_on_batch') 1054 1055 # If at this point we are in the replica context, then it is okay to execute 1056 # the Eager code path. The expected way to get here is to call `fit` that 1057 # calls `train_on_batch` on each replica. 1058 if (self._distribution_strategy and 1059 distribution_strategy_context.in_cross_replica_context()): 1060 raise NotImplementedError('`train_on_batch` is not supported for models ' 1061 'distributed with tf.distribute.Strategy.') 1062 # Validate and standardize user data. 1063 x, y, sample_weights = self._standardize_user_data( 1064 x, y, sample_weight=sample_weight, class_weight=class_weight, 1065 extract_tensors_from_dataset=True) 1066 1067 # If `self._distribution_strategy` is True, then we are in a replica context 1068 # at this point because of the check above. `train_on_batch` is being run 1069 # for each replica by `self._distribution_strategy` and the same code path 1070 # as Eager is expected to be taken. 1071 if self.run_eagerly or self._distribution_strategy: 1072 output_dict = training_eager_v1.train_on_batch( 1073 self, 1074 x, 1075 y, 1076 sample_weights=sample_weights, 1077 output_loss_metrics=self._output_loss_metrics) 1078 outputs = (output_dict['total_loss'] + output_dict['output_losses'] 1079 + output_dict['metrics']) 1080 outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access 1081 else: 1082 x = training_utils_v1.ModelInputs(x).as_list() 1083 ins = x + list(y or []) + list(sample_weights or []) 1084 1085 if not isinstance(backend.symbolic_learning_phase(), int): 1086 ins += [True] # Add learning phase value. 1087 1088 self._update_sample_weight_modes(sample_weights=sample_weights) 1089 self._make_train_function() 1090 outputs = self.train_function(ins) # pylint: disable=not-callable 1091 1092 if reset_metrics: 1093 self.reset_metrics() 1094 1095 if len(outputs) == 1: 1096 return outputs[0] 1097 return outputs 1098 1099 def test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True): 1100 """Test the model on a single batch of samples. 1101 1102 Args: 1103 x: Input data. It could be: 1104 - A Numpy array (or array-like), or a list of arrays 1105 (in case the model has multiple inputs). 1106 - A TensorFlow tensor, or a list of tensors 1107 (in case the model has multiple inputs). 1108 - A dict mapping input names to the corresponding array/tensors, 1109 if the model has named inputs. 1110 - A `tf.data` dataset. 1111 y: Target data. Like the input data `x`, 1112 it could be either Numpy array(s) or TensorFlow tensor(s). 1113 It should be consistent with `x` (you cannot have Numpy inputs and 1114 tensor targets, or inversely). If `x` is a dataset `y` should 1115 not be specified (since targets will be obtained from the iterator). 1116 sample_weight: Optional array of the same length as x, containing 1117 weights to apply to the model's loss for each sample. 1118 In the case of temporal data, you can pass a 2D array 1119 with shape (samples, sequence_length), 1120 to apply a different weight to every timestep of every sample. 1121 In this case you should make sure to specify 1122 sample_weight_mode="temporal" in compile(). This argument is not 1123 supported when `x` is a dataset. 
1124 reset_metrics: If `True`, the metrics returned will be only for this 1125 batch. If `False`, the metrics will be statefully accumulated across 1126 batches. 1127 1128 Returns: 1129 Scalar test loss (if the model has a single output and no metrics) 1130 or list of scalars (if the model has multiple outputs 1131 and/or metrics). The attribute `model.metrics_names` will give you 1132 the display labels for the scalar outputs. 1133 1134 Raises: 1135 ValueError: In case of invalid user-provided arguments. 1136 """ 1137 self._assert_compile_was_called() 1138 self._check_call_args('test_on_batch') 1139 1140 if (self._distribution_strategy and 1141 distribution_strategy_context.in_cross_replica_context()): 1142 raise NotImplementedError('`test_on_batch` is not supported for models ' 1143 'distributed with tf.distribute.Strategy.') 1144 # Validate and standardize user data. 1145 x, y, sample_weights = self._standardize_user_data( 1146 x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True) 1147 1148 # If `self._distribution_strategy` is True, then we are in a replica context 1149 # at this point. 1150 if self.run_eagerly or self._distribution_strategy: 1151 output_dict = training_eager_v1.test_on_batch( 1152 self, 1153 x, 1154 y, 1155 sample_weights=sample_weights, 1156 output_loss_metrics=self._output_loss_metrics) 1157 outputs = (output_dict['total_loss'] + output_dict['output_losses'] 1158 + output_dict['metrics']) 1159 outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access 1160 else: 1161 x = training_utils_v1.ModelInputs(x).as_list() 1162 inputs = x + list(y or []) + list(sample_weights or []) 1163 1164 self._update_sample_weight_modes(sample_weights=sample_weights) 1165 self._make_test_function() 1166 outputs = self.test_function(inputs) # pylint: disable=not-callable 1167 1168 if reset_metrics: 1169 self.reset_metrics() 1170 1171 if len(outputs) == 1: 1172 return outputs[0] 1173 return outputs 1174 1175 def predict_on_batch(self, x): 1176 """Returns predictions for a single batch of samples. 1177 1178 Args: 1179 x: Input data. It could be: 1180 - A Numpy array (or array-like), or a list of arrays 1181 (in case the model has multiple inputs). 1182 - A TensorFlow tensor, or a list of tensors 1183 (in case the model has multiple inputs). 1184 - A `tf.data` dataset. 1185 1186 Returns: 1187 Numpy array(s) of predictions. 1188 1189 Raises: 1190 ValueError: In case of mismatch between given number of inputs and 1191 expectations of the model. 1192 """ 1193 self._check_call_args('predict_on_batch') 1194 1195 if (self._distribution_strategy and 1196 distribution_strategy_context.in_cross_replica_context()): 1197 raise NotImplementedError( 1198 '`predict_on_batch` is not supported for models distributed with' 1199 ' tf.distribute.Strategy.') 1200 # Validate and standardize user data. 1201 inputs, _, _ = self._standardize_user_data( 1202 x, extract_tensors_from_dataset=True) 1203 # If `self._distribution_strategy` is True, then we are in a replica context 1204 # at this point. 
1205 if self.run_eagerly or self._distribution_strategy: 1206 inputs = training_utils_v1.cast_if_floating_dtype(inputs) 1207 if isinstance(inputs, collections.abc.Sequence): 1208 # Unwrap lists with only one input, as we do when training on batch 1209 if len(inputs) == 1: 1210 inputs = inputs[0] 1211 1212 return self(inputs) # pylint: disable=not-callable 1213 1214 self._make_predict_function() 1215 outputs = self.predict_function(inputs) 1216 1217 if len(outputs) == 1: 1218 return outputs[0] 1219 return outputs 1220 1221 def fit_generator(self, 1222 generator, 1223 steps_per_epoch=None, 1224 epochs=1, 1225 verbose=1, 1226 callbacks=None, 1227 validation_data=None, 1228 validation_steps=None, 1229 validation_freq=1, 1230 class_weight=None, 1231 max_queue_size=10, 1232 workers=1, 1233 use_multiprocessing=False, 1234 shuffle=True, 1235 initial_epoch=0): 1236 """Fits the model on data yielded batch-by-batch by a Python generator. 1237 1238 DEPRECATED: 1239 `Model.fit` now supports generators, so there is no longer any need to use 1240 this endpoint. 1241 """ 1242 warnings.warn('`model.fit_generator` is deprecated and ' 1243 'will be removed in a future version. ' 1244 'Please use `Model.fit`, which supports generators.') 1245 return self.fit( 1246 generator, 1247 steps_per_epoch=steps_per_epoch, 1248 epochs=epochs, 1249 verbose=verbose, 1250 callbacks=callbacks, 1251 validation_data=validation_data, 1252 validation_steps=validation_steps, 1253 validation_freq=validation_freq, 1254 class_weight=class_weight, 1255 max_queue_size=max_queue_size, 1256 workers=workers, 1257 use_multiprocessing=use_multiprocessing, 1258 shuffle=shuffle, 1259 initial_epoch=initial_epoch) 1260 1261 def evaluate_generator(self, 1262 generator, 1263 steps=None, 1264 callbacks=None, 1265 max_queue_size=10, 1266 workers=1, 1267 use_multiprocessing=False, 1268 verbose=0): 1269 """Evaluates the model on a data generator. 1270 1271 DEPRECATED: 1272 `Model.evaluate` now supports generators, so there is no longer any need 1273 to use this endpoint. 1274 """ 1275 warnings.warn('`Model.evaluate_generator` is deprecated and ' 1276 'will be removed in a future version. ' 1277 'Please use `Model.evaluate`, which supports generators.') 1278 self._check_call_args('evaluate_generator') 1279 1280 return self.evaluate( 1281 generator, 1282 steps=steps, 1283 max_queue_size=max_queue_size, 1284 workers=workers, 1285 use_multiprocessing=use_multiprocessing, 1286 verbose=verbose, 1287 callbacks=callbacks) 1288 1289 def predict_generator(self, 1290 generator, 1291 steps=None, 1292 callbacks=None, 1293 max_queue_size=10, 1294 workers=1, 1295 use_multiprocessing=False, 1296 verbose=0): 1297 """Generates predictions for the input samples from a data generator. 1298 1299 DEPRECATED: 1300 `Model.predict` now supports generators, so there is no longer any need 1301 to use this endpoint. 1302 """ 1303 warnings.warn('`Model.predict_generator` is deprecated and ' 1304 'will be removed in a future version. ' 1305 'Please use `Model.predict`, which supports generators.') 1306 return self.predict( 1307 generator, 1308 steps=steps, 1309 max_queue_size=max_queue_size, 1310 workers=workers, 1311 use_multiprocessing=use_multiprocessing, 1312 verbose=verbose, 1313 callbacks=callbacks) 1314 1315 def _check_call_args(self, method_name): 1316 """Check that `call` has only one positional arg.""" 1317 # Always allow first arg, regardless of arg name. 
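    # For example (illustrative signatures, not from this file):
    #   `def call(self, inputs, training=None)`  -> allowed
    #   `def call(self, inputs, other_tensor)`   -> rejected below, since
    #   `other_tensor` is an extra positional argument.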
    fullargspec = self._call_full_argspec
    if fullargspec.defaults:
      positional_args = fullargspec.args[:-len(fullargspec.defaults)]
    else:
      positional_args = fullargspec.args
    if 'training' in positional_args:
      positional_args.remove('training')

    # self and first arg can be positional.
    if len(positional_args) > 2:
      extra_args = positional_args[2:]
      raise ValueError(
          'Models passed to `' + method_name + '` can only have `training` '
          'and the first argument in `call` as positional arguments, '
          'found: ' + str(extra_args) + '.')

  def _set_optimizer(self, optimizer):
    """Sets self.optimizer.

    Sets self.optimizer to `optimizer`, potentially wrapping it with a
    LossScaleOptimizer.

    Args:
      optimizer: The optimizer(s) to assign to self.optimizer.
    """
    if isinstance(optimizer, (list, tuple)):
      self.optimizer = [optimizers.get(opt) for opt in optimizer]
    else:
      self.optimizer = optimizers.get(optimizer)

    if isinstance(self._dtype_policy, policy.PolicyV1):
      loss_scale = self._dtype_policy.loss_scale
    elif self._dtype_policy.name == 'mixed_float16':
      loss_scale = 'dynamic'
    else:
      loss_scale = None

    if (loss_scale is not None and
        not isinstance(self.optimizer,
                       loss_scale_optimizer.LossScaleOptimizer)):
      if isinstance(self.optimizer, list):
        raise ValueError('When a dtype policy with a loss scale is used, you '
                         'can only pass a single optimizer. Using policy %s '
                         'and got optimizers: %s' %
                         (self._dtype_policy, self.optimizer))
      if not isinstance(self.optimizer, optimizer_v2.OptimizerV2):
        raise ValueError('"optimizer" must be an instance of '
                         'tf.keras.optimizers.Optimizer when a dtype policy '
                         'with a loss scale is used, but got: %s. Using '
                         'policy: %s' %
                         (self.optimizer, self._dtype_policy))
      if loss_scale == 'dynamic':
        self.optimizer = loss_scale_optimizer.LossScaleOptimizer(self.optimizer)
      else:
        self.optimizer = loss_scale_optimizer.LossScaleOptimizerV1(
            self.optimizer, loss_scale)

  def _prepare_validation_data(self, validation_data, batch_size,
                               validation_steps):
    """Unpack and check the validation data."""
    val_x, val_y, val_sample_weights = training_utils_v1.unpack_validation_data(
        validation_data)
    return self._standardize_user_data(
        val_x,
        val_y,
        sample_weight=val_sample_weights,
        batch_size=batch_size,
        steps=validation_steps,
        steps_name='validation_steps')

  def _validate_compile_param_for_distribution_strategy(
      self, run_eagerly, sample_weight_mode, target_tensors, weighted_metrics):
    # Validate that arguments passed by the user to `compile` are supported by
    # tf.distribute.Strategy.
1392 if self._distribution_strategy: 1393 if sample_weight_mode: 1394 raise NotImplementedError('sample_weight_mode is not supported with ' 1395 'tf.distribute.Strategy.') 1396 if weighted_metrics: 1397 raise NotImplementedError('weighted_metrics is not supported with ' 1398 'tf.distribute.Strategy.') 1399 if target_tensors: 1400 raise ValueError('target_tensors is not supported with ' 1401 'tf.distribute.Strategy.') 1402 1403 if run_eagerly: 1404 raise ValueError( 1405 'We currently do not support enabling `run_eagerly` with ' 1406 'distribution strategy.') 1407 1408 if (distributed_training_utils_v1.is_distributing_by_cloning(self) and 1409 (not self.built or not self.inputs or not self.outputs)): 1410 raise ValueError( 1411 'We currently do not support distribution strategy with a ' 1412 '`Sequential` model that is created without `input_shape`/' 1413 '`input_dim` set in its first layer or a subclassed model.') 1414 1415 def _process_target_tensor_for_compile(self, target_tensors): 1416 if self.run_eagerly: 1417 # target tensor is not supported with run_eagerly. Create a list with None 1418 # as placeholder for each output. 1419 return [None for _ in self.output_names] 1420 1421 if target_tensors is not None and not (isinstance(target_tensors, list) and 1422 target_tensors == []): # pylint: disable=g-explicit-bool-comparison 1423 if isinstance(target_tensors, list): 1424 if len(target_tensors) != len(self.outputs): 1425 raise ValueError( 1426 'When passing a list as `target_tensors`, ' 1427 'it should have one entry per model output. ' 1428 'The model has %s outputs, but you passed target_tensors=%s' % 1429 (len(self.outputs), target_tensors)) 1430 elif isinstance(target_tensors, dict): 1431 unexpected_target_tensor_names = set(target_tensors.keys()).difference( 1432 self.output_names) 1433 if unexpected_target_tensor_names: 1434 raise ValueError( 1435 'Unknown entry in `target_tensors` dictionary: "{name}". ' 1436 'Only expected the following keys: {keys}'.format( 1437 name=unexpected_target_tensor_names, 1438 keys=str(self.output_names))) 1439 tmp_target_tensors = [] 1440 for name in self.output_names: 1441 tmp_target_tensors.append(target_tensors.get(name, None)) 1442 target_tensors = tmp_target_tensors 1443 elif tensor_util.is_tf_type(target_tensors): 1444 target_tensors = [target_tensors] 1445 else: 1446 raise TypeError('Expected `target_tensors` to be a list or tuple or ' 1447 'dict or a single tensor, but got:', target_tensors) 1448 else: 1449 # In case target tensor is empty or None, create a list with Nones 1450 # that has same length as self.output_names. With that, the None check of 1451 # target tensor can be skipped downstream. 1452 target_tensors = [None for _ in self.output_names] 1453 return target_tensors 1454 1455 def _compile_eagerly(self, metrics, weighted_metrics, sample_weight_mode): 1456 # Prepare sample weight modes. List with the same length as model outputs. 1457 training_utils_v1.prepare_sample_weight_modes( 1458 self._training_endpoints, sample_weight_mode) 1459 # Prepare sample weights. 1460 self._prepare_sample_weights() 1461 # Save all metric attributes per output of the model. 1462 self._cache_output_metric_attributes(metrics, weighted_metrics) 1463 self.total_loss = None 1464 # Set metric attributes on model. 1465 self._set_metric_attributes() 1466 1467 self._collected_trainable_weights = self.trainable_weights 1468 1469 def _update_sample_weight_modes(self, sample_weights=None): 1470 """Updates sample weight modes based on training/eval inputs. 
1471 1472 Sample weight placeholders will be created for all or no outputs 1473 based on whether sample_weight is provided for any output. 1474 1475 If model contains `_sample_weight_modes` we check if the input 1476 `sample_weights` corresponds to the sample weight modes. 1477 1. Set sample weight mode to be 'temporal' for output i, if `compile` 1478 sample_weight_mode was set to `temporal` and sample weight inputs 1479 are given for one or more outputs. 1480 2. Set sample weight mode to be 'samplewise' for output i, if `compile` 1481 sample_weight_mode was not set and sample weight inputs are given for 1482 one or more outputs. 1483 3. Reset sample weight mode to None for output i if sample weight mode 1484 was set but there is no sample weight input. 1485 1486 Args: 1487 sample_weights: List of sample weights of the same length as model outputs 1488 or None. 1489 """ 1490 if not self._is_compiled: 1491 return 1492 if sample_weights and any(s is not None for s in sample_weights): 1493 for endpoint in self._training_endpoints: 1494 endpoint.sample_weight_mode = ( 1495 endpoint.sample_weight_mode or 'samplewise') 1496 else: 1497 for endpoint in self._training_endpoints: 1498 endpoint.sample_weight_mode = None 1499 1500 def _recompile_weights_loss_and_weighted_metrics(self): 1501 if not self._is_compiled: 1502 return False 1503 recompile = any( 1504 e.sample_weights_mismatch() for e in self._training_endpoints) 1505 1506 if recompile: 1507 self._compile_weights_loss_and_weighted_metrics() 1508 return recompile 1509 1510 @trackable.no_automatic_dependency_tracking 1511 def _compile_weights_loss_and_weighted_metrics(self, sample_weights=None): 1512 """Compiles the model loss and weighted metric sub-graphs. 1513 1514 This may be used to set graph tensors as sample weights (instead of creating 1515 placeholders). This functionality is necessary for 1516 `tf.keras.estimator.model_to_estimator`, which calls Keras models in a v1 1517 graph, and creates iterator tensors for inputs, targets, and sample weights. 1518 1519 Args: 1520 sample_weights: List of tensors to use as the sample weights. Must be the 1521 same length as the number of outputs. If left as `None`, placeholders 1522 are used instead. 1523 """ 1524 with backend.get_graph().as_default(): 1525 if sample_weights is not None: 1526 self._update_sample_weight_modes(sample_weights) 1527 self._prepare_sample_weights(sample_weights) 1528 1529 masks = self._prepare_output_masks() 1530 1531 # Compute weighted metrics. 1532 self._handle_metrics( 1533 self.outputs, 1534 targets=self._targets, 1535 skip_target_masks=self._prepare_skip_target_masks(), 1536 sample_weights=self.sample_weights, 1537 masks=masks, 1538 return_weighted_metrics=True) 1539 1540 # Compute total loss. 1541 # Used to keep track of the total loss value (stateless). 1542 # eg., total_loss = loss_weight_1 * output_1_loss_fn(...) + 1543 # loss_weight_2 * output_2_loss_fn(...) + 1544 # layer losses. 1545 self.total_loss = self._prepare_total_loss(masks) 1546 1547 def _prepare_skip_target_masks(self): 1548 """Boolean mask for whether the target in the output list should be skipped. 1549 1550 If the loss function corresponding to a model output is None, then this 1551 output will be skipped during total loss calculation and feed targets 1552 preparation. 1553 1554 Returns: 1555 A boolean list for whether the corresponding target in the output list 1556 should be skipped during loss calculation. 
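      For example (illustrative), compiling with `loss=['mse', None]` yields
      `[False, True]`, so the second output is skipped.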
1557    """
1558    return [l is None for l in self.loss_functions]
1559
1560  def _prepare_output_masks(self):
1561    """Returns masks corresponding to model outputs."""
1562    return [getattr(x, '_keras_mask', None) for x in self.outputs]
1563
1564  def _prepare_total_loss(self, masks):
1565    """Computes total loss from loss functions.
1566
1567    Args:
1568      masks: List of mask values corresponding to each model output.
1569
1570    Returns:
1571      A scalar tensor: the weighted sum of output losses plus any layer-added losses.
1572
1573    Raises:
1574      TypeError: If model run_eagerly is True.
1575    """
1576    if self.run_eagerly:
1577      raise TypeError('total loss cannot be computed when compiled with '
1578                      'run_eagerly = True.')
1579    loss_list = []
1580    with backend.name_scope('loss'):
1581      for endpoint, mask in zip(self._training_endpoints, masks):
1582        if endpoint.should_skip_target():
1583          continue
1584        y_true = endpoint.training_target.target
1585        y_pred = endpoint.output
1586        loss_fn = endpoint.loss_fn
1587        loss_weight = endpoint.loss_weight
1588        loss_name = endpoint.loss_name()
1589        sample_weight = endpoint.sample_weight
1590
1591        with backend.name_scope(loss_name):
1592          if mask is not None:
1593            mask = math_ops.cast(mask, y_pred.dtype)
1594            # Update weights with mask.
1595            if sample_weight is None:
1596              sample_weight = mask
1597            else:
1598              # Update dimensions of weights to match with mask if possible.
1599              mask, _, sample_weight = (
1600                  losses_utils.squeeze_or_expand_dimensions(
1601                      mask, sample_weight=sample_weight))
1602              sample_weight *= mask
1603
1604          if hasattr(loss_fn, 'reduction'):
1605            per_sample_losses = loss_fn.call(y_true, y_pred)
1606            weighted_losses = losses_utils.compute_weighted_loss(
1607                per_sample_losses,
1608                sample_weight=sample_weight,
1609                reduction=losses_utils.ReductionV2.NONE)
1610            loss_reduction = loss_fn.reduction
1611
1612            # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all
1613            # compile use cases.
1614            if loss_reduction == losses_utils.ReductionV2.AUTO:
1615              loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
1616
1617            # Compute the stateless loss value.
1618            output_loss = losses_utils.reduce_weighted_loss(
1619                weighted_losses, reduction=loss_reduction)
1620          else:
1621            # Compute the stateless loss value for a custom loss class.
1622            # Here we assume that the class takes care of loss reduction
1623            # because if this class returns a vector value we cannot
1624            # differentiate between the use case where a custom optimizer
1625            # expects a vector loss value vs an unreduced per-sample loss value.
1626            output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
1627            loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
1628
1629          if len(self.outputs) > 1:
1630            # Keep track of stateful result tensor for the loss.
1631            endpoint.output_loss_metric(output_loss)
1632
1633          # Scale output loss for distribution. For custom losses we assume
1634          # reduction was mean.
1635          if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE:
1636            output_loss = losses_utils.scale_loss_for_distribution(output_loss)
1637
1638          loss_list.append(loss_weight * output_loss)
1639      if not loss_list and not self.losses:
1640        raise ValueError('The model cannot be compiled '
1641                         'because it has no loss to optimize.')
1642
1643      # Add regularization penalties and other layer-specific losses.
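      # Illustrative note (assumption): a layer created with, e.g.,
      # `Dense(10, kernel_regularizer='l2')` contributes its penalty through
      # the losses collected below, so the final total is roughly
      #   sum(loss_weight_i * output_loss_i) + sum(regularization_losses).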
1644 custom_losses = self.get_losses_for(None) + self.get_losses_for( 1645 self.inputs) 1646 if custom_losses: 1647 total_custom_loss = math_ops.add_n( 1648 losses_utils.cast_losses_to_common_dtype(custom_losses)) 1649 loss_list.append( 1650 losses_utils.scale_loss_for_distribution(total_custom_loss)) 1651 1652 loss_list = losses_utils.cast_losses_to_common_dtype(loss_list) 1653 if loss_list: 1654 total_loss = math_ops.add_n(loss_list) 1655 else: 1656 total_loss = 0. 1657 return total_loss 1658 1659 def _get_callback_model(self): 1660 """Returns the Callback Model for this Model.""" 1661 1662 if hasattr(self, '_replicated_model') and self._replicated_model: 1663 # When using training_distributed, we set the callback model 1664 # to an instance of the `DistributedModel` that we create in 1665 # the `compile` call. The `DistributedModel` is initialized 1666 # with the first replicated model. We need to set the callback 1667 # model to a DistributedModel to allow us to override saving 1668 # and loading weights when we checkpoint the model during training. 1669 return self._replicated_model 1670 if hasattr(self, 'callback_model') and self.callback_model: 1671 return self.callback_model 1672 return self 1673 1674 @trackable.no_automatic_dependency_tracking 1675 def _make_callback_model(self, grouped_model): 1676 first_replicated_model = self._distribution_strategy.unwrap( 1677 grouped_model)[0] 1678 # We initialize the callback model with the first replicated model. 1679 self._replicated_model = DistributedCallbackModel(first_replicated_model) 1680 self._replicated_model.set_original_model(self) 1681 1682 def _validate_or_infer_batch_size(self, batch_size, steps, x): 1683 """Validates that the `batch_size` provided is consistent with InputLayer. 1684 1685 It's possible that the user specified a static batch size in their 1686 InputLayer. If so, this method checks the provided `batch_size` and `x` 1687 arguments are consistent with this static batch size. Also, if 1688 `batch_size` is `None`, this method will attempt to infer the batch size 1689 from the static batch size of the InputLayer. Lastly, ValueError will be 1690 raised if `x` is a tf.data.Dataset and `batch_size` is specified as we 1691 expect users to provide batched datasets. 1692 1693 Args: 1694 batch_size: The batch_size provided as an argument to 1695 fit/evaluate/predict. 1696 steps: The steps provided as an argument to fit/evaluate/predict. 1697 x: The data passed as `x` to fit/evaluate/predict. 1698 1699 Returns: 1700 The validated batch_size, auto-inferred from the first layer if not 1701 provided. 1702 """ 1703 if (isinstance(x, (dataset_ops.DatasetV1, 1704 dataset_ops.DatasetV2, 1705 data_utils.Sequence)) or 1706 tf_inspect.isgenerator(x)): 1707 if batch_size is not None: 1708 raise ValueError( 1709 'The `batch_size` argument must not be specified for the given ' 1710 'input type. Received input: {}, batch_size: {}'.format( 1711 x, batch_size)) 1712 return 1713 1714 # Avoids the override in Sequential.layers which filters Input layers. 1715 # (Which are often the very layers that we're after.) 1716 layers = self._flatten_layers(include_self=False, recursive=False) 1717 first_layer = next(layers, None) 1718 if first_layer: 1719 # The per-replica static batch size. 1720 static_batch_size = training_utils.get_static_batch_size(first_layer) 1721 if static_batch_size is not None: 1722 1723 # Determine number of times the user-supplied batch size will be split. 
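        # Illustrative example (assumption): with a mirrored strategy over 4
        # replicas and a user-supplied batch_size of 64, num_splits_for_ds is
        # 4, so the per-replica batch size of 16 must match any static batch
        # size declared on the InputLayer.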
1724        if (self._distribution_strategy and
1725            distributed_training_utils.global_batch_size_supported(
1726                self._distribution_strategy)):
1727          num_splits_for_ds = self._distribution_strategy.num_replicas_in_sync
1728        else:
1729          num_splits_for_ds = 1
1730
1731        # Check `batch_size` argument is consistent with InputLayer.
1732        if batch_size is not None:
1733          if batch_size % num_splits_for_ds != 0:
1734            raise ValueError('The `batch_size` argument ({}) must be divisible '
1735                             'by the number of replicas ({})'.format(
1736                                 batch_size, num_splits_for_ds))
1737          per_replica_batch_size = batch_size // num_splits_for_ds
1738
1739          if per_replica_batch_size != static_batch_size:
1740            raise ValueError('The `batch_size` argument value {} is '
1741                             'incompatible with the specified batch size of '
1742                             'your Input Layer: {}'.format(
1743                                 per_replica_batch_size, static_batch_size))
1744
1745        # Check Dataset/Iterator batch size is consistent with InputLayer.
1746        if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator,
1747                          iterator_ops.IteratorBase)):
1748          ds_batch_size = tensor_shape.Dimension(
1749              nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value
1750          if ds_batch_size is not None:
1751            if ds_batch_size % num_splits_for_ds != 0:
1752              raise ValueError(
1753                  'The batch output shape of your `Dataset` ({}) is not '
1754                  'divisible by the number of replicas ({})'.format(
1755                      ds_batch_size, num_splits_for_ds))
1756
1757            ds_per_replica_batch_size = ds_batch_size // num_splits_for_ds
1758            if ds_per_replica_batch_size != static_batch_size:
1759              raise ValueError('The batch output shape of your `Dataset` is '
1760                               '{}, which is incompatible with the specified '
1761                               'batch size of your Input Layer: {}'.format(
1762                                   ds_per_replica_batch_size,
1763                                   static_batch_size))
1764
1765        # Set inferred batch size from the InputLayer.
1766        if steps is None:
1767          batch_size = static_batch_size * num_splits_for_ds
1768
1769    if batch_size is None and steps is None:
1770      # Backwards compatibility
1771      batch_size = 32
1772    return batch_size
1773
1774  def _prepare_sample_weights(self, sample_weights=None):
1775    """Sets sample weight attribute on the model."""
1776    # List with the same length as model outputs.
1777    if sample_weights is not None:
1778      if len(sample_weights) != len(self._training_endpoints):
1779        raise ValueError('Provided sample weights must have the same length as the '
1780                         'number of outputs. 
Expected: {}, got: {}.'.format( 1781 len(self._training_endpoints), 1782 len(sample_weights))) 1783 else: 1784 sample_weights = [None] * len(self._training_endpoints) 1785 for endpoint, weight in zip(self._training_endpoints, sample_weights): 1786 endpoint.populate_sample_weight(weight, endpoint.sample_weight_mode) 1787 1788 def _cache_output_metric_attributes(self, metrics, weighted_metrics): 1789 """Caches metric name and function attributes for every model output.""" 1790 output_shapes = [] 1791 for output in self.outputs: 1792 if output is None or output.shape.rank is None: 1793 output_shapes.append(None) 1794 else: 1795 output_shapes.append(output.shape.as_list()) 1796 self._per_output_metrics = training_utils_v1.collect_per_output_metric_info( 1797 metrics, self.output_names, output_shapes, self.loss_functions, 1798 from_serialized=self._from_serialized) 1799 self._per_output_weighted_metrics = ( 1800 training_utils_v1.collect_per_output_metric_info( 1801 weighted_metrics, 1802 self.output_names, 1803 output_shapes, 1804 self.loss_functions, 1805 from_serialized=self._from_serialized, 1806 is_weighted=True)) 1807 1808 def _add_unique_metric_name(self, metric_name, metric_fn, output_index): 1809 """Makes the metric name unique. 1810 1811 If there are multiple outputs for which the metrics are calculated, the 1812 metric names have to be made unique by appending an integer. 1813 1814 Args: 1815 metric_name: Metric name that corresponds to the metric specified by the 1816 user. For example: 'acc'. 1817 metric_fn: The Metric object. 1818 output_index: The index of the model output for which the metric name is 1819 being added. 1820 1821 Returns: 1822 string, name of the model's unique metric name 1823 """ 1824 # For multi-output models, prepend the output names to the metric name. 1825 if len(self.output_names) > 1: 1826 # If we're loading from an already-serialized model, we've already 1827 # prepended the output name, and we don't want to do it again. 1828 # 1829 # Alternatively, we may be receiving a stateless metric (e.g. the string 1830 # "accuracy") rather than a `Metric` object, in which case we want to 1831 # prepend the output name even if we are loading a serialized model. 1832 if not getattr(metric_fn, '_from_serialized', False): 1833 metric_name = '%s_%s' % (self.output_names[output_index], metric_name) 1834 1835 j = 1 1836 base_metric_name = metric_name 1837 while metric_name in self.metrics_names: 1838 metric_name = '%s_%d' % (base_metric_name, j) 1839 j += 1 1840 1841 return metric_name 1842 1843 def _init_metric_attributes(self): 1844 """Initialized model metric attributes.""" 1845 # List of stateful metric functions. Used for resetting metric state during 1846 # training/eval. 1847 self._compile_metric_functions = [] 1848 1849 def _set_per_output_metric_attributes(self, metrics_dict, output_index): 1850 """Sets the metric attributes on the model for the given output. 1851 1852 Args: 1853 metrics_dict: A dict with metric names as keys and metric fns as values. 1854 output_index: The index of the model output for which the metric 1855 attributes are added. 1856 1857 Returns: 1858 Metrics dict updated with unique metric names as keys. 1859 """ 1860 updated_metrics_dict = collections.OrderedDict() 1861 for metric_name, metric_fn in metrics_dict.items(): 1862 metric_name = self._add_unique_metric_name( 1863 metric_name, metric_fn, output_index) 1864 1865 # Update the name on the metric class to be the unique generated name. 
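      # Illustrative example (assumption): for a two-output model with output
      # names 'dense' and 'dense_1', the compile metric 'acc' typically
      # becomes 'dense_acc' and 'dense_1_acc'; a further clash would get a
      # numeric suffix such as 'dense_acc_1' (see _add_unique_metric_name).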
1866 metric_fn._name = metric_name # pylint: disable=protected-access 1867 updated_metrics_dict[metric_name] = metric_fn 1868 # Keep track of metric name and function. 1869 self._compile_metric_functions.append(metric_fn) 1870 return updated_metrics_dict 1871 1872 def _set_metric_attributes(self): 1873 """Sets the metric attributes on the model for all the model outputs.""" 1874 updated_per_output_metrics = [] 1875 updated_per_output_weighted_metrics = [] 1876 for i, endpoint in enumerate(self._training_endpoints): 1877 if endpoint.should_skip_target(): 1878 updated_per_output_metrics.append(self._per_output_metrics[i]) 1879 updated_per_output_weighted_metrics.append( 1880 self._per_output_weighted_metrics[i]) 1881 continue 1882 updated_per_output_metrics.append( 1883 self._set_per_output_metric_attributes(self._per_output_metrics[i], 1884 i)) 1885 updated_per_output_weighted_metrics.append( 1886 self._set_per_output_metric_attributes( 1887 self._per_output_weighted_metrics[i], i)) 1888 1889 # Create a metric wrapper for each output loss. This computes mean of an 1890 # output loss across mini-batches (irrespective of how we reduce within a 1891 # batch). 1892 if len(self._training_endpoints) > 1: 1893 for endpoint in self._training_endpoints: 1894 if not endpoint.should_skip_target(): 1895 endpoint.output_loss_metric = metrics_module.Mean( 1896 name=endpoint.loss_name()) 1897 1898 self._per_output_metrics = updated_per_output_metrics 1899 self._per_output_weighted_metrics = updated_per_output_weighted_metrics 1900 1901 def _handle_per_output_metrics(self, 1902 metrics_dict, 1903 y_true, 1904 y_pred, 1905 mask, 1906 weights=None): 1907 """Calls metric functions for a single output. 1908 1909 Args: 1910 metrics_dict: A dict with metric names as keys and metric fns as values. 1911 y_true: Target output. 1912 y_pred: Predicted output. 1913 mask: Computed mask value for the current output. 1914 weights: Weights to be applied on the current output. 1915 1916 Returns: 1917 A list of metric result tensors. 1918 """ 1919 metric_results = [] 1920 for metric_name, metric_fn in metrics_dict.items(): 1921 with backend.name_scope(metric_name): 1922 metric_result = training_utils_v1.call_metric_function( 1923 metric_fn, y_true, y_pred, weights=weights, mask=mask) 1924 metric_results.append(metric_result) 1925 return metric_results 1926 1927 def _handle_metrics(self, 1928 outputs, 1929 targets=None, 1930 skip_target_masks=None, 1931 sample_weights=None, 1932 masks=None, 1933 return_weighted_metrics=False, 1934 return_weighted_and_unweighted_metrics=False): 1935 """Handles calling metric functions. 1936 1937 Args: 1938 outputs: List of outputs (predictions). 1939 targets: List of targets. 1940 skip_target_masks: Optional. List of boolean for whether the corresponding 1941 target should be ignored or not. 1942 sample_weights: Optional list of sample weight arrays. 1943 masks: List of computed output mask values. 1944 return_weighted_metrics: Flag that indicates whether weighted metrics 1945 should be computed instead of unweighted metrics. This flag is ignored 1946 when `return_weighted_and_unweighted_metrics` is enabled. 1947 return_weighted_and_unweighted_metrics: Flag that is used to indicate 1948 whether both weighted and unweighted metrics should be computed. When 1949 this is not enabled, we use `return_weighted_metrics` param to indicate 1950 whether weighted or unweighted metrics should be returned. 1951 1952 Returns: 1953 A list of metric result tensors. 
1954 """ 1955 # TODO(scottzhu): Update this to use the new training_endpoints. Currently 1956 # the eager and graph logic is bit different. 1957 skip_target_masks = skip_target_masks or [False] * len(outputs) 1958 metric_results = [] 1959 with backend.name_scope('metrics'): 1960 # Invoke all metrics added using `compile`. 1961 for i in range(len(outputs)): 1962 if skip_target_masks[i]: 1963 continue 1964 output = outputs[i] if outputs else None 1965 target = targets[i] if targets else None 1966 output_mask = masks[i] if masks else None 1967 1968 if (return_weighted_and_unweighted_metrics or 1969 not return_weighted_metrics): 1970 metric_results.extend( 1971 self._handle_per_output_metrics(self._per_output_metrics[i], 1972 target, output, output_mask)) 1973 if return_weighted_and_unweighted_metrics or return_weighted_metrics: 1974 metric_results.extend( 1975 self._handle_per_output_metrics( 1976 self._per_output_weighted_metrics[i], 1977 target, 1978 output, 1979 output_mask, 1980 weights=sample_weights[i] if sample_weights else None)) 1981 return metric_results 1982 1983 def _check_trainable_weights_consistency(self): 1984 """Check trainable weights count consistency. 1985 1986 This will raise a warning if `trainable_weights` and 1987 `_collected_trainable_weights` are inconsistent (i.e. have different 1988 number of parameters). 1989 Inconsistency will typically arise when one modifies `model.trainable` 1990 without calling `model.compile` again. 1991 """ 1992 if not hasattr(self, '_collected_trainable_weights'): 1993 return 1994 1995 if len(self.trainable_weights) != len(self._collected_trainable_weights): 1996 logging.log_first_n( 1997 logging.WARN, 'Discrepancy between trainable weights and collected' 1998 ' trainable weights, did you set `model.trainable`' 1999 ' without calling `model.compile` after ?', 1) 2000 2001 def _make_train_function(self): 2002 has_recompiled = self._recompile_weights_loss_and_weighted_metrics() 2003 self._check_trainable_weights_consistency() 2004 if isinstance(self.optimizer, list): 2005 raise ValueError('The `optimizer` in `compile` should be a single ' 2006 'optimizer.') 2007 # If we have re-compiled the loss/weighted metric sub-graphs then create 2008 # train function even if one exists already. This is because 2009 # `_feed_sample_weights` list has been updated on re-compile. 2010 if getattr(self, 'train_function', None) is None or has_recompiled: 2011 # Restore the compiled trainable state. 2012 current_trainable_state = self._get_trainable_state() 2013 self._set_trainable_state(self._compiled_trainable_state) 2014 2015 inputs = (self._feed_inputs + 2016 self._feed_targets + 2017 self._feed_sample_weights) 2018 if not isinstance(backend.symbolic_learning_phase(), int): 2019 inputs += [backend.symbolic_learning_phase()] 2020 2021 with backend.get_graph().as_default(): 2022 with backend.name_scope('training'): 2023 # Training updates 2024 updates = self.optimizer.get_updates( 2025 params=self._collected_trainable_weights, loss=self.total_loss) 2026 # Unconditional updates 2027 updates += self.get_updates_for(None) 2028 # Conditional updates relevant to this model 2029 updates += self.get_updates_for(self.inputs) 2030 2031 metrics = self._get_training_eval_metrics() 2032 metrics_tensors = [ 2033 m._call_result for m in metrics if hasattr(m, '_call_result') # pylint: disable=protected-access 2034 ] 2035 2036 with backend.name_scope('training'): 2037 # Gets loss and metrics. Updates weights at each call. 
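        # Rough sketch of the resulting callable (illustrative only):
        #   train_fn([x_batch, y_batch, sw_batch, 1]) -> [loss, metric_1, ...]
        # where the trailing 1 is the learning phase (when it is fed
        # symbolically) and the optimizer/state updates in `updates` run as a
        # side effect of every call.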
2038 fn = backend.function( 2039 inputs, [self.total_loss] + metrics_tensors, 2040 updates=updates, 2041 name='train_function', 2042 **self._function_kwargs) 2043 setattr(self, 'train_function', fn) 2044 2045 # Restore the current trainable state 2046 self._set_trainable_state(current_trainable_state) 2047 2048 def _make_test_function(self): 2049 has_recompiled = self._recompile_weights_loss_and_weighted_metrics() 2050 # If we have re-compiled the loss/weighted metric sub-graphs then create 2051 # test function even if one exists already. This is because 2052 # `_feed_sample_weights` list has been updated on re-compile. 2053 if getattr(self, 'test_function', None) is None or has_recompiled: 2054 inputs = (self._feed_inputs + 2055 self._feed_targets + 2056 self._feed_sample_weights) 2057 2058 with backend.get_graph().as_default(): 2059 metrics = self._get_training_eval_metrics() 2060 metrics_tensors = [ 2061 m._call_result for m in metrics if hasattr(m, '_call_result') # pylint: disable=protected-access 2062 ] 2063 2064 with backend.name_scope('evaluation'): 2065 updates = self.state_updates 2066 # Return loss and metrics, no gradient updates. 2067 # Does update the network states. 2068 fn = backend.function( 2069 inputs, [self.total_loss] + metrics_tensors, 2070 updates=updates, 2071 name='test_function', 2072 **self._function_kwargs) 2073 setattr(self, 'test_function', fn) 2074 2075 def _make_predict_function(self): 2076 if not hasattr(self, 'predict_function'): 2077 self.predict_function = None 2078 if self.predict_function is None: 2079 inputs = self._feed_inputs 2080 # Gets network outputs. Does not update weights. 2081 # Does update the network states. 2082 kwargs = getattr(self, '_function_kwargs', {}) 2083 with backend.name_scope(ModeKeys.PREDICT): 2084 self.predict_function = backend.function( 2085 inputs, 2086 self.outputs, 2087 updates=self.state_updates, 2088 name='predict_function', 2089 **kwargs) 2090 2091 def _make_execution_function(self, mode): 2092 if mode == ModeKeys.TRAIN: 2093 self._make_train_function() 2094 return self.train_function 2095 if mode == ModeKeys.TEST: 2096 self._make_test_function() 2097 return self.test_function 2098 if mode == ModeKeys.PREDICT: 2099 self._make_predict_function() 2100 return self.predict_function 2101 2102 def _distribution_standardize_user_data(self, 2103 x, 2104 y=None, 2105 sample_weight=None, 2106 class_weight=None, 2107 batch_size=None, 2108 validation_split=0, 2109 shuffle=False, 2110 epochs=1, 2111 allow_partial_batch=False): 2112 """Runs validation checks on input and target data passed by the user. 2113 2114 This is called when using tf.distribute.Strategy to train, evaluate or serve 2115 the model. 2116 2117 Args: 2118 x: Input data. A numpy array or `tf.data` dataset. 2119 y: Target data. A numpy array or None if x is a `tf.data` dataset. 2120 sample_weight: An optional sample-weight array passed by the user to 2121 weight the importance of each sample in `x`. 2122 class_weight: An optional class-weight array by the user to 2123 weight the importance of samples in `x` based on the class they belong 2124 to, as conveyed by `y`. 2125 batch_size: Integer batch size. If provided, it is used to run additional 2126 validation checks on stateful models. 2127 validation_split: Float between 0 and 1. 2128 Fraction of the training data to be used as validation data. 2129 shuffle: Boolean whether to shuffle the training data before each epoch. 2130 epochs: Integer epochs. 
If > 1, the numpy training data is repeated `epochs`
2131        times when converting it to a training dataset.
2132      allow_partial_batch: Boolean whether the final batch may be smaller than
2133        the others; if False, the remainder may be dropped to keep batch sizes static.
2134
2135    Returns:
2136      Dataset instance.
2137
2138    Raises:
2139      ValueError: In case of invalid user-provided data.
2140      RuntimeError: If the model was never compiled.
2141    """
2142    if class_weight:
2143      raise NotImplementedError('`class_weight` is currently not supported '
2144                                'when using tf.distribute.Strategy.')
2145
2146    if (sample_weight is not None and sample_weight.all() and
2147        backend.is_tpu_strategy(self._distribution_strategy)):
2148      raise NotImplementedError('`sample_weight` is currently not supported '
2149                                'when using TPUStrategy.')
2150
2151    # Validates `steps` and `shuffle` arguments right at the beginning
2152    # since we use them to construct the dataset object.
2153    # TODO(anjalisridhar): Remove this check once we refactor the
2154    # _standardize_user_data code path. This check is already present elsewhere
2155    # in the codebase.
2156    if isinstance(x, dataset_ops.DatasetV2):
2157      if shuffle:
2158        training_utils_v1.verify_dataset_shuffled(x)
2159
2160    strategy = self._distribution_strategy
2161    with strategy.scope():
2162      # We should be sure to call get_session() inside the strategy.scope()
2163      # so the strategy can affect the session options.
2164      if ops.executing_eagerly_outside_functions():
2165        session = None
2166      else:
2167        session = backend.get_session()
2168
2169      first_x_value = nest.flatten(x)[0]
2170      if isinstance(first_x_value, np.ndarray):
2171        x = training_utils.list_to_tuple(x)
2172        if y is not None:
2173          y = training_utils.list_to_tuple(y)
2174          if sample_weight is not None:
2175            sample_weight = training_utils.list_to_tuple(sample_weight)
2176            in_tuple = (x, y, sample_weight)
2177          else:
2178            in_tuple = (x, y)
2179        else:
2180          in_tuple = x
2181
2182        ds = strategy.extended.experimental_make_numpy_dataset(in_tuple,
2183                                                               session=session)
2184        if shuffle:
2185          # We want a buffer size that is larger than the batch size provided by
2186          # the user and provides sufficient randomness. Note that larger
2187          # numbers introduce more memory usage based on the size of each
2188          # sample.
2189          ds = ds.shuffle(max(1024, batch_size * 8))
2190        if epochs > 1:
2191          ds = ds.repeat(epochs)
2192
2193        # We need to use the drop_remainder argument to get a known static
2194        # input shape which is required for TPUs.
2195        drop_remainder = (not allow_partial_batch and
2196                          strategy.extended.experimental_require_static_shapes)
2197
2198        # TODO(b/131720208): We still drop remainder here if number of examples
2199        # is divisible by batch size, as sometimes dynamic padder will time out
2200        # with keras.metrics.CategoricalAccuracy() metric.
2201        if backend.is_tpu_strategy(strategy) and not drop_remainder:
2202          dataset_size = first_x_value.shape[0]
2203          if dataset_size % batch_size == 0:
2204            drop_remainder = True
2205
2206        x = ds.batch(batch_size, drop_remainder=drop_remainder)
2207      else:
2208        assert isinstance(x, dataset_ops.DatasetV2)
2209        training_utils_v1.validate_dataset_input(x, y, sample_weight,
2210                                                 validation_split)
2211    return x
2212
2213  def _standardize_user_data(self,
2214                             x,
2215                             y=None,
2216                             sample_weight=None,
2217                             class_weight=None,
2218                             batch_size=None,
2219                             check_steps=False,
2220                             steps_name='steps',
2221                             steps=None,
2222                             validation_split=0,
2223                             shuffle=False,
2224                             extract_tensors_from_dataset=False):
2225    """Runs validation checks on input and target data passed by the user.
2226 2227 Also standardizes the data to lists of arrays, in order. 2228 2229 Also builds and compiles the model on the fly if it is a subclassed model 2230 that has never been called before (and thus has no inputs/outputs). 2231 2232 This is a purely internal method, subject to refactoring at any time. 2233 2234 Args: 2235 x: Input data. It could be: 2236 - A Numpy array (or array-like), or a list of arrays 2237 (in case the model has multiple inputs). 2238 - A TensorFlow tensor, or a list of tensors 2239 (in case the model has multiple inputs). 2240 - A dict mapping input names to the corresponding array/tensors, 2241 if the model has named inputs. 2242 - A `tf.data` dataset. 2243 y: Target data. Like the input data `x`, 2244 it could be either Numpy array(s) or TensorFlow tensor(s). 2245 It should be consistent with `x` (you cannot have Numpy inputs and 2246 tensor targets, or inversely). If `x` is a dataset, `y` should not be 2247 specified (since targets will be obtained from the iterator). 2248 sample_weight: An optional sample-weight array passed by the user to 2249 weight the importance of each sample in `x`. 2250 class_weight: An optional class-weight array by the user to 2251 weight the importance of samples in `x` based on the class they belong 2252 to, as conveyed by `y`. If both `sample_weight` and `class_weight` are 2253 provided, the weights are multiplied. 2254 batch_size: Integer batch size. If provided, it is used to run additional 2255 validation checks on stateful models. 2256 check_steps: boolean, True if we want to check for validity of `steps` and 2257 False, otherwise. For example, when we are standardizing one batch of 2258 data for train_on_batch/predict_on_batch/test_on_batch APIs, `steps` 2259 value is not required and we should not check for its validity in these 2260 cases. 2261 steps_name: The public API's parameter name for `steps`. 2262 steps: Integer or `None`. Total number of steps (batches of samples) to 2263 execute. 2264 validation_split: Float between 0 and 1. 2265 Fraction of the training data to be used as validation data. 2266 shuffle: Boolean whether to shuffle the training data before each epoch. 2267 extract_tensors_from_dataset: Boolean. When `x` is a dataset instance, 2268 this indicates whether to extract actual tensors from the dataset or 2269 instead output the dataset instance itself. 2270 Set to True when calling from `train_on_batch`/etc. 2271 2272 Returns: 2273 A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict 2274 or not), target arrays, sample-weight arrays. 2275 If the model's input and targets are symbolic, these lists are empty 2276 (since the model takes no user-provided data, instead the data comes 2277 from the symbolic inputs/targets). 2278 2279 Raises: 2280 ValueError: In case of invalid user-provided data. 2281 RuntimeError: If the model was never compiled. 2282 """ 2283 if isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): 2284 # Graph mode dataset. We'll pass the dataset as-is (unless 2285 # `extract_tensors_from_dataset` is True, in which case we extract 2286 # the tensors from the dataset and we output them. 2287 training_utils_v1.validate_dataset_input(x, y, sample_weight, 2288 validation_split) 2289 if shuffle: 2290 training_utils_v1.verify_dataset_shuffled(x) 2291 2292 is_dataset = True 2293 if extract_tensors_from_dataset: 2294 # We do this for `train_on_batch`/etc. 
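        # Illustrative example (assumption): `train_on_batch` passes
        # `extract_tensors_from_dataset=True`, so a dataset yielding
        # `(features, labels)` is unpacked here into tensors for a single
        # step instead of being iterated as a whole dataset.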
2295 x, y, sample_weight = training_utils_v1.extract_tensors_from_dataset(x) 2296 elif isinstance(x, iterator_ops.Iterator): 2297 # Graph mode iterator. We extract the symbolic tensors. 2298 training_utils_v1.validate_dataset_input(x, y, sample_weight, 2299 validation_split) 2300 iterator = x 2301 x, y, sample_weight = training_utils_v1.unpack_iterator_input(iterator) 2302 is_dataset = True 2303 else: 2304 is_dataset = False 2305 2306 # Validates `steps` argument based on x's type. 2307 if check_steps: 2308 training_utils_v1.check_steps_argument(x, steps, steps_name) 2309 2310 # First, we build the model on the fly if necessary. 2311 if not self.inputs: 2312 all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y) 2313 is_build_called = True 2314 else: 2315 all_inputs = [] 2316 # Whether this is a subclassed model that expects dictionary inputs 2317 # rather than list inputs (e.g. FeatureColumn-based models). 2318 dict_inputs = isinstance(self.inputs, dict) 2319 is_build_called = False 2320 y_input = y 2321 2322 # Second, we compile the model on the fly if necessary, mostly for subclass 2323 # models. 2324 is_compile_called = False 2325 if not self._is_compiled and self.optimizer: 2326 self._compile_from_inputs(all_inputs, y_input, x, y) 2327 is_compile_called = True 2328 2329 # In graph mode, if we had just set inputs and targets as symbolic tensors 2330 # by invoking build and compile on the model respectively, we do not have to 2331 # feed anything to the model. Model already has input and target data as 2332 # part of the graph. 2333 # Note: in this case, `any` and `all` are equivalent since we disallow 2334 # mixed symbolic/value inputs. 2335 2336 # self.run_eagerly is not free to compute, so we want to reuse the value. 2337 run_eagerly = self.run_eagerly 2338 2339 if (not run_eagerly and is_build_called and is_compile_called and 2340 not is_dataset and any(_is_symbolic_tensor(v) for v in all_inputs)): 2341 return [], [], None 2342 2343 return self._standardize_tensors( 2344 x, y, sample_weight, 2345 run_eagerly=run_eagerly, 2346 dict_inputs=dict_inputs, 2347 is_dataset=is_dataset, 2348 class_weight=class_weight, 2349 batch_size=batch_size) 2350 2351 def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs, 2352 is_dataset, class_weight=None, batch_size=None): 2353 if run_eagerly: 2354 # In eager mode, do not do shape validation 2355 # since the network has no input nodes (placeholders) to be fed. 2356 feed_input_names = self.input_names 2357 feed_input_shapes = None 2358 elif not self._is_graph_network: 2359 # Case: symbolic-mode subclassed network. Do not do shape validation. 2360 feed_input_names = self._feed_input_names 2361 feed_input_shapes = None 2362 else: 2363 # Case: symbolic-mode graph network. 2364 # In this case, we run extensive shape validation checks. 2365 feed_input_names = self._feed_input_names 2366 feed_input_shapes = self._feed_input_shapes 2367 2368 # Standardize the inputs. 2369 if not isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): 2370 # TODO(fchollet): run static checks with dataset output shape(s). 2371 x = training_utils_v1.standardize_input_data( 2372 x, 2373 feed_input_names, 2374 feed_input_shapes, 2375 check_batch_axis=False, # Don't enforce the batch size. 2376 exception_prefix='input') 2377 2378 # Get typespecs for the input data and sanitize it if necessary. 
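    # Illustrative example (assumption): a NumPy array of shape (32, 3) and
    # dtype float32 is summarized below as `TensorSpec(shape=(32, 3),
    # dtype=float32)` without being converted to a tensor, so its structure
    # can be compared against `self.inputs`.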
2379 # TODO(momernick): This should be capable of doing full input validation 2380 # at all times - validate that this is so and refactor the standardization 2381 # code. 2382 if isinstance(x, dataset_ops.DatasetV2): 2383 x_shapes = dataset_ops.get_structure(x) 2384 if isinstance(x_shapes, tuple): 2385 # If the output of a Dataset is a tuple, we assume it's either of the 2386 # form (x_data, y_data) or (x_data, y_data, sample_weights). In either 2387 # case, we only care about x_data here. 2388 x_shapes = x_shapes[0] 2389 else: 2390 flat_inputs = nest.flatten(x, expand_composites=False) 2391 flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False) 2392 converted_x = [] 2393 for (a, b) in zip(flat_inputs, flat_expected_inputs): 2394 converted_x.append(_convert_scipy_sparse_tensor(a, b)) 2395 x = nest.pack_sequence_as(x, converted_x, expand_composites=False) 2396 2397 def _type_spec_from_value(value): 2398 """Grab type_spec without converting array-likes to tensors.""" 2399 if tf_utils.is_extension_type(value): 2400 return value._type_spec # pylint: disable=protected-access 2401 # Get a TensorSpec for array-like data without 2402 # converting the data to a Tensor 2403 if hasattr(value, 'shape') and hasattr(value, 'dtype'): 2404 return tensor_spec.TensorSpec(value.shape, value.dtype) 2405 else: 2406 return type_spec.type_spec_from_value(value) 2407 2408 x_shapes = nest.map_structure(_type_spec_from_value, x) 2409 2410 flat_inputs = nest.flatten(x_shapes, expand_composites=False) 2411 flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False) 2412 for (a, b) in zip(flat_inputs, flat_expected_inputs): 2413 nest.assert_same_structure(a, b, expand_composites=True) 2414 2415 if y is not None: 2416 # Prepare self._sample_weight_modes. List with the same length as 2417 # model outputs. 2418 training_utils_v1.prepare_sample_weight_modes(self._training_endpoints, 2419 self.sample_weight_mode) 2420 feed_output_names = self._feed_output_names 2421 feed_sample_weight_modes = self._sample_weight_modes 2422 if not self._is_graph_network: 2423 feed_output_shapes = None 2424 else: 2425 feed_output_shapes = self._feed_output_shapes 2426 2427 # Standardize the outputs. 2428 y = training_utils_v1.standardize_input_data( 2429 y, 2430 feed_output_names, 2431 # Don't enforce target shapes to match output shapes. 2432 # Precise checks will be run in `check_loss_and_target_compatibility`. 2433 shapes=None, 2434 check_batch_axis=False, # Don't enforce the batch size. 2435 exception_prefix='target') 2436 2437 # Generate sample-wise weight values given the `sample_weight` and 2438 # `class_weight` arguments. 2439 sample_weights = training_utils_v1.standardize_sample_weights( 2440 sample_weight, feed_output_names) 2441 class_weights = training_utils_v1.standardize_class_weights( 2442 class_weight, feed_output_names) 2443 2444 sample_weights = [ 2445 training_utils_v1.standardize_weights(ref, sw, cw, mode) 2446 for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights, 2447 feed_sample_weight_modes) 2448 ] 2449 # Check that all arrays have the same length. 2450 if not self._distribution_strategy: 2451 training_utils_v1.check_array_lengths(x, y, sample_weights) 2452 if self._is_graph_network and not run_eagerly: 2453 # Additional checks to avoid users mistakenly using improper loss fns. 
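        # Illustrative example (assumption): passing integer class labels of
        # shape (batch, 1) to an output compiled with
        # `categorical_crossentropy` is typically flagged here, pointing the
        # user at `sparse_categorical_crossentropy` instead.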
2454 training_utils_v1.check_loss_and_target_compatibility( 2455 y, self._feed_loss_fns, feed_output_shapes) 2456 2457 sample_weights, _, _ = training_utils.handle_partial_sample_weights( 2458 y, sample_weights, feed_sample_weight_modes, check_all_flat=True) 2459 else: 2460 y = [] 2461 sample_weights = None 2462 2463 if self.stateful and batch_size and not is_dataset: 2464 # Check that for stateful networks, number of samples is a multiple 2465 # of the static batch size. 2466 if x[0].shape[0] % batch_size != 0: 2467 raise ValueError('In a stateful network, ' 2468 'you should only pass inputs with ' 2469 'a number of samples that can be ' 2470 'divided by the batch size. Found: ' + 2471 str(x[0].shape[0]) + ' samples') 2472 2473 # If dictionary inputs were provided, we return a dictionary as well. 2474 if dict_inputs and not isinstance(x, (dataset_ops.DatasetV1, 2475 dataset_ops.DatasetV2)): 2476 x = dict(zip(feed_input_names, x)) 2477 return x, y, sample_weights 2478 2479 def _build_model_with_inputs(self, inputs, targets): 2480 """Build the model (set model inputs/outputs), mainly for subclass model.""" 2481 processed_inputs = [] 2482 is_dict_inputs = False 2483 orig_inputs = inputs 2484 # We need to use `inputs` to set the model inputs. 2485 # If input data is a dataset iterator in graph mode or if it is an eager 2486 # iterator and only one batch of samples is required, we fetch the data 2487 # tensors from the iterator and then standardize them. 2488 if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): 2489 inputs, targets, _ = training_utils_v1.extract_tensors_from_dataset( 2490 inputs) 2491 # We type-check that `inputs` and `targets` are either single arrays 2492 # or lists of arrays, and extract a flat list of inputs from the passed 2493 # structure. 2494 training_utils_v1.validate_input_types(inputs, orig_inputs) 2495 2496 if isinstance(inputs, (list, tuple)): 2497 processed_inputs += list(inputs) 2498 elif isinstance(inputs, dict): 2499 is_dict_inputs = True 2500 keys = sorted(inputs.keys()) 2501 processed_inputs = [inputs[k] for k in keys] 2502 else: 2503 processed_inputs.append(inputs) 2504 # Now that we have a flat set of inputs, we make sure that none of them 2505 # are CompositeTensors or CompositeTensorValues of any type (or scipy 2506 # sparse arrays, which we treat as SparseTensor values). We cannot safely 2507 # infer input data from an arbitrary composite tensor, so we don't try - 2508 # users should explicitly add composite tensor inputs to their subclassed 2509 # models. 2510 for input_tensor in processed_inputs: 2511 if training_utils_v1.is_composite_or_composite_value(input_tensor): 2512 # TODO(b/132691975): Document subclass-model CT input handling. 2513 raise ValueError( 2514 'All SparseTensor and RaggedTensor inputs must be explicitly ' 2515 'declared using a keras.Input() with sparse=True or ragged=True. ' 2516 'We found an undeclared input %s. For Sequential models, please ' 2517 'add a keras.Input() as your first Layer. For subclassed models, ' 2518 'please call self._set_inputs() on your input set, which you can ' 2519 'create using keras.Input() for each input to your model.' % 2520 (input_tensor,)) 2521 # Build the model using the retrieved inputs (value or symbolic). 2522 # If values are generated from a dataset, then in symbolic-mode 2523 # placeholders will be created to match the value shapes. 
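    # Illustrative example (assumption): a subclassed model fit on a dataset
    # of float features with shape (None, 8) has its inputs cast to the model
    # dtype and summarized as TensorSpecs below, from which `_set_inputs`
    # builds matching placeholders.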
2524 if isinstance(orig_inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2, 2525 iterator_ops.Iterator)): 2526 if not self.inputs: 2527 # For subclassed models, a robust input spec is not available so we 2528 # must cast to the model dtype. 2529 inputs = training_utils_v1.cast_if_floating_dtype(inputs, self.dtype) 2530 2531 def create_tensor_spec(t): 2532 return tensor_spec.TensorSpec(t.shape, t.dtype) 2533 2534 cast_inputs = nest.map_structure(create_tensor_spec, inputs) 2535 elif training_utils_v1.has_tensors(inputs): 2536 cast_inputs = training_utils_v1.cast_if_floating_dtype(inputs) 2537 else: 2538 cast_inputs = inputs 2539 self._set_inputs(cast_inputs) 2540 return processed_inputs, targets, is_dict_inputs 2541 2542 def _compile_from_inputs(self, all_inputs, target, orig_inputs, orig_target): 2543 if target is not None: 2544 # We need to use `y` to set the model targets. 2545 if training_utils_v1.has_tensors(target): 2546 target = training_utils_v1.cast_if_floating_dtype_and_mismatch( 2547 target, self.outputs) 2548 training_utils_v1.validate_input_types( 2549 target, orig_target, allow_dict=False, field_name='target') 2550 if isinstance(target, (list, tuple)): 2551 all_inputs += list(target) 2552 else: 2553 all_inputs.append(target) 2554 # Type check that all inputs are *either* value *or* symbolic. 2555 # TODO(fchollet): this check could be removed in Eager mode? 2556 if any(tensor_util.is_tf_type(v) for v in all_inputs): 2557 if not all(tensor_util.is_tf_type(v) for v in all_inputs): 2558 raise ValueError('Do not pass inputs that mix Numpy arrays and ' 2559 'TensorFlow tensors. ' 2560 'You passed: x=' + str(orig_inputs) + 2561 '; y=' + str(orig_target)) 2562 is_dataset = isinstance(orig_inputs, (dataset_ops.DatasetV1, 2563 dataset_ops.DatasetV2, 2564 iterator_ops.Iterator)) 2565 if is_dataset or context.executing_eagerly(): 2566 target_tensors = None 2567 else: 2568 # Handle target tensors if any passed. 2569 if target is not None: 2570 if not isinstance(target, (list, tuple)): 2571 target = [target] 2572 target_tensors = [v for v in target if _is_symbolic_tensor(v)] 2573 else: 2574 target_tensors = None 2575 2576 self.compile( 2577 optimizer=self.optimizer, 2578 loss=self.loss, 2579 metrics=self._compile_metrics, 2580 weighted_metrics=self._compile_weighted_metrics, 2581 loss_weights=self.loss_weights, 2582 target_tensors=target_tensors, 2583 sample_weight_mode=self.sample_weight_mode, 2584 run_eagerly=self.run_eagerly, 2585 experimental_run_tf_function=self._experimental_run_tf_function) 2586 2587 # TODO(omalleyt): Consider changing to a more descriptive function name. 2588 def _set_inputs(self, inputs, outputs=None, training=None): 2589 """Set model's input and output specs based on the input data received. 2590 2591 This is to be used for Model subclasses, which do not know at instantiation 2592 time what their inputs look like. 2593 2594 Args: 2595 inputs: Single array, or list of arrays. The arrays could be placeholders, 2596 Numpy arrays, data tensors, or TensorSpecs. 2597 - if placeholders: the model is built on top of these placeholders, 2598 and we expect Numpy data to be fed for them when calling `fit`/etc. 2599 - if Numpy data or TensorShapes: we create placeholders matching the 2600 TensorShapes or shapes of the Numpy arrays. We expect Numpy data to be 2601 fed for these placeholders when calling `fit`/etc. 2602 - if data tensors: the model is built on top of these tensors. 2603 We do not expect any Numpy data to be provided when calling `fit`/etc. 
2604 outputs: None, a data tensor, or a list of tensors. If None, the 2605 outputs will be determined by invoking `self.call()`, otherwise the 2606 provided value will be used. 2607 training: Boolean or None. Only relevant in symbolic mode. Specifies 2608 whether to build the model's graph in inference mode (False), training 2609 mode (True), or using the Keras learning phase (None). 2610 Raises: 2611 ValueError: If dict inputs are passed to a Sequential Model where the 2612 first layer isn't FeatureLayer. 2613 """ 2614 self._set_save_spec(inputs) 2615 inputs = self._set_input_attrs(inputs) 2616 2617 if outputs is None: 2618 kwargs = {} 2619 if self._expects_training_arg: 2620 # In V2 mode, feeding `training=None` is not allowed because any value 2621 # explicitly passed by the user is respected, even `None`.` 2622 if training is None and not ops.executing_eagerly_outside_functions(): 2623 training = backend.learning_phase() 2624 if training is not None: 2625 kwargs['training'] = training 2626 try: 2627 outputs = self(inputs, **kwargs) 2628 except NotImplementedError: 2629 # This Model or a submodel is dynamic and hasn't overridden 2630 # `compute_output_shape`. 2631 outputs = None 2632 2633 self._set_output_attrs(outputs) 2634 2635 @trackable.no_automatic_dependency_tracking 2636 def _set_input_attrs(self, inputs): 2637 """Sets attributes related to the inputs of the Model.""" 2638 if self.inputs: 2639 raise ValueError('Model inputs are already set.') 2640 2641 if self.__class__.__name__ == 'Sequential' and not self.built: 2642 if tensor_util.is_tf_type(inputs): 2643 input_shape = (None,) + tuple(inputs.shape.as_list()[1:]) 2644 elif isinstance(inputs, tensor_shape.TensorShape): 2645 input_shape = (None,) + tuple(inputs.as_list()[1:]) 2646 elif isinstance(inputs, dict): 2647 # We assert that the first layer is a FeatureLayer. 2648 if not training_utils_v1.is_feature_layer(self.layers[0]): 2649 raise ValueError('Passing a dictionary input to a Sequential Model ' 2650 'which doesn\'t have FeatureLayer as the first layer' 2651 ' is an error.') 2652 input_shape = (None,) 2653 else: 2654 input_shape = (None,) + tuple(inputs.shape[1:]) 2655 self._build_input_shape = input_shape 2656 2657 # Cast inputs to the compute dtype. This is primarily used 2658 # when saving to determine the correct dtype in the input signature. 2659 inputs = self._maybe_cast_inputs(inputs) 2660 2661 # On-the-fly setting of symbolic model inputs (either by using the tensor 2662 # provided, or by creating a placeholder if Numpy data was provided). 2663 model_inputs = training_utils_v1.ModelInputs(inputs) 2664 inputs = model_inputs.get_symbolic_inputs() 2665 self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True) 2666 self.input_names = model_inputs.get_input_names() 2667 2668 self._feed_inputs = [] 2669 self._feed_input_names = [] 2670 self._feed_input_shapes = [] 2671 2672 for k, v in model_inputs.as_dict(): 2673 if backend.is_placeholder(v): 2674 self._feed_input_names.append(k) 2675 self._feed_inputs.append(v) 2676 self._feed_input_shapes.append(backend.int_shape(v)) 2677 2678 return inputs 2679 2680 @trackable.no_automatic_dependency_tracking 2681 def _set_output_attrs(self, outputs): 2682 """Sets attributes related to the outputs of the Model.""" 2683 # NOTE(taylorrobie): This convention cannot be changed without updating the 2684 # data adapter since it assumes nest.flatten ordering. 
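    # Illustrative note (assumption): a dict of outputs such as
    # {'logits': t1, 'probs': t2} is flattened here in sorted-key order
    # (['logits', 'probs']), and downstream code relies on that ordering.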
2685 outputs = nest.flatten(outputs) 2686 self.outputs = outputs 2687 self.output_names = training_utils_v1.generic_output_names(outputs) 2688 # TODO(scottzhu): Should we cleanup the self._training_endpoints here? 2689 self.built = True 2690 2691 @property 2692 def _targets(self): 2693 """The output target tensors for the model.""" 2694 return [ 2695 e.training_target.target 2696 for e in self._training_endpoints 2697 if e.has_training_target() 2698 ] 2699 2700 @property 2701 def _feed_targets(self): 2702 return [ 2703 e.training_target.target 2704 for e in self._training_endpoints 2705 if e.has_feedable_training_target() 2706 ] 2707 2708 @property 2709 def _feed_output_names(self): 2710 return [ 2711 e.output_name 2712 for e in self._training_endpoints 2713 if e.has_feedable_training_target() 2714 ] 2715 2716 @property 2717 def _feed_output_shapes(self): 2718 return [ 2719 e.feed_output_shape 2720 for e in self._training_endpoints 2721 if e.has_feedable_training_target() 2722 ] 2723 2724 @property 2725 def _feed_loss_fns(self): 2726 return [ 2727 e.loss_fn 2728 for e in self._training_endpoints 2729 if e.has_feedable_training_target() 2730 ] 2731 2732 @property 2733 def _loss_weights_list(self): 2734 return [e.loss_weight for e in self._training_endpoints] 2735 2736 @property 2737 def _output_loss_metrics(self): 2738 if hasattr(self, '_training_endpoints'): 2739 return [ 2740 e.output_loss_metric 2741 for e in self._training_endpoints 2742 if e.output_loss_metric is not None 2743 ] 2744 return None 2745 2746 @property 2747 def sample_weights(self): 2748 return [e.sample_weight for e in self._training_endpoints] 2749 2750 @property 2751 def _sample_weight_modes(self): 2752 return [e.sample_weight_mode for e in self._training_endpoints] 2753 2754 @property 2755 def _feed_sample_weights(self): 2756 return [e.sample_weight for e in self._training_endpoints 2757 if e.sample_weight is not None] 2758 2759 def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode): 2760 """Maybe load initial epoch from ckpt considering possible worker recovery. 2761 2762 Refer to tensorflow/python/keras/distribute/worker_training_state.py 2763 for more information. 2764 2765 Args: 2766 initial_epoch: The original initial_epoch user passes in in `fit()`. 2767 mode: The mode for running `model.fit()`. 2768 2769 Returns: 2770 If the training is recovering from previous failure under multi-worker 2771 training setting, return the epoch the training is supposed to continue 2772 at. Otherwise, return the `initial_epoch` the user passes in. 2773 """ 2774 if self._training_state is not None: 2775 return self._training_state.maybe_load_initial_epoch_from_ckpt( 2776 initial_epoch, mode) 2777 return initial_epoch 2778 2779 def _get_training_eval_metrics(self): 2780 """Returns all the metrics that are to be reported. 2781 2782 This includes the output loss metrics, compile metrics/weighted metrics, 2783 add_metric metrics. 2784 """ 2785 metrics = [] 2786 metrics.extend(getattr(self, '_output_loss_metrics', None) or []) 2787 metrics.extend(getattr(self, 'metrics', None) or []) 2788 return metrics 2789 2790 def _assert_compile_was_called(self): 2791 # Checks whether `compile` has been called. If it has been called, 2792 # then the optimizer is set. This is different from whether the 2793 # model is compiled 2794 # (i.e. whether the model is built and its inputs/outputs are set). 2795 if not self._compile_was_called: 2796 raise RuntimeError('You must compile your model before ' 2797 'training/testing. 

  def _in_multi_worker_mode(self):
    """Method to infer if this `Model` is working in multi-worker settings.

    Multi-worker training refers to the setup where the training is
    distributed across multiple workers, as opposed to the case where
    only a local process performs the training. This function is
    used to infer, for example, whether or not a distribute coordinator
    should be run (and thus TensorFlow servers should be started for
    communication with other servers in the cluster), or whether or not
    saving/restoring checkpoints is relevant for preemption fault tolerance.

    Experimental. Signature and implementation are subject to change.

    Returns:
      Whether this model indicates it's working in multi-worker settings.
    """
    strategy = self._distribution_strategy

    # Otherwise, use the strategy whose scope this is in.
    if not strategy and distribution_strategy_context.has_strategy():
      strategy = distribution_strategy_context.get_strategy()
    return strategy and strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access

  @property
  def _trackable_saved_model_saver(self):
    return model_serialization.ModelSavedModelSaver(self)

  def _get_compile_args(self, user_metrics=True):
    del user_metrics
    self._assert_compile_was_called()
    kwargs = {
        'loss': self.loss,
        'metrics': self._compile_metrics,
        'loss_weights': self.loss_weights,
        'sample_weight_mode': self.sample_weight_mode,
        'weighted_metrics': self._compile_weighted_metrics,
    }
    return kwargs

  @property
  def _compile_was_called(self):
    return self._v1_compile_was_called


class DistributedCallbackModel(Model):
  """Model that is used for callbacks with tf.distribute.Strategy."""

  def __init__(self, model):
    super(DistributedCallbackModel, self).__init__()
    self.optimizer = model.optimizer

  def set_original_model(self, orig_model):
    self._original_model = orig_model

  def save_weights(self, filepath, overwrite=True, save_format=None):
    self._replicated_model.save_weights(filepath, overwrite=overwrite,
                                        save_format=save_format)

  def save(self, filepath, overwrite=True, include_optimizer=True):
    # Save weights from the distributed model to the original model.
    distributed_model_weights = self.get_weights()
    self._original_model.set_weights(distributed_model_weights)
    # TODO(anjalisridhar): Do we need to save the original model here?
    # Saving the first replicated model works as well.
    self._original_model.save(filepath, overwrite=True,
                              include_optimizer=False)

  def load_weights(self, filepath, by_name=False):
    self._original_model.load_weights(filepath, by_name=False)
    # Copy the weights from the original model to each of the replicated
    # models.
    orig_model_weights = self._original_model.get_weights()
    distributed_training_utils_v1.set_weights(
        self._original_model._distribution_strategy, self,  # pylint: disable=protected-access
        orig_model_weights)

  def __getattr__(self, item):
    # Allowed attributes of the model that can be accessed by the user
    # during a callback.
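    # Added clarification: '_setattr_tracking' and '_layers' are looked up
    # internally by the base classes, so they are exempt from the warning
    # below; every other attribute lookup is logged as potentially stale and
    # then delegated to the parent class.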
    if item not in ('_setattr_tracking', '_layers'):
      logging.warning('You are accessing attribute ' + item + ' of the '
                      'DistributedCallbackModel that may not have been set '
                      'correctly.')
    return super(DistributedCallbackModel, self).__getattr__(item)


class _TrainingEndpoint(object):
  """A container for the training output/target and related entities.

  In the case of a model with multiple outputs, there is a one-to-one mapping
  between model output (y_pred), model target (y_true), loss, metrics, etc.
  By unifying these entities into one class, the different entities can access
  information about each other, rather than reading it from separate lists of
  attributes on the model.
  """

  def __init__(self,
               output,
               output_name,
               loss_fn,
               loss_weight=None,
               training_target=None,
               output_loss_metric=None,
               sample_weight=None,
               sample_weight_mode=None):
    """Initialize the _TrainingEndpoint.

    Note that the output and output_name should be stable as long as the model
    structure doesn't change. The training_target is supposed to be mutable
    since the information is provided via `compile()`.

    Args:
      output: the output tensor of the model.
      output_name: the unique name of the output tensor.
      loss_fn: the loss function for the output tensor.
      loss_weight: float, the weight for the loss.
      training_target: the _TrainingTarget for the model.
      output_loss_metric: the metric object for the loss function.
      sample_weight: the weights for how a sample is weighted during metric
        and loss calculation. Could be None.
      sample_weight_mode: string, 'temporal', 'samplewise' or None. The mode
        for how the sample_weight is populated.
    """
    self._output = output
    self._output_name = output_name
    self._loss_fn = loss_fn
    self._loss_weight = loss_weight
    self._training_target = training_target
    self._output_loss_metric = output_loss_metric
    self._sample_weight = sample_weight
    self._sample_weight_mode = sample_weight_mode

  @property
  def output(self):
    return self._output

  @property
  def output_name(self):
    return self._output_name

  @property
  def shape(self):
    return backend.int_shape(self.output)

  @property
  def loss_fn(self):
    return self._loss_fn

  @property
  def loss_weight(self):
    return self._loss_weight

  @loss_weight.setter
  def loss_weight(self, value):
    self._loss_weight = value

  @property
  def training_target(self):
    return self._training_target

  @training_target.setter
  def training_target(self, value):
    self._training_target = value

  def create_training_target(self, target, run_eagerly=False):
    """Create a training_target instance and update self.training_target.

    Note that the input target should just be a tensor or None, and the
    corresponding training target will be created based on the output and
    loss_fn.

    Args:
      target: the target tensor for the current output. Could be None.
      run_eagerly: boolean, whether the model is in run_eagerly mode.

    Raises:
      ValueError: if the training_target field for the current instance has
        already been populated.
    """
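    # Illustrative example (an addition, not part of the original source): for
    # an output named 'dense' of shape (None, 5) trained with
    # `categorical_crossentropy` and no explicit target tensor, the code below
    # creates a placeholder named 'dense_target' with ndim=2 and the output's
    # dtype, wrapped in a feedable _TrainingTarget.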
2976 """ 2977 if self.has_training_target(): 2978 raise ValueError('The training_target field for the _TrainingEndpoint ' 2979 'instance has already been populated') 2980 if run_eagerly: 2981 # When run_eagerly, the target tensor is ignored, and the None placeholder 2982 # is created instead. 2983 self.training_target = _TrainingTarget( 2984 None, feedable=True, skip_target_weights=False) 2985 return 2986 2987 if self.should_skip_target(): 2988 self.training_target = _TrainingTarget(None) 2989 else: 2990 if target is not None and not backend.is_placeholder(target): 2991 feedable = False 2992 skip_target_weights = True 2993 else: 2994 feedable = True 2995 skip_target_weights = False 2996 2997 if target is None: 2998 target_dtype = losses.LABEL_DTYPES_FOR_LOSSES.get( 2999 self.loss_fn, backend.dtype(self.output)) 3000 3001 target = backend.placeholder( 3002 ndim=len(self.shape), 3003 name=self.output_name + '_target', 3004 sparse=backend.is_sparse(self.output), 3005 dtype=target_dtype) 3006 3007 self.training_target = _TrainingTarget( 3008 target, 3009 feedable=feedable, 3010 skip_target_weights=skip_target_weights) 3011 3012 @property 3013 def output_loss_metric(self): 3014 return self._output_loss_metric 3015 3016 @output_loss_metric.setter 3017 def output_loss_metric(self, value): 3018 self._output_loss_metric = value 3019 3020 @property 3021 def sample_weight(self): 3022 return self._sample_weight 3023 3024 @sample_weight.setter 3025 def sample_weight(self, value): 3026 self._sample_weight = value 3027 3028 @property 3029 def sample_weight_mode(self): 3030 return self._sample_weight_mode 3031 3032 @sample_weight_mode.setter 3033 def sample_weight_mode(self, value): 3034 self._sample_weight_mode = value 3035 3036 def should_skip_target(self): 3037 return self._loss_fn is None 3038 3039 def should_skip_target_weights(self): 3040 return (self.should_skip_target() or self.training_target is None or 3041 self.training_target.skip_target_weights) 3042 3043 def has_training_target(self): 3044 return self.training_target is not None 3045 3046 def has_feedable_training_target(self): 3047 return (not self.should_skip_target() and 3048 self.training_target is not None and self.training_target.feedable) 3049 3050 def loss_name(self): 3051 if self._loss_fn is not None: 3052 return self._output_name + '_loss' 3053 return None 3054 3055 @property 3056 def feed_output_shape(self): 3057 """The output shape for the feedable target.""" 3058 if not self.has_feedable_training_target(): 3059 return None 3060 3061 if ((isinstance(self.loss_fn, losses.LossFunctionWrapper) and 3062 self.loss_fn.fn == losses.sparse_categorical_crossentropy)) or ( 3063 isinstance(self.loss_fn, losses.SparseCategoricalCrossentropy)): 3064 if backend.image_data_format() == 'channels_first': 3065 return (self.shape[0], 1) + self.shape[2:] 3066 else: 3067 return self.shape[:-1] + (1,) 3068 elif (not isinstance(self.loss_fn, losses.Loss) or 3069 (isinstance(self.loss_fn, losses.LossFunctionWrapper) and 3070 (getattr(losses, self.loss_fn.fn.__name__, None) is None))): 3071 # If the given loss is not an instance of the `Loss` class (custom 3072 # class) or if the loss function that is wrapped is not in the 3073 # `losses` module, then it is a user-defined loss and we make no 3074 # assumptions about it. 
      return None
    else:
      return self.shape

  def sample_weights_mismatch(self):
    """Check if the sample weight and the mode match or not."""
    # If there is a mismatch between sample weight mode and the placeholders
    # created, then recompile the sub-graphs that depend on sample weights.
    return (
        (self.sample_weight_mode is not None and self.sample_weight is None) or
        (self.sample_weight_mode is None and self.sample_weight is not None))

  def populate_sample_weight(self, sample_weight, sample_weight_mode):
    """Populate the sample weight based on the sample weight mode."""
    if (sample_weight is None and
        (self.should_skip_target_weights() or sample_weight_mode is None or
         context.executing_eagerly())):
      self._sample_weight = None
      return

    assert sample_weight_mode in ['temporal', 'samplewise']
    if sample_weight_mode == 'temporal':
      default_value = [[1.]]
      shape = [None, None]
    else:
      # sample_weight_mode == 'samplewise'
      default_value = [1.]
      shape = [None]

    if sample_weight is not None:
      if not sample_weight.shape.is_compatible_with(shape):
        raise ValueError('Received sample weight with shape {}. Expected '
                         'shape {}.'.format(sample_weight.shape, shape))
      self._sample_weight = sample_weight
    else:
      self._sample_weight = array_ops.placeholder_with_default(
          constant_op.constant(default_value, dtype=backend.floatx()),
          shape=shape,
          name=self.output_name + '_sample_weights')


class _TrainingTarget(object):
  """Container for a target tensor (y_true) and its metadata (shape, loss...).

  Args:
    target: A target tensor for the model. It may be `None` if the
      output is excluded from loss computation. It is still kept (as None)
      since each output of the model should have a corresponding target. If
      the target is None, the rest of the attributes will be None as well.
    feedable: Boolean, whether the target is feedable (requires data to be
      passed in `fit` or `train_on_batch`), or not (model compiled with
      `target_tensors` argument).
    skip_target_weights: Boolean, whether the target should be skipped during
      weights calculation.
  """

  def __init__(self, target, feedable=False, skip_target_weights=True):
    self._target = target
    self._feedable = feedable
    self._skip_target_weights = skip_target_weights

  @property
  def target(self):
    return self._target

  @property
  def feedable(self):
    return self._feedable

  @property
  def skip_target_weights(self):
    return self._skip_target_weights


def _is_symbolic_tensor(x):
  return tensor_util.is_tf_type(x)


def _convert_scipy_sparse_tensor(value, expected_input):
  """Handle scipy sparse matrix conversions.

  This method takes a value 'value' and returns the proper conversion. If
  value is a scipy sparse matrix and the expected input is a dense tensor,
  we densify 'value'. If value is a scipy sparse matrix and the expected input
  is a TF SparseTensor, we convert 'value' to a SparseTensor. If 'value' is
  not a scipy sparse matrix, or scipy is not imported, we pass it through
  unchanged.

  Args:
    value: An object that may be a scipy sparse matrix.
    expected_input: The expected input placeholder.

  Returns:
    The possibly-converted 'value'.
  """
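  # A hedged sketch of the sparse branch below (added for illustration, not
  # part of the original source):
  #
  #   from scipy.sparse import coo_matrix
  #   value = coo_matrix(([1., 2.], ([0, 1], [1, 0])), shape=(2, 2))
  #   # If expected_input is a sparse placeholder, the conversion returns
  #   # SparseTensor(indices=[[0, 1], [1, 0]], values=[1., 2.],
  #   #              dense_shape=[2, 2])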
3169 """ 3170 if issparse is not None and issparse(value): 3171 if backend.is_sparse(expected_input): 3172 sparse_coo = value.tocoo() 3173 row, col = sparse_coo.row, sparse_coo.col 3174 data, shape = sparse_coo.data, sparse_coo.shape 3175 indices = np.concatenate((np.expand_dims(row, 1), np.expand_dims(col, 1)), 3176 1) 3177 return sparse_tensor.SparseTensor(indices, data, shape) 3178 else: 3179 if ops.executing_eagerly_outside_functions(): 3180 # In TF2 we do not silently densify sparse matrices. 3181 raise ValueError('A SciPy sparse matrix was passed to a model ' 3182 'that expects dense inputs. Please densify your ' 3183 'inputs first, such as by calling `x.toarray().') 3184 return value.toarray() 3185 else: 3186 return value 3187 3188 3189def _get_metrics_from_layers(layers): 3190 """Returns list of metrics from the given layers. 3191 3192 This will not include the `compile` metrics of a model layer. 3193 3194 Args: 3195 layers: List of layers. 3196 3197 Returns: 3198 List of metrics. 3199 """ 3200 metrics = [] 3201 layers = layer_utils.filter_empty_layer_containers(layers) 3202 for layer in layers: 3203 if isinstance(layer, Model): 3204 # We cannot call 'metrics' on the model because we do not want to 3205 # include the metrics that were added in compile API of a nested model. 3206 metrics.extend(layer._metrics) # pylint: disable=protected-access 3207 metrics.extend(_get_metrics_from_layers(layer.layers)) 3208 else: 3209 metrics.extend(layer.metrics) 3210 return metrics 3211 3212 3213def _non_none_constant_value(v): 3214 constant_value = tensor_util.constant_value(v) 3215 return constant_value if constant_value is not None else v 3216