# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility object to handle partial batches for TPUStrategy."""
# pylint: disable=protected-access

import numpy as np

from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest


class PartialBatchPaddingHandler(object):
  """A container that holds info about partial batches for `predict()`."""

  def __init__(self, output_shape):
    self.padded_batch_size = 0
    self.padding_mask = array_ops.zeros(0)
    self.output_shape = output_shape

  def get_real_batch_size(self, dataset_batch):
    """Returns the number of elements in a potentially partial batch."""
    if isinstance(dataset_batch, (tuple, list)):
      dataset_batch = dataset_batch[0]

    assert nest.flatten(dataset_batch)

    def _find_any_tensor(batch_features):
      tensors = [
          x for x in nest.flatten(batch_features) if tensor_util.is_tf_type(x)
      ]
      if not tensors:
        raise ValueError('Cannot find any Tensor in features dict.')
      return tensors[0]

    return backend.cast(backend.shape(_find_any_tensor(dataset_batch))[0],
                        dtype='int64')

  def update_mask(self, padding_mask, dataset_batch):
    """Calculate and cache the amount of padding required for a batch."""
    original_batch_size = self.get_real_batch_size(dataset_batch)
    missing_count = self.padded_batch_size - original_batch_size
    mask = backend.concatenate([array_ops.ones(original_batch_size),
                                array_ops.zeros(missing_count)], axis=0)
    return backend.concatenate([padding_mask, mask], axis=0)

  def pad_batch(self, *dataset_batch_elements):
    """Pads out the batch dimension of a tensor to the complete batch size."""
    def _pad(batch):
      """Helper function to pad nested data within each batch element."""
      padded_dict_batch = {}
      if isinstance(batch, dict):
        for key, value in batch.items():
          padded_dict_batch[key] = _pad(value)
        return padded_dict_batch

      rank = len(batch.shape)
      assert rank > 0
      missing_count = (self.padded_batch_size -
                       self.get_real_batch_size(batch))
      padding = backend.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
      return array_ops.pad(batch, padding, 'constant')

    if len(dataset_batch_elements) == 1:
      return _pad(dataset_batch_elements[0])

    batch_elements = []
    for batch_element in dataset_batch_elements:
      batch_elements.append(_pad(batch_element))
    return tuple(batch_elements)

  def apply_mask(self, prediction_result):
    """Removes prediction output that corresponds to padded input."""
    padding_mask = backend.get_value(self.padding_mask)
    assert len(padding_mask.shape) == 1

    if len(self.output_shape) == 1:
      prediction = np.take(prediction_result,
                           np.nonzero(
                               padding_mask[:len(prediction_result)]),
                           axis=0)
      if prediction.shape[0] == 1:
        prediction = np.squeeze(prediction, axis=0)
      return prediction

    else:
      predictions = []
      for i in range(len(self.output_shape)):
        prediction = prediction_result[i]
        prediction = np.take(prediction, np.nonzero(
            padding_mask[:len(prediction)]), axis=0)
        predictions.append(np.squeeze(prediction))

      return predictions
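

# A minimal usage sketch (not part of the original module), assuming eager
# execution, a single-output model with output dimension 2, and a hypothetical
# padded batch size of 4: pad a partial batch up to the padded size, record
# which rows are real, and strip the padded rows from the predictions.
if __name__ == '__main__':
  handler = PartialBatchPaddingHandler(output_shape=[(None, 2)])
  handler.padded_batch_size = 4

  # A partial batch containing only 3 real examples of 5 features each.
  features = array_ops.ones((3, 5))
  handler.padding_mask = handler.update_mask(handler.padding_mask, features)
  padded = handler.pad_batch(features)
  print(padded.shape)  # (4, 5): one all-zero row appended.

  # Pretend the model produced one output row per padded example.
  fake_predictions = np.arange(8.).reshape(4, 2)
  real_predictions = handler.apply_mask(fake_predictions)
  print(real_predictions.shape)  # (3, 2): the padded row is removed.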