# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for `Model.compile`."""

import copy

from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.keras import losses as losses_mod
from tensorflow.python.keras import metrics as metrics_mod
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest


class Container(object):
  """Base Container class."""

  def __init__(self, output_names=None):
    self._output_names = output_names

  def build(self, y_pred):
    if self._output_names is None:
      # In the Subclass API, output names like 'output_1' are used for
      # `Metric` names.
      self._output_names = create_pseudo_output_names(y_pred)

  def _conform_to_outputs(self, outputs, struct):
    """Convenience method to conform `struct` to the `outputs` structure.

    Mappings performed:

    (1) Map a dict to a list of outputs, using the output names.
    (2) Fill missing keys in a dict w/ `None`s.
    (3) Map a single item to all outputs.

    Args:
      outputs: Model predictions.
      struct: Arbitrary nested structure (e.g. of labels, sample_weights,
        losses, or metrics).

    Returns:
      `struct` mapped to the `outputs` structure.
    """
    struct = map_to_output_names(outputs, self._output_names, struct)
    struct = map_missing_dict_keys(outputs, struct)
    # Allow passing one object that applies to all outputs.
    if not nest.is_nested(struct) and nest.is_nested(outputs):
      struct = nest.map_structure(lambda _: struct, outputs)
    return struct
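
  # A minimal sketch of the mappings performed above, assuming a hypothetical
  # two-output model whose outputs form the dict `outputs` (keys 'a' and 'b'):
  #
  #   outputs = {'a': pred_a, 'b': pred_b}
  #   container._conform_to_outputs(outputs, {'a': 'mse'})
  #   # -> {'a': 'mse', 'b': None}   (missing keys are filled with `None`)
  #   container._conform_to_outputs(outputs, 'mse')
  #   # -> {'a': 'mse', 'b': 'mse'}  (a single item is mapped to all outputs)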

  def _maybe_broadcast_to_outputs(self, outputs, objects):
    """Determines if losses / metrics should be applied to all outputs.

    NOTE: This method should only be called for Metrics / Losses, not for
    y_true / sample_weight.

    Args:
      outputs: Model predictions.
      objects: Arbitrary nested structure (e.g. of losses or metrics).

    Returns:
      Arbitrary nested structure of objects, maybe copied to each output,
      so that a Loss / Metric is applied to all outputs.
    """
    if not self._should_broadcast(objects):
      return objects

    # When there is more than one Model output, this is needed to keep
    # each Metric / Loss separate. When there is only one Model output,
    # the user-supplied object should be used.
    should_copy_objects = len(nest.flatten(outputs)) > 1

    def _broadcast_fn():
      if should_copy_objects:
        return nest.map_structure(self._copy_object, objects)
      return objects

    return nest.map_structure(lambda _: _broadcast_fn(), outputs)

  def _should_broadcast(self, objects):
    raise NotImplementedError

  def _copy_object(self, obj):
    raise NotImplementedError


class LossesContainer(Container):
  """A container class for losses passed to `Model.compile`."""

  def __init__(self, losses, loss_weights=None, output_names=None):
    super(LossesContainer, self).__init__(output_names=output_names)

    # Keep user-supplied values untouched for recompiling and serialization.
    self._user_losses = losses
    self._user_loss_weights = loss_weights

    self._losses = losses
    self._loss_weights = loss_weights
    self._per_output_metrics = None  # Per-output losses become metrics.
    self._loss_metric = metrics_mod.Mean(name='loss')  # Total loss.
    self._built = False

  @property
  def metrics(self):
    """Per-output loss metrics."""
    if not self._built:
      return []
    per_output_metrics = [
        metric_obj for metric_obj in nest.flatten(self._per_output_metrics)
        if metric_obj is not None
    ]
    return [self._loss_metric] + per_output_metrics

  def build(self, y_pred):
    """One-time setup of loss objects."""
    super(LossesContainer, self).build(y_pred)

    self._losses = self._maybe_broadcast_to_outputs(y_pred, self._losses)
    self._losses = self._conform_to_outputs(y_pred, self._losses)
    self._losses = nest.map_structure(self._get_loss_object, self._losses)
    self._losses = nest.flatten(self._losses)

    self._loss_weights = self._maybe_broadcast_to_outputs(
        y_pred, self._loss_weights)
    self._loss_weights = self._conform_to_outputs(y_pred, self._loss_weights)
    self._loss_weights = nest.flatten(self._loss_weights)

    self._create_metrics()
    self._built = True

  @property
  def built(self):
    return self._built

  def _create_metrics(self):
    """Creates per-output loss metrics, but only for multi-output Models."""
    if len(self._output_names) == 1:
      self._per_output_metrics = [None]
    else:
      self._per_output_metrics = []
      for loss_obj, output_name in zip(self._losses, self._output_names):
        if loss_obj is None:
          self._per_output_metrics.append(None)
        else:
          self._per_output_metrics.append(
              metrics_mod.Mean(output_name + '_loss'))
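
  # A minimal usage sketch (hypothetical tensors; `y_true` / `y_pred` are
  # assumed to match the Model's output structure):
  #
  #   container = LossesContainer('mse', output_names=['out'])
  #   total_loss = container(y_true, y_pred)  # Builds on first call and
  #                                           # updates the 'loss' metric.
  #   container.metrics[0].result()           # Running mean of the total loss.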

  def __call__(self,
               y_true,
               y_pred,
               sample_weight=None,
               regularization_losses=None):
    """Computes the overall loss.

    Args:
      y_true: An arbitrary structure of Tensors representing the ground truth.
      y_pred: An arbitrary structure of Tensors representing a Model's outputs.
      sample_weight: An arbitrary structure of Tensors representing the
        per-sample loss weights. If one Tensor is passed, it is used for all
        losses. If multiple Tensors are passed, the structure should match
        `y_pred`.
      regularization_losses: Additional losses to be added to the total loss.

    Returns:
      The total loss as a scalar Tensor, or a zero Tensor if the Model has
      no compiled loss.
    """
    y_true = self._conform_to_outputs(y_pred, y_true)
    sample_weight = self._conform_to_outputs(y_pred, sample_weight)

    if not self._built:
      self.build(y_pred)

    y_pred = nest.flatten(y_pred)
    y_true = nest.flatten(y_true)
    sample_weight = nest.flatten(sample_weight)

    loss_values = []  # Used for gradient calculation.
    loss_metric_values = []  # Used for loss metric calculation.
    batch_dim = None
    zip_args = (y_true, y_pred, sample_weight, self._losses,
                self._loss_weights, self._per_output_metrics)
    for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args):
      if y_t is None or loss_obj is None:  # Ok to have no loss for an output.
        continue

      y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
      sw = apply_mask(y_p, sw, get_mask(y_p))
      loss_value = loss_obj(y_t, y_p, sample_weight=sw)

      loss_metric_value = loss_value
      # Correct for the `Mean` loss metrics counting each replica as a batch.
      if loss_obj.reduction == losses_utils.ReductionV2.SUM:
        loss_metric_value *= ds_context.get_strategy().num_replicas_in_sync

      if batch_dim is None:
        if tf_utils.is_ragged(y_t):
          batch_dim = y_t.nrows()
        else:
          batch_dim = array_ops.shape(y_t)[0]

      if metric_obj is not None:
        metric_obj.update_state(loss_metric_value, sample_weight=batch_dim)

      if loss_weight is not None:
        loss_value *= loss_weight
        loss_metric_value *= loss_weight

      if (loss_obj.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE or
          loss_obj.reduction == losses_utils.ReductionV2.AUTO):
        loss_value = losses_utils.scale_loss_for_distribution(loss_value)

      loss_values.append(loss_value)
      loss_metric_values.append(loss_metric_value)

    if regularization_losses:
      regularization_losses = losses_utils.cast_losses_to_common_dtype(
          regularization_losses)
      reg_loss = math_ops.add_n(regularization_losses)
      loss_metric_values.append(reg_loss)
      loss_values.append(losses_utils.scale_loss_for_distribution(reg_loss))

    if loss_values:
      loss_metric_values = losses_utils.cast_losses_to_common_dtype(
          loss_metric_values)
      total_loss_metric_value = math_ops.add_n(loss_metric_values)
      self._loss_metric.update_state(
          total_loss_metric_value, sample_weight=batch_dim)

      loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
      total_loss = math_ops.add_n(loss_values)
      return total_loss
    else:
      # Ok for a model to have no compiled loss.
      return array_ops.zeros(shape=())

  def reset_state(self):
    """Resets the state of loss metrics."""
    if not self._built:
      return
    metrics = [self._loss_metric] + nest.flatten(self._per_output_metrics)
    for metric_obj in metrics:
      if metric_obj is not None:
        metric_obj.reset_state()
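
  # Sketch of the conversions `_get_loss_object` (below) performs; the inputs
  # shown are illustrative:
  #
  #   self._get_loss_object('mse')       # -> `LossFunctionWrapper` wrapping
  #                                      #    `losses_mod.mean_squared_error`
  #   self._get_loss_object(my_loss_fn)  # -> `LossFunctionWrapper` named via
  #                                      #    `get_custom_object_name`
  #   self._get_loss_object(None)        # -> None (output has no loss)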

  def _get_loss_object(self, loss):
    """Returns a `Loss` object.

    Converts the user-supplied loss to a `Loss` object. Also allows
    `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

    Args:
      loss: A string, function, or `Loss` object.

    Returns:
      A `Loss` object.
    """
    if loss is None:
      return None  # Ok to have no loss for an output.

    loss = losses_mod.get(loss)
    if not isinstance(loss, losses_mod.Loss):
      loss_name = get_custom_object_name(loss)
      if loss_name is None:
        raise ValueError('Loss should be a callable, found: {}'.format(loss))
      loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
    loss._allow_sum_over_batch_size = True  # pylint: disable=protected-access
    return loss

  def _should_broadcast(self, obj):
    return not nest.is_nested(obj)

  def _copy_object(self, obj):
    return obj  # Losses don't need to be copied.


class MetricsContainer(Container):
  """A container class for metrics passed to `Model.compile`."""

  def __init__(self, metrics=None, weighted_metrics=None, output_names=None,
               from_serialized=False):
    """Initializes a container for metrics.

    Args:
      metrics: See the `metrics` argument from `tf.keras.Model.compile`.
      weighted_metrics: See the `weighted_metrics` argument from
        `tf.keras.Model.compile`.
      output_names: A list of strings of names of outputs for the model.
      from_serialized: Whether the model being compiled is from a serialized
        model. Used to avoid redundantly applying pre-processing renaming
        steps.
    """
    super(MetricsContainer, self).__init__(output_names=output_names)

    # Keep user-supplied values untouched for recompiling and serialization.
    self._user_metrics = metrics
    self._user_weighted_metrics = weighted_metrics

    self._metrics = metrics
    self._weighted_metrics = weighted_metrics
    self._built = False

    self._from_serialized = from_serialized

  @property
  def metrics(self):
    """All metrics in this container."""
    if not self._built:
      return []
    return self._metrics_in_order

  @property
  def unweighted_metrics(self):
    """Metrics in this container that should not be passed `sample_weight`."""
    if not self._built:
      return None
    return nest.flatten(self._metrics)

  @property
  def weighted_metrics(self):
    """Metrics in this container that should be passed `sample_weight`."""
    if not self._built:
      return None
    return nest.flatten(self._weighted_metrics)
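
  # A minimal usage sketch (hypothetical values; the dict keys must match the
  # Model's output names):
  #
  #   container = MetricsContainer(
  #       metrics={'out_1': 'acc', 'out_2': ['mse', 'mae']},
  #       output_names=['out_1', 'out_2'])
  #   container.update_state(y_true, y_pred)  # Builds on the first call.
  #   [m.name for m in container.metrics]
  #   # -> ['out_1_acc', 'out_2_mse', 'out_2_mae']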

  def build(self, y_pred, y_true):
    """One-time setup of metric objects."""
    super(MetricsContainer, self).build(y_pred)

    self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
    self._metrics = self._conform_to_outputs(y_pred, self._metrics)

    self._weighted_metrics = self._maybe_broadcast_to_outputs(
        y_pred, self._weighted_metrics)
    self._weighted_metrics = self._conform_to_outputs(y_pred,
                                                      self._weighted_metrics)

    # Standardize on tuple since `tf.data` turns lists into `Tensor`s.
    y_pred = nest.list_to_tuple(y_pred)
    y_true = nest.list_to_tuple(y_true)
    self._metrics = nest.list_to_tuple(self._metrics)
    self._weighted_metrics = nest.list_to_tuple(self._weighted_metrics)

    # Convert to `Metric` objects, potentially disambiguating based on output
    # properties.
    self._metrics = nest.map_structure_up_to(y_pred, self._get_metric_objects,
                                             self._metrics, y_true, y_pred)
    self._weighted_metrics = nest.map_structure_up_to(y_pred,
                                                      self._get_metric_objects,
                                                      self._weighted_metrics,
                                                      y_true, y_pred)

    self._metrics = nest.flatten_up_to(y_pred, self._metrics,
                                       check_types=False)
    self._weighted_metrics = nest.flatten_up_to(
        y_pred, self._weighted_metrics, check_types=False)

    # Assumes metrics and weighted_metrics have been flattened up to outputs.
    #
    # If we are loading a model that has already been serialized, we do not
    # want to re-apply any pre-processing metric renaming steps.
    if not self._from_serialized:
      self._set_metric_names()
    self._create_ordered_metrics()
    self._built = True

  @property
  def built(self):
    return self._built

  def _set_metric_names(self):
    """Sets unique metric names."""
    # For multi-output models, prepend the output name to the metric name.
    # For weighted metrics, prepend "weighted_" if the name would be
    # non-unique.
    # pylint: disable=protected-access
    metric_names = set()
    is_multi_output = len(self._output_names) > 1
    zip_args = (self._output_names, self._metrics, self._weighted_metrics)
    for output_name, output_metrics, weighted_output_metrics in zip(*zip_args):
      for m in output_metrics:
        if m is None:
          continue
        if is_multi_output:
          m._name = output_name + '_' + m._name
        if m._name in metric_names:
          raise ValueError('Found two metrics with the same name: {}'.format(
              m._name))
        metric_names.add(m._name)

      for wm in weighted_output_metrics:
        if wm is None:
          continue
        if is_multi_output:
          if output_name + '_' + wm._name in metric_names:
            wm._name = output_name + '_weighted_' + wm._name
          else:
            wm._name = output_name + '_' + wm._name
        elif wm._name in metric_names:
          wm._name = 'weighted_' + wm._name

        if wm._name in metric_names:
          raise ValueError('Found two metrics with the same name: {}'.format(
              wm._name))
        metric_names.add(wm._name)
    # pylint: enable=protected-access

  def _create_ordered_metrics(self):
    """Caches the flat order needed when returning metrics, for backcompat."""
    self._metrics_in_order = []
    for output_metrics, output_weighted_metrics in zip(self._metrics,
                                                       self._weighted_metrics):
      for m in nest.flatten(output_metrics):
        if m is not None:
          self._metrics_in_order.append(m)
      for wm in nest.flatten(output_weighted_metrics):
        if wm is not None:
          self._metrics_in_order.append(wm)
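
  # Sketch of the naming scheme `_set_metric_names` applies, assuming a
  # two-output model compiled with `metrics='mse'` and
  # `weighted_metrics='mse'`:
  #
  #   unweighted: 'output_1_mse', 'output_2_mse'
  #   weighted:   'output_1_weighted_mse', 'output_2_weighted_mse'
  #
  # The 'weighted_' prefix is only added when the plain name would collide.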

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Updates the state of per-output metrics."""
    y_true = self._conform_to_outputs(y_pred, y_true)
    sample_weight = self._conform_to_outputs(y_pred, sample_weight)

    if not self._built:
      self.build(y_pred, y_true)

    y_pred = nest.flatten(y_pred)
    y_true = nest.flatten(y_true) if y_true is not None else []
    sample_weight = nest.flatten(sample_weight)

    zip_args = (y_true, y_pred, sample_weight, self._metrics,
                self._weighted_metrics)
    for y_t, y_p, sw, metric_objs, weighted_metric_objs in zip(*zip_args):
      # Ok to have no metrics for an output.
      if (y_t is None or (all(m is None for m in metric_objs) and
                          all(wm is None for wm in weighted_metric_objs))):
        continue

      y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
      mask = get_mask(y_p)
      sw = apply_mask(y_p, sw, mask)

      for metric_obj in metric_objs:
        if metric_obj is None:
          continue
        metric_obj.update_state(y_t, y_p, sample_weight=mask)

      for weighted_metric_obj in weighted_metric_objs:
        if weighted_metric_obj is None:
          continue
        weighted_metric_obj.update_state(y_t, y_p, sample_weight=sw)

  def reset_state(self):
    """Resets the state of all `Metric`s in this container."""
    if self._built:
      metrics = self._metrics_in_order
    else:
      # If the user supplied `Metric` objects directly, we should
      # reset those. This could also contain `str`s or `function`s
      # though.
      metrics = nest.flatten(self._user_metrics) + nest.flatten(
          self._user_weighted_metrics)

    for metric_obj in metrics:
      if isinstance(metric_obj, metrics_mod.Metric):
        metric_obj.reset_state()

  def _get_metric_objects(self, metrics, y_t, y_p):
    """Converts user-supplied metrics to `Metric` objects."""
    metrics = nest.flatten(metrics)
    return [self._get_metric_object(m, y_t, y_p) for m in metrics]

  def _get_metric_object(self, metric, y_t, y_p):
    """Converts a user-supplied metric to a `Metric` object.

    Args:
      metric: A string, function, or `Metric` object.
      y_t: Sample of label.
      y_p: Sample of output.

    Returns:
      A `Metric` object.
    """
    if metric is None:
      return None  # Ok to have no metric for an output.

    # Convenience feature for selecting between binary, categorical,
    # and sparse categorical.
    if str(metric).lower() not in ['accuracy', 'acc', 'crossentropy', 'ce']:
      metric_obj = metrics_mod.get(metric)
    else:
      y_t_rank = len(y_t.shape.as_list())
      y_p_rank = len(y_p.shape.as_list())
      y_t_last_dim = y_t.shape.as_list()[-1]
      y_p_last_dim = y_p.shape.as_list()[-1]

      is_binary = y_p_last_dim == 1
      is_sparse_categorical = (
          y_t_rank < y_p_rank or (y_t_last_dim == 1 and y_p_last_dim > 1))

      if str(metric).lower() in ['accuracy', 'acc']:
        if is_binary:
          metric_obj = metrics_mod.binary_accuracy
        elif is_sparse_categorical:
          metric_obj = metrics_mod.sparse_categorical_accuracy
        else:
          metric_obj = metrics_mod.categorical_accuracy
      else:
        if is_binary:
          metric_obj = metrics_mod.binary_crossentropy
        elif is_sparse_categorical:
          metric_obj = metrics_mod.sparse_categorical_crossentropy
        else:
          metric_obj = metrics_mod.categorical_crossentropy

    if isinstance(metric_obj, losses_mod.Loss):
      metric_obj._allow_sum_over_batch_size = True  # pylint: disable=protected-access

    if not isinstance(metric_obj, metrics_mod.Metric):
      if isinstance(metric, str):
        metric_name = metric
      else:
        metric_name = get_custom_object_name(metric)
        if metric_name is None:
          raise ValueError(
              'Metric should be a callable, found: {}'.format(metric))

      metric_obj = metrics_mod.MeanMetricWrapper(metric_obj, name=metric_name)

    return metric_obj
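
  # Sketch of the 'accuracy' / 'crossentropy' disambiguation above, for
  # assumed label / prediction shapes:
  #
  #   y_p: (batch, 1)                       -> binary variant
  #   y_t: (batch,),    y_p: (batch, 10)    -> sparse categorical variant
  #   y_t: (batch, 10), y_p: (batch, 10)    -> categorical variant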

  def _should_broadcast(self, obj):
    # e.g. 'mse'.
    if not nest.is_nested(obj):
      return True
    # e.g. ['mse'] or ['mse', 'mae'].
    return (isinstance(obj, (list, tuple)) and
            not any(nest.is_nested(o) for o in obj))

  def _copy_object(self, obj):
    if isinstance(obj, metrics_mod.Metric):
      return obj.__class__.from_config(obj.get_config())
    return obj  # Can be a function or `None`.


def create_pseudo_output_names(outputs):
  """Creates pseudo output names for a subclassed Model."""
  return _create_pseudo_names(outputs, prefix='output_')


def create_pseudo_input_names(inputs):
  """Creates pseudo input names for a subclassed Model."""
  return _create_pseudo_names(inputs, prefix='input_')


def _create_pseudo_names(tensors, prefix):
  """Creates pseudo {input | output} names for subclassed Models.

  Warning: this function should only be used to define default
  names for `Metrics` and `SavedModel`. No other use cases should
  rely on a `Model`'s input or output names.

  Example with dict:

  `{'a': [x1, x2], 'b': x3}` becomes:
  `['a_1', 'a_2', 'b']`

  Example with list:

  `[x, y]` becomes:
  `['output_1', 'output_2']`

  Args:
    tensors: `Model`'s outputs or inputs.
    prefix: 'output_' for outputs, 'input_' for inputs.

  Returns:
    Flattened list of pseudo names.
  """

  def one_index(ele):
    # Start with "output_1" instead of "output_0".
    if isinstance(ele, int):
      return ele + 1
    return ele

  flat_paths = list(nest.yield_flat_paths(tensors))
  flat_paths = nest.map_structure(one_index, flat_paths)
  names = []
  for path in flat_paths:
    if not path:
      name = prefix + '1'  # Single output.
    else:
      name = '_'.join(str(p) for p in path)
      if isinstance(path[0], int):
        name = prefix + name
    names.append(name)
  return names
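

# A short sketch of how `map_to_output_names` (below) remaps a dict, assuming
# a Model with flat-list outputs named 'out_a' and 'out_b':
#
#   map_to_output_names([p_a, p_b], ['out_a', 'out_b'], {'out_b': 'mse'})
#   # -> [None, 'mse']
#   map_to_output_names([p_a, p_b], ['out_a', 'out_b'], {'oops': 'mse'})
#   # -> ValueError (the key 'oops' matches no Model output)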


def map_to_output_names(y_pred, output_names, struct):
  """Maps a dict to a list using `output_names` as keys.

  This is a convenience feature only. When a `Model`'s outputs
  are a list, you can specify per-output losses and metrics as
  a dict, where the keys are the output names. If you specify
  per-output losses and metrics via the same structure as the
  `Model`'s outputs (recommended), no mapping is performed.

  For the Functional API, the output names are the names of the
  last layer of each output. For the Subclass API, the output names
  are determined by `create_pseudo_output_names` (for example:
  `['output_1', 'output_2']` for a list of outputs).

  This mapping preserves backwards compatibility for `compile` and
  `fit`.

  Args:
    y_pred: Sample outputs of the Model, to determine if this convenience
      feature should be applied (`struct` is returned unmodified if `y_pred`
      isn't a flat list).
    output_names: List. The names of the outputs of the Model.
    struct: The structure to map.

  Returns:
    `struct` mapped to a list in the same order as `output_names`.
  """
  single_output = not nest.is_nested(y_pred)
  outputs_are_flat_list = (not single_output and
                           isinstance(y_pred, (list, tuple)) and
                           not any(nest.is_nested(y_p) for y_p in y_pred))

  if (single_output or outputs_are_flat_list) and isinstance(struct, dict):
    output_names = output_names or create_pseudo_output_names(y_pred)
    struct = copy.copy(struct)
    new_struct = [struct.pop(name, None) for name in output_names]
    if struct:
      raise ValueError('Found unexpected keys that do not correspond '
                       'to any Model output: {}. Expected: {}'.format(
                           struct.keys(), output_names))
    if len(new_struct) == 1:
      return new_struct[0]
    return new_struct
  else:
    return struct


def map_missing_dict_keys(y_pred, struct):
  """Replaces missing dict keys in `struct` with `None` placeholders."""
  if not isinstance(y_pred, dict) or not isinstance(struct, dict):
    return struct
  for k in y_pred.keys():
    if k not in struct:
      struct[k] = None
  return struct


def match_dtype_and_rank(y_t, y_p, sw):
  """Matches the dtype and rank of `y_t` and `sw` to those of `y_p`."""
  # Rank.
  if y_t.shape.rank == 1 and y_p.shape.rank == 2:
    y_t = array_ops.expand_dims_v2(y_t, axis=-1)
  if sw is not None:
    if sw.shape.rank == 1 and y_p.shape.rank == 2:
      sw = array_ops.expand_dims_v2(sw, axis=-1)

  # Dtype.
  # This is required mainly for custom loss functions which do not take care
  # of casting dtypes themselves.
  if ((y_t.dtype.is_floating and y_p.dtype.is_floating) or
      (y_t.dtype.is_integer and y_p.dtype.is_integer)):
    y_t = math_ops.cast(y_t, y_p.dtype)

  if sw is not None:
    sw = math_ops.cast(sw, y_p.dtype)
  return y_t, y_p, sw


def get_mask(y_p):
  """Returns the Keras mask from a tensor, if any."""
  return getattr(y_p, '_keras_mask', None)


def apply_mask(y_p, sw, mask):
  """Applies any mask on predictions to sample weights."""
  if mask is not None:
    mask = math_ops.cast(mask, y_p.dtype)
    if sw is not None:
      mask, _, sw = (
          losses_utils.squeeze_or_expand_dimensions(mask, sample_weight=sw))
      sw *= mask
    else:
      sw = mask
  return sw


def get_custom_object_name(obj):
  """Returns the name to use for a custom loss or metric callable.

  Args:
    obj: Custom loss or metric callable.

  Returns:
    Name to use, or `None` if the object was not recognized.
  """
  if hasattr(obj, 'name'):  # Accept a `Loss` instance used as a `Metric`.
    return obj.name
  elif hasattr(obj, '__name__'):  # Function.
    return obj.__name__
  elif hasattr(obj, '__class__'):  # Class instance.
    return generic_utils.to_snake_case(obj.__class__.__name__)
  else:  # Unrecognized object.
    return None
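

# A short sketch of `get_custom_object_name`; `my_loss_fn` and
# `MyLossClassInstance` are hypothetical:
#
#   get_custom_object_name(losses_mod.MeanSquaredError())
#   # -> 'mean_squared_error'  (via the `name` attribute)
#   get_custom_object_name(my_loss_fn)           # -> 'my_loss_fn'
#   get_custom_object_name(MyLossClassInstance())
#   # -> 'my_loss_class_instance'  (snake-cased class name)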