# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training-related utilities."""

import abc
import atexit
import collections
import functools
import multiprocessing.pool
import threading
import time

import numpy as np

from tensorflow.core.framework import graph_pb2
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def is_composite_or_composite_value(tensor):
  """Returns true if 'tensor' is a CompositeTensor or a CT Value object."""
  # TODO(b/125094323): This should be isinstance(CompositeTensor) or
  # isinstance(CompositeTensorValue) once we support that.
  return isinstance(
      tensor,
      (composite_tensor.CompositeTensor, sparse_tensor.SparseTensorValue,
       ragged_tensor_value.RaggedTensorValue))


class Aggregator(object, metaclass=abc.ABCMeta):
  """Abstract base class used to aggregate batch-level outputs of a loop.

  Attributes:
    use_steps: Whether the loop is using `step` or `batch_size`.
    num_samples: Total number of samples: `batch_size * num_batches`.
    steps: Total number of steps.
    batch_size: Batch size. It is used for validation checks between inputs and
      outputs.
    results: What to return at the end of the aggregation loop.
  """

  def __init__(self, use_steps, num_samples=None, steps=None, batch_size=None):
    self.use_steps = use_steps
    self.num_samples = num_samples
    self.steps = steps
    self.batch_size = batch_size
    self.results = []

  @abc.abstractmethod
  def create(self, batch_outs):
    """Creates the initial results from the first batch outputs.

    Args:
      batch_outs: A list of batch-level outputs.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

  @abc.abstractmethod
  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    """Aggregates batch-level results into total results.

    Args:
      batch_outs: A list of batch-level outputs.
      batch_start: The start index of this batch. Always `None` if `use_steps`
        is `True`.
      batch_end: The end index of this batch. Always `None` if `use_steps` is
        `True`.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

  @abc.abstractmethod
  def finalize(self):
    """Prepares the total results to be returned."""
    raise NotImplementedError('Must be implemented in subclasses.')


class MetricsAggregator(Aggregator):
  """Aggregator that calculates loss and metrics info.

  Attributes:
    use_steps: Whether the loop is using `step` or `batch_size`.
    num_samples: Total number of samples: `batch_size * num_batches`.
    steps: Total number of steps, i.e. number of times to iterate over a
      dataset to cover all samples.
  """

  def __init__(self, use_steps, num_samples=None, steps=None):
    super(MetricsAggregator, self).__init__(
        use_steps=use_steps,
        num_samples=num_samples,
        steps=steps,
        batch_size=None)

  def create(self, batch_outs):
    self.results = [0.] * len(batch_outs)

  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    # Loss.
    if self.use_steps:
      self.results[0] += batch_outs[0]
    else:
      self.results[0] += batch_outs[0] * (batch_end - batch_start)
    # Metrics (always stateful, just grab current values.)
    self.results[1:] = batch_outs[1:]

  def finalize(self):
    if not self.results:
      raise ValueError('Empty training data.')
    self.results[0] /= (self.num_samples or self.steps)
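

# Illustrative sketch (not part of the original API): one possible way to drive
# `MetricsAggregator` by hand. The batch outputs and sample counts below are
# made up, and `_demo_metrics_aggregator` is a hypothetical helper that exists
# only for illustration.
def _demo_metrics_aggregator():
  # Aggregate [loss, metric] pairs over two batches of 2 samples each.
  aggregator = MetricsAggregator(use_steps=False, num_samples=4)
  aggregator.create([0., 0.])
  aggregator.aggregate([0.5, 0.8], batch_start=0, batch_end=2)
  aggregator.aggregate([0.3, 0.9], batch_start=2, batch_end=4)
  aggregator.finalize()
  # Loss is a sample-weighted mean (here 0.4); metrics keep their last
  # (stateful) value (here 0.9).
  return aggregator.results

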
def _append_sparse_tensor_value(target, to_append):
  """Append sparse tensor value objects."""
  # Make sure the sparse tensors are of the same size (except for the 0th dim).
  if len(target.dense_shape) != len(to_append.dense_shape):
    raise RuntimeError(
        'Unable to concatenate %s and %s. The inner dense shapes do not '
        'have the same number of dimensions (%s vs %s)' %
        (target, to_append, target.dense_shape, to_append.dense_shape))

  if target.dense_shape[1:] != to_append.dense_shape[1:]:
    raise RuntimeError(
        'Unable to concatenate %s and %s. The inner dense shapes do not '
        'match inner dimensions (%s vs %s)' %
        (target, to_append, target.dense_shape[1:], to_append.dense_shape[1:]))

  # Add the to_append indices to target, updating the 0th value, and keeping
  # track of the maximum so we know the final dense_shape of this tensor.
  base_dim0_value = target.dense_shape[0]
  max_dim0_value = target.dense_shape[0]
  new_indices = target.indices
  for index in to_append.indices:
    # Here, we iterate through the sparse indices of the tensor to append. For
    # each index, we update its zeroth value (the batch index) by adding the
    # number of batch items in the tensor we are appending to (so an index
    # of [0, 0, 1] for a value that is being appended to a tensor with 0th dim
    # size 3 would become [3, 0, 1].)
    index[0] += base_dim0_value
    max_dim0_value = max(max_dim0_value, index[0])
    new_indices = np.append(new_indices, [index], axis=0)

  # Extend the values array to contain all of the appended values. These will
  # be in the same order as the indices added above.
  new_values = np.concatenate((target.values, to_append.values), axis=0)

  # Create a new dense shape by replacing the value for the 0th dimension
  # with the new max dim0 value.
  new_dense_shape = list(target.dense_shape)
  new_dense_shape[0] = max_dim0_value + 1
  new_dense_shape = tuple(new_dense_shape)

  return sparse_tensor.SparseTensorValue(
      indices=new_indices, values=new_values, dense_shape=new_dense_shape)


def _append_ragged_tensor_value(target, to_append):
  """Append ragged tensor value objects."""
  # Make sure the ragged tensors are of the same size (save for the 0th dim).
  if len(target.shape) != len(to_append.shape):
    raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append))

  if target.shape[1:] != to_append.shape[1:]:
    raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append))

  adjusted_row_splits = to_append.row_splits[1:] + target.row_splits[-1]
  new_row_splits = np.append(target.row_splits, adjusted_row_splits)
  if isinstance(target.values, ragged_tensor_value.RaggedTensorValue):
    new_values = _append_ragged_tensor_value(target.values, to_append.values)
  else:
    new_values = np.concatenate((target.values, to_append.values), axis=0)

  return ragged_tensor_value.RaggedTensorValue(new_values, new_row_splits)


def _append_composite_tensor(target, to_append):
  """Helper function to append composite tensors to each other in the 0 axis.

  In order to support batching within a fit/evaluate/predict call, we need
  to be able to aggregate within a CompositeTensor. Unfortunately, the CT
  API currently does not make this easy - especially in V1 mode, where we're
  working with CompositeTensor Value objects that have no connection with the
  CompositeTensors that created them.

  Args:
    target: CompositeTensor or CompositeTensor value object that will be
      appended to.
    to_append: CompositeTensor or CompositeTensor value object to append to
      `target`.

  Returns:
    A CompositeTensor or CompositeTensor value object.

  Raises:
    RuntimeError: if concatenation is not possible.
  """
  if type(target) is not type(to_append):
    raise RuntimeError('Unable to concatenate %s and %s' %
                       (type(target), type(to_append)))

  # Perform type-specific concatenation.
  # TODO(b/125094323): This should be replaced by a simple call to
  # target.append() that should work on all of the below classes.

  # If we're seeing a CompositeTensor here, we know it's because we're in
  # Eager mode (or else we'd have evaluated the CT to a CT Value object
  # already). Therefore, it's safe to call concat() on it without evaluating
  # the result any further. If not - that is, if we're seeing a
  # SparseTensorValue or a RaggedTensorValue - we need to hand-update it
  # since we're outside of the graph anyways.
  if isinstance(target, sparse_tensor.SparseTensor):
    # We need to invoke the sparse version of concatenate here - tf.concat
    # won't work.
    return sparse_ops.sparse_concat(sp_inputs=[target, to_append], axis=0)
  elif isinstance(target, ragged_tensor.RaggedTensor):
    return array_ops.concat([target, to_append], axis=0)
  elif isinstance(target, sparse_tensor.SparseTensorValue):
    return _append_sparse_tensor_value(target, to_append)
  elif isinstance(target, ragged_tensor_value.RaggedTensorValue):
    return _append_ragged_tensor_value(target, to_append)
  else:
    raise RuntimeError('Attempted to concatenate unsupported object %s.' %
                       type(target))


class ConcatAggregator(Aggregator):
  """Combine tensor-likes which cannot be merged on the fly.

  This class expects to aggregate a single tensor-like rather than a nested
  structure of tensor-likes.
  """

  def __init__(self, batch_size):
    self.composite = None
    super(ConcatAggregator, self).__init__(
        use_steps=True, num_samples=None, steps=None, batch_size=batch_size)

  def create(self, batch_element):
    self.composite = is_composite_or_composite_value(batch_element)

  def aggregate(self, batch_element, batch_start=None, batch_end=None):

    # TODO(psv): Add num_samples check here to detect when output batch
    # #samples is < batch size and != input batch #samples.
    if self.batch_size and self.batch_size < batch_element.shape[0]:
      raise ValueError(
          'Mismatch between expected batch size and model output batch size. '
          'Output shape = {}, expected output shape = {}'.format(
              batch_element.shape,
              (self.batch_size,) + batch_element.shape[1:]))
    self.results.append(batch_element)

  def finalize(self):
    # Special case of single batch inference which skips a copy.
    if len(self.results) == 1:
      self.results = self.results[0]

    elif self.composite:
      # TODO(taylorrobie): efficiently concatenate.
      results = self.results[0]
      for r in self.results[1:]:
        results = _append_composite_tensor(results, r)
      self.results = results

    else:
      self.results = np.concatenate(self.results, axis=0)


_COPY_THREADS = 4
_COPY_POOL = None


def get_copy_pool():
  """Shared threadpool for copying arrays.

  Pool instantiation takes ~ 2ms, so a singleton pool is used rather than
  creating a pool per SliceAggregator.

  Returns:
    The global copy threadpool.
  """
  global _COPY_POOL
  if _COPY_POOL is None:
    _COPY_POOL = multiprocessing.pool.ThreadPool(_COPY_THREADS)
    atexit.register(_COPY_POOL.close)
  return _COPY_POOL


class SliceAggregator(Aggregator):
  """Combine arrays where the final size is known.

  This class expects to aggregate a single tensor-like rather than a nested
  structure of tensor-likes.

  NumPy copies are an operation that threads handle quite well because all of
  the heavy lifting is in C and does not need the GIL. Moreover, we can perform
  lock-free writes to the same buffer in multiple threads because the nature of
  result aggregation guarantees that either the indices are disjoint or the
  aggregator will throw an exception in finalize. Moreover, because aggregation
  is performed on the slowest varying dimension, assignments for a given batch
  will write to contiguous blocks of memory, further minimizing contention.

  There is, however, some scheduling and context switching overhead which will
  offset the gains from pipelining the slice assignment. Below a given
  threshold it is faster to simply assign in the main thread rather than
  enqueue the assignment in a side thread. The exact threshold will vary from
  system to system, but the time is not very sensitive to the exact transition
  so a value of 2 ** 14 was chosen which should be reasonable on most systems.
  """

  _BINARY_SIZE_THRESHOLD = 2 ** 14
  _MAX_COPY_SECONDS = 300

  def __init__(self, num_samples, batch_size):
    self._async_copies = []
    self._pool = get_copy_pool()
    self._errors = []
    super(SliceAggregator, self).__init__(
        use_steps=False,
        num_samples=num_samples,
        steps=None,
        batch_size=batch_size)

  def create(self, batch_element):
    # This step does not need to be pipelined because NumPy empty array
    # initialization is effectively instantaneous.
    shape = (self.num_samples,) + batch_element.shape[1:]
    dtype = batch_element.dtype

    self.results = np.empty(shape=shape, dtype=dtype)

  def aggregate(self, batch_element, batch_start, batch_end):
    # Fail early.
    if self._errors:
      raise self._errors[0]

    # In the special case of single batch inference, no copy is needed.
    if batch_end - batch_start == self.num_samples:
      if self.num_samples != batch_element.shape[0]:
        raise ValueError(
            'Mismatch between expected batch size and model output batch size. '
            'Output shape = {}, expected output shape = {}'.format(
                batch_element.shape, self.results.shape))

      self.results = batch_element
      return

    # This is an approximate threshold, so we don't need to consider the number
    # of bytes per element.
    num_elements = np.prod(batch_element.shape)
    if num_elements < self._BINARY_SIZE_THRESHOLD:
      self.results[batch_start:batch_end] = batch_element
    else:
      is_finished = threading.Event()
      self._pool.apply_async(
          self._slice_assign,
          args=(batch_element, batch_start, batch_end, is_finished))
      self._async_copies.append(is_finished)

  def _slice_assign(self, batch_element, batch_start, batch_end, is_finished):
    """Legacy utility method to slice input arrays."""
    try:
      self.results[batch_start:batch_end] = batch_element

    except Exception as e:  # pylint: disable=broad-except
      # `_slice_assign` should only be called in threads and exceptions raised
      # in threads do not carry over to the main thread. So instead we perform
      # a broad catch in the thread and then store the exception to be
      # re-raised in the main thread.
      self._errors.append(e)

    finally:
      is_finished.set()

  def finalize(self):
    start_time = time.time()
    for is_finished in self._async_copies:
      timeout = max([0., self._MAX_COPY_SECONDS - (time.time() - start_time)])
      if not is_finished.wait(timeout):
        raise ValueError('Timed out waiting for copy to complete.')

    if self._errors:
      raise self._errors[0]


class OutputsAggregator(Aggregator):
  """Aggregator that concatenates outputs."""

  _structure = None

  def create(self, batch_outs):
    # SparseTensorValue is a named tuple which nest will flatten, so we need
    # to guard it to properly handle the structure.
    self._structure = nest.get_traverse_shallow_structure(
        lambda x: not is_composite_or_composite_value(x), batch_outs)
    batch_outs = nest.flatten_up_to(self._structure, batch_outs)

    for batch_element in batch_outs:
      if is_composite_or_composite_value(batch_element):
        # If the output is not a ndarray, it will be either a composite tensor
        # or a composite tensor's Value object. In either case, we can't
        # allocate an array to hold the object - we'll handle it later.
        self.results.append(ConcatAggregator(self.batch_size))
      elif isinstance(batch_element, np.ndarray):
        self.results.append(
            (ConcatAggregator(self.batch_size) if self.use_steps else
             SliceAggregator(self.num_samples, self.batch_size)))
      else:
        # This is not a ndarray, a CompositeTensor, or a CompositeTensorValue.
        # Fail fast rather than trying to concatenate it.
        raise RuntimeError('Attempted to aggregate unsupported object {}.'
                           .format(batch_element))

      self.results[-1].create(batch_element)

  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    batch_outs = nest.flatten_up_to(self._structure, batch_outs)
    for batch_element, result in zip(batch_outs, self.results):
      result.aggregate(batch_element, batch_start, batch_end)

  def finalize(self):
    for result in self.results:
      result.finalize()
    self.results = [i.results for i in self.results]
    self.results = nest.pack_sequence_as(self._structure, self.results)
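

# Illustrative sketch (not part of the original API): aggregating two batches
# of dense model outputs with `OutputsAggregator`. The shapes and values are
# made up, and `_demo_outputs_aggregator` is a hypothetical helper included
# only for illustration.
def _demo_outputs_aggregator():
  aggregator = OutputsAggregator(use_steps=False, num_samples=4, batch_size=2)
  first_batch = np.zeros((2, 3), dtype=np.float32)
  aggregator.create(first_batch)
  aggregator.aggregate(first_batch, batch_start=0, batch_end=2)
  aggregator.aggregate(np.ones((2, 3), dtype=np.float32), batch_start=2,
                       batch_end=4)
  aggregator.finalize()
  # Returns a single ndarray of shape (4, 3): zeros in the first two rows,
  # ones in the last two.
  return aggregator.results

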
def get_progbar(model, count_mode, include_metrics=True):
  """Get Progbar."""
  if include_metrics:
    stateful_metric_names = getattr(model, 'metrics_names', None)
    if stateful_metric_names:
      stateful_metric_names = stateful_metric_names[1:]  # Exclude `loss`
  else:
    stateful_metric_names = None
  return cbks.ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names)


def check_num_samples(ins, batch_size=None, steps=None, steps_name='steps'):
  """Determine the number of samples provided for training and evaluation.

  The number of samples is not defined when running with `steps`,
  in which case the number of samples is set to `None`.

  Args:
    ins: List of tensors to be fed to the Keras function.
    batch_size: Integer batch size or `None` if not defined.
    steps: Total number of steps (batches of samples) before declaring
      `_predict_loop` finished. Ignored with the default value of `None`.
    steps_name: The public API's parameter name for `steps`.

  Raises:
    ValueError: when `steps` is `None` and the attribute `ins.shape`
      does not exist. Also raises ValueError when `steps` is not `None`
      and `batch_size` is not `None` because they are mutually
      exclusive.

  Returns:
    When steps is `None`, returns the number of samples to be
    processed based on the size of the first dimension of the
    first input numpy array. When steps is not `None` and
    `batch_size` is `None`, returns `None`.
  """
  if steps is not None and batch_size is not None:
    raise ValueError('If ' + steps_name +
                     ' is set, the `batch_size` must be None.')
  if check_steps_argument(ins, steps, steps_name):
    return None

  if hasattr(ins[0], 'shape'):
    return int(ins[0].shape[0])
  return None  # Edge case where ins == [static_learning_phase]


def standardize_single_array(x, expected_shape=None):
  """Expand data of shape (x,) to (x, 1), unless len(expected_shape)==1."""
  if x is None:
    return None

  if is_composite_or_composite_value(x):
    return x

  if isinstance(x, int):
    raise ValueError(
        'Expected an array data type but received an integer: {}'.format(x))

  if (x.shape is not None and len(x.shape) == 1 and
      (expected_shape is None or len(expected_shape) != 1)):
    if tensor_util.is_tf_type(x):
      x = array_ops.expand_dims(x, axis=1)
    else:
      x = np.expand_dims(x, 1)
  return x


def get_composite_shape(tensor):
  """Returns the shape of the passed composite tensor."""
  if isinstance(tensor, sparse_tensor.SparseTensorValue):
    # SparseTensorValues use a 'dense_shape' attribute
    return tensor.dense_shape
  else:
    return tensor.shape


def standardize_input_data(data,
                           names,
                           shapes=None,
                           check_batch_axis=True,
                           exception_prefix=''):
  """Normalizes inputs and targets provided by users.

  Users may pass data as a list of arrays, dictionary of arrays,
  or as a single array. We normalize this to an ordered list of
  arrays (same order as `names`), while checking that the provided
  arrays have shapes that match the network's expectations.

  Args:
    data: User-provided input data (polymorphic).
    names: List of expected array names.
    shapes: Optional list of expected array shapes.
    check_batch_axis: Boolean; whether to check that the batch axis of the
      arrays matches the expected value found in `shapes`.
    exception_prefix: String prefix used for exception formatting.

  Returns:
    List of standardized input arrays (one array per model input).

  Raises:
    ValueError: in case of improperly formatted user-provided data.
  """
  try:
    data_len = len(data)
  except TypeError:
    # For instance if data is `None` or a symbolic Tensor.
    data_len = None

  if not names:
    if data_len and not isinstance(data, dict):
      raise ValueError(
          'Error when checking model ' + exception_prefix + ': '
          'expected no data, but got:', data)
    return []
  if data is None:
    return [None for _ in range(len(names))]

  if isinstance(data, dict):
    try:
      data = [
          data[x].values
          if data[x].__class__.__name__ == 'DataFrame' else data[x]
          for x in names
      ]
    except KeyError as e:
      raise ValueError('No data provided for "' + e.args[0] + '". Need data '
                       'for each key in: ' + str(names))
  elif isinstance(data, (list, tuple)):
    if isinstance(data[0], (list, tuple)):
      data = [np.asarray(d) for d in data]
    elif len(names) == 1 and isinstance(data[0], (float, int)):
      data = [np.asarray(data)]
    else:
      data = [
          x.values if x.__class__.__name__ == 'DataFrame' else x for x in data
      ]
  else:
    data = data.values if data.__class__.__name__ == 'DataFrame' else data
    data = [data]

  if shapes is not None:
    data = [
        standardize_single_array(x, shape) for (x, shape) in zip(data, shapes)
    ]
  else:
    data = [standardize_single_array(x) for x in data]

  if len(data) != len(names):
    if data and hasattr(data[0], 'shape'):
      raise ValueError('Error when checking model ' + exception_prefix +
                       ': the list of Numpy arrays that you are passing to '
                       'your model is not the size the model expected. '
                       'Expected to see ' + str(len(names)) + ' array(s), ' +
                       'for inputs ' + str(names) + ' but instead got the '
                       'following list of ' + str(len(data)) + ' arrays: ' +
                       str(data)[:200] + '...')
    elif len(names) > 1:
      raise ValueError('Error when checking model ' + exception_prefix +
                       ': you are passing a list as input to your model, '
                       'but the model expects a list of ' + str(len(names)) +
                       ' Numpy arrays instead. The list you passed was: ' +
                       str(data)[:200])
    elif len(data) == 1 and not hasattr(data[0], 'shape'):
      raise TypeError('Error when checking model ' + exception_prefix +
                      ': data should be a Numpy array, or list/dict of '
                      'Numpy arrays. Found: ' + str(data)[:200] + '...')
    elif len(names) == 1:
      data = [np.asarray(data)]

  # Check shapes compatibility.
  if shapes:
    for i in range(len(names)):
      if shapes[i] is not None:
        if tensor_util.is_tf_type(data[i]):
          tensorshape = data[i].shape
          if not tensorshape:
            continue
          data_shape = tuple(tensorshape.as_list())
        elif is_composite_or_composite_value(data[i]):
          tensorshape = get_composite_shape(data[i])
          data_shape = tuple(tensorshape.as_list())
        else:
          data_shape = data[i].shape

        shape = shapes[i]
        if len(data_shape) != len(shape):
          raise ValueError('Error when checking ' + exception_prefix +
                           ': expected ' + names[i] + ' to have ' +
                           str(len(shape)) + ' dimensions, but got array '
                           'with shape ' + str(data_shape))
        if not check_batch_axis:
          data_shape = data_shape[1:]
          shape = shape[1:]
        for dim, ref_dim in zip(data_shape, shape):
          if ref_dim != dim and ref_dim is not None and dim is not None:
            raise ValueError('Error when checking ' + exception_prefix +
                             ': expected ' + names[i] + ' to have shape ' +
                             str(shape) + ' but got array with shape ' +
                             str(data_shape))
  return data
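

# Illustrative sketch (not part of the original API): normalizing a dict of
# user-provided arrays into the ordered list a model expects. The input names
# and shapes are made up, and `_demo_standardize_input_data` is a hypothetical
# helper included only for illustration.
def _demo_standardize_input_data():
  data = {'dense_input': np.zeros((8, 3)), 'aux_input': np.ones((8,))}
  return standardize_input_data(
      data,
      names=['dense_input', 'aux_input'],
      shapes=[(None, 3), (None, 1)],
      check_batch_axis=False,
      exception_prefix='input')
  # -> [array of shape (8, 3), array of shape (8, 1)]

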
def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
  """Maps `sample_weight` or `class_weight` to model outputs.

  Args:
    x_weight: User-provided `sample_weight` or `class_weight` argument.
    output_names: List of output names (strings) in the model.
    weight_type: A string used purely for exception printing.

  Returns:
    A list of `sample_weight` or `class_weight` with exactly one element
    per model output.

  Raises:
    ValueError: In case of invalid user-provided argument.
  """
  if x_weight is None or (isinstance(x_weight, (list, tuple)) and
                          len(x_weight) == 0):  # pylint: disable=g-explicit-length-test
    return [None for _ in output_names]
  if len(output_names) == 1:
    if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1:
      return x_weight
    if isinstance(x_weight, dict) and output_names[0] in x_weight:
      return [x_weight[output_names[0]]]
    else:
      return [x_weight]
  if isinstance(x_weight, (list, tuple)):
    if len(x_weight) != len(output_names):
      raise ValueError('Provided `' + weight_type + '` was a list of ' +
                       str(len(x_weight)) + ' elements, but the model has ' +
                       str(len(output_names)) + ' outputs. '
                       'You should provide one `' + weight_type + '` '
                       'array per model output.')
    return x_weight
  if isinstance(x_weight, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names)
    x_weights = []
    for name in output_names:
      x_weights.append(x_weight.get(name))
    return x_weights
  else:
    raise TypeError('The model has multiple outputs, so `' + weight_type + '` '
                    'should be either a list or a dict. '
                    'Provided `' + weight_type + '` type not understood: ' +
                    str(x_weight))


def standardize_class_weights(class_weight, output_names):
  return standardize_sample_or_class_weights(class_weight, output_names,
                                             'class_weight')


def standardize_sample_weights(sample_weight, output_names):
  return standardize_sample_or_class_weights(sample_weight, output_names,
                                             'sample_weight')


def check_array_lengths(inputs, targets, weights=None):
  """Does user input validation for numpy arrays.

  Args:
    inputs: list of Numpy arrays of inputs.
    targets: list of Numpy arrays of targets.
    weights: list of Numpy arrays of sample weights.

  Raises:
    ValueError: in case of incorrectly formatted data.
  """

  def is_tensor_or_composite_tensor(x):
    return tensor_util.is_tf_type(x) or is_composite_or_composite_value(x)

  def set_of_lengths(x):
    # Returns a set with the variation between
    # different shapes, with None => 0
    if x is None:
      return {}
    else:
      return set([
          y.shape[0]
          for y in x
          if y is not None and not is_tensor_or_composite_tensor(y)
      ])

  set_x = set_of_lengths(inputs)
  set_y = set_of_lengths(targets)
  set_w = set_of_lengths(weights)
  if len(set_x) > 1:
    raise ValueError('All input arrays (x) should have '
                     'the same number of samples. Got array shapes: ' +
                     str([x.shape for x in inputs]))
  if len(set_y) > 1:
    raise ValueError('All target arrays (y) should have '
                     'the same number of samples. Got array shapes: ' +
                     str([y.shape for y in targets]))
  if set_x and set_y and list(set_x)[0] != list(set_y)[0]:
    raise ValueError('Input arrays should have '
                     'the same number of samples as target arrays. '
                     'Found ' + str(list(set_x)[0]) + ' input samples '
                     'and ' + str(list(set_y)[0]) + ' target samples.')
  if len(set_w) > 1:
    raise ValueError('All sample_weight arrays should have '
                     'the same number of samples. Got array shapes: ' +
                     str([w.shape for w in weights]))
  if set_y and set_w and list(set_y)[0] != list(set_w)[0]:
    raise ValueError('Sample_weight arrays should have '
                     'the same number of samples as target arrays. Got ' +
                     str(list(set_y)[0]) + ' input samples and ' +
                     str(list(set_w)[0]) + ' target samples.')


def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
  """Does validation on the compatibility of targets and loss functions.

  This helps prevent users from using loss functions incorrectly. This check
  is purely for UX purposes.

  Args:
    targets: list of Numpy arrays of targets.
    loss_fns: list of loss functions.
    output_shapes: list of shapes of model outputs.

  Raises:
    ValueError: if a loss function or target array
      is incompatible with an output.
  """
  key_loss_fns = {
      losses.mean_squared_error, losses.binary_crossentropy,
      losses.categorical_crossentropy
  }
  key_loss_classes = (losses.MeanSquaredError, losses.BinaryCrossentropy,
                      losses.CategoricalCrossentropy)
  for y, loss, shape in zip(targets, loss_fns, output_shapes):
    if y is None or loss is None or tensor_util.is_tf_type(y):
      continue
    if losses.is_categorical_crossentropy(loss):
      if y.shape[-1] == 1:
        raise ValueError('You are passing a target array of shape ' +
                         str(y.shape) +
                         ' while using as loss `categorical_crossentropy`. '
                         '`categorical_crossentropy` expects '
                         'targets to be binary matrices (1s and 0s) '
                         'of shape (samples, classes). '
                         'If your targets are integer classes, '
                         'you can convert them to the expected format via:\n'
                         '```\n'
                         'from keras.utils import to_categorical\n'
                         'y_binary = to_categorical(y_int)\n'
                         '```\n'
                         '\n'
                         'Alternatively, you can use the loss function '
                         '`sparse_categorical_crossentropy` instead, '
                         'which does expect integer targets.')

    is_loss_wrapper = isinstance(loss, losses.LossFunctionWrapper)
    if (isinstance(loss, key_loss_classes) or (is_loss_wrapper and
                                               (loss.fn in key_loss_fns))):
      for target_dim, out_dim in zip(y.shape[1:], shape[1:]):
        if out_dim is not None and target_dim != out_dim:
          loss_name = loss.name
          if loss_name is None:
            loss_type = loss.fn if is_loss_wrapper else type(loss)
            loss_name = loss_type.__name__
          raise ValueError('A target array with shape ' + str(y.shape) +
                           ' was passed for an output of shape ' + str(shape) +
                           ' while using as loss `' + loss_name + '`. '
                           'This loss expects targets to have the same shape '
                           'as the output.')


def collect_per_output_metric_info(metrics,
                                   output_names,
                                   output_shapes,
                                   loss_fns,
                                   from_serialized=False,
                                   is_weighted=False):
  """Maps metric names and functions to model outputs.

  Args:
    metrics: a list or a list of lists or a dict of metric functions.
    output_names: a list of the names (strings) of model outputs.
    output_shapes: a list of the shapes of model outputs.
    loss_fns: a list of the loss functions corresponding to the model outputs.
    from_serialized: whether the model the metrics are being sourced from is
      being initialized from a serialized format.
    is_weighted: Boolean indicating whether the given metrics are weighted.

  Returns:
    A list (one entry per model output) of dicts.
    For instance, if the model has 2 outputs, and for the first output
    we want to compute "binary_accuracy" and "binary_crossentropy",
    and just "binary_accuracy" for the second output,
    the list would look like: `[{
      'acc': binary_accuracy(),
      'ce': binary_crossentropy(),
    }, {
      'acc': binary_accuracy(),
    }]`

  Raises:
    TypeError: if an incorrect type is passed for the `metrics` argument.
  """
  if not metrics:
    return [{} for _ in output_names]

  if isinstance(metrics, list):
    any_sub_list = any(isinstance(m, list) for m in metrics)
    if any_sub_list:
      if len(metrics) != len(output_names):
        raise ValueError('When passing a list of lists as `metrics`, '
                         'it should have one entry per model output. '
                         'The model has ' + str(len(output_names)) +
                         ' outputs, but you passed metrics=' + str(metrics))
      # User has provided a list of len = len(outputs).
      nested_metrics = [generic_utils.to_list(m) for m in metrics]
    else:
      # If it is a single list we then apply all metrics to all outputs.
      if len(output_names) > 1:
        nested_metrics = []
        for _ in output_names:
          nested_metrics.append(
              [metrics_module.clone_metric(m) for m in metrics])
      else:
        nested_metrics = [metrics]
  elif isinstance(metrics, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys('metrics', metrics, output_names)
    nested_metrics = []
    for name in output_names:
      output_metrics = generic_utils.to_list(metrics.get(name, []))
      nested_metrics.append(output_metrics)
  else:
    raise TypeError('Type of `metrics` argument not understood. '
                    'Expected a list or dictionary, found: ' + str(metrics))

  per_output_metrics = []
  for i, metrics in enumerate(nested_metrics):
    metrics_dict = collections.OrderedDict()
    for metric in metrics:
      metric_name = get_metric_name(metric, is_weighted)
      metric_fn = get_metric_function(
          metric, output_shape=output_shapes[i], loss_fn=loss_fns[i])
      metric_fn._from_serialized = from_serialized  # pylint: disable=protected-access

      # If the metric function is not stateful, we create a stateful version.
      if not isinstance(metric_fn, metrics_module.Metric):
        metric_fn = metrics_module.MeanMetricWrapper(
            metric_fn, name=metric_name)
        # If the metric is being revived from something stateless, such as a
        # string (e.g. "accuracy"), we may need to later reapply transformations
        # such as renaming.
        metric_fn._from_serialized = False  # pylint: disable=protected-access
      metrics_dict[metric_name] = metric_fn
    per_output_metrics.append(metrics_dict)

  return per_output_metrics


def batch_shuffle(index_array, batch_size):
  """Shuffles an array in a batch-wise fashion.

  Useful for shuffling HDF5 arrays
  (where one cannot access arbitrary indices).

  Args:
    index_array: array of indices to be shuffled.
    batch_size: integer.

  Returns:
    The `index_array` array, shuffled in a batch-wise fashion.
  """
  batch_count = int(len(index_array) / batch_size)
  # to reshape we need to be cleanly divisible by batch size
  # we stash extra items and reappend them after shuffling
  last_batch = index_array[batch_count * batch_size:]
  index_array = index_array[:batch_count * batch_size]
  index_array = index_array.reshape((batch_count, batch_size))
  np.random.shuffle(index_array)
  index_array = index_array.flatten()
  return np.append(index_array, last_batch)


def standardize_weights(y,
                        sample_weight=None,
                        class_weight=None,
                        sample_weight_mode=None):
  """Performs sample weight validation and standardization.

  Everything gets normalized to a single sample-wise (or timestep-wise)
  weight array. If both `sample_weight` and `class_weight` are provided,
  the weights are multiplied.

  Args:
    y: Numpy array or Tensor of model targets to be weighted.
    sample_weight: User-provided `sample_weight` argument.
    class_weight: User-provided `class_weight` argument.
    sample_weight_mode: One of `None` or `"temporal"`. `"temporal"` indicates
      that we expect 2D weight data that will be applied to the last 2
      dimensions of the targets (i.e. we are weighting timesteps, not
      samples).

  Returns:
    A numpy array of target weights, one entry per sample to weight.

  Raises:
    ValueError: In case of invalid user-provided arguments.
  """
  # Iterator may return sample_weight as 1-tuple
  if isinstance(sample_weight, tuple):
    sample_weight = sample_weight[0]
  if sample_weight_mode is not None and sample_weight_mode != 'samplewise':
    if sample_weight_mode != 'temporal':
      raise ValueError('`sample_weight_mode` '
                       'should be None or "temporal". '
                       'Found: ' + str(sample_weight_mode))
    if len(y.shape) < 3:
      raise ValueError('Found a sample_weight array for '
                       'an input with shape ' + str(y.shape) + '. '
                       'Timestep-wise sample weighting (use of '
                       'sample_weight_mode="temporal") is restricted to '
                       'outputs that are at least 3D, i.e. that have '
                       'a time dimension.')
    if sample_weight is not None and len(sample_weight.shape) != 2:
      raise ValueError('Found a sample_weight array with shape ' +
                       str(sample_weight.shape) + '. '
                       'In order to use timestep-wise sample weighting, '
                       'you should pass a 2D sample_weight array.')
  else:
    if sample_weight is not None and len(sample_weight.shape) != 1:
      raise ValueError(
          'Found a sample_weight array with shape {}. In order to '
          'use timestep-wise sample weights, you should specify '
          'sample_weight_mode="temporal" in compile(); found "{}" '
          'instead. If you just mean to use sample-wise weights, '
          'make sure your sample_weight array is 1D.'.format(
              sample_weight.shape, sample_weight_mode))

  if sample_weight is not None:
    if len(sample_weight.shape) > len(y.shape):
      raise ValueError('Found a sample_weight with shape ' +
                       str(sample_weight.shape) + '. '
                       'Expected sample_weight with rank '
                       'less than or equal to ' + str(len(y.shape)))

    if (not tensor_util.is_tf_type(sample_weight) and
        y.shape[:sample_weight.ndim] != sample_weight.shape):
      raise ValueError('Found a sample_weight array with shape ' +
                       str(sample_weight.shape) + ' for an input with shape ' +
                       str(y.shape) + '. '
                       'sample_weight cannot be broadcast.')

  # Class weights applied per-sample.
  class_sample_weight = None
  if isinstance(class_weight, dict):
    if len(y.shape) > 2:
      raise ValueError('`class_weight` not supported for '
                       '3+ dimensional targets.')

    if tensor_util.is_tf_type(y):
      # Few classes are expected, so densifying is reasonable.
      keys = np.array(sorted(class_weight.keys()))
      values = np.array([class_weight[i] for i in keys])
      weight_vector = np.zeros(np.max(keys) + 1)
      weight_vector[:] = np.nan
      weight_vector[keys] = values

      y_classes = smart_cond.smart_cond(
          len(y.shape.as_list()) == 2 and backend.shape(y)[1] > 1,
          lambda: backend.argmax(y, axis=1),
          lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64))
      class_sample_weight = array_ops.gather(weight_vector, y_classes)
      gen_array_ops.check_numerics(
          class_sample_weight,
          'Invalid classes or class weights detected. NaN values indicate that '
          'an appropriate class weight could not be determined.')
      class_sample_weight = math_ops.cast(class_sample_weight, backend.floatx())
      if sample_weight is not None:
        sample_weight = math_ops.cast(
            ops.convert_to_tensor_v2_with_dispatch(sample_weight),
            backend.floatx())
    else:
      y_classes = y
      if len(y.shape) == 2:
        if y.shape[1] > 1:
          y_classes = np.argmax(y, axis=1)
        elif y.shape[1] == 1:
          y_classes = np.reshape(y, y.shape[0])

      class_sample_weight = np.asarray(
          [class_weight[cls] for cls in y_classes if cls in class_weight])

      if len(class_sample_weight) != len(y_classes):
        # subtract the sets to pick all missing classes
        existing_classes = set(y_classes)
        existing_class_weight = set(class_weight.keys())
        raise ValueError(
            '`class_weight` must contain all classes in the data.'
            ' The classes %s exist in the data but not in '
            '`class_weight`.' % (existing_classes - existing_class_weight))

  if class_sample_weight is not None and sample_weight is not None:
    # Multiply weights if both are provided.
    return class_sample_weight * sample_weight
  if sample_weight is not None:
    return sample_weight
  if class_sample_weight is not None:
    return class_sample_weight
  return None
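

# Illustrative sketch (not part of the original API): combining per-sample and
# per-class weights for a small in-memory batch. The targets and weights are
# made up, and `_demo_standardize_weights` is a hypothetical helper included
# only for illustration.
def _demo_standardize_weights():
  y = np.array([0, 1, 1, 0])
  sample_weight = np.array([1., 2., 1., 1.])
  class_weight = {0: 1., 1: 3.}
  # Per-sample and per-class weights are multiplied together:
  # -> array([1., 6., 3., 1.])
  return standardize_weights(y, sample_weight, class_weight)

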
def has_symbolic_tensors(ls):
  if context.executing_eagerly():
    return False
  return has_tensors(ls)


def has_tensors(ls):
  """Returns true if `ls` contains tensors."""
  # Note: at some point in time ragged tensors didn't count as tensors, so this
  # returned false for ragged tensors. Making this return true fails some tests
  # which would then require a steps_per_epoch argument.
  if isinstance(ls, (list, tuple)):
    return any(
        tensor_util.is_tf_type(v) and
        not isinstance(v, ragged_tensor.RaggedTensor) for v in ls)
  if isinstance(ls, dict):
    return any(
        tensor_util.is_tf_type(v) and
        not isinstance(v, ragged_tensor.RaggedTensor)
        for _, v in ls.items())
  return tensor_util.is_tf_type(ls) and not isinstance(
      ls, ragged_tensor.RaggedTensor)


def get_metric_name(metric, weighted=False):
  """Returns the name corresponding to the given metric input.

  Args:
    metric: Metric function name or reference.
    weighted: Boolean indicating if the given metric is weighted.

  Returns:
    The metric name.
  """
  if tf2.enabled():
    # We keep the string that the user has set in compile as the metric name.
    if isinstance(metric, str):
      return metric

    metric = metrics_module.get(metric)
    return metric.name if hasattr(metric, 'name') else metric.__name__
  else:
    metric_name_prefix = 'weighted_' if weighted else ''
    if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
      if metric in ('accuracy', 'acc'):
        suffix = 'acc'
      elif metric in ('crossentropy', 'ce'):
        suffix = 'ce'
    else:
      metric_fn = metrics_module.get(metric)
      # Get metric name as string
      if hasattr(metric_fn, 'name'):
        suffix = metric_fn.name
      else:
        suffix = metric_fn.__name__
    metric_name = metric_name_prefix + suffix
    return metric_name


def get_metric_function(metric, output_shape=None, loss_fn=None):
  """Returns the metric function corresponding to the given metric input.

  Args:
    metric: Metric function name or reference.
    output_shape: The shape of the output that this metric will be calculated
      for.
    loss_fn: The loss function used.

  Returns:
    The metric function.
  """
  if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']:
    return metrics_module.get(metric)

  is_sparse_categorical_crossentropy = (
      isinstance(loss_fn, losses.SparseCategoricalCrossentropy) or
      (isinstance(loss_fn, losses.LossFunctionWrapper) and
       loss_fn.fn == losses.sparse_categorical_crossentropy))

  is_binary_crossentropy = (
      isinstance(loss_fn, losses.BinaryCrossentropy) or
      (isinstance(loss_fn, losses.LossFunctionWrapper) and
       loss_fn.fn == losses.binary_crossentropy))

  if metric in ['accuracy', 'acc']:
    if output_shape[-1] == 1 or is_binary_crossentropy:
      return metrics_module.binary_accuracy
    elif is_sparse_categorical_crossentropy:
      return metrics_module.sparse_categorical_accuracy
    # If the output_shape[-1] is not 1, then we know output is `categorical`.
    # We assume it is sparse categorical only if loss is explicitly given
    # as sparse categorical crossentropy loss.
    return metrics_module.categorical_accuracy
  else:
    if output_shape[-1] == 1 or is_binary_crossentropy:
      return metrics_module.binary_crossentropy
    elif is_sparse_categorical_crossentropy:
      return metrics_module.sparse_categorical_crossentropy
    return metrics_module.categorical_crossentropy
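

# Illustrative sketch (not part of the original API): how the 'accuracy'
# shorthand resolves to a concrete metric for two made-up output/loss
# configurations. `_demo_get_metric_function` is a hypothetical helper included
# only for illustration.
def _demo_get_metric_function():
  binary_acc = get_metric_function(
      'accuracy', output_shape=(None, 1),
      loss_fn=losses.BinaryCrossentropy())
  sparse_acc = get_metric_function(
      'accuracy', output_shape=(None, 10),
      loss_fn=losses.SparseCategoricalCrossentropy())
  # -> (metrics_module.binary_accuracy,
  #     metrics_module.sparse_categorical_accuracy)
  return binary_acc, sparse_acc

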
def call_metric_function(metric_fn,
                         y_true,
                         y_pred=None,
                         weights=None,
                         mask=None):
  """Invokes metric function and returns the metric result tensor."""
  if mask is not None:
    mask = math_ops.cast(mask, y_pred.dtype)
    if weights is None:
      # Use mask as sample weight.
      weights = mask
    else:
      # Update dimensions of weights to match with mask.
      weights = math_ops.cast(weights, dtype=y_pred.dtype)
      mask, _, weights = losses_utils.squeeze_or_expand_dimensions(
          mask, sample_weight=weights)
      weights *= mask

  if y_pred is not None:
    return metric_fn(y_true, y_pred, sample_weight=weights)
  # `Mean` metric only takes a single value.
  return metric_fn(y_true, sample_weight=weights)


def get_loss_function(loss):
  """Returns the loss corresponding to the loss input in `compile` API."""
  if loss is None or isinstance(loss, losses.Loss):
    return loss

  if tf_inspect.isclass(loss) and issubclass(loss, losses.Loss):
    # It is not safe to assume that the loss takes no constructor arguments.
    raise ValueError(
        'Received uninstantiated Loss class: {}\nPlease call loss classes '
        'before passing them to Model.compile.'.format(loss))

  # Deserialize loss configuration, if needed.
  if isinstance(loss, collections.abc.Mapping):
    loss = losses.get(loss)

  # Custom callable class.
  if callable(loss) and not hasattr(loss, '__name__'):
    return loss

  # Wrap loss function with signature `(y_true, y_pred, **kwargs)`
  # in `LossFunctionWrapper` class.
  loss_fn = losses.get(loss)

  # For losses which are given as strings/functions in the compile API,
  # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE`
  # (both in distribution strategy context and otherwise).
  return losses.LossFunctionWrapper(
      loss_fn,
      name=loss_fn.__name__,
      reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE)


def validate_dataset_input(x, y, sample_weight, validation_split=None):
  """Validates user input arguments when a dataset iterator is passed.

  Args:
    x: Input data. A `tf.data` dataset or iterator.
    y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s).
      Expected to be `None` when `x` is a dataset iterator.
    sample_weight: An optional sample-weight array passed by the user to weight
      the importance of each sample in `x`. Expected to be `None` when `x` is a
      dataset iterator.
    validation_split: Float between 0 and 1. Fraction of the training data to
      be used as validation data. Expected to be `None` when `x` is a dataset
      iterator.

  Raises:
    ValueError: if argument `y` or `sample_weight` or `validation_split` are
      provided by user.
  """
  if y is not None:
    raise ValueError('You passed a dataset or dataset iterator (%s) as '
                     'input `x` to your model. In that case, you should '
                     'not specify a target (`y`) argument, since the dataset '
                     'or dataset iterator generates both input data and '
                     'target data. '
                     'Received: %s' % (x, y))
  if sample_weight is not None:
    raise ValueError('`sample_weight` argument is not supported when input '
                     '`x` is a dataset or a dataset iterator. Instead, you '
                     'can provide sample_weight as the third element of your '
                     'dataset, i.e. (inputs, targets, sample_weight). '
                     'Received: x=%s, sample_weight=%s' % (x, sample_weight))
  if validation_split is not None and validation_split != 0.0:
    raise ValueError(
        '`validation_split` argument is not supported when '
        'input `x` is a dataset or a dataset iterator. '
        'Received: x=%s, validation_split=%f' % (x, validation_split))


def validate_input_types(inp, orig_inp, allow_dict=True, field_name='inputs'):
  """Helper function to validate either inputs or targets."""
  if isinstance(inp, (list, tuple)):
    if not all(isinstance(v, np.ndarray) or
               tensor_util.is_tf_type(v) for v in inp):
      raise ValueError(
          'Please provide as model inputs either a single array or a list of '
          'arrays. You passed: {}={}'.format(field_name, str(orig_inp)))
  elif isinstance(inp, dict):
    if not allow_dict:
      raise ValueError(
          'You cannot pass a dictionary as model {}.'.format(field_name))
  elif not isinstance(inp, np.ndarray) and not tensor_util.is_tf_type(inp):
    raise ValueError(
        'Please provide as model inputs either a single array or a list of '
        'arrays. You passed: {}={}'.format(field_name, orig_inp))


def check_generator_arguments(y=None, sample_weight=None,
                              validation_split=None):
  """Validates arguments passed when using a generator."""
  if y is not None:
    raise ValueError('`y` argument is not supported when data is '
                     'a generator or Sequence instance. Instead pass targets'
                     ' as the second element of the generator.')
  if sample_weight is not None:
    raise ValueError('`sample_weight` argument is not supported when data is '
                     'a generator or Sequence instance. Instead pass sample'
                     ' weights as the third element of the generator.')
  if validation_split:
    raise ValueError('If your data is in the form of a Python generator, '
                     'you cannot use `validation_split`.')


def check_steps_argument(input_data, steps, steps_name):
  """Validates `steps` argument based on input data's type.

  The cases when `steps` value must be provided are when
    1. input data passed is an iterator.
    2. model was built on top of symbolic tensors, input data is not
       required and is `None`.
    3. input data passed is a symbolic tensor.

  Args:
    input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or
      tf.data.Dataset iterator or `None`.
    steps: Integer or `None`. Total number of steps (batches of samples) to
      execute.
    steps_name: The public API's parameter name for `steps`.

  Returns:
    boolean, True if `steps` argument is required, else False.

  Raises:
    ValueError: if `steps` argument is required for given input data type
      but not provided.
  """
  is_x_iterator = isinstance(
      input_data, (iterator_ops.Iterator, iterator_ops.IteratorBase))
  if (input_data is None or is_x_iterator or has_symbolic_tensors(input_data) or
      (isinstance(input_data, list) and not input_data)):
    if steps is None:
      input_type_str = 'a Dataset iterator' if is_x_iterator else 'data tensors'
      raise ValueError('When using {input_type} as input to a model, you should'
                       ' specify the `{steps_name}` argument.'.format(
                           input_type=input_type_str, steps_name=steps_name))
    return True

  if isinstance(input_data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
    return True

  if steps is not None:
    list_types = (np.ndarray, list, tuple)
    if (isinstance(input_data, list_types) or
        (isinstance(input_data, dict) and
         any(isinstance(v, list_types) for v in input_data.values()))):
      logging.warning('When passing input data as arrays, do not specify '
                      '`steps_per_epoch`/`steps` argument. '
                      'Please use `batch_size` instead.')
  return False
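

# Illustrative sketch (not part of the original API): `steps` is mandatory for
# dataset inputs but not for in-memory arrays. The inputs below are made up,
# and `_demo_check_steps_argument` is a hypothetical helper included only for
# illustration.
def _demo_check_steps_argument():
  dataset = dataset_ops.DatasetV2.from_tensor_slices(np.zeros((4, 2)))
  needs_steps = check_steps_argument(dataset, steps=2, steps_name='steps')
  array_needs_steps = check_steps_argument(
      [np.zeros((4, 2))], steps=None, steps_name='steps')
  # -> (True, False)
  return needs_steps, array_needs_steps

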
def cast_single_tensor(x, dtype=None):
  if isinstance(x, np.ndarray):
    x = ops.convert_to_tensor_v2_with_dispatch(x)
  dtype = dtype or backend.floatx()
  if x.dtype.is_floating:
    return math_ops.cast(x, dtype=dtype)
  return x


def cast_if_floating_dtype_and_mismatch(targets, outputs):
  """Returns target data tensors using correct datatype.

  Checks that each target and output pair are the same datatype. If not, casts
  the target to the output's datatype.

  Args:
    targets: tensor or list of targets.
    outputs: tensor or list of outputs.

  Returns:
    Targets in appropriate datatype.
  """
  if tensor_util.is_tf_type(targets):
    # There is one target, so output[0] should be the only output.
    return cast_single_tensor(targets, dtype=outputs[0].dtype)
  new_targets = []
  for target, out in zip(targets, outputs):
    if isinstance(target, np.ndarray):
      target = ops.convert_to_tensor_v2_with_dispatch(target)
    if target.dtype != out.dtype:
      new_targets.append(cast_single_tensor(target, dtype=out.dtype))
    else:
      new_targets.append(target)
  return new_targets


def cast_if_floating_dtype(x, dtype=None):
  """Casts the given data tensors to the default floating point type.

  Casts only if the input is already a floating point type.

  Args:
    x: tensor or list/tuple of tensors.
    dtype: The dtype to which Tensors should be cast.

  Returns:
    Converted input.
  """
  return nest.map_structure(functools.partial(cast_single_tensor, dtype=dtype),
                            x)


def cast_to_model_input_dtypes(x, model):
  """Casts the given data tensors to the dtypes of the model inputs.

  Args:
    x: tensor or list/tuple of tensors.
    model: The model.

  Returns:
    Converted input. Each tensor is casted to the corresponding input in
    `model.inputs`.
  """
  input_dtypes = nest.map_structure(lambda t: t.dtype, model.inputs)
  return nest.map_structure(math_ops.cast, x, input_dtypes)


def prepare_sample_weight_modes(training_endpoints, sample_weight_mode):
  """Prepares sample weight modes for the model.

  Args:
    training_endpoints: List of model _TrainingEndpoints.
    sample_weight_mode: sample weight mode user input passed from compile API.

  Raises:
    ValueError: In case of invalid `sample_weight_mode` input.
  """

  if isinstance(sample_weight_mode, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys(
        'sample_weight_mode', sample_weight_mode,
        [e.output_name for e in training_endpoints])

    for end_point in training_endpoints:
      if not end_point.should_skip_target_weights():
        if end_point.output_name not in sample_weight_mode:
          raise ValueError('Output ' + end_point.output_name +
                           ' missing from `_sample_weight_modes` dictionary')
        else:
          end_point.sample_weight_mode = sample_weight_mode.get(
              end_point.output_name)
  elif isinstance(sample_weight_mode, (list, tuple)):
    if len(sample_weight_mode) != len(training_endpoints):
      raise ValueError('When passing a list as sample_weight_mode, '
                       'it should have one entry per model output. '
                       'The model has ' + str(len(training_endpoints)) +
                       ' outputs, but you passed ' +
                       str(len(sample_weight_mode)) + ' _sample_weight_modes.')
    for mode, endpoint in zip(sample_weight_mode, training_endpoints):
      if not endpoint.should_skip_target_weights():
        endpoint.sample_weight_mode = mode
  else:
    for endpoint in training_endpoints:
      if not endpoint.should_skip_target_weights():
        endpoint.sample_weight_mode = sample_weight_mode


def prepare_loss_functions(loss, output_names):
  """Converts loss to a list of loss functions.

  Args:
    loss: String (name of objective function), objective function or
      `tf.losses.Loss` instance. See `tf.losses`. If the model has multiple
      outputs, you can use a different loss on each output by passing a
      dictionary or a list of losses. The loss value that will be minimized by
      the model will then be the sum of all individual losses.
    output_names: List of model output names.
1471 1472 Returns: 1473 A list of loss objective functions. 1474 1475 Raises: 1476 ValueError: If loss is a dict with keys not in model output names, 1477 or if loss is a list with len not equal to model outputs. 1478 """ 1479 if isinstance(loss, collections.abc.Mapping): 1480 generic_utils.check_for_unexpected_keys('loss', loss, output_names) 1481 loss_functions = [] 1482 for name in output_names: 1483 if name not in loss: 1484 logging.warning( 1485 'Output {0} missing from loss dictionary. We assume ' 1486 'this was done on purpose. The fit and evaluate APIs will not be ' 1487 'expecting any data to be passed to {0}.'.format(name)) 1488 loss_functions.append(get_loss_function(loss.get(name, None))) 1489 elif isinstance(loss, str): 1490 loss_functions = [get_loss_function(loss) for _ in output_names] 1491 elif isinstance(loss, collections.abc.Sequence): 1492 if len(loss) != len(output_names): 1493 raise ValueError('When passing a list as loss, it should have one entry ' 1494 'per model outputs. The model has {} outputs, but you ' 1495 'passed loss={}'.format(len(output_names), loss)) 1496 loss_functions = nest.map_structure(get_loss_function, loss) 1497 else: 1498 loss_functions = [get_loss_function(loss) for _ in range(len(output_names))] 1499 1500 return loss_functions 1501 1502 1503def prepare_loss_weights(training_endpoints, loss_weights=None): 1504 """Converts loss weights to a list of loss weights. 1505 1506 The result loss weights will be populated on the training endpoint. 1507 1508 Args: 1509 training_endpoints: List of model training endpoints. 1510 loss_weights: Optional list or dictionary specifying scalar coefficients 1511 (Python floats) to weight the loss contributions of different model 1512 outputs. The loss value that will be minimized by the model will then be 1513 the *weighted sum* of all individual losses, weighted by the 1514 `loss_weights` coefficients. If a list, it is expected to have a 1:1 1515 mapping to the model's outputs. If a dict, it is expected to map 1516 output names (strings) to scalar coefficients. 1517 1518 Raises: 1519 ValueError: If loss weight is a dict with key not in model output names, 1520 or if loss is a list with len not equal to model outputs. 1521 """ 1522 if loss_weights is None: 1523 for e in training_endpoints: 1524 e.loss_weight = 1. 1525 elif isinstance(loss_weights, collections.abc.Mapping): 1526 generic_utils.check_for_unexpected_keys( 1527 'loss_weights', loss_weights, 1528 [e.output_name for e in training_endpoints]) 1529 for e in training_endpoints: 1530 e.loss_weight = loss_weights.get(e.output_name, 1.) 1531 elif isinstance(loss_weights, list): 1532 if len(loss_weights) != len(training_endpoints): 1533 raise ValueError('When passing a list as loss_weights, ' 1534 'it should have one entry per model output. ' 1535 'The model has ' + str(len(training_endpoints)) + 1536 ' outputs, but you passed loss_weights=' + 1537 str(loss_weights)) 1538 for w, e in zip(loss_weights, training_endpoints): 1539 e.loss_weight = w 1540 else: 1541 raise TypeError('Could not interpret loss_weights argument: ' + 1542 str(loss_weights) + ' - expected a list of dicts.') 1543 1544 1545# TODO(rohanj): This is a hack to get around not depending on feature_column and 1546# create a cyclical dependency. 
Figure out a cleaner solution 1547def is_feature_layer(layer): 1548 """Returns whether `layer` is a FeatureLayer or not.""" 1549 return getattr(layer, '_is_feature_layer', False) 1550 1551 1552def is_eager_dataset_or_iterator(data): 1553 return context.executing_eagerly() and isinstance( 1554 data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2, 1555 iterator_ops.IteratorBase)) 1556 1557 1558# pylint: disable=protected-access 1559def get_dataset_graph_def(dataset): 1560 if context.executing_eagerly(): 1561 graph_def_str = dataset._as_serialized_graph().numpy() 1562 else: 1563 graph_def_str = backend.get_value(dataset._as_serialized_graph()) 1564 return graph_pb2.GraphDef().FromString(graph_def_str) 1565 1566 1567def verify_dataset_shuffled(x): 1568 """Verifies that the dataset is shuffled. 1569 1570 Args: 1571 x: Dataset passed as an input to the model. 1572 1573 Returns: 1574 boolean, whether the input dataset is shuffled or not. 1575 """ 1576 assert isinstance(x, dataset_ops.DatasetV2) 1577 graph_def = get_dataset_graph_def(x) 1578 for node in graph_def.node: 1579 if node.op.startswith('ShuffleDataset'): 1580 return True 1581 # Also check graph_def.library.function for ds.interleave or ds.flat_map 1582 for function in graph_def.library.function: 1583 for node in function.node_def: 1584 if node.op.startswith('ShuffleDataset'): 1585 return True 1586 logging.warning('Expected a shuffled dataset but input dataset `x` is ' 1587 'not shuffled. Please invoke `shuffle()` on input dataset.') 1588 return False 1589 1590 1591def is_dataset_or_iterator(data): 1592 return isinstance(data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2, 1593 iterator_ops.Iterator, iterator_ops.IteratorBase)) 1594 1595 1596def get_iterator(dataset): 1597 """Create and initialize an iterator from a dataset.""" 1598 if context.executing_eagerly(): 1599 iterator = dataset_ops.make_one_shot_iterator(dataset) 1600 else: 1601 iterator = dataset_ops.make_initializable_iterator(dataset) 1602 initialize_iterator(iterator) 1603 return iterator 1604 1605 1606def initialize_iterator(iterator): 1607 if not context.executing_eagerly(): 1608 init_op = iterator.initializer 1609 backend.get_session((init_op,)).run(init_op) 1610 1611 1612def extract_tensors_from_dataset(dataset): 1613 """Extract a tuple of tensors `inputs, targets, sample_weight` from a dataset. 1614 1615 Args: 1616 dataset: Dataset instance. 1617 1618 Returns: 1619 Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None. 1620 """ 1621 iterator = get_iterator(dataset) 1622 inputs, targets, sample_weight = unpack_iterator_input(iterator) 1623 return inputs, targets, sample_weight 1624 1625 1626def unpack_iterator_input(iterator): 1627 """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`. 1628 1629 Args: 1630 iterator: Instance of a dataset iterator. 1631 1632 Returns: 1633 Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None. 
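
  Example (a minimal sketch; the toy dataset below is for illustration only):

    dataset = dataset_ops.DatasetV2.from_tensor_slices(
        (np.zeros((8, 3)), np.ones((8, 1)))).batch(4)
    x, y, weights = unpack_iterator_input(get_iterator(dataset))
    # `x` and `y` hold the first batch; `weights` is None because the dataset
    # yields 2-tuples rather than 3-tuples.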
1634 """ 1635 try: 1636 next_element = iterator.get_next() 1637 except errors.OutOfRangeError: 1638 raise RuntimeError('Your dataset iterator ran out of data; ' 1639 'Make sure that your dataset can generate ' 1640 'required number of samples.') 1641 1642 if isinstance(next_element, (list, tuple)): 1643 if len(next_element) not in [2, 3]: 1644 raise ValueError( 1645 'Please provide model inputs as a list or tuple of 2 or 3 ' 1646 'elements: (input, target) or (input, target, sample_weights) ' 1647 'Received %s' % next_element) 1648 if len(next_element) == 2: 1649 x, y = next_element 1650 weights = None 1651 else: 1652 x, y, weights = next_element 1653 else: 1654 x = next_element 1655 y = None 1656 weights = None 1657 return x, y, weights 1658 1659 1660def infer_steps_for_dataset(model, 1661 dataset, 1662 steps, 1663 epochs=1, 1664 steps_name='steps'): 1665 """Infers steps_per_epoch needed to loop through a dataset. 1666 1667 Args: 1668 model: Keras model instance. 1669 dataset: Input data of type tf.data.Dataset. 1670 steps: Number of steps to draw from the dataset (may be None if unknown). 1671 epochs: Number of times to iterate over the dataset. 1672 steps_name: The string name of the steps argument, either `steps`, 1673 `validation_steps`, or `steps_per_epoch`. Only used for error message 1674 formatting. 1675 1676 Returns: 1677 Integer or `None`. Inferred number of steps to loop through the dataset. 1678 `None` is returned if 1) the size of the dataset is unknown and `steps` was 1679 not specified, or 2) this is multi-worker training and auto sharding is 1680 enabled. 1681 1682 Raises: 1683 ValueError: In case of invalid argument values. 1684 """ 1685 assert isinstance(dataset, dataset_ops.DatasetV2) 1686 if (model._in_multi_worker_mode() and 1687 (dataset.options().experimental_distribute.auto_shard_policy != 1688 options_lib.AutoShardPolicy.OFF)): 1689 # If the dataset would be auto-sharded, we should not infer a local 1690 # steps_per_epoch due to the possible inbalanced sharding between workers. 1691 return None 1692 1693 size = backend.get_value(cardinality.cardinality(dataset)) 1694 if size == cardinality.INFINITE and steps is None: 1695 raise ValueError('When passing an infinitely repeating dataset, you ' 1696 'must specify the `%s` argument.' % (steps_name,)) 1697 if size >= 0: 1698 if steps is not None and steps * epochs > size: 1699 if epochs > 1: 1700 raise ValueError('The dataset you passed contains %s batches, but you ' 1701 'passed `epochs=%s` and `%s=%s`, which is a total of ' 1702 '%s steps. We cannot draw that many steps from this ' 1703 'dataset. We suggest to set `%s=%s`.' % 1704 (size, epochs, steps_name, steps, steps * epochs, 1705 steps_name, size // epochs)) 1706 else: 1707 raise ValueError('The dataset you passed contains %s batches, but you ' 1708 'passed `%s=%s`. We cannot draw that many steps from ' 1709 'this dataset. We suggest to set `%s=%s`.' % 1710 (size, steps_name, steps, steps_name, size)) 1711 if steps is None: 1712 if size >= 0: 1713 return size 1714 return None 1715 return steps 1716 1717 1718class ModelInputs(object): 1719 """Encapsulates model inputs. 1720 1721 Allows for transforming model inputs while keeping the same structure. 
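
  Example (an illustrative sketch, assuming a dict of NumPy inputs):

    model_inputs = ModelInputs({'b': np.zeros((2, 1)), 'a': np.ones((2, 3))})
    model_inputs.get_input_names()  # ['a', 'b'] -- dict keys, sorted.
    # `get_symbolic_inputs()` would replace each array with a placeholder of
    # shape (None, ...) so the result can be set as the model's inputs.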
1722 """ 1723 1724 def __init__(self, inputs): 1725 self._inputs = inputs 1726 self._is_dict = isinstance(self._inputs, dict) 1727 self._is_single_input = not isinstance(self._inputs, (list, tuple, dict)) 1728 1729 self._flattened_inputs = [] 1730 self._input_names = [] 1731 1732 if self._is_dict: 1733 for k in sorted(self._inputs.keys()): 1734 self._flattened_inputs.append(self._inputs[k]) 1735 self._input_names.append(k) 1736 else: 1737 self._flattened_inputs = nest.flatten(self._inputs) 1738 self._input_names = [ 1739 'input_%d' % (i + 1) for i in range(len(self._flattened_inputs)) 1740 ] 1741 1742 def get_input_names(self): 1743 """Returns keys to name inputs by. 1744 1745 In case inputs provided were a list, tuple or single entry, we make up a 1746 key 'input_%d'. For dictionary case, we return a sorted list of keys. 1747 """ 1748 return self._input_names 1749 1750 def get_symbolic_inputs(self, return_single_as_list=False): 1751 """Returns inputs to be set as self.inputs for a model.""" 1752 # TODO(karmel): There is a side-effect here where what you get 1753 # with as_list and as_dict depends on whether you have called this 1754 # method first, since it modifies in place. 1755 for i, (k, v) in enumerate(zip(self._input_names, self._flattened_inputs)): 1756 if isinstance(v, (list, float, int)): 1757 v = np.asarray(v) 1758 if v.ndim == 1: 1759 v = np.expand_dims(v, 1) 1760 1761 if isinstance(v, np.ndarray): 1762 # We fix the placeholder shape except the batch size. 1763 # This is suboptimal, but it is the best we can do with the info 1764 # we have. The user should call `model._set_inputs(placeholders)` 1765 # to specify custom placeholders if the need arises. 1766 shape = (None,) + tuple(v.shape[1:]) 1767 if shape == (None,): 1768 shape = (None, 1) 1769 dtype = dtypes.as_dtype(v.dtype) 1770 if dtype.is_floating: 1771 dtype = backend.floatx() 1772 v = backend.placeholder(shape=shape, name=k, dtype=dtype) 1773 elif isinstance(v, tensor_spec.TensorSpec): 1774 shape = (None,) + tuple(v.shape.as_list()[1:]) 1775 if shape == (None,): 1776 shape = (None, 1) 1777 v = backend.placeholder(shape=shape, name=k, dtype=v.dtype) 1778 1779 self._flattened_inputs[i] = v 1780 1781 if self._is_dict: 1782 return dict(zip(self._input_names, self._flattened_inputs)) 1783 if self._is_single_input and not return_single_as_list: 1784 return self._flattened_inputs[0] 1785 return self._flattened_inputs 1786 1787 def as_dict(self): 1788 """An iterable over a dictionary version of inputs.""" 1789 for k, v in zip(self._input_names, self._flattened_inputs): 1790 yield k, v 1791 1792 def as_list(self): 1793 """Returning the inputs as a list.""" 1794 return self._flattened_inputs 1795 1796 1797# Allow use of methods not exposed to the user. 1798# pylint: disable=protected-access 1799 1800 1801# pylint: enable=protected-access 1802 1803 1804def generic_output_names(outputs_list): 1805 return ['output_%d' % (i + 1) for i in range(len(outputs_list))] 1806 1807 1808def should_run_validation(validation_freq, epoch): 1809 """Checks if validation should be run this epoch. 1810 1811 Args: 1812 validation_freq: Integer or list. If an integer, specifies how many training 1813 epochs to run before a new validation run is performed. If a list, 1814 specifies the epochs on which to run validation. 1815 epoch: Integer, the number of the training epoch just completed. 1816 1817 Returns: 1818 Bool, True if validation should be run. 
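
  Example (illustrative values only):

    should_run_validation(2, epoch=0)       # False: this is 1-indexed epoch 1.
    should_run_validation(2, epoch=1)       # True: every 2nd epoch.
    should_run_validation([1, 5], epoch=4)  # True: epoch 5 is in the list.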
1819 1820 Raises: 1821 ValueError: if `validation_freq` is an Integer and less than 1, or if 1822 it is neither an Integer nor a Sequence. 1823 """ 1824 # `epoch` is 0-indexed internally but 1-indexed in the public API. 1825 one_indexed_epoch = epoch + 1 1826 1827 if isinstance(validation_freq, int): 1828 if validation_freq < 1: 1829 raise ValueError('`validation_freq` can not be less than 1.') 1830 return one_indexed_epoch % validation_freq == 0 1831 1832 if not isinstance(validation_freq, collections.abc.Container): 1833 raise ValueError('`validation_freq` must be an Integer or ' 1834 '`collections.abc.Container` (e.g. list, tuple, etc.)') 1835 return one_indexed_epoch in validation_freq 1836 1837 1838def split_training_and_validation_data(x, y, sample_weights, validation_split): 1839 """Split input data into train/eval section based on validation_split.""" 1840 if has_symbolic_tensors(x): 1841 raise ValueError('If your data is in the form of symbolic tensors, ' 1842 'you cannot use `validation_split`.') 1843 if hasattr(x[0], 'shape'): 1844 split_at = int(x[0].shape[0] * (1. - validation_split)) 1845 else: 1846 split_at = int(len(x[0]) * (1. - validation_split)) 1847 x, val_x = (generic_utils.slice_arrays(x, 0, split_at), 1848 generic_utils.slice_arrays(x, split_at)) 1849 y, val_y = (generic_utils.slice_arrays(y, 0, split_at), 1850 generic_utils.slice_arrays(y, split_at)) 1851 if sample_weights: 1852 sample_weights, val_sample_weights = ( 1853 generic_utils.slice_arrays(sample_weights, 0, split_at), 1854 generic_utils.slice_arrays(sample_weights, split_at), 1855 ) 1856 else: 1857 val_sample_weights = None 1858 return x, y, sample_weights, val_x, val_y, val_sample_weights 1859 1860 1861def unpack_validation_data(validation_data, raise_if_ambiguous=True): 1862 """Unpack validation data based input type. 1863 1864 The validation data is not touched if its dataset or dataset iterator. 1865 For other type of input (Numpy or tensor), it will be unpacked into tuple of 1866 3 which is x, y and sample weights. 1867 1868 Args: 1869 validation_data: dataset, dataset iterator, or numpy, tensor tuple. 1870 raise_if_ambiguous: boolean on whether to fail if validation_data cannot be 1871 parsed. Otherwise simply return validation_data, None, None and defer the 1872 decision to the caller. 1873 1874 Returns: 1875 tuple of 3, (x, y, sample_weights) for numpy and tensor input. 1876 """ 1877 if (isinstance(validation_data, (iterator_ops.Iterator, 1878 iterator_ops.IteratorBase, 1879 dataset_ops.DatasetV2, 1880 data_utils.Sequence)) 1881 or not hasattr(validation_data, '__len__')): 1882 val_x = validation_data 1883 val_y = None 1884 val_sample_weight = None 1885 elif len(validation_data) == 2: 1886 try: 1887 val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence 1888 val_sample_weight = None 1889 except ValueError: 1890 val_x, val_y, val_sample_weight = validation_data, None, None 1891 elif len(validation_data) == 3: 1892 try: 1893 val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence 1894 except ValueError: 1895 val_x, val_y, val_sample_weight = validation_data, None, None 1896 else: 1897 if raise_if_ambiguous: 1898 raise ValueError( 1899 'When passing a `validation_data` argument, ' 1900 'it must contain either 2 items (x_val, y_val), ' 1901 'or 3 items (x_val, y_val, val_sample_weights), ' 1902 'or alternatively it could be a dataset or a ' 1903 'dataset or a dataset iterator. 
'
1904 'However we received `validation_data=%s`' % validation_data)
1905 val_x, val_y, val_sample_weight = validation_data, None, None
1906 return val_x, val_y, val_sample_weight
1907
1908
1909class TrainingLoop(object):
1910 """TrainingLoop is a wrapper class around the training logic.
1911
1912 This class encapsulates the fit/eval/predict logic for different kinds of
1913 data input and model configurations.
1914
1915 Note that TrainingLoop is stateless: it holds no internal state and can be
1916 reused with different models and inputs.
1917 """
1918
1919 def fit(self,
1920 model,
1921 x=None,
1922 y=None,
1923 batch_size=None,
1924 epochs=1,
1925 verbose=1,
1926 callbacks=None,
1927 validation_split=0.,
1928 validation_data=None,
1929 shuffle=True,
1930 class_weight=None,
1931 sample_weight=None,
1932 initial_epoch=0,
1933 steps_per_epoch=None,
1934 validation_steps=None,
1935 validation_freq=1,
1936 **kwargs):
1937 """Train the model with the inputs and targets."""
1938 raise NotImplementedError()
1939
1940 def evaluate(self,
1941 model,
1942 x=None,
1943 y=None,
1944 batch_size=None,
1945 verbose=1,
1946 sample_weight=None,
1947 steps=None,
1948 callbacks=None,
1949 **kwargs):
1950 """Returns the loss value & metrics values for the model in test mode."""
1951 raise NotImplementedError()
1952
1953 def predict(self,
1954 model,
1955 x,
1956 batch_size=None,
1957 verbose=0,
1958 steps=None,
1959 callbacks=None,
1960 **kwargs):
1961 raise NotImplementedError()
1962
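

# A minimal sketch (not part of the public API) of how a concrete training
# loop could subclass `TrainingLoop`. The class below is hypothetical and is
# never instantiated in this module -- the real subclasses live elsewhere in
# the Keras engine -- it only illustrates where an implementation would use
# `should_run_validation` to gate the per-epoch validation run.
class _ExampleTrainingLoop(TrainingLoop):
  """Hypothetical TrainingLoop subclass, shown purely for illustration."""

  def fit(self, model, x=None, y=None, batch_size=None, epochs=1,
          validation_data=None, validation_freq=1, **kwargs):
    for epoch in range(epochs):
      # ... run the training batches for this epoch here ...
      if validation_data is not None and should_run_validation(
          validation_freq, epoch):
        # ... run the evaluation loop on `validation_data` here ...
        pass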