/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// Copies the slices of `params` selected by `indices` into `out`, using
// memcpy for simple types and an Eigen assignment otherwise. The work is
// sharded across the CPU worker threads. Returns -1 on success, or the
// position in `indices` of the first out-of-range index.
template <typename T, typename Index, typename SliceIndex,
          SliceIndex static_slice_elems>
SliceIndex HandleCopies(OpKernelContext* ctx,
                        typename TTypes<T, 3>::ConstTensor params,
                        typename TTypes<Index>::ConstFlat indices,
                        SliceIndex slice_elems,
                        typename TTypes<T, 3>::Tensor out) {
  const SliceIndex indices_size = static_cast<SliceIndex>(indices.dimension(0));
  const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0));
  const Index limit = static_cast<Index>(params.dimension(1));
  T* out_base = out.data();
  const T* params_base = params.data();
  if (static_slice_elems >= 0) {
    // Give compiler static knowledge of the number of elements/bytes
    slice_elems = static_slice_elems;
  }
  // Compute slice_bytes here so that static knowledge is available
  const size_t slice_bytes = slice_elems * sizeof(T);
  auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
  mutex mu;
  // Stores the first out-of-range index encountered (or -1 if none). Shared
  // across shards, hence protected by `mu`, and used for error reporting.
  SliceIndex result = -1;
  auto work = [&](int64_t start, int64_t end) {
    // The flat range [start, end) enumerates (batch, index) pairs in
    // row-major order.
    SliceIndex batch_idx = static_cast<SliceIndex>(start / indices_size);
    SliceIndex indices_idx = static_cast<SliceIndex>(start % indices_size);
    SliceIndex batch_idx_end = static_cast<SliceIndex>(end / indices_size);
    SliceIndex indices_idx_end = static_cast<SliceIndex>(end % indices_size);

    while ((batch_idx < batch_idx_end) ||
           (batch_idx == batch_idx_end && indices_idx < indices_idx_end)) {
      SliceIndex i_next = indices_idx + 1;
      SliceIndex b_next = batch_idx + 1;
      const Index index = internal::SubtleMustCopy(indices(indices_idx));
      if (!FastBoundsCheck(index, limit)) {
        mutex_lock l(mu);
        result = indices_idx;
        return;
      }
      // Prefetch the source and destination of the next slice to be copied.
      if ((batch_idx == batch_idx_end && i_next < indices_idx_end) ||
          (i_next < indices_size)) {
        port::prefetch<port::PREFETCH_HINT_T0>(
            &params(batch_idx, indices(i_next), 0));
        port::prefetch<port::PREFETCH_HINT_T0>(&out(batch_idx, i_next, 0));
        b_next = batch_idx;
      } else if (b_next <= batch_idx_end) {
        port::prefetch<port::PREFETCH_HINT_T0>(&params(b_next, indices(0), 0));
        port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, 0, 0));
        i_next = 0;
      }
      // Copy using memcpy if possible, otherwise an Eigen loop
      // TODO(cwhipkey): avoid linking to framework to get Allocator (to improve
      // ahead-of-time compilation binary size).
      if (is_simple_type<T>::value) {
        // Avoid auto-promotion to Index from SliceIndex by casting.
        memcpy(
            out_base + (batch_idx * indices_size + indices_idx) * slice_elems,
            params_base + (batch_idx * static_cast<SliceIndex>(limit) +
                           static_cast<SliceIndex>(index)) *
                              slice_elems,
            slice_bytes);
      } else {
        // For non-"simple" types (e.g. strings).
        out.template chip<0>(batch_idx).template chip<0>(indices_idx) =
            params.template chip<0>(batch_idx).template chip<0>(index);
      }
      indices_idx = i_next;
      batch_idx = b_next;
    }
  };

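  // One work unit per output slice; the per-unit cost passed to the sharder
  // is approximated by the number of bytes copied per slice.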
  Shard(worker_threads->num_threads, worker_threads->workers,
        batch_size * indices_size, slice_elems * sizeof(T), work);
  return result;
}

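// CPU implementation of gather. Chooses 32-bit internal indexing when all
// sizes fit in int32 (and 64-bit indexing otherwise), and dispatches to
// HandleCopies instantiations with a compile-time slice length for a couple
// of common slice sizes so the copy size is known statically.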
template <typename T, typename Index>
struct GatherFunctorCPU {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<T, 3>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<T, 3>::Tensor out) {
    const int64_t indices_size = indices.size();
    const int64_t slice_size = out.dimension(2);
    int64_t bad_i;

    const int64_t batch_size = params.dimension(0);

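    // Use 64-bit indexing only when a dimension or the total number of output
    // elements could exceed the int32 range.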
    bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
                      params.size() > std::numeric_limits<int32>::max() ||
                      indices_size > std::numeric_limits<int32>::max() ||
                      batch_size * indices_size * slice_size >
                          std::numeric_limits<int32>::max());
#define CALL(elems)                                                        \
  do {                                                                     \
    if (use_large) {                                                       \
      bad_i = HandleCopies<T, Index, int64_t, elems>(ctx, params, indices, \
                                                     slice_size, out);     \
    } else {                                                               \
      const int32 small_slice = static_cast<int32>(slice_size);            \
      bad_i = HandleCopies<T, Index, int32, elems>(ctx, params, indices,   \
                                                   small_slice, out);      \
    }                                                                      \
  } while (0)

    if (slice_size == 10)
      CALL(10);
    else if (slice_size == 20)
      CALL(20);
    else
      CALL(-1);
#undef CALL

    return bad_i;
  }
};

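// Device-dispatched gather functor. The CPU specialization below forwards to
// GatherFunctorCPU; specializations for other devices are defined elsewhere.
// A hypothetical call site in a gather kernel might look roughly like the
// following (the real call site lives in the gather op and may differ):
//
//   functor::GatherFunctor<Device, T, Index> functor;
//   const int64_t bad_i = functor(ctx, params_flat, indices_flat, out_flat);
//   OP_REQUIRES(ctx, bad_i < 0,
//               errors::InvalidArgument("indices out of range"));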
template <typename Device, typename T, typename Index>
struct GatherFunctor {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<T, 3>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<T, 3>::Tensor out);
};

template <typename T, typename Index>
struct GatherFunctor<CPUDevice, T, Index> {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<T, 3>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<T, 3>::Tensor out) {
    return GatherFunctorCPU<T, Index>()(ctx, params, indices, out);
  }
};

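// Variant elements are handled on the host, so the GPU specialization for
// Variant falls back to the CPU implementation.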
template <typename Index>
struct GatherFunctor<GPUDevice, Variant, Index> {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<Variant, 3>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<Variant, 3>::Tensor out) {
    return GatherFunctorCPU<Variant, Index>()(ctx, params, indices, out);
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_