/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace concat_split_util {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Concatenates 'inputs' into a single tensor along the zeroth dimension.
// Requires that all elements of 'inputs' have element type T. Writes to
// 'output' using 'context' for the allocation to ensure proper device
// placement.
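//
// Example usage (a sketch; assumes the caller runs inside an op kernel whose
// OpKernelContext is 'context', and that all inputs share their non-zeroth
// dimensions):
//
//   std::vector<Tensor> tasks = ...;  // e.g. three float tensors of shape {1, 64}
//   Tensor batch;
//   TF_RETURN_IF_ERROR(
//       concat_split_util::Concat<float>(context, tasks, &batch));
//   // 'batch' now has shape {3, 64}.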
// 'Device' selects the concat backend and defaults to CPUDevice; the temporary
// output below is allocated on host.
template <typename T, typename Device = CPUDevice>
Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
              Tensor* output) {
  const int input_dims = inputs[0].dims();
  const TensorShape& input_shape = inputs[0].shape();

  // Note that we reduce the concat of k-dimensional tensors into a two
  // dimensional concat. Assuming the dimensions of any input tensor are
  // {y0, y1,...,ym-1}, we flatten it to {1, y}, where y = Prod_i(yi).
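  // For example, an input of shape {2, 3, 4} is viewed as a {1, 24} matrix, so
  // the ConcatCPU/ConcatGPU helpers below only ever see rank-2 operands.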
  std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> inputs_flat;
  inputs_flat.reserve(inputs.size());
  int64_t output_dim0 = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const Tensor& input = inputs[i];
    if (input.dims() != input_dims) {
      return errors::InvalidArgument(
          "Ranks of all input tensors should match: shape[0] = ",
          input_shape.DebugString(), " vs. shape[", i,
          "] = ", input.shape().DebugString());
    }
    for (int j = 1; j < input_dims; ++j) {
      if (input.dim_size(j) != input_shape.dim_size(j)) {
        return errors::InvalidArgument(
            "Dimensions of inputs should match: shape[0] = ",
            input_shape.DebugString(), " vs. shape[", i,
            "] = ", input.shape().DebugString());
      }
    }
    if (input.NumElements() > 0) {
      inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
          input.shaped<T, 2>({1, input.NumElements()})));
    }
    output_dim0 += input.dim_size(0);
  }

  TensorShape output_shape(input_shape);
  output_shape.set_dim(0, output_dim0);
  AllocatorAttributes attr;
  attr.set_on_host(true);
  TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<T>::value,
                                            output_shape, output, attr));
  if (output->NumElements() > 0) {
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
    if (std::is_same<Device, GPUDevice>::value) {
      ConcatGPU<T>(context, inputs_flat, output, &output_flat);
      return OkStatus();
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    ConcatCPU<T>(context->device(), inputs_flat, &output_flat);
  }

  return OkStatus();
}

// Same as 'Concat' above, but handles Tensor dtype deduction automatically.
inline Status Concat(OpKernelContext* context,
                     const gtl::ArraySlice<Tensor> inputs, Tensor* output) {
  const DataType type = inputs[0].dtype();
  Status concat_status;
  switch (type) {
#define CASE(type)                                         \
  case DataTypeToEnum<type>::value:                        \
    concat_status = Concat<type>(context, inputs, output); \
    break;
    TF_CALL_ALL_TYPES(CASE);
#undef CASE
    default:
      concat_status = errors::InvalidArgument("Unsupported data type: ", type);
      break;
  }
  return concat_status;
}

// The Split*() functions split 'input' with element type T into 'sizes.size()'
// tensors along the zeroth dimension, with the ith split having zeroth-
// dimension size 'sizes[i]'. They allocate the output tensors using 'context',
// for proper device placement.
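//
// Example usage (a sketch; 'context' is assumed to be the calling kernel's
// OpKernelContext):
//
//   Tensor batch = ...;  // e.g. a float tensor of shape {5, 64}
//   std::vector<int64_t> task_sizes = {2, 2, 1};
//   std::vector<Tensor> sub_batches;
//   TF_RETURN_IF_ERROR(concat_split_util::Split(context, batch, task_sizes,
//                                               &sub_batches));
//   // sub_batches[0] and [1] have shape {2, 64}; sub_batches[2] has {1, 64}.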

// Handles special cases that are cheap. Sets '*done' to true iff it found an
// applicable special case and wrote to the outputs; otherwise acts as a no-op.
template <typename T>
Status SplitEasyCases(OpKernelContext* context, const Tensor& input,
                      const gtl::ArraySlice<int64_t> sizes,
                      std::vector<Tensor>* outputs, bool* done) {
  *done = false;

  int64_t total_size = 0;
  for (const int64_t size : sizes) {
    total_size += size;
  }
  if (total_size > input.shape().dim_size(0)) {
    return errors::InvalidArgument(
        "Sum of split sizes must not exceed dim0-size of input tensor");
  }

  // Special case 0: trivial 1-way split.
  if (sizes.size() == 1 && sizes.at(0) == input.shape().dim_size(0)) {
    outputs->push_back(input);
    *done = true;
    return OkStatus();
  }

  // Special case 1: input is aligned.
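  // Slicing here shares the input buffer rather than copying; the
  // IsInnerDimsSizeAligned check guarantees each slice starts at an aligned
  // offset, so the resulting views are safe to hand to downstream kernels.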
  if (IsInnerDimsSizeAligned<T>(input.shape())) {
    int64_t position = 0;
    for (const int64_t size : sizes) {
      outputs->emplace_back(input.Slice(position, position + size));
      position += size;
    }
    *done = true;
    return OkStatus();
  }

  return OkStatus();
}

// Handles the general case, on CPU.
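// The input is viewed as a {dim0, suffix} matrix (suffix being the product of
// the remaining dimensions), and each output is filled by a 2-D Eigen slice
// copy into a freshly allocated host tensor.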
template <typename T>
Status SplitCPU(OpKernelContext* context, const Tensor& input,
                const gtl::ArraySlice<int64_t> sizes,
                std::vector<Tensor>* outputs) {
  int64_t suffix_dim_size = 1;
  for (int i = 1; i < input.shape().dims(); ++i) {
    suffix_dim_size *= input.shape().dim_size(i);
  }
  auto input_reshaped =
      input.shaped<T, 2>({input.shape().dim_size(0), suffix_dim_size});

  int64_t position = 0;
  for (const int64_t size : sizes) {
    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, size);
    Tensor output;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    TF_RETURN_IF_ERROR(
        context->allocate_temp(input.dtype(), output_shape, &output, attr));
    auto output_shaped = output.shaped<T, 2>({size, suffix_dim_size});

    Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{
        static_cast<Eigen::DenseIndex>(position), 0};
    Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{
        static_cast<Eigen::DenseIndex>(size),
        static_cast<Eigen::DenseIndex>(suffix_dim_size)};
    functor::Split<CPUDevice, T, 2>()(context->eigen_device<CPUDevice>(),
                                      output_shaped, input_reshaped,
                                      slice_indices, slice_sizes);

    outputs->emplace_back(output);

    position += size;
  }

  return OkStatus();
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

// Handles the general case, on GPU.
template <typename T>
Status SplitGPU(OpKernelContext* context, const Tensor& input,
                const gtl::ArraySlice<int64_t>& sizes,
                std::vector<Tensor>* outputs) {
  // TODO(olston, apassos): Implement this.
  LOG(FATAL) << "Not yet implemented";  // Crash ok
}

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// The outer function that dispatches to the various Split*() functions above.
template <typename T>
Status Split(OpKernelContext* context, const Tensor& input,
             const gtl::ArraySlice<int64_t> sizes,
             std::vector<Tensor>* outputs) {
  bool easy_cases_done;
  TF_RETURN_IF_ERROR(
      SplitEasyCases<T>(context, input, sizes, outputs, &easy_cases_done));
  if (easy_cases_done) {
    return OkStatus();
  }

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
  // TODO(olston, apassos): Handle non-CPU cases.
  // return SplitGPU<T>(context, input, sizes, outputs);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  return SplitCPU<T>(context, input, sizes, outputs);
}

// Same as 'Split' above, but handles Tensor dtype automatically.
inline Status Split(OpKernelContext* context, const Tensor& input,
                    const gtl::ArraySlice<int64_t> sizes,
                    std::vector<Tensor>* outputs) {
  const DataType type = input.dtype();
  Status split_status;
  switch (type) {
#define CASE(type)                                              \
  case DataTypeToEnum<type>::value:                             \
    split_status = Split<type>(context, input, sizes, outputs); \
    break;
    TF_CALL_ALL_TYPES(CASE);
#undef CASE
    default:
      split_status = errors::InvalidArgument("Unsupported data type: ", type);
      break;
  }
  return split_status;
}

}  // namespace concat_split_util
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_