/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/batch_util.h"

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)

namespace tensorflow {
namespace batch_util {

namespace {

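// Returns an error if `element` does not have the same number of elements as
// a single slice of `parent` along dimension 0, i.e. if it cannot be copied
// into (or out of) the `index`^th slice of `parent`.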
Status ValidateInput(const Tensor& parent, const Tensor& element,
                     int64_t index) {
  DCHECK_NE(parent.dim_size(0), 0);
  DCHECK_GE(index, 0);
  if (element.NumElements() != (parent.NumElements() / parent.dim_size(0))) {
    TensorShape chip_shape = parent.shape();
    chip_shape.RemoveDim(0);
    return errors::Internal(
        "ValidateInput Cannot perform copy: number of elements does not "
        "match. Shapes are: [element]: ",
        element.shape().DebugString(),
        ", [parent slice]: ", chip_shape.DebugString());
  }
  return OkStatus();
}

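// Copies a single element's flat buffer into a slice of the batched tensor.
// The generic version memcpy()s simple types; the specializations below copy
// element-wise, and move strings and variants when `element` holds the only
// reference to its buffer.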
template <typename T>
Status HandleElementToSlice(const Tensor& /* element */, T* src, T* dest,
                            int64_t num_values) {
  static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
  memcpy(dest, src, num_values * sizeof(T));
  return OkStatus();
}

template <>
Status HandleElementToSlice<tstring>(const Tensor& element, tstring* src,
                                     tstring* dest, int64_t num_values) {
  if (element.RefCountIsOne()) {
    for (int64_t i = 0; i < num_values; ++i) {
      *dest++ = std::move(*src++);
    }
  } else {
    std::copy_n(src, num_values, dest);
  }
  return OkStatus();
}

template <>
Status HandleElementToSlice<Variant>(const Tensor& element, Variant* src,
                                     Variant* dest, int64_t num_values) {
  if (element.RefCountIsOne()) {
    for (int64_t i = 0; i < num_values; ++i) {
      *dest++ = std::move(*src++);
    }
  } else {
    std::copy_n(src, num_values, dest);
  }
  return OkStatus();
}

template <>
Status HandleElementToSlice<ResourceHandle>(const Tensor& /* element */,
                                            ResourceHandle* src,
                                            ResourceHandle* dest,
                                            int64_t num_values) {
  std::copy_n(src, num_values, dest);
  return OkStatus();
}

template <>
Status HandleElementToSlice<Eigen::half>(const Tensor& /* element */,
                                         Eigen::half* src, Eigen::half* dest,
                                         int64_t num_values) {
  std::copy_n(src, num_values, dest);
  return OkStatus();
}

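// Copies a slice of the batched tensor into an element without modifying the
// parent. Types that are not simple enough to memcpy() are copied with
// std::copy_n.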
template <typename T>
void HandleSliceToElement(const T* src, T* dest, int64_t num_values) {
  static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
  memcpy(dest, src, num_values * sizeof(T));
}

template <>
void HandleSliceToElement<tstring>(const tstring* src, tstring* dest,
                                   int64_t num_values) {
  std::copy_n(src, num_values, dest);
}

template <>
void HandleSliceToElement<Variant>(const Variant* src, Variant* dest,
                                   int64_t num_values) {
  std::copy_n(src, num_values, dest);
}

template <>
void HandleSliceToElement<ResourceHandle>(const ResourceHandle* src,
                                          ResourceHandle* dest,
                                          int64_t num_values) {
  std::copy_n(src, num_values, dest);
}

template <>
void HandleSliceToElement<Eigen::half>(const Eigen::half* src,
                                       Eigen::half* dest, int64_t num_values) {
  std::copy_n(src, num_values, dest);
}

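// Overloads of HandleSliceToElement that take the parent tensor so that
// strings and variants can be moved out of it when `parent` holds the only
// reference to its buffer (used by MaybeMoveSliceToElement below).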
template <typename T>
void HandleSliceToElement(Tensor* parent, T* src, T* dest, int64_t num_values) {
  static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
  memcpy(dest, src, num_values * sizeof(T));
}

template <>
void HandleSliceToElement<tstring>(Tensor* parent, tstring* src, tstring* dest,
                                   int64_t num_values) {
  if (parent->RefCountIsOne()) {
    for (int64_t i = 0; i < num_values; ++i) {
      dest[i] = std::move(src[i]);
    }
  } else {
    std::copy_n(src, num_values, dest);
  }
}

template <>
void HandleSliceToElement<Variant>(Tensor* parent, Variant* src, Variant* dest,
                                   int64_t num_values) {
  if (parent->RefCountIsOne()) {
    for (int64_t i = 0; i < num_values; ++i) {
      dest[i] = std::move(src[i]);
    }
  } else {
    std::copy_n(src, num_values, dest);
  }
}

template <>
void HandleSliceToElement<ResourceHandle>(Tensor* parent, ResourceHandle* src,
                                          ResourceHandle* dest,
                                          int64_t num_values) {
  std::copy_n(src, num_values, dest);
}

template <>
void HandleSliceToElement<Eigen::half>(Tensor* parent, Eigen::half* src,
                                       Eigen::half* dest, int64_t num_values) {
  std::copy_n(src, num_values, dest);
}

}  // namespace

// Copies element into the index^th slice of parent (in the 0th dimension).
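//
// For example (illustrative only, not part of this file): stacking four
// elements of shape [2] into a batch of shape [4, 2]:
//
//   Tensor batch(DT_INT64, TensorShape({4, 2}));
//   for (int64_t i = 0; i < 4; ++i) {
//     Tensor element(DT_INT64, TensorShape({2}));
//     element.flat<int64_t>().setConstant(i);
//     TF_RETURN_IF_ERROR(CopyElementToSlice(std::move(element), &batch, i));
//   }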
Status CopyElementToSlice(Tensor element, Tensor* parent, int64_t index) {
  TF_RETURN_IF_ERROR(ValidateInput(*parent, element, index));
  const int64_t num_values = element.NumElements();
#define HANDLE_TYPE(T)                                                  \
  case DataTypeToEnum<T>::value: {                                     \
    T* src = element.base<T>();                                        \
    T* dest = parent->base<T>() + (num_values * index);                \
    return HandleElementToSlice<T>(element, src, dest, num_values);    \
  }

  switch (element.dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
    TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented("CopyElementToSlice Unhandled data type: ",
                                   element.dtype());
  }
}

// Copies the index^th slice of parent (in the 0th dimension) into element.
Status CopySliceToElement(const Tensor& parent, Tensor* element,
                          int64_t index) {
  TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index));
  const int64_t num_values = element->NumElements();

#define HANDLE_TYPE(T)                                      \
  case DataTypeToEnum<T>::value: {                          \
    const T* src = parent.base<T>() + (num_values * index); \
    T* dest = element->base<T>();                           \
    HandleSliceToElement<T>(src, dest, num_values);         \
    return OkStatus();                                      \
  }

  switch (parent.dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
    TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented("CopySliceToElement Unhandled data type: ",
                                   element->dtype());
  }
}

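// Copies `num_slices` contiguous slices along dimension 0 from `src`
// (starting at `src_offset`) into `dst` (starting at `dst_offset`). The two
// tensors must have the same dtype and the same number of elements per slice.
//
// For example (illustrative only; shapes assumed): copying slices 2 and 3 of
// an [8, 3] tensor `src` into slices 0 and 1 of a [4, 3] tensor `dst`:
//
//   TF_RETURN_IF_ERROR(CopyContiguousSlices(src, /*src_offset=*/2,
//                                           /*dst_offset=*/0,
//                                           /*num_slices=*/2, &dst));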
Status CopyContiguousSlices(const Tensor& src, int64_t src_offset,
                            int64_t dst_offset, int64_t num_slices,
                            Tensor* dst) {
  if (src.dtype() != dst->dtype()) {
    return errors::FailedPrecondition(
        "CopyContiguousSlices cannot perform copy: src and dst have different "
        "dtypes. Source dtype: ",
        src.dtype(), " destination dtype: ", dst->dtype(), ".");
  }
  if (src.dims() < 1) {
    return errors::FailedPrecondition(
        "CopyContiguousSlices cannot perform copy: src has to be a tensor with "
        "rank >= 1. Source shape: ",
        src.shape().DebugString());
  }

  if (dst->dims() < 1) {
    return errors::FailedPrecondition(
        "CopyContiguousSlices cannot perform copy: dst has to be a tensor "
        "with rank >= 1. Dest shape: ",
        dst->shape().DebugString());
  }

  const int64_t src_dim0 = src.dim_size(0);
  const int64_t dst_dim0 = dst->dim_size(0);
  int64_t src_chip_size = 1;
  int64_t dst_chip_size = 1;
  for (int i = 1; i < src.dims(); ++i) {
    src_chip_size *= src.dim_size(i);
  }
  for (int i = 1; i < dst->dims(); ++i) {
    dst_chip_size *= dst->dim_size(i);
  }

  if (src_chip_size != dst_chip_size) {
    return errors::FailedPrecondition(
        "CopyContiguousSlices cannot perform copy: source and dst shapes are "
        "not compatible. Source shape: ",
        src.shape().DebugString(), ", dst shape: ", dst->shape().DebugString());
  }

  if (src_chip_size == 0 && dst_chip_size == 0) {
    return OkStatus();
  }

  if (src_offset < 0 || src_offset + num_slices > src_dim0 || dst_offset < 0 ||
      dst_offset + num_slices > dst_dim0) {
    return errors::FailedPrecondition(
        "CopyContiguousSlices cannot perform copy: index out of range. "
        "src_offset: ",
        src_offset, ", num_slices: ", num_slices, ", src_dim0: ", src_dim0,
        ", dst_offset: ", dst_offset, ", dst_dim0: ", dst_dim0, ".");
  }

#define HANDLE_TYPE(T)                                                  \
  case DataTypeToEnum<T>::value: {                                     \
    const T* src_p = src.base<T>() + (src_chip_size * src_offset);     \
    T* dst_p = dst->base<T>() + (dst_chip_size * dst_offset);          \
    HandleSliceToElement<T>(src_p, dst_p, src_chip_size * num_slices); \
    return OkStatus();                                                 \
  }

  switch (src.dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
    TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented("CopyContiguousSlices unhandled data type: ",
                                   src.dtype());
  }
}

// Copies the index^th slice of parent (in the 0th dimension) into element.
//
// NOTE(mrry): The implementation may be able to optimize the copy to a move.
// This is particularly important for DT_STRING tensors.
Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64_t index) {
  TF_RETURN_IF_ERROR(ValidateInput(*parent, *element, index));
  const int64_t num_values = element->NumElements();

#define HANDLE_TYPE(T)                                        \
  case DataTypeToEnum<T>::value: {                            \
    T* src = parent->base<T>() + (num_values * index);        \
    T* dest = element->base<T>();                             \
    HandleSliceToElement<T>(parent, src, dest, num_values);   \
    return OkStatus();                                        \
  }

  switch (parent->dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
    TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented(
          "MaybeMoveSliceToElement Unhandled data type: ", element->dtype());
  }
}

// The following five functions are copied from padding_fifo_queue.cc.
// TODO(mrry): Reconcile these functions with the similar methods in the
// queue implementation.
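
// Returns an error if `element` has more elements than a single slice of
// `parent` along dimension 0; smaller elements are allowed.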
Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) {
  DCHECK_NE(parent->dim_size(0), 0);
  if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
    TensorShape chip_shape = parent->shape();
    chip_shape.RemoveDim(0);
    return errors::Internal(
        "HandleElementToLargerSlice Cannot copy slice: number of entries in "
        "element is greater than number of elements in parent slice. ",
        "Shapes are: [element]: ", element.shape().DebugString(),
        ", [parent slice]: ", chip_shape.DebugString());
  }
  return OkStatus();
}

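// Copies `element` into the `index`^th slice of `parent` using an Eigen slice
// assignment. `element` may be smaller than the slice; entries of the slice
// that it does not cover are left untouched.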
template <typename T, int NDIMS>
Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
                                  int index) {
  TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent));
  if (element.NumElements() == 0) {
    return OkStatus();
  }
  auto element_t = element.tensor<T, NDIMS>();
  auto parent_t = parent->tensor<T, NDIMS + 1>();
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
  slice_indices[0] = index;
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
  slice_size[0] = 1;
  for (size_t i = 1; i < slice_size.size(); ++i) {
    slice_size[i] = element_t.dimension(i - 1);
  }
  parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
  return OkStatus();
}

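// Dispatches HandleElementToLargerSlice on the element's data type for a
// fixed rank NDIMS.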
template <int NDIMS>
Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
                                          int index) {
#define HANDLE_TYPE(T)                                                    \
  case DataTypeToEnum<T>::value: {                                       \
    return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
  }

  switch (element.dtype()) {
    TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented(
          "HandleElementToLargerSliceWithRank Unhandled data type: ",
          element.dtype());
  }
}

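// Copies `element` into the `index`^th slice of `parent`, dispatching on the
// element's rank. Ranks 0 through 5 are supported.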
Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
                                int index) {
  if (parent->dims() != element.dims() + 1) {
    return errors::Internal(
        "Mismatched ranks. Element's rank is: ", element.dims(),
        " but element is meant to be a slice in output Tensor having rank: ",
        parent->dims(), " (should be: ", element.dims() + 1, ")");
  }

#define HANDLE_DIMS(NDIMS)                                                  \
  case NDIMS: {                                                             \
    TF_RETURN_IF_ERROR(                                                     \
        HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
    return OkStatus();                                                      \
  }

  switch (element.dims()) {
    HANDLE_DIMS(0);
    HANDLE_DIMS(1);
    HANDLE_DIMS(2);
    HANDLE_DIMS(3);
    HANDLE_DIMS(4);
    HANDLE_DIMS(5);
#undef HANDLE_DIMS
    default:
      return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
                                   element.dims());
  }
}

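// Fills every entry of `element` with the scalar `padding` value.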
Status SetElementZero(Tensor* element, const Tensor& padding) {
#define HANDLE_TYPE(T)                                      \
  if (element->dtype() == DataTypeToEnum<T>::value) {       \
    element->flat<T>().setConstant(padding.scalar<T>()());  \
    return OkStatus();                                      \
  }
  TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
  return errors::Unimplemented("SetElementZero Unhandled data type: ",
                               element->dtype());
}

}  // namespace batch_util
}  // namespace tensorflow