/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <functional>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace {

constexpr absl::string_view kNumSplitsAttrName = "num_splits";
constexpr absl::string_view kNumConcatsAttrName = "num_concats";

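// Reads and validates the 'num_splits'/'num_concats', 'N', and 'paddings'
// attributes. On success, num_partitions holds the per-dimension partition
// counts, num_slices their product, and paddings a per-dimension padding
// amount (zero-filled when the attribute is empty).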
Status GetAndValidateAttributesHelper(bool split, OpKernelConstruction* ctx,
                                      std::vector<int32>& num_partitions,
                                      int& num_slices,
                                      std::vector<int32>& paddings,
                                      bool& has_paddings) {
  absl::string_view num_partitions_attr_name =
      split ? kNumSplitsAttrName : kNumConcatsAttrName;
  TF_RETURN_IF_ERROR(ctx->GetAttr(num_partitions_attr_name, &num_partitions));

  int num_dims_to_split = 0;
  for (int i = 0, e = num_partitions.size(); i < e; ++i) {
    const auto& num_split = num_partitions[i];
    if (num_split <= 0) {
      return errors::InvalidArgument("'", num_partitions_attr_name,
                                     "' at index ", i,
                                     " must be positive, but got ", num_split,
                                     ".");
    }
    if (num_split > 1) {
      ++num_dims_to_split;
    }
    num_slices *= num_split;
  }

  int n;
  TF_RETURN_IF_ERROR(ctx->GetAttr("N", &n));
  if (n != num_slices) {
    return errors::InvalidArgument(
        "'N' must match number of slices ", num_slices, " from '",
        num_partitions_attr_name, "', but got ", n, ".");
  }

  TF_RETURN_IF_ERROR(ctx->GetAttr("paddings", &paddings));
  const int expected_rank = num_partitions.size();
  if (!paddings.empty()) {
    if (paddings.size() != expected_rank) {
      return errors::InvalidArgument(
          "'paddings' length must match '", num_partitions_attr_name,
          "' length ", expected_rank, ", but got ", paddings.size(), ".");
    }

    for (int dim = 0; dim < expected_rank; ++dim) {
      if (paddings[dim] < 0) {
        return errors::InvalidArgument(
            "'paddings' must be all non-negative, but got ", paddings[dim],
            " at index ", dim, ".");
      }
      if (paddings[dim] > 0) {
        has_paddings = true;
      }
    }
  } else {
    paddings.assign(expected_rank, 0);
  }

  return OkStatus();
}

void GetAndValidateAttributes(bool split, OpKernelConstruction* ctx,
                              std::vector<int32>& num_partitions,
                              int& num_slices, std::vector<int32>& paddings,
                              bool& has_paddings) {
  OP_REQUIRES_OK(
      ctx, GetAndValidateAttributesHelper(split, ctx, num_partitions,
                                          num_slices, paddings, has_paddings));
}

constexpr absl::string_view kHandle = "handle";
constexpr absl::string_view kTensor = "tensor";

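// Builds the InvalidArgument error reported on a dtype mismatch against a
// resource variable, naming either the handle or the stored tensor depending
// on the Handle template parameter.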
template <bool Handle>
Status CreateResourceInvalidDTypeError(const ResourceHandle& handle,
                                       DataType actual_dtype,
                                       DataType expected_dtype) {
  absl::string_view resource_component = Handle ? kHandle : kTensor;
  return errors::InvalidArgument(
      "'T' must match 'resource' variable ", resource_component, " ('",
      handle.name(), "') container ('", handle.container(), "') dtype ",
      DataTypeString(actual_dtype), ", but got ",
      DataTypeString(expected_dtype), ".");
}

// Converts a flattened slice index to per-dimension start indices (the
// subscript scaled by the slice shape) for determining where a slice starts
// in the input tensor.
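// For example, with num_partitions = {2, 3} and slice_shape = {4, 5},
// flattened index 4 maps to subscript (1, 1) in row-major order, so the
// slice starts at {1 * 4, 1 * 5} = {4, 5}.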
template <int Rank>
Eigen::DSizes<Eigen::DenseIndex, Rank> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, Rank>& slice_shape, const int index);
template <>
Eigen::DSizes<Eigen::DenseIndex, 1> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 1>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;
template <>
Eigen::DSizes<Eigen::DenseIndex, 2> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 2>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;
template <>
Eigen::DSizes<Eigen::DenseIndex, 3> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;
template <>
Eigen::DSizes<Eigen::DenseIndex, 4> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 4>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;
template <>
Eigen::DSizes<Eigen::DenseIndex, 5> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 5>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;
template <>
Eigen::DSizes<Eigen::DenseIndex, 6> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 6>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;
template <>
Eigen::DSizes<Eigen::DenseIndex, 7> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 7>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;
template <>
Eigen::DSizes<Eigen::DenseIndex, 8> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 8>& slice_shape,
    const int index) TF_ATTRIBUTE_NOINLINE;

template <int Rank>
Eigen::DSizes<Eigen::DenseIndex, Rank> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, Rank>& slice_shape,
    const int index) {
  return Eigen::DSizes<Eigen::DenseIndex, Rank>();
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 1> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 1>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 1> subscript;
  subscript[0] = index * slice_shape[0];
  return subscript;
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 2> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 2>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 2> subscript;
  subscript[1] = (index % num_partitions[1]) * slice_shape[1];
  subscript[0] = (index / num_partitions[1]) * slice_shape[0];
  return subscript;
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 3> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 3> subscript;
  subscript[2] = (index % num_partitions[2]) * slice_shape[2];
  subscript[1] =
      ((index / num_partitions[2]) % num_partitions[1]) * slice_shape[1];
  subscript[0] =
      (index / (num_partitions[2] * num_partitions[1])) * slice_shape[0];
  return subscript;
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 4> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 4>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 4> subscript;
  subscript[3] = (index % num_partitions[3]) * slice_shape[3];
  subscript[2] =
      ((index / num_partitions[3]) % num_partitions[2]) * slice_shape[2];
  subscript[1] =
      ((index / (num_partitions[3] * num_partitions[2])) % num_partitions[1]) *
      slice_shape[1];
  subscript[0] =
      (index / (num_partitions[3] * num_partitions[2] * num_partitions[1])) *
      slice_shape[0];
  return subscript;
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 5> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 5>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 5> subscript;
  subscript[4] = (index % num_partitions[4]) * slice_shape[4];
  subscript[3] =
      ((index / num_partitions[4]) % num_partitions[3]) * slice_shape[3];
  subscript[2] =
      ((index / (num_partitions[4] * num_partitions[3])) % num_partitions[2]) *
      slice_shape[2];
  subscript[1] =
      ((index / (num_partitions[4] * num_partitions[3] * num_partitions[2])) %
       num_partitions[1]) *
      slice_shape[1];
  subscript[0] = (index / (num_partitions[4] * num_partitions[3] *
                           num_partitions[2] * num_partitions[1])) *
                 slice_shape[0];
  return subscript;
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 6> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 6>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 6> subscript;
  subscript[5] = (index % num_partitions[5]) * slice_shape[5];
  subscript[4] =
      ((index / num_partitions[5]) % num_partitions[4]) * slice_shape[4];
  subscript[3] =
      ((index / (num_partitions[5] * num_partitions[4])) % num_partitions[3]) *
      slice_shape[3];
  subscript[2] =
      ((index / (num_partitions[5] * num_partitions[4] * num_partitions[3])) %
       num_partitions[2]) *
      slice_shape[2];
  subscript[1] = ((index / (num_partitions[5] * num_partitions[4] *
                            num_partitions[3] * num_partitions[2])) %
                  num_partitions[1]) *
                 slice_shape[1];
  subscript[0] =
      (index / (num_partitions[5] * num_partitions[4] * num_partitions[3] *
                num_partitions[2] * num_partitions[1])) *
      slice_shape[0];
  return subscript;
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 7> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 7>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 7> subscript;
  subscript[6] = (index % num_partitions[6]) * slice_shape[6];
  subscript[5] =
      ((index / num_partitions[6]) % num_partitions[5]) * slice_shape[5];
  subscript[4] =
      ((index / (num_partitions[6] * num_partitions[5])) % num_partitions[4]) *
      slice_shape[4];
  subscript[3] =
      ((index / (num_partitions[6] * num_partitions[5] * num_partitions[4])) %
       num_partitions[3]) *
      slice_shape[3];
  subscript[2] = ((index / (num_partitions[6] * num_partitions[5] *
                            num_partitions[4] * num_partitions[3])) %
                  num_partitions[2]) *
                 slice_shape[2];
  subscript[1] =
      ((index / (num_partitions[6] * num_partitions[5] * num_partitions[4] *
                 num_partitions[3] * num_partitions[2])) %
       num_partitions[1]) *
      slice_shape[1];
  subscript[0] =
      (index / (num_partitions[6] * num_partitions[5] * num_partitions[4] *
                num_partitions[3] * num_partitions[2] * num_partitions[1])) *
      slice_shape[0];
  return subscript;
}

template <>
Eigen::DSizes<Eigen::DenseIndex, 8> GetSliceIndices(
    absl::Span<const int32> num_partitions,
    const Eigen::DSizes<Eigen::DenseIndex, 8>& slice_shape, const int index) {
  Eigen::DSizes<Eigen::DenseIndex, 8> subscript;
  subscript[7] = (index % num_partitions[7]) * slice_shape[7];
  subscript[6] =
      ((index / num_partitions[7]) % num_partitions[6]) * slice_shape[6];
  subscript[5] =
      ((index / (num_partitions[7] * num_partitions[6])) % num_partitions[5]) *
      slice_shape[5];
  subscript[4] =
      ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5])) %
       num_partitions[4]) *
      slice_shape[4];
  subscript[3] = ((index / (num_partitions[7] * num_partitions[6] *
                            num_partitions[5] * num_partitions[4])) %
                  num_partitions[3]) *
                 slice_shape[3];
  subscript[2] =
      ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5] *
                 num_partitions[4] * num_partitions[3])) %
       num_partitions[2]) *
      slice_shape[2];
  subscript[1] =
      ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5] *
                 num_partitions[4] * num_partitions[3] * num_partitions[2])) %
       num_partitions[1]) *
      slice_shape[1];
  subscript[0] =
      (index / (num_partitions[7] * num_partitions[6] * num_partitions[5] *
                num_partitions[4] * num_partitions[3] * num_partitions[2] *
                num_partitions[1])) *
      slice_shape[0];
  return subscript;
}

constexpr absl::string_view kTensorName = "'input' tensor";
constexpr absl::string_view kResourceName = "'resource' variable tensor";

template <int Rank>
Eigen::DSizes<Eigen::DenseIndex, Rank> ShapeAsEigenDSizes(
    const TensorShape& shape) TF_ATTRIBUTE_NOINLINE;
template <int Rank>
Eigen::DSizes<Eigen::DenseIndex, Rank> ShapeAsEigenDSizes(
    const TensorShape& shape) {
  return shape.AsEigenDSizes<Rank>();
}

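// Returns true iff the input can be split: its rank must be in (0, 8] and
// match the length of 'num_splits', and each padded dimension must divide
// evenly by its split count. On failure, reports InvalidArgument on ctx.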
bool ValidateShapesForSlice(
    OpKernelContext* ctx, bool resource, const Tensor* input,
    const std::vector<int32>& num_splits,
    const std::vector<int32>& paddings) TF_ATTRIBUTE_NOINLINE;

bool ValidateShapesForSlice(OpKernelContext* ctx, bool resource,
                            const Tensor* input,
                            const std::vector<int32>& num_splits,
                            const std::vector<int32>& paddings) {
  const auto& ishape = input->shape();

  Status s;

  absl::string_view input_name = resource ? kResourceName : kTensorName;
  const int rank = ishape.dims();
  const auto& input_shape = ishape.dim_sizes();
  if (rank <= 0 || rank > 8) {
    s = errors::InvalidArgument(
        input_name, " must have rank in range (0, 8], but got ", rank, ".");
  } else if (rank != num_splits.size()) {
    s = errors::InvalidArgument(
        input_name, " rank must be the same as 'num_splits' length ",
        num_splits.size(), ", but got rank ", rank, ".");
  } else {
    for (int dim = 0; dim < rank; ++dim) {
      const auto input_shape_dim = input_shape[dim];
      const auto paddings_dim = paddings[dim];
      const auto num_splits_dim = num_splits[dim];
      if ((input_shape_dim + paddings_dim) % num_splits_dim != 0) {
        s = errors::InvalidArgument(
            input_name, " shape dimension ", dim, " (", input_shape_dim,
            ") with padding ", paddings_dim,
            " must be evenly divisible by 'num_splits' ", num_splits_dim, ".");
        break;
      }
    }
  }
  if (!s.ok()) {
    ctx->CtxFailure(__FILE__, __LINE__, s);
    return false;
  }
  return true;
}

// Shared base class to save code space.
class XlaSplitNDShared : public OpKernel {
 public:
  explicit XlaSplitNDShared(OpKernelConstruction* ctx) TF_ATTRIBUTE_NOINLINE
      : OpKernel(ctx),
        num_slices_(1),
        has_paddings_(false) {
    GetAndValidateAttributes(/*split=*/true, ctx, num_splits_, num_slices_,
                             paddings_, has_paddings_);
  }

 protected:
  template <int Rank>
  class SliceAndMaybePadState {
   public:
    int num_complete_pad_dims_;
    int num_partial_pad_dims_;
    TensorShape non_padded_slice_shape_;
    Eigen::array<Eigen::IndexPair<int64_t>, Rank> slice_paddings_;
    Eigen::DSizes<Eigen::DenseIndex, Rank> slice_indices_;
    Eigen::DSizes<Eigen::DenseIndex, Rank> output_slice_shape_dsizes_;
    Eigen::DSizes<Eigen::DenseIndex, Rank> non_padded_slice_shape_dsizes_;

    SliceAndMaybePadState(absl::Span<const int32> num_splits,
                          const absl::Span<const int64_t> input_shape,
                          const TensorShape& output_slice_shape,
                          int slice_index) TF_ATTRIBUTE_NOINLINE {
      output_slice_shape_dsizes_ = ShapeAsEigenDSizes<Rank>(output_slice_shape);
      num_complete_pad_dims_ = 0;
      num_partial_pad_dims_ = 0;
      slice_indices_ = GetSliceIndices<Rank>(
          num_splits, output_slice_shape_dsizes_, slice_index);

      // Calculate the paddings needed for each slice instead of padding the
      // whole input and slicing afterwards, to reduce temporary memory
      // allocation.
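      // For example, with dim_size = 3 and out_dim = 2, the slice starting
      // at index 0 needs no padding, the slice starting at index 2 is
      // partially padded (one real element, one pad element), and a slice
      // starting at index 4 would be completely padded.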
      for (int dim = 0; dim < Rank; ++dim) {
        const int64_t dim_size = input_shape[dim];
        const int64_t out_dim = output_slice_shape_dsizes_[dim];
        int64_t non_padded_dim = 0;
        if (slice_indices_[dim] >= dim_size) {
          // Complete padding.
          slice_indices_[dim] = dim_size;
          non_padded_dim = 0;
          slice_paddings_[dim] = {0, out_dim};
          num_complete_pad_dims_++;
        } else if (slice_indices_[dim] + out_dim > dim_size) {
          // Partial padding.
          non_padded_dim = dim_size - slice_indices_[dim];
          slice_paddings_[dim] = {0, out_dim - non_padded_dim};
          num_partial_pad_dims_++;
        } else {
          non_padded_dim = out_dim;
        }
        non_padded_slice_shape_.AddDim(non_padded_dim);
      }
      non_padded_slice_shape_dsizes_ =
          ShapeAsEigenDSizes<Rank>(non_padded_slice_shape_);
    }
  };

  static void GetDtypeHelper(OpKernelConstruction* ctx, const char* attr_name,
                             DataType* dtype_ptr) TF_ATTRIBUTE_NOINLINE {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(attr_name, dtype_ptr));
  }

  std::vector<int32> num_splits_;
  int num_slices_;
  std::vector<int32> paddings_;
  bool has_paddings_;
};

template <typename Device, typename T>
class XlaSplitNDBaseOp : public XlaSplitNDShared {
 public:
  explicit XlaSplitNDBaseOp(OpKernelConstruction* ctx)
      : XlaSplitNDShared(ctx) {}

 protected:
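  // Validates the input shape, then either forwards the input unchanged
  // (single slice, no padding) or allocates num_slices_ output tensors and
  // dispatches to the rank-specialized SliceAndMaybePad.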
  void ComputeInternal(
      bool resource, OpKernelContext* ctx,
      const std::function<Status(const Tensor&)>& assign_or_copy_value_fn,
      const Tensor* input) {
    const int rank = input->shape().dims();
    const auto& input_shape = input->shape().dim_sizes();

    if (!ValidateShapesForSlice(ctx, resource, input, num_splits_, paddings_)) {
      return;
    }

    TensorShape output_slice_shape;
    for (int i = 0; i < rank; ++i) {
      output_slice_shape.AddDim((input_shape[i] + paddings_[i]) /
                                ((num_slices_ == 1) ? 1 : num_splits_[i]));
    }
    if (num_slices_ == 1 && !has_paddings_) {
      // Handle the simple pass-through case first.
      OP_REQUIRES_OK(ctx, assign_or_copy_value_fn(*input));
    } else {
      const Device& device = ctx->eigen_device<Device>();
      std::vector<Tensor*> output_slices(num_slices_);
      for (int i = 0; i < num_slices_; i++) {
        OP_REQUIRES_OK(ctx,
                       ctx->allocate_output(
                           /*index=*/i, output_slice_shape, &output_slices[i]));
      }

      if (rank == 1) {
        SliceAndMaybePad<1>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      } else if (rank == 2) {
        SliceAndMaybePad<2>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      } else if (rank == 3) {
        SliceAndMaybePad<3>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      } else if (rank == 4) {
        SliceAndMaybePad<4>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      } else if (rank == 5) {
        SliceAndMaybePad<5>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      } else if (rank == 6) {
        SliceAndMaybePad<6>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      } else if (rank == 7) {
        SliceAndMaybePad<7>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      } else if (rank == 8) {
        SliceAndMaybePad<8>(ctx, device, input, input_shape, output_slice_shape,
                            output_slices);
      }
      return;
    }
  }

 private:
  void SetToConstant(Tensor* output_slice,
                     const Device& device) TF_ATTRIBUTE_NOINLINE {
    auto output_flat = output_slice->flat<T>();
    output_flat.device(device) = output_flat.constant(T());
  }

  template <int Rank>
  void AssignFromInput(
      Tensor* output_slice, const Device& device, const Tensor* input,
      const Eigen::DSizes<Eigen::DenseIndex, Rank>& slice_indices,
      const Eigen::DSizes<Eigen::DenseIndex, Rank>& output_slice_shape_dsizes)
      TF_ATTRIBUTE_NOINLINE {
    output_slice->tensor<T, Rank>().device(device) =
        input->tensor<T, Rank>().slice(slice_indices,
                                       output_slice_shape_dsizes);
  }

  template <int Rank>
  void SliceAndMaybePad(
      OpKernelContext* ctx, const Device& device, const Tensor* input,
      const absl::Span<const int64_t> input_shape,
      const TensorShape& output_slice_shape,
      const std::vector<Tensor*>& output_slices) TF_ATTRIBUTE_NOINLINE {
    const auto& input_tensor = input->tensor<T, Rank>();
    // Slice shape with optional padding.
    for (int i = 0; i < num_slices_; ++i) {
      Tensor* output_slice = output_slices[i];
      SliceAndMaybePadState<Rank> r(num_splits_, input_shape,
                                    output_slice_shape, i);
      if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) {
        // The slice contains padding, so initialize the whole output first.
        SetToConstant(output_slice, device);
      }
      if (r.num_complete_pad_dims_ == Rank) {
        // The slice is entirely padding; nothing left to copy.
      } else if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) {
        output_slice->tensor<T, Rank>()
            .slice(Eigen::DSizes<Eigen::DenseIndex, Rank>(),
                   r.non_padded_slice_shape_dsizes_)
            .device(device) = input_tensor.slice(
            r.slice_indices_, r.non_padded_slice_shape_dsizes_);
      } else {
        AssignFromInput<Rank>(output_slice, device, input, r.slice_indices_,
                              r.output_slice_shape_dsizes_);
      }
    }
  }
};

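// Splits an input tensor into N slices. For example, a [4, 4] input with
// num_splits = {2, 2} and no padding produces four [2, 2] output slices in
// row-major slice order.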
template <typename Device, typename T>
class XlaSplitNDOp : public XlaSplitNDBaseOp<Device, T> {
 public:
  explicit XlaSplitNDOp(OpKernelConstruction* ctx) TF_ATTRIBUTE_NOINLINE
      : XlaSplitNDBaseOp<Device, T>(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);

    auto assign_or_copy_value_fn = [&ctx](const Tensor& input) -> Status {
      ctx->set_output(/*index=*/0, input);
      return OkStatus();
    };

    this->ComputeInternal(/*resource=*/false, ctx, assign_or_copy_value_fn,
                          &input);
  }
};

template <typename Device, typename T>
class ReadVariableXlaSplitNDOp : public XlaSplitNDBaseOp<Device, T> {
 public:
  explicit ReadVariableXlaSplitNDOp(OpKernelConstruction* ctx)
      TF_ATTRIBUTE_NOINLINE : XlaSplitNDBaseOp<Device, T>(ctx) {
    XlaSplitNDShared::GetDtypeHelper(ctx, "T", &dtype_);
  }

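  // Looks up the variable, holds a shared lock while reading its tensor, and
  // verifies that the stored dtype matches 'T' before splitting.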
  void Compute(OpKernelContext* ctx) override {
    core::RefCountPtr<Var> variable;
    const ResourceHandle& handle = HandleFromInput(ctx, 0);
    const Status status = LookupResource(ctx, handle, &variable);
    OP_REQUIRES(
        ctx, status.ok(),
        errors::InvalidArgument("'resource' variable handle ('", handle.name(),
                                "') container ('", handle.container(),
                                "') cannot be found."));

    tf_shared_lock ml(*variable->mu());
    const Tensor* input = variable->tensor();
    OP_REQUIRES(
        ctx, input->dtype() == dtype_,
        CreateResourceInvalidDTypeError<false>(handle, input->dtype(), dtype_));

    auto assign_or_copy_value_fn = [&ctx,
                                    &variable](const Tensor& input) -> Status {
      if (variable->copy_on_read_mode.load()) {
        Tensor* output;
        TF_RETURN_IF_ERROR(
            ctx->allocate_output(/*index=*/0, input.shape(), &output));
        output->flat<T>().device(ctx->eigen_device<Device>()) = input.flat<T>();
      } else {
        ctx->set_output(/*index=*/0, input);
      }
      return OkStatus();
    };

    this->ComputeInternal(/*resource=*/true, ctx, assign_or_copy_value_fn,
                          input);
  }

 private:
  DataType dtype_;
};

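// CPU kernel registrations for the split ops, instantiated for all POD and
// quantized types.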
#define REGISTER_XLA_SPLIT_ND(type)                                    \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("XlaSplitND").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      XlaSplitNDOp<Eigen::ThreadPoolDevice, type>)

TF_CALL_POD_TYPES(REGISTER_XLA_SPLIT_ND);
TF_CALL_QUANTIZED_TYPES(REGISTER_XLA_SPLIT_ND);
#undef REGISTER_XLA_SPLIT_ND

#define REGISTER_READ_VARIABLE_XLA_SPLIT_ND(type) \
  REGISTER_KERNEL_BUILDER(                        \
      Name("ReadVariableXlaSplitND")              \
          .Device(DEVICE_CPU)                     \
          .TypeConstraint<type>("T"),             \
      ReadVariableXlaSplitNDOp<Eigen::ThreadPoolDevice, type>)

TF_CALL_POD_TYPES(REGISTER_READ_VARIABLE_XLA_SPLIT_ND);
TF_CALL_QUANTIZED_TYPES(REGISTER_READ_VARIABLE_XLA_SPLIT_ND);
#undef REGISTER_READ_VARIABLE_XLA_SPLIT_ND

// Shared base class to save code space.
class XlaConcatNDShared : public OpKernel {
 public:
  explicit XlaConcatNDShared(OpKernelConstruction* ctx) TF_ATTRIBUTE_NOINLINE
      : OpKernel(ctx),
        num_slices_(1),
        has_paddings_(false) {
    GetAndValidateAttributes(/*split=*/false, ctx, num_concats_, num_slices_,
                             paddings_, has_paddings_);
  }

 protected:
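  // Collects the input slices, checks that they all share one shape of the
  // expected rank, and computes the output shape as
  // slice_dim * num_concats - paddings per dimension. For example, [2, 2]
  // slices with num_concats = {2, 2} and paddings = {0, 1} produce a [4, 3]
  // output.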
  Status GetInputsAndOutputShape(OpKernelContext* ctx, OpInputList& inputs,
                                 TensorShape& output_shape) {
    TF_RETURN_IF_ERROR(ctx->input_list("inputs", &inputs));
    DCHECK_EQ(inputs.size(), num_slices_);

    const TensorShape& slice_shape = inputs[0].shape();
    if (slice_shape.dims() != num_concats_.size()) {
      return errors::InvalidArgument(
          "'inputs' rank must be the same as 'num_concats' length ",
          num_concats_.size(), ", but got rank ", slice_shape.dims(), ".");
    }
    for (int i = 1; i < num_slices_; ++i) {
      const TensorShape& slice_shape_i = inputs[i].shape();
      if (slice_shape != slice_shape_i) {
        return errors::InvalidArgument(
            "'inputs' must all have the same expected shape ", slice_shape,
            ", but got ", slice_shape_i, " at index ", i, ".");
      }
    }

    for (int i = 0, e = num_concats_.size(); i < e; ++i) {
      const int max_dim_size = slice_shape.dim_size(i) * num_concats_[i];
      if (paddings_[i] > max_dim_size) {
        return errors::InvalidArgument(
            "'paddings' must not exceed expected output shape dimension ",
            max_dim_size, " at index ", i, ", but got ", paddings_[i], ".");
      }
      output_shape.AddDim(max_dim_size - paddings_[i]);
    }

    return OkStatus();
  }

  void ApplyAssignOrCopyShared(
      OpKernelContext* ctx,
      const std::function<Status(const Tensor&)>& assign_or_copy_value_fn,
      const Tensor& input) {
    OP_REQUIRES_OK(ctx, assign_or_copy_value_fn(input));
  }

  template <int Rank>
  class MaybeUnpadAndAssignState {
   public:
    int num_complete_pad_dims_;
    int num_partial_pad_dims_;
    TensorShape non_padded_slice_shape_;
    Eigen::DSizes<Eigen::DenseIndex, Rank> slice_shape_dsizes_;
    Eigen::array<Eigen::IndexPair<int64_t>, Rank> slice_paddings_;
    Eigen::DSizes<Eigen::DenseIndex, Rank> slice_indices_;
    Eigen::DSizes<Eigen::DenseIndex, Rank> output_slice_shape_dsizes_;
    Eigen::DSizes<Eigen::DenseIndex, Rank> non_padded_slice_shape_dsizes_;

    MaybeUnpadAndAssignState(absl::Span<const int32> num_concats,
                             const Tensor& input0, Tensor* output,
                             int slice_index) TF_ATTRIBUTE_NOINLINE {
      slice_shape_dsizes_ = input0.shape().AsEigenDSizes<Rank>();
      slice_indices_ =
          GetSliceIndices<Rank>(num_concats, slice_shape_dsizes_, slice_index);
      num_complete_pad_dims_ = 0;
      num_partial_pad_dims_ = 0;
      // Calculate paddings necessary to strip from slice.
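      // For example, with output dim_size = 3, slice dim 2, and
      // num_concats = 2: slice 0 is copied whole, while slice 1 contributes
      // one element and its trailing pad element is stripped.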
      for (int dim = 0; dim < Rank; ++dim) {
        const int64_t dim_size = output->shape().dim_size(dim);
        int64_t non_padded_dim = 0;
        if (slice_indices_[dim] >= dim_size) {
          // Complete padding.
          slice_indices_[dim] = dim_size;
          non_padded_dim = 0;
          num_complete_pad_dims_++;
        } else if (slice_indices_[dim] + slice_shape_dsizes_[dim] > dim_size) {
          // Partial padding.
          non_padded_dim = dim_size - slice_indices_[dim];
          num_partial_pad_dims_++;
        } else {
          non_padded_dim = slice_shape_dsizes_[dim];
        }
        non_padded_slice_shape_.AddDim(non_padded_dim);
      }
      non_padded_slice_shape_dsizes_ =
          non_padded_slice_shape_.AsEigenDSizes<Rank>();
    }
  };

  std::vector<int32> num_concats_;
  int num_slices_;
  std::vector<int32> paddings_;
  bool has_paddings_;
};

template <typename Device, typename T>
class XlaConcatNDBaseOp : public XlaConcatNDShared {
 public:
  explicit XlaConcatNDBaseOp(OpKernelConstruction* ctx) TF_ATTRIBUTE_NOINLINE
      : XlaConcatNDShared(ctx) {}

 protected:
  void ComputeInternal(
      bool resource, OpKernelContext* ctx, const OpInputList& inputs,
      const std::function<Status(const Tensor&)>& assign_or_copy_value_fn,
      const std::function<StatusOr<Tensor*>()>& get_output_fn) {
    const int rank = inputs[0].shape().dims();

    OP_REQUIRES(ctx, rank > 0 && rank <= 8,
                errors::InvalidArgument(
                    "'inputs' tensors must have rank in range (0, 8], but got ",
                    rank, "."));

    if (num_slices_ == 1 && !has_paddings_) {
      // Handle the simple pass-through case first.
      ApplyAssignOrCopyShared(ctx, assign_or_copy_value_fn, inputs[0]);
      return;
    }

    const Device& device = ctx->eigen_device<Device>();
    auto status_or_output = get_output_fn();
    OP_REQUIRES_OK(ctx, status_or_output.status());
    Tensor* output = std::move(status_or_output).value();

    if (rank == 1) {
      MaybeUnpadAndAssign<1>(ctx, device, inputs, output);
    } else if (rank == 2) {
      MaybeUnpadAndAssign<2>(ctx, device, inputs, output);
    } else if (rank == 3) {
      MaybeUnpadAndAssign<3>(ctx, device, inputs, output);
    } else if (rank == 4) {
      MaybeUnpadAndAssign<4>(ctx, device, inputs, output);
    } else if (rank == 5) {
      MaybeUnpadAndAssign<5>(ctx, device, inputs, output);
    } else if (rank == 6) {
      MaybeUnpadAndAssign<6>(ctx, device, inputs, output);
    } else if (rank == 7) {
      MaybeUnpadAndAssign<7>(ctx, device, inputs, output);
    } else if (rank == 8) {
      MaybeUnpadAndAssign<8>(ctx, device, inputs, output);
    }
  }

 private:
  template <int Rank>
  void MaybeUnpadAndAssign(OpKernelContext* ctx, const Device& device,
                           const OpInputList& inputs,
                           Tensor* output) TF_ATTRIBUTE_NOINLINE {
    for (int i = 0; i < num_slices_; ++i) {
      MaybeUnpadAndAssignState<Rank> r(num_concats_, inputs[0], output, i);
      if (r.num_complete_pad_dims_ == Rank) {
        continue;
      } else if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) {
        output->tensor<T, Rank>()
            .slice(r.slice_indices_, r.non_padded_slice_shape_dsizes_)
            .device(device) = inputs[i].tensor<T, Rank>().slice(
            Eigen::DSizes<Eigen::DenseIndex, Rank>(),
            r.non_padded_slice_shape_dsizes_);
      } else {
        output->tensor<T, Rank>()
            .slice(r.slice_indices_, r.slice_shape_dsizes_)
            .device(device) = inputs[i].tensor<T, Rank>();
      }
    }
  }
};

template <typename Device, typename T>
class XlaConcatNDOp : public XlaConcatNDBaseOp<Device, T> {
 public:
  explicit XlaConcatNDOp(OpKernelConstruction* ctx)
      : XlaConcatNDBaseOp<Device, T>(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OpInputList inputs;
    TensorShape output_shape;
    OP_REQUIRES_OK(ctx,
                   this->GetInputsAndOutputShape(ctx, inputs, output_shape));

    auto assign_or_copy_value_fn = [&ctx](const Tensor& input) -> Status {
      ctx->set_output(/*index=*/0, input);
      return OkStatus();
    };

    auto get_output_fn = [&ctx, &output_shape]() -> StatusOr<Tensor*> {
      Tensor* output = nullptr;
      TF_RETURN_IF_ERROR(
          ctx->allocate_output(/*index=*/0, output_shape, &output));
      return output;
    };
    this->ComputeInternal(/*resource=*/false, ctx, inputs,
                          assign_or_copy_value_fn, get_output_fn);
  }
};

template <typename Device, typename T>
class AssignVariableXlaConcatNDOp : public XlaConcatNDBaseOp<Device, T> {
 public:
  explicit AssignVariableXlaConcatNDOp(OpKernelConstruction* ctx)
      TF_ATTRIBUTE_NOINLINE : XlaConcatNDBaseOp<Device, T>(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    OpInputList inputs;
    TensorShape output_shape;
    OP_REQUIRES_OK(ctx,
                   this->GetInputsAndOutputShape(ctx, inputs, output_shape));

    core::RefCountPtr<Var> variable;
    const ResourceHandle& handle = HandleFromInput(ctx, 0);
    if (handle.dtypes_and_shapes().size() == 1) {
      const DtypeAndPartialTensorShape dtype_and_shape =
          handle.dtypes_and_shapes().front();
      OP_REQUIRES(ctx, dtype_and_shape.dtype == dtype_,
                  CreateResourceInvalidDTypeError<true>(
                      handle, dtype_and_shape.dtype, dtype_));
      OP_REQUIRES(ctx, dtype_and_shape.shape.IsCompatibleWith(output_shape),
                  errors::InvalidArgument(
                      "'resource' variable handle ('", handle.name(),
                      "') container ('", handle.container(),
                      "') shape must be compatible with expected shape ",
                      output_shape, ", but got ", dtype_and_shape.shape, "."));
    }
    OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(ctx, handle, &variable,
                                                    [this](Var** ptr) {
                                                      *ptr = new Var(dtype_);
                                                      return OkStatus();
                                                    }));
    mutex_lock ml(*variable->mu());

    OP_REQUIRES(ctx, variable->tensor()->dtype() == dtype_,
                CreateResourceInvalidDTypeError<false>(
                    handle, variable->tensor()->dtype(), dtype_));

    auto assign_or_copy_value_fn = [this, &ctx, &output_shape,
                                    &variable](const Tensor& input) -> Status {
      if (variable->copy_on_read_mode.load()) {
        TF_RETURN_IF_ERROR(
            ctx->allocate_temp(dtype_, output_shape, variable->tensor()));
        variable->tensor()->flat<T>().device(ctx->eigen_device<Device>()) =
            input.flat<T>();
      } else {
        *variable->tensor() = input;
      }
      return OkStatus();
    };

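    // Reuse the variable's existing buffer as the output only when
    // copy-on-read mode is off, the buffer is exclusively owned, and it
    // already has the output shape; otherwise allocate a fresh temporary.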
    auto get_output_fn = [this, &ctx, &output_shape,
                          &variable]() -> StatusOr<Tensor*> {
      if (variable->copy_on_read_mode.load() ||
          !variable->tensor()->RefCountIsOne() ||
          !variable->tensor()->shape().IsSameSize(output_shape)) {
        TF_RETURN_IF_ERROR(
            ctx->allocate_temp(dtype_, output_shape, variable->tensor()));
      }
      return variable->tensor();
    };

    this->ComputeInternal(/*resource=*/true, ctx, inputs,
                          assign_or_copy_value_fn, get_output_fn);
    variable->is_initialized = true;
  }

 private:
  DataType dtype_;
};

#define REGISTER_XLA_CONCAT_ND(type)                                    \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("XlaConcatND").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      XlaConcatNDOp<Eigen::ThreadPoolDevice, type>)

TF_CALL_POD_TYPES(REGISTER_XLA_CONCAT_ND);
TF_CALL_QUANTIZED_TYPES(REGISTER_XLA_CONCAT_ND);
#undef REGISTER_XLA_CONCAT_ND

#define REGISTER_ASSIGN_VARIABLE_XLA_CONCAT_ND(type) \
  REGISTER_KERNEL_BUILDER(                           \
      Name("AssignVariableXlaConcatND")              \
          .Device(DEVICE_CPU)                        \
          .TypeConstraint<type>("T"),                \
      AssignVariableXlaConcatNDOp<Eigen::ThreadPoolDevice, type>)

TF_CALL_POD_TYPES(REGISTER_ASSIGN_VARIABLE_XLA_CONCAT_ND);
TF_CALL_QUANTIZED_TYPES(REGISTER_ASSIGN_VARIABLE_XLA_CONCAT_ND);
#undef REGISTER_ASSIGN_VARIABLE_XLA_CONCAT_ND

}  // anonymous namespace
}  // namespace tensorflow