/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// Helper method that copies the gathered slices, using memcpy when the
// element type allows it.
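//
// Shapes, as used below: `params` is [batch, outer, limit, slice_elems],
// `out` is [batch, outer, indices_size, slice_elems], and `indices` is a
// flat vector of length batch * indices_size. Returns -1 on success, or the
// flat position in `indices` of the first out-of-range index otherwise.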
template <typename T, typename Index, typename SliceIndex,
          SliceIndex static_slice_elems>
SliceIndex HandleCopiesBatched(OpKernelContext* ctx,
                               typename TTypes<T, 4>::ConstTensor params,
                               typename TTypes<Index>::ConstFlat indices,
                               SliceIndex slice_elems,
                               typename TTypes<T, 4>::Tensor out) {
  const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0));
  const SliceIndex outer_size = static_cast<SliceIndex>(params.dimension(1));
  const SliceIndex indices_size =
      static_cast<SliceIndex>(indices.dimension(0)) / batch_size;

  const Index limit = static_cast<Index>(params.dimension(2));
  if (static_slice_elems >= 0) {
    // Give compiler static knowledge of the number of elements/bytes
    slice_elems = static_slice_elems;
  }
  // Compute slice_bytes here so that static knowledge is available
  const size_t slice_bytes = slice_elems * sizeof(T);
  auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
  mutex mu;
  // Stores the flat position of the first invalid index encountered, for use
  // in the caller's error message; shared across worker threads.
  SliceIndex result = -1;
  auto work = [&](int64_t start, int64_t end) {
    const int64_t r_start = start % (outer_size * indices_size);
    SliceIndex batch_idx = static_cast<SliceIndex>(
        start / (outer_size * indices_size));
    SliceIndex outer_idx = static_cast<SliceIndex>(r_start / indices_size);
    SliceIndex indices_idx = static_cast<SliceIndex>(r_start % indices_size);

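    // Each value of `start` in [start, end) identifies one (batch, outer,
    // index) triple in row-major order; the triple is carried forward
    // incrementally below instead of being recomputed with divisions on
    // every iteration.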
    SliceIndex batch_offset = batch_idx * indices_size;
    for (; start < end; ++start) {
      SliceIndex i_next = indices_idx + 1;
      SliceIndex o_next = outer_idx;
      SliceIndex b_next = batch_idx;
      SliceIndex b_offset_next = batch_offset;

      if (i_next >= indices_size) {
        i_next = 0;
        if (++o_next >= outer_size) {
          o_next = 0;
          ++b_next;
          b_offset_next += indices_size;
        }
      }
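      // Prefetch the source and destination of the next iteration's copy
      // while the current copy is being issued.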
      if (start + 1 < end) {
        port::prefetch<port::PREFETCH_HINT_T0>(
            &params(b_next, o_next, indices(b_offset_next + i_next), 0));
        port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, o_next, i_next, 0));
      }
      const Index index = internal::SubtleMustCopy(
          indices(batch_offset + indices_idx));
      if (!FastBoundsCheck(index, limit)) {
        mutex_lock l(mu);
        result = batch_offset + indices_idx;
        return;
      }

      // Copy using memcpy if possible, otherwise an Eigen loop
      // TODO(cwhipkey): avoid linking to framework to get Allocator (to improve
      // ahead-of-time compilation binary size).
      if (is_simple_type<T>::value) {
        // Avoid auto-promotion to Index from SliceIndex by casting.
        memcpy(
            &out(batch_idx, outer_idx, indices_idx, 0),
            &params(batch_idx, outer_idx, static_cast<SliceIndex>(index), 0),
            slice_bytes);
      } else {
        // For non-"simple" types (e.g. strings).
        out.template chip<0>(batch_idx)
            .template chip<0>(outer_idx)
            .template chip<0>(indices_idx) =
            params.template chip<0>(batch_idx)
                .template chip<0>(outer_idx)
                .template chip<0>(static_cast<SliceIndex>(index));
      }

      indices_idx = i_next;
      outer_idx = o_next;
      batch_idx = b_next;
      batch_offset = b_offset_next;
    }
  };

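  // Partition the batch_size * outer_size * indices_size copies across the
  // worker threads; the per-unit cost estimate is the slice size in bytes.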
  Shard(worker_threads->num_threads, worker_threads->workers,
        batch_size * outer_size * indices_size, slice_elems * sizeof(T), work);
  return result;
}

template <typename T, typename Index>
struct GatherFunctorBatchedCPU {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<T, 4>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<T, 4>::Tensor out) {
    const int64_t indices_size = indices.size();  // Includes the batch_size.
    const int64_t slice_size = out.dimension(3);
    int64_t bad_i;

    const int64_t batch_size = params.dimension(0);
    const int64_t outer_size = params.dimension(1);

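    // Use 32-bit indexing whenever every size involved fits in int32; the
    // narrower SliceIndex type keeps the index arithmetic in the copy loop
    // cheap.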
    bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
                      params.size() > std::numeric_limits<int32>::max() ||
                      indices_size > std::numeric_limits<int32>::max() ||
                      batch_size * outer_size * indices_size * slice_size >
                          std::numeric_limits<int32>::max());
#define CALL(elems)                                             \
  do {                                                          \
    if (use_large) {                                            \
      bad_i = HandleCopiesBatched<T, Index, int64_t, elems>(    \
          ctx, params, indices, slice_size, out);               \
    } else {                                                    \
      const int32 small_slice = static_cast<int32>(slice_size); \
      bad_i = HandleCopiesBatched<T, Index, int32, elems>(      \
          ctx, params, indices, small_slice, out);              \
    }                                                           \
  } while (0)

    // TODO(rmlarsen): Investigate whether these specializations are still
    // needed and, if yes, whether the slice sizes are appropriate.
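    // Passing a non-negative `elems` supplies static_slice_elems, so the
    // compiler can specialize the copy for a fixed slice length; CALL(-1)
    // selects the generic path.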
    if (slice_size == 10)
      CALL(10);
    else if (slice_size == 20)
      CALL(20);
    else
      CALL(-1);
#undef CALL

    return bad_i;
  }
};

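// Generic declaration; the partial specializations below provide the CPU
// path and the GPU path for Variant. Other GPU types are presumably handled
// by a separate CUDA specialization outside this header.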
template <typename Device, typename T, typename Index>
struct GatherFunctorBatched {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<T, 4>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<T, 4>::Tensor out);
};

template <typename T, typename Index>
struct GatherFunctorBatched<CPUDevice, T, Index> {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<T, 4>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<T, 4>::Tensor out) {
    return GatherFunctorBatchedCPU<T, Index>()(ctx, params, indices, out);
  }
};
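//
// Illustrative only (not part of this header): a kernel's Compute() might
// invoke the functor roughly like this, where `params_4d`, `indices_flat`
// and `out_4d` are hypothetical reshaped views owned by the caller:
//
//   functor::GatherFunctorBatched<Device, T, Index> functor;
//   const int64_t bad_i = functor(ctx, params_4d, indices_flat, out_4d);
//   OP_REQUIRES(ctx, bad_i < 0,
//               errors::InvalidArgument("indices[", bad_i,
//                                       "] is out of range"));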
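// The Variant case falls back to the CPU implementation (Variant elements
// are presumably stored in host memory).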
template <typename Index>
struct GatherFunctorBatched<GPUDevice, Variant, Index> {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<Variant, 4>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<Variant, 4>::Tensor out) {
    return GatherFunctorBatchedCPU<Variant, Index>()(ctx, params, indices, out);
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_