/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <functional>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"  // TF_ATTRIBUTE_NOINLINE
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace {

constexpr absl::string_view kNumSplitsAttrName = "num_splits";
constexpr absl::string_view kNumConcatsAttrName = "num_concats";

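// Reads and validates the 'num_splits'/'num_concats', 'N', and 'paddings'
// attributes. `num_slices` is multiplied by every entry of the partition
// attribute; for example (illustrative values), num_splits = {2, 3} yields
// num_slices = 6, so the 'N' attribute must also be 6. An empty 'paddings'
// attribute is expanded to all zeros.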
Status GetAndValidateAttributesHelper(bool split, OpKernelConstruction* ctx,
                                      std::vector<int32>& num_partitions,
                                      int& num_slices,
                                      std::vector<int32>& paddings,
                                      bool& has_paddings) {
  absl::string_view num_partitions_attr_name =
      split ? kNumSplitsAttrName : kNumConcatsAttrName;
  TF_RETURN_IF_ERROR(ctx->GetAttr(num_partitions_attr_name, &num_partitions));

  int num_dims_to_split = 0;
  for (int i = 0, e = num_partitions.size(); i < e; ++i) {
    const auto& split = num_partitions[i];
    if (split <= 0) {
      return errors::InvalidArgument("'", num_partitions_attr_name,
                                     "' at index ", i,
                                     " must be positive, but got ", split, ".");
    }
    if (split > 1) {
      ++num_dims_to_split;
    }
    num_slices *= split;
  }

  int n;
  TF_RETURN_IF_ERROR(ctx->GetAttr("N", &n));
  if (n != num_slices) {
    return errors::InvalidArgument(
        "'N' must match number of slices ", num_slices, " from '",
        num_partitions_attr_name, "', but got ", n, ".");
  }

  TF_RETURN_IF_ERROR(ctx->GetAttr("paddings", &paddings));
  const int expected_rank = num_partitions.size();
  if (!paddings.empty()) {
    if (paddings.size() != expected_rank) {
      return errors::InvalidArgument(
          "'paddings' length must match '", num_partitions_attr_name,
          "' length ", expected_rank, ", but got ", paddings.size(), ".");
    }

    for (int dim = 0; dim < expected_rank; ++dim) {
      if (paddings[dim] < 0) {
        return errors::InvalidArgument(
            "'paddings' must be all non-negative, but got ", paddings[dim],
            " at index ", dim, ".");
      }
      if (paddings[dim] > 0) {
        has_paddings = true;
      }
    }
  } else {
    paddings.assign(expected_rank, 0);
  }

  return OkStatus();
}

void GetAndValidateAttributes(bool split, OpKernelConstruction* ctx,
                              std::vector<int32>& num_partitions,
                              int& num_slices, std::vector<int32>& paddings,
                              bool& has_paddings) {
  OP_REQUIRES_OK(
      ctx, GetAndValidateAttributesHelper(split, ctx, num_partitions,
                                          num_slices, paddings, has_paddings));
}

absl::string_view kHandle = "handle";
absl::string_view kTensor = "tensor";

template <bool Handle>
Status CreateResourceInvalidDTypeError(const ResourceHandle& handle,
                                       DataType actual_dtype,
                                       DataType expected_dtype) {
  absl::string_view resource_component = Handle ? kHandle : kTensor;
  return errors::InvalidArgument(
      "'T' must match 'resource' variable ", resource_component, " ('",
      handle.name(), "') container ('", handle.container(), "') dtype ",
      DataTypeString(actual_dtype), ", but got ",
      DataTypeString(expected_dtype), ".");
}

// Converts a flat slice index into per-dimension start indices (the subscript
// scaled by the slice shape), i.e. where the slice starts in the input tensor.
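// For example (illustrative values): with num_partitions = {2, 3} and
// slice_shape = {4, 5}, flat index 4 corresponds to grid subscript {1, 1}
// (last dimension varies fastest), giving start indices {1 * 4, 1 * 5} =
// {4, 5}.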
template <int Rank>
Eigen::DSizes<Eigen::DenseIndex, Rank> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, Rank>& slice_shape, const int index);
template <>
Eigen::DSizes<Eigen::DenseIndex, 1> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 1>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;
template <>
Eigen::DSizes<Eigen::DenseIndex, 2> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 2>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;
template <>
Eigen::DSizes<Eigen::DenseIndex, 3> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;
template <>
Eigen::DSizes<Eigen::DenseIndex, 4> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 4>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;
template <>
Eigen::DSizes<Eigen::DenseIndex, 5> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 5>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;
template <>
Eigen::DSizes<Eigen::DenseIndex, 6> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 6>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;
template <>
Eigen::DSizes<Eigen::DenseIndex, 7> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 7>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;
template <>
Eigen::DSizes<Eigen::DenseIndex, 8> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 8>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;

template <int Rank>
Eigen::DSizes<Eigen::DenseIndex, Rank> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, Rank>& slice_shape,
    const int index) {
  return Eigen::DSizes<Eigen::DenseIndex, Rank>();
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 1> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 1>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 1> subscript;
  subscript[0] = index * slice_shape[0];
  return subscript;
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 2> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 2>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 2> subscript;
  subscript[1] = (index % num_partitions[1]) * slice_shape[1];
  subscript[0] = (index / num_partitions[1]) * slice_shape[0];
  return subscript;
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 3> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 3> subscript;
  subscript[2] = (index % num_partitions[2]) * slice_shape[2];
  subscript[1] =
      ((index / num_partitions[2]) % num_partitions[1]) * slice_shape[1];
  subscript[0] =
      (index / (num_partitions[2] * num_partitions[1])) * slice_shape[0];
  return subscript;
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 4> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 4>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 4> subscript;
  subscript[3] = (index % num_partitions[3]) * slice_shape[3];
  subscript[2] =
      ((index / num_partitions[3]) % num_partitions[2]) * slice_shape[2];
  subscript[1] =
      ((index / (num_partitions[3] * num_partitions[2])) % num_partitions[1]) *
      slice_shape[1];
  subscript[0] =
      (index / (num_partitions[3] * num_partitions[2] * num_partitions[1])) *
      slice_shape[0];
  return subscript;
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 5> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 5>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 5> subscript;
  subscript[4] = (index % num_partitions[4]) * slice_shape[4];
  subscript[3] =
      ((index / num_partitions[4]) % num_partitions[3]) * slice_shape[3];
  subscript[2] =
      ((index / (num_partitions[4] * num_partitions[3])) % num_partitions[2]) *
      slice_shape[2];
  subscript[1] =
      ((index / (num_partitions[4] * num_partitions[3] * num_partitions[2])) %
       num_partitions[1]) *
      slice_shape[1];
  subscript[0] = (index / (num_partitions[4] * num_partitions[3] *
                           num_partitions[2] * num_partitions[1])) *
                 slice_shape[0];
  return subscript;
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 6> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 6>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 6> subscript;
  subscript[5] = (index % num_partitions[5]) * slice_shape[5];
  subscript[4] =
      ((index / num_partitions[5]) % num_partitions[4]) * slice_shape[4];
  subscript[3] =
      ((index / (num_partitions[5] * num_partitions[4])) % num_partitions[3]) *
      slice_shape[3];
  subscript[2] =
      ((index / (num_partitions[5] * num_partitions[4] * num_partitions[3])) %
       num_partitions[2]) *
      slice_shape[2];
  subscript[1] = ((index / (num_partitions[5] * num_partitions[4] *
                            num_partitions[3] * num_partitions[2])) %
                  num_partitions[1]) *
                 slice_shape[1];
  subscript[0] =
      (index / (num_partitions[5] * num_partitions[4] * num_partitions[3] *
                num_partitions[2] * num_partitions[1])) *
      slice_shape[0];
  return subscript;
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 7> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 7>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 7> subscript;
  subscript[6] = (index % num_partitions[6]) * slice_shape[6];
  subscript[5] =
      ((index / num_partitions[6]) % num_partitions[5]) * slice_shape[5];
  subscript[4] =
      ((index / (num_partitions[6] * num_partitions[5])) % num_partitions[4]) *
      slice_shape[4];
  subscript[3] =
      ((index / (num_partitions[6] * num_partitions[5] * num_partitions[4])) %
       num_partitions[3]) *
      slice_shape[3];
  subscript[2] = ((index / (num_partitions[6] * num_partitions[5] *
                            num_partitions[4] * num_partitions[3])) %
                  num_partitions[2]) *
                 slice_shape[2];
  subscript[1] =
      ((index / (num_partitions[6] * num_partitions[5] * num_partitions[4] *
                 num_partitions[3] * num_partitions[2])) %
       num_partitions[1]) *
      slice_shape[1];
  subscript[0] =
      (index / (num_partitions[6] * num_partitions[5] * num_partitions[4] *
                num_partitions[3] * num_partitions[2] * num_partitions[1])) *
      slice_shape[0];
  return subscript;
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 8> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 8>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 8> subscript;
  subscript[7] = (index % num_partitions[7]) * slice_shape[7];
  subscript[6] =
      ((index / num_partitions[7]) % num_partitions[6]) * slice_shape[6];
  subscript[5] =
      ((index / (num_partitions[7] * num_partitions[6])) % num_partitions[5]) *
      slice_shape[5];
  subscript[4] =
      ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5])) %
       num_partitions[4]) *
      slice_shape[4];
  subscript[3] = ((index / (num_partitions[7] * num_partitions[6] *
                            num_partitions[5] * num_partitions[4])) %
                  num_partitions[3]) *
                 slice_shape[3];
  subscript[2] =
      ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5] *
                 num_partitions[4] * num_partitions[3])) %
       num_partitions[2]) *
      slice_shape[2];
  subscript[1] =
      ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5] *
                 num_partitions[4] * num_partitions[3] * num_partitions[2])) %
       num_partitions[1]) *
      slice_shape[1];
  subscript[0] =
      (index / (num_partitions[7] * num_partitions[6] * num_partitions[5] *
                num_partitions[4] * num_partitions[3] * num_partitions[2] *
                num_partitions[1])) *
      slice_shape[0];
  return subscript;
}

constexpr absl::string_view kTensorName = "'input' tensor";
constexpr absl::string_view kResourceName = "'resource' variable tensor";

template <int Rank>
Eigen::DSizes<Eigen::DenseIndex, Rank> ShapeAsEigenDSizes(
    const TensorShape& shape) TF_ATTRIBUTE_NOINLINE;
template <int Rank>
Eigen::DSizes<Eigen::DenseIndex, Rank> ShapeAsEigenDSizes(
    const TensorShape& shape) {
  return shape.AsEigenDSizes<Rank>();
}

bool ValidateShapesForSlice(
    OpKernelContext* ctx, bool resource, const Tensor* input,
    const std::vector<int32>& num_splits,
    const std::vector<int32>& paddings) TF_ATTRIBUTE_NOINLINE;

bool ValidateShapesForSlice(OpKernelContext* ctx, bool resource,
                            const Tensor* input,
                            const std::vector<int32>& num_splits,
                            const std::vector<int32>& paddings) {
  const auto& ishape = input->shape();

  Status s;

  absl::string_view input_name = resource ? kResourceName : kTensorName;
  const int rank = ishape.dims();
  const auto& input_shape = ishape.dim_sizes();
  if (rank <= 0 || rank > 8) {
    s = errors::InvalidArgument(
        input_name, " must have rank in range (0, 8], but got ", rank, ".");
  } else if (rank != num_splits.size()) {
    s = errors::InvalidArgument(
        input_name, " rank must be the same as 'num_splits' length ",
        num_splits.size(), ", but got rank ", rank, ".");
  } else {
    for (int dim = 0; dim < rank; ++dim) {
      const auto input_shape_dim = input_shape[dim];
      const auto paddings_dim = paddings[dim];
      const auto num_splits_dim = num_splits[dim];
      if ((input_shape_dim + paddings_dim) % num_splits_dim != 0) {
        s = errors::InvalidArgument(
            input_name, " shape dimension ", dim, " (", input_shape_dim,
            ") with padding ", paddings_dim,
            " must be evenly divisible by 'num_splits' ", num_splits_dim, ".");
        break;
      }
    }
  }
  if (!s.ok()) {
    ctx->CtxFailure(__FILE__, __LINE__, s);
    return false;
  }
  return true;
}

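// XlaSplitND splits its input along each dimension d into num_splits[d]
// pieces, producing N = prod(num_splits) output slices and right-padding each
// dimension with default-constructed values (zero for numeric types) as
// needed. For illustration (example values, not from the original source):
// with num_splits = {2, 2}, N = 4, and paddings = {0, 1}, a [4, 3] input is
// treated as a padded [4, 4] tensor and split into four [2, 2] slices,
// produced in row-major order of the split grid.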
// Shared base class to save code space
class XlaSplitNDShared : public OpKernel {
 public:
  explicit XlaSplitNDShared(OpKernelConstruction* ctx) TF_ATTRIBUTE_NOINLINE
      : OpKernel(ctx),
        num_slices_(1),
        has_paddings_(false) {
    GetAndValidateAttributes(/*split=*/true, ctx, num_splits_, num_slices_,
                             paddings_, has_paddings_);
  }

 protected:
  template <int Rank>
  class SliceAndMaybePadState {
   public:
    int num_complete_pad_dims_;
    int num_partial_pad_dims_;
    TensorShape non_padded_slice_shape_;
    Eigen::array<Eigen::IndexPair<int64_t>, Rank> slice_paddings_;
    Eigen::DSizes<Eigen::DenseIndex, Rank> slice_indices_;
    Eigen::DSizes<Eigen::DenseIndex, Rank> output_slice_shape_dsizes_;
    Eigen::DSizes<Eigen::DenseIndex, Rank> non_padded_slice_shape_dsizes_;

    SliceAndMaybePadState(absl::Span<const int32> num_splits,
                          const absl::Span<const int64_t> input_shape,
                          const TensorShape& output_slice_shape,
                          int slice_index) TF_ATTRIBUTE_NOINLINE {
      output_slice_shape_dsizes_ = ShapeAsEigenDSizes<Rank>(output_slice_shape);
      num_complete_pad_dims_ = 0;
      num_partial_pad_dims_ = 0;
      slice_indices_ = GetSliceIndices<Rank>(
          num_splits, output_slice_shape_dsizes_, slice_index);

      // Calculate the padding needed for this slice directly, rather than
      // padding the whole input and then slicing, to reduce temporary memory
      // allocation.
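      // For example (illustrative values): if input_shape[dim] = 5 and
      // out_dim = 3, the slice starting at index 3 keeps 2 input elements and
      // pads the remaining 1 (partial padding), while a slice whose start
      // index is at or beyond 5 would be entirely padding (complete padding).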
      for (int dim = 0; dim < Rank; ++dim) {
        const int64_t dim_size = input_shape[dim];
        const int64_t out_dim = output_slice_shape_dsizes_[dim];
        int64_t non_padded_dim = 0;
        if (slice_indices_[dim] >= dim_size) {
          // Complete padding.
          slice_indices_[dim] = dim_size;
          non_padded_dim = 0;
          slice_paddings_[dim] = {0, out_dim};
          num_complete_pad_dims_++;
        } else if (slice_indices_[dim] + out_dim > dim_size) {
          // Partial padding.
          non_padded_dim = dim_size - slice_indices_[dim];
          slice_paddings_[dim] = {0, out_dim - non_padded_dim};
          num_partial_pad_dims_++;
        } else {
          non_padded_dim = out_dim;
        }
        non_padded_slice_shape_.AddDim(non_padded_dim);
      }
      non_padded_slice_shape_dsizes_ =
          ShapeAsEigenDSizes<Rank>(non_padded_slice_shape_);
    }
  };

  static void GetDtypeHelper(OpKernelConstruction* ctx, const char* attr_name,
                             DataType* dtype_ptr) TF_ATTRIBUTE_NOINLINE {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(attr_name, dtype_ptr));
  }

  std::vector<int32> num_splits_;
  int num_slices_;
  std::vector<int32> paddings_;
  bool has_paddings_;
};

template <typename Device, typename T>
class XlaSplitNDBaseOp : public XlaSplitNDShared {
 public:
  explicit XlaSplitNDBaseOp(OpKernelConstruction* ctx)
      : XlaSplitNDShared(ctx) {}

 protected:
  void ComputeInternal(
      bool resource, OpKernelContext* ctx,
      const std::function<Status(const Tensor&)>& assign_or_copy_value_fn,
      const Tensor* input) {
    const int rank = input->shape().dims();
    const auto& input_shape = input->shape().dim_sizes();

    if (!ValidateShapesForSlice(ctx, resource, input, num_splits_, paddings_)) {
      return;
    }

    TensorShape output_slice_shape;
    for (int i = 0; i < rank; ++i) {
      output_slice_shape.AddDim((input_shape[i] + paddings_[i]) /
                                ((num_slices_ == 1) ? 1 : num_splits_[i]));
    }
    if (num_slices_ == 1 && !has_paddings_) {
      // Handle simple case first
      OP_REQUIRES_OK(ctx, assign_or_copy_value_fn(*input));
    } else {
      const Device& device = ctx->eigen_device<Device>();
      std::vector<Tensor*> output_slices(num_slices_);
      for (int i = 0; i < num_slices_; i++) {
        OP_REQUIRES_OK(ctx,
                       ctx->allocate_output(
                           /*index=*/i, output_slice_shape, &output_slices[i]));
      }

      if (rank == 1) {
        SliceAndMaybePad<1>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      } else if (rank == 2) {
        SliceAndMaybePad<2>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      } else if (rank == 3) {
        SliceAndMaybePad<3>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      } else if (rank == 4) {
        SliceAndMaybePad<4>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      } else if (rank == 5) {
        SliceAndMaybePad<5>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      } else if (rank == 6) {
        SliceAndMaybePad<6>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      } else if (rank == 7) {
        SliceAndMaybePad<7>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      } else if (rank == 8) {
        SliceAndMaybePad<8>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      }
      return;
    }
  }

 private:
  void SetToConstant(Tensor* output_slice,
                     const Device& device) TF_ATTRIBUTE_NOINLINE {
    auto output_flat = output_slice->flat<T>();
    output_flat.device(device) = output_flat.constant(T());
  }

  template <int Rank>
  void AssignFromInput(
      Tensor* output_slice, const Device& device, const Tensor* input,
      const Eigen::DSizes<Eigen::DenseIndex, Rank>& slice_indices,
      const Eigen::DSizes<Eigen::DenseIndex, Rank>& output_slice_shape_dsizes)
      TF_ATTRIBUTE_NOINLINE {
    output_slice->tensor<T, Rank>().device(device) =
        input->tensor<T, Rank>().slice(slice_indices,
                                       output_slice_shape_dsizes);
  }

  template <int Rank>
  void SliceAndMaybePad(
      OpKernelContext* ctx, const Device& device, const Tensor* input,
      const absl::Span<const int64_t> input_shape,
      const TensorShape& output_slice_shape,
      const std::vector<Tensor*>& output_slices) TF_ATTRIBUTE_NOINLINE {
    const auto& input_tensor = input->tensor<T, Rank>();
    // Slice shape with optional padding.
    for (int i = 0; i < num_slices_; ++i) {
      Tensor* output_slice = output_slices[i];
      SliceAndMaybePadState<Rank> r(num_splits_, input_shape,
                                    output_slice_shape, i);
      if (r.num_complete_pad_dims_ == Rank ||
          (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0)) {
        // Need to init padding
        SetToConstant(output_slice, device);
      }
      if (r.num_complete_pad_dims_ == Rank) {
        // Done
      } else if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) {
        output_slice->tensor<T, Rank>()
            .slice(Eigen::DSizes<Eigen::DenseIndex, Rank>(),
                   r.non_padded_slice_shape_dsizes_)
            .device(device) = input_tensor.slice(
            r.slice_indices_, r.non_padded_slice_shape_dsizes_);
      } else {
        AssignFromInput<Rank>(output_slice, device, input, r.slice_indices_,
                              r.output_slice_shape_dsizes_);
      }
    }
  }
};

template <typename Device, typename T>
class XlaSplitNDOp : public XlaSplitNDBaseOp<Device, T> {
 public:
  explicit XlaSplitNDOp(OpKernelConstruction* ctx) TF_ATTRIBUTE_NOINLINE
      : XlaSplitNDBaseOp<Device, T>(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);

    auto assign_or_copy_value_fn = [&ctx](const Tensor& input) -> Status {
      ctx->set_output(/*index=*/0, input);
      return OkStatus();
    };

    this->ComputeInternal(/*resource=*/false, ctx, assign_or_copy_value_fn,
                          &input);
  }
};

template <typename Device, typename T>
class ReadVariableXlaSplitNDOp : public XlaSplitNDBaseOp<Device, T> {
 public:
  explicit ReadVariableXlaSplitNDOp(OpKernelConstruction* ctx)
      TF_ATTRIBUTE_NOINLINE : XlaSplitNDBaseOp<Device, T>(ctx) {
    XlaSplitNDShared::GetDtypeHelper(ctx, "T", &dtype_);
  }

  void Compute(OpKernelContext* ctx) override {
    core::RefCountPtr<Var> variable;
    const ResourceHandle& handle = HandleFromInput(ctx, 0);
    const Status status = LookupResource(ctx, handle, &variable);
    OP_REQUIRES(
        ctx, status.ok(),
        errors::InvalidArgument("'resource' variable handle ('", handle.name(),
                                "') container ('", handle.container(),
                                "') cannot be found."));

    tf_shared_lock ml(*variable->mu());
    const Tensor* input = variable->tensor();
    OP_REQUIRES(
        ctx, input->dtype() == dtype_,
        CreateResourceInvalidDTypeError<false>(handle, input->dtype(), dtype_));

    auto assign_or_copy_value_fn = [&ctx,
                                    &variable](const Tensor& input) -> Status {
      if (variable->copy_on_read_mode.load()) {
        Tensor* output;
        TF_RETURN_IF_ERROR(
            ctx->allocate_output(/*index=*/0, input.shape(), &output));
        output->flat<T>().device(ctx->eigen_device<Device>()) = input.flat<T>();
      } else {
        ctx->set_output(/*index=*/0, input);
      }
      return OkStatus();
    };

    this->ComputeInternal(/*resource=*/true, ctx, assign_or_copy_value_fn,
                          input);
  }

 private:
  DataType dtype_;
};

#define REGISTER_XLA_SPLIT_ND(type)                                    \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("XlaSplitND").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      XlaSplitNDOp<Eigen::ThreadPoolDevice, type>)

TF_CALL_POD_TYPES(REGISTER_XLA_SPLIT_ND);
TF_CALL_QUANTIZED_TYPES(REGISTER_XLA_SPLIT_ND);
#undef REGISTER_XLA_SPLIT_ND

#define REGISTER_READ_VARIABLE_XLA_SPLIT_ND(type) \
  REGISTER_KERNEL_BUILDER(                        \
      Name("ReadVariableXlaSplitND")              \
          .Device(DEVICE_CPU)                     \
          .TypeConstraint<type>("T"),             \
      ReadVariableXlaSplitNDOp<Eigen::ThreadPoolDevice, type>)

TF_CALL_POD_TYPES(REGISTER_READ_VARIABLE_XLA_SPLIT_ND);
TF_CALL_QUANTIZED_TYPES(REGISTER_READ_VARIABLE_XLA_SPLIT_ND);
#undef REGISTER_READ_VARIABLE_XLA_SPLIT_ND

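// XlaConcatND is the inverse operation: it concatenates N = prod(num_concats)
// equally shaped input slices (in row-major grid order) and then strips
// 'paddings' elements from the end of each dimension. For illustration
// (example values, not from the original source): with num_concats = {2, 2},
// N = 4, and paddings = {0, 1}, four [2, 2] slices form a [4, 4] tensor whose
// last column is dropped, yielding a [4, 3] output.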
// Shared base class to save code space
class XlaConcatNDShared : public OpKernel {
 public:
  explicit XlaConcatNDShared(OpKernelConstruction* ctx) TF_ATTRIBUTE_NOINLINE
      : OpKernel(ctx),
        num_slices_(1),
        has_paddings_(false) {
    GetAndValidateAttributes(/*split=*/false, ctx, num_concats_, num_slices_,
                             paddings_, has_paddings_);
  }

 protected:
  Status GetInputsAndOutputShape(OpKernelContext* ctx, OpInputList& inputs,
                                 TensorShape& output_shape) {
    TF_RETURN_IF_ERROR(ctx->input_list("inputs", &inputs));
    DCHECK_EQ(inputs.size(), num_slices_);

    const TensorShape& slice_shape = inputs[0].shape();
    if (slice_shape.dims() != num_concats_.size()) {
      return errors::InvalidArgument(
          "'inputs' rank must be the same as 'num_concats' length ",
          num_concats_.size(), ", but got rank ", slice_shape.dims(), ".");
    }
    for (int i = 1; i < num_slices_; ++i) {
      const TensorShape& slice_shape_i = inputs[i].shape();
      if (slice_shape != slice_shape_i) {
        return errors::InvalidArgument(
            "'inputs' must all have the same expected shape ", slice_shape,
            ", but got ", slice_shape_i, " at index ", i, ".");
      }
    }

    for (int i = 0, e = num_concats_.size(); i < e; ++i) {
      const int max_dim_size = slice_shape.dim_size(i) * num_concats_[i];
      if (paddings_[i] > max_dim_size) {
        return errors::InvalidArgument(
            "'paddings' must not exceed expected output shape dimension ",
            max_dim_size, " at index ", i, ", but got ", paddings_[i], ".");
      }
      output_shape.AddDim(max_dim_size - paddings_[i]);
    }

    return OkStatus();
  }
  void ApplyAssignOrCopyShared(
      OpKernelContext* ctx,
      const std::function<Status(const Tensor&)>& assign_or_copy_value_fn,
      const Tensor& input) {
    OP_REQUIRES_OK(ctx, assign_or_copy_value_fn(input));
  }

  template <int Rank>
  class MaybeUnpadAndAssignState {
   public:
    int num_complete_pad_dims_;
    int num_partial_pad_dims_;
    TensorShape non_padded_slice_shape_;
    Eigen::DSizes<Eigen::DenseIndex, Rank> slice_shape_dsizes_;
    Eigen::array<Eigen::IndexPair<int64_t>, Rank> slice_paddings_;
    Eigen::DSizes<Eigen::DenseIndex, Rank> slice_indices_;
    Eigen::DSizes<Eigen::DenseIndex, Rank> output_slice_shape_dsizes_;
    Eigen::DSizes<Eigen::DenseIndex, Rank> non_padded_slice_shape_dsizes_;

    MaybeUnpadAndAssignState(absl::Span<const int32> num_concats,
                             const Tensor& input0, Tensor* output,
                             int slice_index) TF_ATTRIBUTE_NOINLINE {
      slice_shape_dsizes_ = input0.shape().AsEigenDSizes<Rank>();
      slice_indices_ =
          GetSliceIndices<Rank>(num_concats, slice_shape_dsizes_, slice_index);
      num_complete_pad_dims_ = 0;
      num_partial_pad_dims_ = 0;
      // Calculate paddings necessary to strip from slice.
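      // For example (illustrative values): if the output dimension is 5 and
      // each slice spans 3, the slice placed at index 3 contributes only its
      // first 2 elements (the rest is stripped padding), and a slice whose
      // start index is at or beyond 5 is skipped entirely.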
      for (int dim = 0; dim < Rank; ++dim) {
        const int64_t dim_size = output->shape().dim_size(dim);
        int64_t non_padded_dim = 0;
        if (slice_indices_[dim] >= dim_size) {
          // Complete padding.
          slice_indices_[dim] = dim_size;
          non_padded_dim = 0;
          num_complete_pad_dims_++;
        } else if (slice_indices_[dim] + slice_shape_dsizes_[dim] > dim_size) {
          // Partial padding.
          non_padded_dim = dim_size - slice_indices_[dim];
          num_partial_pad_dims_++;
        } else {
          non_padded_dim = slice_shape_dsizes_[dim];
        }
        non_padded_slice_shape_.AddDim(non_padded_dim);
      }
      non_padded_slice_shape_dsizes_ =
          non_padded_slice_shape_.AsEigenDSizes<Rank>();
    }
  };

  std::vector<int32> num_concats_;
  int num_slices_;
  std::vector<int32> paddings_;
  bool has_paddings_;
};

template <typename Device, typename T>
class XlaConcatNDBaseOp : public XlaConcatNDShared {
 public:
  explicit XlaConcatNDBaseOp(OpKernelConstruction* ctx) TF_ATTRIBUTE_NOINLINE
      : XlaConcatNDShared(ctx) {}

 protected:
  void ComputeInternal(
      bool resource, OpKernelContext* ctx, const OpInputList& inputs,
      const std::function<Status(const Tensor&)>& assign_or_copy_value_fn,
      const std::function<StatusOr<Tensor*>()>& get_output_fn) {
    const int rank = inputs[0].shape().dims();

    OP_REQUIRES(ctx, rank > 0 && rank <= 8,
                errors::InvalidArgument(
                    "'inputs' tensors must have rank in range (0, 8], but got ",
                    rank, "."));

    if (num_slices_ == 1 && !has_paddings_) {
      // Simple case
      ApplyAssignOrCopyShared(ctx, assign_or_copy_value_fn, inputs[0]);
      return;
    }

    const Device& device = ctx->eigen_device<Device>();
    auto status_or_output = get_output_fn();
    OP_REQUIRES_OK(ctx, status_or_output.status());
    Tensor* output = std::move(status_or_output).value();

    if (rank == 1) {
      MaybeUnpadAndAssign<1>(ctx, device, inputs, output);
    } else if (rank == 2) {
      MaybeUnpadAndAssign<2>(ctx, device, inputs, output);
    } else if (rank == 3) {
      MaybeUnpadAndAssign<3>(ctx, device, inputs, output);
    } else if (rank == 4) {
      MaybeUnpadAndAssign<4>(ctx, device, inputs, output);
    } else if (rank == 5) {
      MaybeUnpadAndAssign<5>(ctx, device, inputs, output);
    } else if (rank == 6) {
      MaybeUnpadAndAssign<6>(ctx, device, inputs, output);
    } else if (rank == 7) {
      MaybeUnpadAndAssign<7>(ctx, device, inputs, output);
    } else if (rank == 8) {
      MaybeUnpadAndAssign<8>(ctx, device, inputs, output);
    }
  }

 private:
  template <int Rank>
  void MaybeUnpadAndAssign(OpKernelContext* ctx, const Device& device,
                           const OpInputList& inputs,
                           Tensor* output) TF_ATTRIBUTE_NOINLINE {
    for (int i = 0; i < num_slices_; ++i) {
      MaybeUnpadAndAssignState<Rank> r(num_concats_, inputs[0], output, i);
      if (r.num_complete_pad_dims_ == Rank) {
        continue;
      } else if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) {
        output->tensor<T, Rank>()
            .slice(r.slice_indices_, r.non_padded_slice_shape_dsizes_)
            .device(device) = inputs[i].tensor<T, Rank>().slice(
            Eigen::DSizes<Eigen::DenseIndex, Rank>(),
            r.non_padded_slice_shape_dsizes_);
      } else {
        output->tensor<T, Rank>()
            .slice(r.slice_indices_, r.slice_shape_dsizes_)
            .device(device) = inputs[i].tensor<T, Rank>();
      }
    }
  }
};

template <typename Device, typename T>
class XlaConcatNDOp : public XlaConcatNDBaseOp<Device, T> {
 public:
  explicit XlaConcatNDOp(OpKernelConstruction* ctx)
      : XlaConcatNDBaseOp<Device, T>(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OpInputList inputs;
    TensorShape output_shape;
    OP_REQUIRES_OK(ctx,
                   this->GetInputsAndOutputShape(ctx, inputs, output_shape));

    auto assign_or_copy_value_fn = [&ctx](const Tensor& input) -> Status {
      ctx->set_output(/*index=*/0, input);
      return OkStatus();
    };

    auto get_output_fn = [&ctx, &output_shape]() -> StatusOr<Tensor*> {
      Tensor* output = nullptr;
      TF_RETURN_IF_ERROR(
          ctx->allocate_output(/*index=*/0, output_shape, &output));
      return output;
    };
    this->ComputeInternal(/*resource=*/false, ctx, inputs,
                          assign_or_copy_value_fn, get_output_fn);
  }
};

template <typename Device, typename T>
class AssignVariableXlaConcatNDOp : public XlaConcatNDBaseOp<Device, T> {
 public:
  explicit AssignVariableXlaConcatNDOp(OpKernelConstruction* ctx)
      TF_ATTRIBUTE_NOINLINE : XlaConcatNDBaseOp<Device, T>(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    OpInputList inputs;
    TensorShape output_shape;
    OP_REQUIRES_OK(ctx,
                   this->GetInputsAndOutputShape(ctx, inputs, output_shape));

    core::RefCountPtr<Var> variable;
    const ResourceHandle& handle = HandleFromInput(ctx, 0);
    if (handle.dtypes_and_shapes().size() == 1) {
      const DtypeAndPartialTensorShape dtype_and_shape =
          handle.dtypes_and_shapes().front();
      OP_REQUIRES(ctx, dtype_and_shape.dtype == dtype_,
                  CreateResourceInvalidDTypeError<true>(
                      handle, dtype_and_shape.dtype, dtype_));
      OP_REQUIRES(ctx, dtype_and_shape.shape.IsCompatibleWith(output_shape),
                  errors::InvalidArgument(
                      "'resource' variable handle ('", handle.name(),
                      "') container ('", handle.container(),
                      "') shape must be compatible with expected shape ",
                      output_shape, ", but got ", dtype_and_shape.shape, "."));
    }
    OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(ctx, handle, &variable,
                                                    [this](Var** ptr) {
                                                      *ptr = new Var(dtype_);
                                                      return OkStatus();
                                                    }));
    mutex_lock ml(*variable->mu());

    OP_REQUIRES(ctx, variable->tensor()->dtype() == dtype_,
                CreateResourceInvalidDTypeError<false>(
                    handle, variable->tensor()->dtype(), dtype_));

    auto assign_or_copy_value_fn = [this, &ctx, &output_shape,
                                    &variable](const Tensor& input) -> Status {
      if (variable->copy_on_read_mode.load()) {
        TF_RETURN_IF_ERROR(
            ctx->allocate_temp(dtype_, output_shape, variable->tensor()));
        variable->tensor()->flat<T>().device(ctx->eigen_device<Device>()) =
            input.flat<T>();
      } else {
        *variable->tensor() = input;
      }
      return OkStatus();
    };

    auto get_output_fn = [this, &ctx, &output_shape,
                          &variable]() -> StatusOr<Tensor*> {
      if (variable->copy_on_read_mode.load() ||
          !variable->tensor()->RefCountIsOne() ||
          !variable->tensor()->shape().IsSameSize(output_shape)) {
        TF_RETURN_IF_ERROR(
            ctx->allocate_temp(dtype_, output_shape, variable->tensor()));
      }
      return variable->tensor();
    };

    this->ComputeInternal(/*resource=*/true, ctx, inputs,
                          assign_or_copy_value_fn, get_output_fn);
    variable->is_initialized = true;
  }

  DataType dtype_;
};

#define REGISTER_XLA_CONCAT_ND(type)                                    \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("XlaConcatND").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      XlaConcatNDOp<Eigen::ThreadPoolDevice, type>)

TF_CALL_POD_TYPES(REGISTER_XLA_CONCAT_ND);
TF_CALL_QUANTIZED_TYPES(REGISTER_XLA_CONCAT_ND);
#undef REGISTER_XLA_CONCAT_ND

#define REGISTER_ASSIGN_VARIABLE_XLA_CONCAT_ND(type) \
  REGISTER_KERNEL_BUILDER(                           \
      Name("AssignVariableXlaConcatND")              \
          .Device(DEVICE_CPU)                        \
          .TypeConstraint<type>("T"),                \
      AssignVariableXlaConcatNDOp<Eigen::ThreadPoolDevice, type>)

TF_CALL_POD_TYPES(REGISTER_ASSIGN_VARIABLE_XLA_CONCAT_ND);
TF_CALL_QUANTIZED_TYPES(REGISTER_ASSIGN_VARIABLE_XLA_CONCAT_ND);
#undef REGISTER_ASSIGN_VARIABLE_XLA_CONCAT_ND

}  // anonymous namespace
}  // namespace tensorflow
969