xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/engine/compile_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for `Model.compile`."""
16
17import copy
18
19from tensorflow.python.distribute import distribution_strategy_context as ds_context
20from tensorflow.python.keras import losses as losses_mod
21from tensorflow.python.keras import metrics as metrics_mod
22from tensorflow.python.keras.utils import generic_utils
23from tensorflow.python.keras.utils import losses_utils
24from tensorflow.python.keras.utils import tf_utils
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.util import nest
28
29
class Container(object):
  """Common base for the loss and metric containers used by `Model.compile`.

  Stores the Model's output names and provides helpers that line up a
  user-supplied nested structure (labels, sample weights, losses or metrics)
  with the structure of the Model's outputs. Subclasses must implement the
  `_should_broadcast` and `_copy_object` hooks.
  """

  def __init__(self, output_names=None):
    self._output_names = output_names

  def build(self, y_pred):
    """Generates pseudo output names when none were provided.

    Args:
      y_pred: Model predictions; only their structure is inspected.
    """
    if self._output_names is not None:
      return
    # Subclassed Models have no layer-derived output names, so names such as
    # 'output_1' are generated here and later appear in `Metric` names.
    self._output_names = create_pseudo_output_names(y_pred)

  def _conform_to_outputs(self, outputs, struct):
    """Rearranges `struct` so that it lines up with `outputs`.

    Three adjustments are applied, in order:

    (1) A dict keyed by output name becomes a list ordered like the outputs
        (only when the outputs are a single tensor or a flat list/tuple).
    (2) When both `outputs` and `struct` are dicts, keys of `outputs` that are
        missing from `struct` get `None` entries.
    (3) A single, non-nested item is repeated for every output.

    Args:
      outputs: Model predictions.
      struct: Arbitrary nested structure, e.g. labels, sample weights,
        losses or metrics.

    Returns:
      `struct` rearranged to follow the structure of `outputs`.
    """
    conformed = map_to_output_names(outputs, self._output_names, struct)
    conformed = map_missing_dict_keys(outputs, conformed)
    if nest.is_nested(outputs) and not nest.is_nested(conformed):
      # One object was supplied for all outputs: repeat it per output.
      return nest.map_structure(lambda _: conformed, outputs)
    return conformed

  def _maybe_broadcast_to_outputs(self, outputs, objects):
    """Repeats `objects` for every output when they apply to all outputs.

    Only meant for Losses / Metrics, never for `y_true` or `sample_weight`.

    Args:
      outputs: Model predictions.
      objects: Arbitrary nested structure of losses or metrics.

    Returns:
      `objects` unchanged when `_should_broadcast(objects)` is false.
      Otherwise, a structure matching `outputs` whose entries are `objects`.
      When the Model has more than one output, each entry is passed through
      `_copy_object`.
    """
    if not self._should_broadcast(objects):
      return objects

    # With several outputs each one gets its own copy so that Metric / Loss
    # state stays separate; with a single output the user's object is reused.
    copy_per_output = len(nest.flatten(outputs)) > 1

    def _objects_for_output(_):
      if copy_per_output:
        return nest.map_structure(self._copy_object, objects)
      return objects

    return nest.map_structure(_objects_for_output, outputs)

  def _should_broadcast(self, objects):
    """Subclass hook: whether `objects` should be applied to every output."""
    raise NotImplementedError

  def _copy_object(self, obj):
    """Subclass hook: returns the per-output copy of `obj`."""
    raise NotImplementedError
101
102
class LossesContainer(Container):
  """A container class for losses passed to `Model.compile`.

  On first call, `build` runs once. It broadcasts the user-supplied losses
  and loss weights, conforms them to the Model outputs, and flattens them to
  one entry per output.

  Calling the container computes the total loss. It also updates a `Mean`
  metric named 'loss'. Multi-output Models additionally get one
  '<output_name>_loss' `Mean` per output that has a loss.
  """

  def __init__(self, losses, loss_weights=None, output_names=None):
    """Initializes a container for losses.

    Args:
      losses: see the `loss` argument from `tf.keras.Model.compile`.
      loss_weights: see the `loss_weights` argument from
        `tf.keras.Model.compile`.
      output_names: A list of strings of names of outputs for the model.
    """
    super(LossesContainer, self).__init__(output_names=output_names)

    # Keep user-supplied values untouched for recompiling and serialization.
    self._user_losses = losses
    self._user_loss_weights = loss_weights

    # Working values; `build` replaces them with flat per-output lists.
    self._losses = losses
    self._loss_weights = loss_weights
    self._per_output_metrics = None  # Per-output losses become metrics.
    self._loss_metric = metrics_mod.Mean(name='loss')  # Total loss.
    self._built = False

  @property
  def metrics(self):
    """Loss metrics: the total-loss metric followed by per-output ones.

    Returns `[]` before the container is built. `None` placeholders (outputs
    without a per-output metric) are filtered out.
    """
    if not self._built:
      return []
    per_output_metrics = [
        metric_obj for metric_obj in nest.flatten(self._per_output_metrics)
        if metric_obj is not None
    ]
    return [self._loss_metric] + per_output_metrics

  def build(self, y_pred):
    """One-time setup of loss objects.

    Args:
      y_pred: Model predictions; they define the output structure.
    """
    super(LossesContainer, self).build(y_pred)

    self._losses = self._maybe_broadcast_to_outputs(y_pred, self._losses)
    self._losses = self._conform_to_outputs(y_pred, self._losses)
    self._losses = nest.map_structure(self._get_loss_object, self._losses)
    self._losses = nest.flatten(self._losses)

    self._loss_weights = self._maybe_broadcast_to_outputs(
        y_pred, self._loss_weights)
    self._loss_weights = self._conform_to_outputs(y_pred, self._loss_weights)
    self._loss_weights = nest.flatten(self._loss_weights)

    self._create_metrics()
    self._built = True

  @property
  def built(self):
    """Whether `build` has already run."""
    return self._built

  def _create_metrics(self):
    """Creates per-output loss metrics, but only for multi-output Models."""
    if len(self._output_names) == 1:
      # A single output is already tracked by the total-loss metric.
      self._per_output_metrics = [None]
    else:
      # Outputs without a loss get a `None` placeholder so the list stays
      # aligned with `self._losses`.
      self._per_output_metrics = [
          None if loss_obj is None else metrics_mod.Mean(output_name + '_loss')
          for loss_obj, output_name in zip(self._losses, self._output_names)
      ]

  def __call__(self,
               y_true,
               y_pred,
               sample_weight=None,
               regularization_losses=None):
    """Computes the overall loss.

    Builds the container on first call. Also updates the total-loss metric
    and, for multi-output Models, the per-output loss metrics.

    Args:
      y_true: An arbitrary structure of Tensors representing the ground truth.
      y_pred: An arbitrary structure of Tensors representing a Model's outputs.
      sample_weight: An arbitrary structure of Tensors representing the
        per-sample loss weights. If one Tensor is passed, it is used for all
        losses. If multiple Tensors are passed, the structure should match
        `y_pred`.
      regularization_losses: Additional losses to be added to the total loss.

    Returns:
      The total loss `Tensor`. It is the sum of every per-output loss (each
      multiplied by its loss weight, if any) and of the regularization
      losses. When there is no loss to sum, a scalar zero `Tensor` is
      returned. Only the total is returned, not per-output losses.
    """
    y_true = self._conform_to_outputs(y_pred, y_true)
    sample_weight = self._conform_to_outputs(y_pred, sample_weight)

    if not self._built:
      self.build(y_pred)

    y_pred = nest.flatten(y_pred)
    y_true = nest.flatten(y_true)
    sample_weight = nest.flatten(sample_weight)

    loss_values = []  # Used for gradient calculation.
    loss_metric_values = []  # Used for loss metric calculation.
    batch_dim = None
    zip_args = (y_true, y_pred, sample_weight, self._losses, self._loss_weights,
                self._per_output_metrics)
    for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args):
      if y_t is None or loss_obj is None:  # Ok to have no loss for an output.
        continue

      y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
      sw = apply_mask(y_p, sw, get_mask(y_p))
      loss_value = loss_obj(y_t, y_p, sample_weight=sw)

      loss_metric_value = loss_value
      # Correct for the `Mean` loss metrics counting each replica as a batch.
      if loss_obj.reduction == losses_utils.ReductionV2.SUM:
        loss_metric_value *= ds_context.get_strategy().num_replicas_in_sync

      # The first output that has a loss defines the batch size used as the
      # `sample_weight` of every loss-metric update.
      if batch_dim is None:
        if tf_utils.is_ragged(y_t):
          batch_dim = y_t.nrows()
        else:
          batch_dim = array_ops.shape(y_t)[0]

      # Per-output metrics record the loss before `loss_weight` is applied.
      if metric_obj is not None:
        metric_obj.update_state(loss_metric_value, sample_weight=batch_dim)

      if loss_weight is not None:
        loss_value *= loss_weight
        loss_metric_value *= loss_weight

      # Only the value used for gradients is scaled for distribution.
      if loss_obj.reduction in (losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                                losses_utils.ReductionV2.AUTO):
        loss_value = losses_utils.scale_loss_for_distribution(loss_value)

      loss_values.append(loss_value)
      loss_metric_values.append(loss_metric_value)

    if regularization_losses:
      regularization_losses = losses_utils.cast_losses_to_common_dtype(
          regularization_losses)
      reg_loss = math_ops.add_n(regularization_losses)
      loss_metric_values.append(reg_loss)
      loss_values.append(losses_utils.scale_loss_for_distribution(reg_loss))

    if loss_values:
      loss_metric_values = losses_utils.cast_losses_to_common_dtype(
          loss_metric_values)
      total_loss_metric_value = math_ops.add_n(loss_metric_values)
      self._loss_metric.update_state(
          total_loss_metric_value, sample_weight=batch_dim)

      loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
      total_loss = math_ops.add_n(loss_values)
      return total_loss
    else:
      # Ok for a model to have no compiled loss.
      return array_ops.zeros(shape=())

  def reset_state(self):
    """Resets the state of loss metrics (a no-op before `build`)."""
    # `self.metrics` is empty before `build` and already skips `None`s.
    for metric_obj in self.metrics:
      metric_obj.reset_state()

  def _get_loss_object(self, loss):
    """Returns a `Loss` object.

    Converts the user-supplied loss to a `Loss` object. Also allows
    `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

    Args:
      loss: A string, function, or `Loss` object.

    Returns:
      A `Loss` object, or `None` if `loss` is `None`.

    Raises:
      ValueError: If no name can be derived for a non-`Loss` callable.
    """
    if loss is None:
      return None  # Ok to have no loss for an output.

    loss = losses_mod.get(loss)
    if not isinstance(loss, losses_mod.Loss):
      loss_name = get_custom_object_name(loss)
      if loss_name is None:
        raise ValueError('Loss should be a callable, found: {}'.format(loss))
      loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
    loss._allow_sum_over_batch_size = True  # pylint: disable=protected-access
    return loss

  def _should_broadcast(self, obj):
    """A single, non-nested loss applies to every output."""
    return not nest.is_nested(obj)

  def _copy_object(self, obj):
    return obj  # Losses don't need to be copied.
290
291
class MetricsContainer(Container):
  """A container class for metrics passed to `Model.compile`.

  `build` runs once, on the first `update_state` call. It broadcasts and
  conforms the user-supplied `metrics` and `weighted_metrics` to the Model
  outputs, converts them to `Metric` objects, flattens them to one list per
  output, and (unless loading from a serialized model) gives them unique
  names. `update_state` then feeds each output's labels and predictions to
  those metrics.
  """

  def __init__(self, metrics=None, weighted_metrics=None, output_names=None,
               from_serialized=False):
    """Initializes a container for metrics.

    Args:
      metrics: see the `metrics` argument from `tf.keras.Model.compile`.
      weighted_metrics: see the `weighted_metrics` argument from
        `tf.keras.Model.compile`.
      output_names: A list of strings of names of outputs for the model.
      from_serialized: Whether the model being compiled is from a serialized
        model.  Used to avoid redundantly applying pre-processing renaming
        steps.
    """
    super(MetricsContainer, self).__init__(output_names=output_names)

    # Keep user-supplied values untouched for recompiling and serialization.
    self._user_metrics = metrics
    self._user_weighted_metrics = weighted_metrics

    # Working values; `build` replaces them with one list of `Metric`
    # objects (or `None` placeholders) per output.
    self._metrics = metrics
    self._weighted_metrics = weighted_metrics
    self._built = False

    self._from_serialized = from_serialized

  @property
  def metrics(self):
    """All metrics in this container.

    Returns `[]` before `build`. Afterwards, returns the non-`None` metrics in
    the order cached by `_create_ordered_metrics`.
    """
    if not self._built:
      return []
    return self._metrics_in_order

  @property
  def unweighted_metrics(self):
    """Metrics in this container that do not receive user sample weights.

    In `update_state` these are passed only the output's Keras mask (if any)
    as `sample_weight`. Returns `None` before `build`. The flattened list may
    contain `None` placeholders.
    """
    if not self._built:
      return None
    return nest.flatten(self._metrics)

  @property
  def weighted_metrics(self):
    """Metrics in this container that should be passed `sample_weight`.

    Returns `None` before `build`. The flattened list may contain `None`
    placeholders.
    """
    if not self._built:
      return None
    return nest.flatten(self._weighted_metrics)

  def build(self, y_pred, y_true):
    """One-time setup of metric objects.

    Args:
      y_pred: Model predictions; they define the output structure.
      y_true: Labels. `update_state` conforms them to `y_pred` before calling
        this. They are used to pick the accuracy / crossentropy variant.
    """
    super(MetricsContainer, self).build(y_pred)

    self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
    self._metrics = self._conform_to_outputs(y_pred, self._metrics)

    self._weighted_metrics = self._maybe_broadcast_to_outputs(
        y_pred, self._weighted_metrics)
    self._weighted_metrics = self._conform_to_outputs(y_pred,
                                                      self._weighted_metrics)

    # Standardize on tuple since `tf.data` turns lists into `Tensor`s.
    y_pred = nest.list_to_tuple(y_pred)
    y_true = nest.list_to_tuple(y_true)
    self._metrics = nest.list_to_tuple(self._metrics)
    self._weighted_metrics = nest.list_to_tuple(self._weighted_metrics)

    # Convert to `Metric` objects, potentially disambiguating based on output
    # properties.
    self._metrics = nest.map_structure_up_to(y_pred, self._get_metric_objects,
                                             self._metrics, y_true, y_pred)
    self._weighted_metrics = nest.map_structure_up_to(y_pred,
                                                      self._get_metric_objects,
                                                      self._weighted_metrics,
                                                      y_true, y_pred)

    # One entry (a list of metrics) per output, in flattened-output order.
    self._metrics = nest.flatten_up_to(y_pred, self._metrics, check_types=False)
    self._weighted_metrics = nest.flatten_up_to(
        y_pred, self._weighted_metrics, check_types=False)

    # Assumes metrics, weighted_metrics have been flattened up to outputs.
    #
    # If we are loading a model that has been already serialized, we do not
    # want to re-apply any pre-processing metric renaming steps.
    if not self._from_serialized:
      self._set_metric_names()
    self._create_ordered_metrics()
    self._built = True

  @property
  def built(self):
    """Whether `build` has already run."""
    return self._built

  def _set_metric_names(self):
    """Sets unique metric names by rewriting each `Metric`'s `_name` in place.

    Raises:
      ValueError: If two metrics still share a name after renaming.
    """
    # For multi-output models, prepend the output name to the metric name.
    # For weighted metrics, prepend "weighted_" if the name would be non-unique.
    # pylint: disable=protected-access
    metric_names = set()
    is_multi_output = len(self._output_names) > 1
    zip_args = (self._output_names, self._metrics, self._weighted_metrics)
    for output_name, output_metrics, weighted_output_metrics in zip(*zip_args):
      for m in output_metrics:
        if m is None:
          continue
        if is_multi_output:
          m._name = output_name + '_' + m._name
        if m._name in metric_names:
          raise ValueError('Found two metrics with the same name: {}'.format(
              m._name))
        metric_names.add(m._name)

      # Weighted metrics are named after the unweighted ones of the same
      # output, so a clash with an unweighted name triggers "weighted_".
      for wm in weighted_output_metrics:
        if wm is None:
          continue
        if is_multi_output:
          if output_name + '_' + wm._name in metric_names:
            wm._name = output_name + '_weighted_' + wm._name
          else:
            wm._name = output_name + '_' + wm._name
        elif wm._name in metric_names:
          wm._name = 'weighted_' + wm._name

        if wm._name in metric_names:
          raise ValueError('Found two metrics with the same name: {}'.format(
              wm._name))
        metric_names.add(wm._name)
    # pylint: enable=protected-access

  def _create_ordered_metrics(self):
    """Cache the flat order needed when returning metrics, for backwards compat.

    For each output in turn, lists its unweighted metrics and then its
    weighted metrics, skipping `None` placeholders.
    """
    self._metrics_in_order = []
    for output_metrics, output_weighted_metrics in zip(self._metrics,
                                                       self._weighted_metrics):
      for m in nest.flatten(output_metrics):
        if m is not None:
          self._metrics_in_order.append(m)
      for wm in nest.flatten(output_weighted_metrics):
        if wm is not None:
          self._metrics_in_order.append(wm)

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Updates the state of per-output metrics.

    Builds the container on first call. Outputs without labels, or without
    any metric, are skipped. Unweighted metrics receive only the output's
    Keras mask as `sample_weight`. Weighted metrics receive the sample
    weights combined with that mask by `apply_mask`.

    Args:
      y_true: Labels; conformed to the structure of `y_pred`.
      y_pred: Model predictions.
      sample_weight: Optional sample weights; conformed like `y_true`.
    """
    y_true = self._conform_to_outputs(y_pred, y_true)
    sample_weight = self._conform_to_outputs(y_pred, sample_weight)

    if not self._built:
      self.build(y_pred, y_true)

    y_pred = nest.flatten(y_pred)
    # With no labels the zip below is empty, so no metric is updated.
    y_true = nest.flatten(y_true) if y_true is not None else []
    sample_weight = nest.flatten(sample_weight)

    zip_args = (y_true, y_pred, sample_weight, self._metrics,
                self._weighted_metrics)
    for y_t, y_p, sw, metric_objs, weighted_metric_objs in zip(*zip_args):
      # Ok to have no metrics for an output.
      if (y_t is None or (all(m is None for m in metric_objs) and
                          all(wm is None for wm in weighted_metric_objs))):
        continue

      y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
      mask = get_mask(y_p)
      sw = apply_mask(y_p, sw, mask)

      # Unweighted metrics still honor the output mask.
      for metric_obj in metric_objs:
        if metric_obj is None:
          continue
        metric_obj.update_state(y_t, y_p, sample_weight=mask)

      for weighted_metric_obj in weighted_metric_objs:
        if weighted_metric_obj is None:
          continue
        weighted_metric_obj.update_state(y_t, y_p, sample_weight=sw)

  def reset_state(self):
    """Resets the state of all `Metric`s in this container."""
    if self._built:
      metrics = self._metrics_in_order
    else:
      # If the user supplied `Metric` objects directly, we should
      # reset those. This could also contain `str`s or `function`s
      # though.
      metrics = nest.flatten(self._user_metrics) + nest.flatten(
          self._user_weighted_metrics)

    for metric_obj in metrics:
      if isinstance(metric_obj, metrics_mod.Metric):
        metric_obj.reset_state()

  def _get_metric_objects(self, metrics, y_t, y_p):
    """Convert user-supplied metrics to `Metric` objects.

    Returns a list with one entry (a `Metric` or `None`) per flattened
    element of `metrics`.
    """
    metrics = nest.flatten(metrics)
    return [self._get_metric_object(m, y_t, y_p) for m in metrics]

  def _get_metric_object(self, metric, y_t, y_p):
    """Converts user-supplied metric to a `Metric` object.

    Args:
      metric: A string, function, or `Metric` object.
      y_t: Sample of label.
      y_p: Sample of output.

    Returns:
      A `Metric` object, or `None` if `metric` is `None`.

    Raises:
      ValueError: If no name can be derived for a non-`Metric` callable.
    """
    if metric is None:
      return None  # Ok to have no metric for an output.

    # Convenience feature for selecting b/t binary, categorical,
    # and sparse categorical.
    if str(metric).lower() not in ['accuracy', 'acc', 'crossentropy', 'ce']:
      metric_obj = metrics_mod.get(metric)
    else:
      # The variant is chosen from the static shapes.
      # - Binary: the output's last dimension is 1.
      # - Sparse: the labels have lower rank than the output, or the labels'
      #   last dimension is 1 while the output's is larger.
      # NOTE(review): this assumes both shapes have a known, non-zero rank,
      # and that `y_p_last_dim` is not `None` whenever the `> 1` comparison is
      # reached (`None > 1` raises TypeError in Python 3).
      y_t_rank = len(y_t.shape.as_list())
      y_p_rank = len(y_p.shape.as_list())
      y_t_last_dim = y_t.shape.as_list()[-1]
      y_p_last_dim = y_p.shape.as_list()[-1]

      is_binary = y_p_last_dim == 1
      is_sparse_categorical = (
          y_t_rank < y_p_rank or y_t_last_dim == 1 and y_p_last_dim > 1)

      if str(metric).lower() in ['accuracy', 'acc']:
        if is_binary:
          metric_obj = metrics_mod.binary_accuracy
        elif is_sparse_categorical:
          metric_obj = metrics_mod.sparse_categorical_accuracy
        else:
          metric_obj = metrics_mod.categorical_accuracy
      else:
        if is_binary:
          metric_obj = metrics_mod.binary_crossentropy
        elif is_sparse_categorical:
          metric_obj = metrics_mod.sparse_categorical_crossentropy
        else:
          metric_obj = metrics_mod.categorical_crossentropy

    if isinstance(metric_obj, losses_mod.Loss):
      metric_obj._allow_sum_over_batch_size = True  # pylint: disable=protected-access

    # Plain functions (and `Loss` objects) are wrapped so they can be
    # averaged like any other `Metric`.
    if not isinstance(metric_obj, metrics_mod.Metric):
      if isinstance(metric, str):
        metric_name = metric
      else:
        metric_name = get_custom_object_name(metric)
        if metric_name is None:
          raise ValueError(
              'Metric should be a callable, found: {}'.format(metric))

      metric_obj = metrics_mod.MeanMetricWrapper(metric_obj, name=metric_name)

    return metric_obj

  def _should_broadcast(self, obj):
    """Whether `obj` applies to every output (a single metric or flat list)."""
    # e.g. 'mse'.
    if not nest.is_nested(obj):
      return True
    # e.g. ['mse'] or ['mse', 'mae'].
    return (isinstance(obj, (list, tuple)) and
            not any(nest.is_nested(o) for o in obj))

  def _copy_object(self, obj):
    """Returns a per-output copy of `obj`.

    A `Metric` is rebuilt from its config so copies keep separate state;
    anything else is returned as is.
    """
    if isinstance(obj, metrics_mod.Metric):
      return obj.__class__.from_config(obj.get_config())
    return obj  # Can be a function or `None`.
558
559
def create_pseudo_output_names(outputs):
  """Returns generated output names (e.g. 'output_1') for a subclassed Model.

  Args:
    outputs: The Model's outputs (any nested structure).

  Returns:
    Flat list of pseudo names, one per flattened output.
  """
  return _create_pseudo_names(tensors=outputs, prefix='output_')
563
564
def create_pseudo_input_names(inputs):
  """Returns generated input names (e.g. 'input_1') for a subclassed Model.

  Args:
    inputs: The Model's inputs (any nested structure).

  Returns:
    Flat list of pseudo names, one per flattened input.
  """
  return _create_pseudo_names(tensors=inputs, prefix='input_')
568
569
def _create_pseudo_names(tensors, prefix):
  """Creates pseudo {input | output} names for subclassed Models.

  Warning: this function should only be used to define default
  names for `Metrics` and `SavedModel`. No other use cases should
  rely on a `Model`'s input or output names.

  Each flattened tensor is named after its path in `tensors`:
  - Path elements are joined with '_'.
  - List / tuple indices are counted from 1.
  - `prefix` is prepended when the path starts with an index.
  - A single (non-nested) tensor is named `prefix + '1'`.

  Example with dict:

  `{'a': [x1, x2], 'b': x3}` becomes:
  `['a_1', 'a_2', 'b']`

  Example with list (and `prefix='output_'`):

  `[x, y]` becomes:
  `['output_1', 'output_2']`

  Args:
    tensors: `Model`'s outputs or inputs.
    prefix: 'output_' for outputs, 'input_' for inputs.

  Returns:
    Flattened list of pseudo names.
  """

  def _to_one_based(path_element):
    # Integer path elements are sequence indices; shift them so numbering
    # starts at 1. Dict keys pass through unchanged.
    if isinstance(path_element, int):
      return path_element + 1
    return path_element

  def _name_for(path):
    if not path:
      return prefix + '1'  # Single output.
    joined = '_'.join(str(p) for p in path)
    return prefix + joined if isinstance(path[0], int) else joined

  one_based_paths = nest.map_structure(
      _to_one_based, list(nest.yield_flat_paths(tensors)))
  return [_name_for(path) for path in one_based_paths]
613
614
def map_to_output_names(y_pred, output_names, struct):
  """Turns a dict keyed by output name into a list ordered like the outputs.

  This is a convenience feature only. When a `Model`'s outputs are a single
  tensor or a flat list/tuple, per-output losses and metrics may be given as
  a dict keyed by output name. If they are instead given in the same
  structure as the outputs (recommended), no mapping is performed.

  Output names come from two places:
  - Functional API: the last layer of each output.
  - Subclass API: `create_pseudo_output_names` (for example
    `['output_1', 'output_2']` for a list of outputs).

  This mapping preserves backwards compatibility for `compile` and `fit`.

  Args:
    y_pred: Sample outputs of the Model. The mapping only applies when they
      are a single tensor or a flat list/tuple.
    output_names: List. The names of the outputs of the Model. When falsy,
      pseudo names are generated from `y_pred`.
    struct: The structure to map.

  Returns:
    `struct` unchanged if it is not a dict, or if the outputs are not a
    single tensor / flat list. Otherwise, a list of the dict's values in
    `output_names` order, with `None` for absent names. For a single output,
    that list's only item is returned.

  Raises:
    ValueError: If the dict has keys that correspond to no output name.
  """
  if not isinstance(struct, dict):
    return struct

  if nest.is_nested(y_pred):
    is_flat_sequence = (isinstance(y_pred, (list, tuple)) and
                        all(not nest.is_nested(y_p) for y_p in y_pred))
    if not is_flat_sequence:
      return struct

  names = output_names or create_pseudo_output_names(y_pred)
  # Pop from a shallow copy so the caller's dict is left intact and any
  # leftover keys can be reported.
  leftover = copy.copy(struct)
  mapped = [leftover.pop(name, None) for name in names]
  if leftover:
    raise ValueError('Found unexpected keys that do not correspond '
                     'to any Model output: {}. Expected: {}'.format(
                         leftover.keys(), names))
  return mapped[0] if len(mapped) == 1 else mapped
660
661
def map_missing_dict_keys(y_pred, struct):
  """Fills in `None` placeholders for output keys missing from `struct`.

  When both `y_pred` and `struct` are dicts, every key of `y_pred` that is
  absent from `struct` is mapped to `None`. Anything else is returned
  unchanged.

  The caller's dict is never modified. If keys need to be added, a shallow
  copy is made first, because `struct` is often user-owned: compiled
  losses / metrics kept for recompiling and serialization, or label and
  sample-weight dicts from the input data.

  Args:
    y_pred: Model predictions; when a dict, its keys are the expected keys.
    struct: Structure to complete, e.g. losses, metrics, labels or sample
      weights.

  Returns:
    `struct` itself if either argument is not a dict or no key is missing.
    Otherwise, a shallow copy of `struct` with the missing keys set to
    `None`.
  """
  if not isinstance(y_pred, dict) or not isinstance(struct, dict):
    return struct
  missing_keys = [k for k in y_pred if k not in struct]
  if not missing_keys:
    return struct
  # Copy rather than mutate: `struct` may belong to the caller.
  struct = copy.copy(struct)
  for k in missing_keys:
    struct[k] = None
  return struct
670
671
def match_dtype_and_rank(y_t, y_p, sw):
  """Aligns labels and sample weights with the predictions.

  Rank: when `y_p` has rank 2 and `y_t` (or `sw`) has rank 1, a trailing
  axis is added so that the ranks agree.

  Dtype: `y_t` is cast to `y_p.dtype` when both are floating point or both
  are integer. This is mainly for custom loss functions that do not cast
  dtypes themselves. `sw`, when given, is always cast to `y_p.dtype`.

  Args:
    y_t: Labels.
    y_p: Predictions.
    sw: Sample weights, or `None`.

  Returns:
    Tuple `(y_t, y_p, sw)`; `y_p` is returned unchanged.
  """
  pred_is_rank_2 = y_p.shape.rank == 2
  if pred_is_rank_2 and y_t.shape.rank == 1:
    y_t = array_ops.expand_dims_v2(y_t, axis=-1)
  if sw is not None and pred_is_rank_2 and sw.shape.rank == 1:
    sw = array_ops.expand_dims_v2(sw, axis=-1)

  same_dtype_family = (
      (y_t.dtype.is_floating and y_p.dtype.is_floating) or
      (y_t.dtype.is_integer and y_p.dtype.is_integer))
  if same_dtype_family:
    y_t = math_ops.cast(y_t, y_p.dtype)

  if sw is not None:
    sw = math_ops.cast(sw, y_p.dtype)
  return y_t, y_p, sw
690
691
def get_mask(y_p):
  """Returns the Keras mask attached to `y_p`, or `None` if there is none."""
  try:
    return y_p._keras_mask  # pylint: disable=protected-access
  except AttributeError:
    return None
695
696
def apply_mask(y_p, sw, mask):
  """Folds a Keras mask on the predictions into the sample weights.

  Args:
    y_p: Predictions; their dtype is used when casting the mask.
    sw: Sample weights, or `None`.
    mask: The mask of `y_p` (see `get_mask`), or `None`.

  Returns:
    - `sw` unchanged when `mask` is `None`.
    - The mask cast to `y_p.dtype` when `sw` is `None`.
    - Otherwise, `sw` multiplied by that mask after their ranks are
      reconciled by `losses_utils.squeeze_or_expand_dimensions`.
  """
  if mask is None:
    return sw
  mask = math_ops.cast(mask, y_p.dtype)
  if sw is None:
    return mask
  mask, _, sw = losses_utils.squeeze_or_expand_dimensions(
      mask, sample_weight=sw)
  sw *= mask
  return sw
708
709
def get_custom_object_name(obj):
  """Returns the name to use for a custom loss or metric callable.

  The lookup order is:
  1. An explicit `name` attribute (e.g. a `Loss` instance used as a metric).
  2. `__name__` (plain functions).
  3. The snake_cased class name.

  Every ordinary Python object has `__class__`, so the `None` fallback is
  effectively unreachable in practice.

  Args:
    obj: Custom loss or metric callable.

  Returns:
    Name to use, or `None` if the object was not recognized.
  """
  if hasattr(obj, 'name'):
    return obj.name
  if hasattr(obj, '__name__'):
    return obj.__name__
  if not hasattr(obj, '__class__'):
    return None
  return generic_utils.to_snake_case(obj.__class__.__name__)
727