/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace concat_split_util {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Concatenates 'inputs' into a single tensor along the zeroth dimension.
// Requires that all elements of 'inputs' have element type T. Writes to
// 'output' using 'context' for the allocation to ensure proper device
// placement.
template <typename T>
Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
              Tensor* output) {
  const int input_dims = inputs[0].dims();
  const TensorShape& input_shape = inputs[0].shape();

  // Note that we reduce the concat of k-dimensional tensors into a two
  // dimensional concat. Assuming the dimensions of any input tensor are
  // {y0, y1,...,ym-1}, we flatten it to {1, y}, where y = Prod_i(yi).
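  // For example, a rank-3 input of shape {4, 5, 6} is viewed here as a
  // {1, 120} matrix before concatenation.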
  std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> inputs_flat;
  inputs_flat.reserve(inputs.size());
  int64_t output_dim0 = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const Tensor& input = inputs[i];
    if (input.dims() != input_dims) {
      return errors::InvalidArgument(
          "Ranks of all input tensors should match: shape[0] = ",
          input_shape.DebugString(), " vs. shape[", i,
          "] = ", input.shape().DebugString());
    }
    for (int j = 1; j < input_dims; ++j) {
      if (input.dim_size(j) != input_shape.dim_size(j)) {
        return errors::InvalidArgument(
            "Dimensions of inputs should match: shape[0] = ",
            input_shape.DebugString(), " vs. shape[", i,
            "] = ", input.shape().DebugString());
      }
    }
    if (input.NumElements() > 0) {
      inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
          input.shaped<T, 2>({1, input.NumElements()})));
    }
    output_dim0 += input.dim_size(0);
  }

  TensorShape output_shape(input_shape);
  output_shape.set_dim(0, output_dim0);
  AllocatorAttributes attr;
  attr.set_on_host(true);
  TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<T>::value,
                                            output_shape, output, attr));
  if (output->NumElements() > 0) {
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
    if (std::is_same<Device, GPUDevice>::value) {
      ConcatGPU<T>(context, inputs_flat, output, &output_flat);
      return OkStatus();
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    ConcatCPU<T>(context->device(), inputs_flat, &output_flat);
  }

  return OkStatus();
}

// Same as 'Concat' above, but handles Tensor dtype deduction automatically.
inline Status Concat(OpKernelContext* context,
                     const gtl::ArraySlice<Tensor> inputs, Tensor* output) {
  const DataType type = inputs[0].dtype();
  Status concat_status;
  switch (type) {
#define CASE(type)                                         \
  case DataTypeToEnum<type>::value:                        \
    concat_status = Concat<type>(context, inputs, output); \
    break;
    TF_CALL_ALL_TYPES(CASE);
#undef CASE
    default:
      concat_status = errors::InvalidArgument("Unsupported data type: ", type);
      break;
  }
  return concat_status;
}
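
// Usage sketch (illustrative only; 'ctx', 'a', and 'b' are hypothetical and
// assume an OpKernelContext plus two tensors of the same dtype whose shapes
// agree in every dimension except the zeroth):
//
//   Tensor batched;
//   TF_RETURN_IF_ERROR(Concat(ctx, {a, b}, &batched));
//   // batched.dim_size(0) == a.dim_size(0) + b.dim_size(0)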

// The Split*() functions split 'input' with element type T into 'sizes.size()'
// tensors along the zeroth dimension, with the ith split having zeroth-
// dimension size 'sizes[i]'. They allocate the output tensors using 'context',
// for proper device placement.

// Handles special cases that are cheap. Sets 'done==true' iff it found an
// applicable special case and wrote to the outputs. Otherwise acts as a no-op.
template <typename T>
Status SplitEasyCases(OpKernelContext* context, const Tensor& input,
                      const gtl::ArraySlice<int64_t> sizes,
                      std::vector<Tensor>* outputs, bool* done) {
  *done = false;

  int64_t total_size = 0;
  for (const int64_t size : sizes) {
    total_size += size;
  }
  if (total_size > input.shape().dim_size(0)) {
    return errors::InvalidArgument(
        "Sum of split sizes must not exceed dim0-size of input tensor");
  }

  // Special case 0: trivial 1-way split.
  if (sizes.size() == 1 && sizes.at(0) == input.shape().dim_size(0)) {
    outputs->push_back(input);
    *done = true;
    return OkStatus();
  }

  // Special case 1: input is aligned.
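  // ('Tensor::Slice' below returns views that share the input's buffer, so
  // this path avoids copying; the alignment check keeps each view's data
  // suitably aligned for downstream kernels.)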
  if (IsInnerDimsSizeAligned<T>(input.shape())) {
    int64_t position = 0;
    for (const int64_t size : sizes) {
      outputs->emplace_back(input.Slice(position, position + size));
      position += size;
    }
    *done = true;
    return OkStatus();
  }

  return OkStatus();
}

// Handles the general case, on CPU.
template <typename T>
Status SplitCPU(OpKernelContext* context, const Tensor& input,
                const gtl::ArraySlice<int64_t> sizes,
                std::vector<Tensor>* outputs) {
  int64_t suffix_dim_size = 1;
  for (int i = 1; i < input.shape().dims(); ++i) {
    suffix_dim_size *= input.shape().dim_size(i);
  }
  auto input_reshaped =
      input.shaped<T, 2>({input.shape().dim_size(0), suffix_dim_size});

  int64_t position = 0;
  for (const int64_t size : sizes) {
    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, size);
    Tensor output;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    TF_RETURN_IF_ERROR(
        context->allocate_temp(input.dtype(), output_shape, &output, attr));
    auto output_shaped = output.shaped<T, 2>({size, suffix_dim_size});

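    // Copy rows [position, position + size) of the flattened input into the
    // newly allocated output via a 2-D Eigen slice.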
    Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{
        static_cast<Eigen::DenseIndex>(position), 0};
    Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{
        static_cast<Eigen::DenseIndex>(size),
        static_cast<Eigen::DenseIndex>(suffix_dim_size)};
    functor::Split<CPUDevice, T, 2>()(context->eigen_device<CPUDevice>(),
                                      output_shaped, input_reshaped,
                                      slice_indices, slice_sizes);

    outputs->emplace_back(output);

    position += size;
  }

  return OkStatus();
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

// Handles the general case, on GPU.
template <typename T>
Status SplitGPU(OpKernelContext* context, const Tensor& input,
                const gtl::ArraySlice<int64_t>& sizes,
                std::vector<Tensor>* outputs) {
  // TODO(olston, apassos): Implement this.
  LOG(FATAL) << "Not yet implemented";  // Crash ok
}

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// The outer function that dispatches to the various Split*() functions above.
template <typename T>
Status Split(OpKernelContext* context, const Tensor& input,
             const gtl::ArraySlice<int64_t> sizes,
             std::vector<Tensor>* outputs) {
  bool easy_cases_done;
  TF_RETURN_IF_ERROR(
      SplitEasyCases<T>(context, input, sizes, outputs, &easy_cases_done));
  if (easy_cases_done) {
    return OkStatus();
  }

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
// TODO(olston, apassos): Handle non-CPU cases.
// return SplitGPU<T>(context, input, sizes, outputs);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  return SplitCPU<T>(context, input, sizes, outputs);
}

// Same as 'Split' above, but handles Tensor dtype automatically.
inline Status Split(OpKernelContext* context, const Tensor& input,
                    const gtl::ArraySlice<int64_t> sizes,
                    std::vector<Tensor>* outputs) {
  const DataType type = input.dtype();
  Status split_status;
  switch (type) {
#define CASE(type)                                              \
  case DataTypeToEnum<type>::value:                             \
    split_status = Split<type>(context, input, sizes, outputs); \
    break;
    TF_CALL_ALL_TYPES(CASE);
#undef CASE
    default:
      split_status = errors::InvalidArgument("Unsupported data type: ", type);
      break;
  }
  return split_status;
}
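
// Usage sketch (illustrative only; 'ctx' and 'batched' are hypothetical and
// assume an OpKernelContext plus a tensor whose zeroth dimension is at least
// the sum of the requested split sizes):
//
//   std::vector<Tensor> parts;
//   TF_RETURN_IF_ERROR(Split(ctx, batched, {2, 3}, &parts));
//   // parts[0].dim_size(0) == 2, parts[1].dim_size(0) == 3; the remaining
//   // dimensions of each part match those of 'batched'.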

}  // namespace concat_split_util
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_