xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/engine/partial_batch_padding_handler.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Utility object to handle partial batches for TPUStrategy."""
16# pylint: disable=protected-access
17
18import numpy as np
19
20from tensorflow.python.framework import tensor_util
21from tensorflow.python.keras import backend
22from tensorflow.python.ops import array_ops
23from tensorflow.python.util import nest
24
25
class PartialBatchPaddingHandler(object):
  """A container that holds info about partial batches for `predict()`.

  Attributes:
    padded_batch_size: Size that every batch is padded up to. Starts at 0 and
      is read by `update_mask` and `pad_batch`; presumably the caller sets it
      to the full batch size before those are used -- TODO confirm.
    padding_mask: 1-D mask over the padded samples seen so far (initially
      empty). `update_mask` produces masks with 1 for real samples and 0 for
      padding; `apply_mask` keeps only the rows where this mask is nonzero.
    output_shape: The model's output shape(s). Only its length is consulted,
      by `apply_mask`, to distinguish single-output from multi-output models.
  """

  def __init__(self, output_shape):
    self.padded_batch_size = 0
    self.padding_mask = array_ops.zeros(0)
    self.output_shape = output_shape

  def get_real_batch_size(self, dataset_batch):
    """Returns the number of elements in a potentially partial batch.

    Args:
      dataset_batch: A batch from the dataset. If it is a tuple or list, only
        its first element (presumably the features) is inspected. The
        inspected value may be a tensor or a nested structure of tensors.

    Returns:
      A scalar int64 tensor: the size of the leading dimension of the first
      tensor found in the (flattened) inspected structure.

    Raises:
      ValueError: If the inspected structure contains no tensors.
    """
    if isinstance(dataset_batch, (tuple, list)):
      dataset_batch = dataset_batch[0]

    assert nest.flatten(dataset_batch)

    # Any tensor will do: all tensors in one batch share the batch dimension.
    first_tensor = next(
        (x for x in nest.flatten(dataset_batch) if tensor_util.is_tf_type(x)),
        None)
    if first_tensor is None:
      raise ValueError('Cannot find any Tensor in features dict.')

    return backend.cast(backend.shape(first_tensor)[0], dtype='int64')

  def update_mask(self, padding_mask, dataset_batch):
    """Returns `padding_mask` extended with the mask for one more batch.

    The appended segment has `self.padded_batch_size` entries: ones for the
    real samples in `dataset_batch`, followed by zeros for the padding that
    brings the batch up to `self.padded_batch_size`. This method does not
    modify `self.padding_mask`; presumably the caller stores the returned
    value there -- TODO confirm.

    Args:
      padding_mask: The 1-D mask accumulated so far.
      dataset_batch: The unpadded batch; its size is determined with
        `get_real_batch_size`.

    Returns:
      A 1-D tensor: `padding_mask` concatenated with the new segment.
    """
    original_batch_size = self.get_real_batch_size(dataset_batch)
    missing_count = self.padded_batch_size - original_batch_size
    mask = backend.concatenate([array_ops.ones(original_batch_size),
                                array_ops.zeros(missing_count)], axis=0)
    return backend.concatenate([padding_mask, mask], axis=0)

  def pad_batch(self, *dataset_batch_elements):
    """Pads out the batch dimension of a tensor to the complete batch size.

    Args:
      *dataset_batch_elements: One or more batch elements. Each is a tensor of
        rank >= 1 or a (possibly nested) dict whose leaves are such tensors.

    Returns:
      The padded element if exactly one element was given, otherwise a tuple
      of the padded elements in their original order. Dicts keep their keys
      and structure.
    """
    def _pad(batch):
      """Zero-pads `batch` (or every tensor inside a dict) along axis 0."""
      if isinstance(batch, dict):
        return {key: _pad(value) for key, value in batch.items()}

      rank = len(batch.shape)
      assert rank > 0
      missing_count = (self.padded_batch_size -
                       self.get_real_batch_size(batch))
      # Append `missing_count` rows at the end of the batch dimension only;
      # every other dimension gets no padding.
      padding = backend.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
      return array_ops.pad(batch, padding, 'constant')

    if len(dataset_batch_elements) == 1:
      return _pad(dataset_batch_elements[0])
    return tuple(_pad(element) for element in dataset_batch_elements)

  def apply_mask(self, prediction_result):
    """Removes prediction output that corresponds to padded input.

    Args:
      prediction_result: For a single-output model, an array-like whose
        leading dimension indexes the (padded) samples. Otherwise an indexable
        sequence with one such array-like per model output.

    Returns:
      For a single-output model, a numpy array holding only the rows whose
      padding-mask entry is nonzero. Otherwise a list with one such array per
      output. Only the leading dimension is filtered; all other dimensions
      (including any of size 1) are preserved.
    """
    padding_mask = backend.get_value(self.padding_mask)
    assert len(padding_mask.shape) == 1

    def _drop_padding(prediction):
      """Returns the rows of `prediction` whose mask entry is nonzero."""
      # Only the leading `len(prediction)` mask entries apply to this
      # prediction. Taking element [0] of the `np.nonzero` result yields a
      # flat index array, so the output's leading dimension is exactly the
      # number of kept rows and no squeeze is needed. (Passing the tuple
      # itself adds an extra size-1 axis, and squeezing that without an
      # `axis` would also drop genuine size-1 dimensions.)
      keep = np.nonzero(padding_mask[:len(prediction)])[0]
      return np.take(prediction, keep, axis=0)

    num_outputs = len(self.output_shape)
    if num_outputs == 1:
      return _drop_padding(prediction_result)
    return [_drop_padding(prediction_result[i]) for i in range(num_outputs)]
107      return predictions
108