xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/batch_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/batch_util.h"
17 
18 #include "tensorflow/core/framework/register_types.h"
19 #include "tensorflow/core/framework/types.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 
22 #define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
23 
24 namespace tensorflow {
25 namespace batch_util {
26 
27 namespace {
28 
ValidateInput(const Tensor & parent,const Tensor & element,int64_t index)29 Status ValidateInput(const Tensor& parent, const Tensor& element,
30                      int64_t index) {
31   DCHECK_NE(parent.dim_size(0), 0);
32   DCHECK_GE(index, 0);
33   if (element.NumElements() != (parent.NumElements() / parent.dim_size(0))) {
34     TensorShape chip_shape = parent.shape();
35     chip_shape.RemoveDim(0);
36     return errors::Internal(
37         "ValidateInput Cannot perform copy: number of elements does not match. "
38         " Shapes are: [element]: ",
39         element.shape().DebugString(),
40         ", [parent slice]: ", chip_shape.DebugString());
41   }
42   return OkStatus();
43 }
44 
45 template <typename T>
HandleElementToSlice(const Tensor &,T * src,T * dest,int64_t num_values)46 Status HandleElementToSlice(const Tensor& /* element */, T* src, T* dest,
47                             int64_t num_values) {
48   static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
49   memcpy(dest, src, num_values * sizeof(T));
50   return OkStatus();
51 }
52 
53 template <>
HandleElementToSlice(const Tensor & element,tstring * src,tstring * dest,int64_t num_values)54 Status HandleElementToSlice<tstring>(const Tensor& element, tstring* src,
55                                      tstring* dest, int64_t num_values) {
56   if (element.RefCountIsOne()) {
57     for (int64_t i = 0; i < num_values; ++i) {
58       *dest++ = std::move(*src++);
59     }
60   } else {
61     std::copy_n(src, num_values, dest);
62   }
63   return OkStatus();
64 }
65 
66 template <>
HandleElementToSlice(const Tensor & element,Variant * src,Variant * dest,int64_t num_values)67 Status HandleElementToSlice<Variant>(const Tensor& element, Variant* src,
68                                      Variant* dest, int64_t num_values) {
69   if (element.RefCountIsOne()) {
70     for (int64_t i = 0; i < num_values; ++i) {
71       *dest++ = std::move(*src++);
72     }
73   } else {
74     std::copy_n(src, num_values, dest);
75   }
76   return OkStatus();
77 }
78 
79 template <>
HandleElementToSlice(const Tensor &,ResourceHandle * src,ResourceHandle * dest,int64_t num_values)80 Status HandleElementToSlice<ResourceHandle>(const Tensor& /* element */,
81                                             ResourceHandle* src,
82                                             ResourceHandle* dest,
83                                             int64_t num_values) {
84   std::copy_n(src, num_values, dest);
85   return OkStatus();
86 }
87 
88 template <>
HandleElementToSlice(const Tensor &,Eigen::half * src,Eigen::half * dest,int64_t num_values)89 Status HandleElementToSlice<Eigen::half>(const Tensor& /* element */,
90                                          Eigen::half* src, Eigen::half* dest,
91                                          int64_t num_values) {
92   std::copy_n(src, num_values, dest);
93   return OkStatus();
94 }
95 
96 template <typename T>
HandleSliceToElement(const T * src,T * dest,int64_t num_values)97 void HandleSliceToElement(const T* src, T* dest, int64_t num_values) {
98   static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
99   memcpy(dest, src, num_values * sizeof(T));
100 }
101 
102 template <>
HandleSliceToElement(const tstring * src,tstring * dest,int64_t num_values)103 void HandleSliceToElement<tstring>(const tstring* src, tstring* dest,
104                                    int64_t num_values) {
105   std::copy_n(src, num_values, dest);
106 }
107 
108 template <>
HandleSliceToElement(const Variant * src,Variant * dest,int64_t num_values)109 void HandleSliceToElement<Variant>(const Variant* src, Variant* dest,
110                                    int64_t num_values) {
111   std::copy_n(src, num_values, dest);
112 }
113 
114 template <>
HandleSliceToElement(const ResourceHandle * src,ResourceHandle * dest,int64_t num_values)115 void HandleSliceToElement<ResourceHandle>(const ResourceHandle* src,
116                                           ResourceHandle* dest,
117                                           int64_t num_values) {
118   std::copy_n(src, num_values, dest);
119 }
120 
121 template <>
HandleSliceToElement(const Eigen::half * src,Eigen::half * dest,int64_t num_values)122 void HandleSliceToElement<Eigen::half>(const Eigen::half* src,
123                                        Eigen::half* dest, int64_t num_values) {
124   std::copy_n(src, num_values, dest);
125 }
126 
127 template <typename T>
HandleSliceToElement(Tensor * parent,T * src,T * dest,int64_t num_values)128 void HandleSliceToElement(Tensor* parent, T* src, T* dest, int64_t num_values) {
129   static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
130   memcpy(dest, src, num_values * sizeof(T));
131 }
132 
133 template <>
HandleSliceToElement(Tensor * parent,tstring * src,tstring * dest,int64_t num_values)134 void HandleSliceToElement<tstring>(Tensor* parent, tstring* src, tstring* dest,
135                                    int64_t num_values) {
136   if (parent->RefCountIsOne()) {
137     for (int64_t i = 0; i < num_values; ++i) {
138       dest[i] = std::move(src[i]);
139     }
140   } else {
141     std::copy_n(src, num_values, dest);
142   }
143 }
144 
145 template <>
HandleSliceToElement(Tensor * parent,Variant * src,Variant * dest,int64_t num_values)146 void HandleSliceToElement<Variant>(Tensor* parent, Variant* src, Variant* dest,
147                                    int64_t num_values) {
148   if (parent->RefCountIsOne()) {
149     for (int64_t i = 0; i < num_values; ++i) {
150       dest[i] = std::move(src[i]);
151     }
152   } else {
153     std::copy_n(src, num_values, dest);
154   }
155 }
156 
157 template <>
HandleSliceToElement(Tensor * parent,ResourceHandle * src,ResourceHandle * dest,int64_t num_values)158 void HandleSliceToElement<ResourceHandle>(Tensor* parent, ResourceHandle* src,
159                                           ResourceHandle* dest,
160                                           int64_t num_values) {
161   std::copy_n(src, num_values, dest);
162 }
163 
164 template <>
HandleSliceToElement(Tensor * parent,Eigen::half * src,Eigen::half * dest,int64_t num_values)165 void HandleSliceToElement<Eigen::half>(Tensor* parent, Eigen::half* src,
166                                        Eigen::half* dest, int64_t num_values) {
167   std::copy_n(src, num_values, dest);
168 }
169 
170 }  // namespace
171 
172 // Copies element into the index^th slice of parent (in the 0th dimension).
CopyElementToSlice(Tensor element,Tensor * parent,int64_t index)173 Status CopyElementToSlice(Tensor element, Tensor* parent, int64_t index) {
174   TF_RETURN_IF_ERROR(ValidateInput(*parent, element, index));
175   const int64_t num_values = element.NumElements();
176 #define HANDLE_TYPE(T)                                              \
177   case DataTypeToEnum<T>::value: {                                  \
178     T* src = element.base<T>();                                     \
179     T* dest = parent->base<T>() + (num_values * index);             \
180     return HandleElementToSlice<T>(element, src, dest, num_values); \
181   }
182 
183   switch (element.dtype()) {
184     TF_CALL_ALL_TYPES(HANDLE_TYPE);
185     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
186 #undef HANDLE_TYPE
187     default:
188       return errors::Unimplemented("CopyElementToSlice Unhandled data type: ",
189                                    element.dtype());
190   }
191 }
192 
193 // Copies the index^th slice of parent (in the 0th dimension) into element.
CopySliceToElement(const Tensor & parent,Tensor * element,int64_t index)194 Status CopySliceToElement(const Tensor& parent, Tensor* element,
195                           int64_t index) {
196   TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index));
197   const int64_t num_values = element->NumElements();
198 
199 #define HANDLE_TYPE(T)                                      \
200   case DataTypeToEnum<T>::value: {                          \
201     const T* src = parent.base<T>() + (num_values * index); \
202     T* dest = element->base<T>();                           \
203     HandleSliceToElement<T>(src, dest, num_values);         \
204     return OkStatus();                                      \
205   }
206 
207   switch (parent.dtype()) {
208     TF_CALL_ALL_TYPES(HANDLE_TYPE);
209     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
210 #undef HANDLE_TYPE
211     default:
212       return errors::Unimplemented("CopySliceToElement Unhandled data type: ",
213                                    element->dtype());
214   }
215 }
216 
CopyContiguousSlices(const Tensor & src,int64_t src_offset,int64_t dst_offset,int64_t num_slices,Tensor * dst)217 Status CopyContiguousSlices(const Tensor& src, int64_t src_offset,
218                             int64_t dst_offset, int64_t num_slices,
219                             Tensor* dst) {
220   if (src.dtype() != dst->dtype()) {
221     return errors::FailedPrecondition(
222         "CopyContiguousSlices cannot perform copy: src and dst have different "
223         "dtypes. Source dtype: ",
224         src.dtype(), " dstination dtype: ", dst->dtype(), ".");
225   }
226   if (src.dims() < 1) {
227     return errors::FailedPrecondition(
228         "CopyContiguousSlices cannot perform copy: src has to be a tensor with "
229         "rank >= 1. Source shape: ",
230         src.shape().DebugString());
231   }
232 
233   if (dst->dims() < 1) {
234     return errors::FailedPrecondition(
235         "CopyContiguousSlices cannot perform copy: dst has to be a tensor "
236         "with rank >= 1. Dest shape: ",
237         dst->shape().DebugString());
238   }
239 
240   const int64_t src_dim0 = src.dim_size(0);
241   const int64_t dst_dim0 = dst->dim_size(0);
242   int64_t src_chip_size = 1;
243   int64_t dst_chip_size = 1;
244   for (int i = 1; i < src.dims(); ++i) {
245     src_chip_size *= src.dim_size(i);
246   }
247   for (int i = 1; i < dst->dims(); ++i) {
248     dst_chip_size *= dst->dim_size(i);
249   }
250 
251   if (src_chip_size != dst_chip_size) {
252     return errors::FailedPrecondition(
253         "CopyContiguousSlices cannot perform copy: source and dst shapes are"
254         "not compatible. Source shape: ",
255         src.shape().DebugString(), ", dst shape: ", dst->shape().DebugString());
256   }
257 
258   if (src_chip_size == 0 && dst_chip_size == 0) {
259     return OkStatus();
260   }
261 
262   if (src_offset < 0 || src_offset + num_slices > src_dim0 || dst_offset < 0 ||
263       dst_offset + num_slices > dst_dim0) {
264     return errors::FailedPrecondition(
265         "CopyContiguousSlices cannot perform copy: index out of range. "
266         "src_offset: ",
267         src_offset, ", num_slices: ", num_slices, ", src_dim0: ", src_dim0,
268         ", dst_offset: ", dst_offset, ", dst_dim0: ", dst_dim0, ".");
269   }
270 
271 #define HANDLE_TYPE(T)                                                 \
272   case DataTypeToEnum<T>::value: {                                     \
273     const T* src_p = src.base<T>() + (src_chip_size * src_offset);     \
274     T* dst_p = dst->base<T>() + (dst_chip_size * dst_offset);          \
275     HandleSliceToElement<T>(src_p, dst_p, src_chip_size * num_slices); \
276     return OkStatus();                                                 \
277   }
278 
279   switch (src.dtype()) {
280     TF_CALL_ALL_TYPES(HANDLE_TYPE);
281     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
282 #undef HANDLE_TYPE
283     default:
284       return errors::Unimplemented("CopyContiguousSlices unhandled data type: ",
285                                    src.dtype());
286   }
287 }
288 
289 // Copies the index^th slice of parent (in the 0th dimension) into element.
290 //
291 // NOTE(mrry): The implementation may be able to optimize the copy to a move.
292 // This is particularly important for DT_STRING tensors.
MaybeMoveSliceToElement(Tensor * parent,Tensor * element,int64_t index)293 Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64_t index) {
294   TF_RETURN_IF_ERROR(ValidateInput(*parent, *element, index));
295   const int64_t num_values = element->NumElements();
296 
297 #define HANDLE_TYPE(T)                                      \
298   case DataTypeToEnum<T>::value: {                          \
299     T* src = parent->base<T>() + (num_values * index);      \
300     T* dest = element->base<T>();                           \
301     HandleSliceToElement<T>(parent, src, dest, num_values); \
302     return OkStatus();                                      \
303   }
304 
305   switch (parent->dtype()) {
306     TF_CALL_ALL_TYPES(HANDLE_TYPE);
307     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
308 #undef HANDLE_TYPE
309     default:
310       return errors::Unimplemented(
311           "MaybeMoveSliceToElement Unhandled data type: ", element->dtype());
312   }
313 }
314 
315 // The following five functions are copied from padding_fifo_queue.cc.
316 // TODO(mrry): Reconcile these functions with the similar methods in the
317 // queue implementation.
ValidateElementToLargerSlice(const Tensor & element,Tensor * parent)318 Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) {
319   DCHECK_NE(parent->dim_size(0), 0);
320   if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
321     TensorShape chip_shape = parent->shape();
322     chip_shape.RemoveDim(0);
323     return errors::Internal(
324         "HandleElementToLargerSlice Cannot copy slice: number of entries in "
325         "element is greater than number of elements in parent slice.  ",
326         "Shapes are: [element]: ", element.shape().DebugString(),
327         ", [parent slice]: ", chip_shape.DebugString());
328   }
329   return OkStatus();
330 }
331 
332 template <typename T, int NDIMS>
HandleElementToLargerSlice(const Tensor & element,Tensor * parent,int index)333 Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
334                                   int index) {
335   TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent));
336   if (element.NumElements() == 0) {
337     return OkStatus();
338   }
339   auto element_t = element.tensor<T, NDIMS>();
340   auto parent_t = parent->tensor<T, NDIMS + 1>();
341   Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
342   slice_indices[0] = index;
343   Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
344   slice_size[0] = 1;
345   for (size_t i = 1; i < slice_size.size(); ++i) {
346     slice_size[i] = element_t.dimension(i - 1);
347   }
348   parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
349   return OkStatus();
350 }
351 
352 template <int NDIMS>
HandleElementToLargerSliceWithRank(const Tensor & element,Tensor * parent,int index)353 Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
354                                           int index) {
355 #define HANDLE_TYPE(T)                                                   \
356   case DataTypeToEnum<T>::value: {                                       \
357     return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
358   }
359 
360   switch (element.dtype()) {
361     TF_CALL_DATASET_TYPES(HANDLE_TYPE);
362 #undef HANDLE_TYPE
363     default:
364       return errors::Unimplemented(
365           "HandleElementToLargerSliceWithRank Unhandled data type: ",
366           element.dtype());
367   }
368 }
369 
CopyElementToLargerSlice(const Tensor & element,Tensor * parent,int index)370 Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
371                                 int index) {
372   if (parent->dims() != element.dims() + 1) {
373     return errors::Internal(
374         "Mismatched ranks.  Element's rank is: ", element.dims(),
375         " but element is meant to be a slice in output Tensor having rank: ",
376         parent->dims(), " (should be: ", element.dims() + 1, ")");
377   }
378 
379 #define HANDLE_DIMS(NDIMS)                                                  \
380   case NDIMS: {                                                             \
381     TF_RETURN_IF_ERROR(                                                     \
382         HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
383     return OkStatus();                                                      \
384   }
385 
386   switch (element.dims()) {
387     HANDLE_DIMS(0);
388     HANDLE_DIMS(1);
389     HANDLE_DIMS(2);
390     HANDLE_DIMS(3);
391     HANDLE_DIMS(4);
392     HANDLE_DIMS(5);
393 #undef HANDLE_DIMS
394     default:
395       return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
396                                    element.dims());
397   }
398 }
399 
SetElementZero(Tensor * element,const Tensor & padding)400 Status SetElementZero(Tensor* element, const Tensor& padding) {
401 #define HANDLE_TYPE(T)                                     \
402   if (element->dtype() == DataTypeToEnum<T>::value) {      \
403     element->flat<T>().setConstant(padding.scalar<T>()()); \
404     return OkStatus();                                     \
405   }
406   TF_CALL_DATASET_TYPES(HANDLE_TYPE);
407 #undef HANDLE_TYPE
408   return errors::Unimplemented("SetElementZero Unhandled data type: ",
409                                element->dtype());
410 }
411 
412 }  // namespace batch_util
413 }  // namespace tensorflow
414