/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/gather_functor_batched.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

template <typename ValueOrVec, typename Index, bool is_axis_zero,
          bool is_batch_dims_zero>
__global__ void GatherOpKernel(const ValueOrVec* __restrict__ params,
                               const Index* __restrict__ indices,
                               ValueOrVec* __restrict__ out, int64 outer_size,
                               int64 gather_dim_size, int64 indices_size,
                               int64 slice_size, int64 out_size) {
  // params is a tensor of shape
  // [batch_size, outer_size, gather_dim_size, slice_size].
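  // GPU_1D_KERNEL_LOOP iterates i over [0, out_size) as a grid-stride loop,
  // so each thread handles a strided subset of the flat output indices.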
  GPU_1D_KERNEL_LOOP(i, out_size) {
    Index batch_i = 0;  // The batch index into params to use for i.
    Index outer_i = 0;  // The outer index into params to use for i.
    Index indices_i = 0;  // The index into indices to use for i.
    Index slice_i = 0;  // Index into the current slice in params to use for i.

    const Index slices_count = i / slice_size;
    if (is_batch_dims_zero) {
      if (is_axis_zero) {
        indices_i = slices_count;
      } else {
        outer_i = slices_count / indices_size;
        indices_i = slices_count - outer_i * indices_size;
      }
    } else {
      const Index entries_count = slices_count / indices_size;
      if (is_axis_zero) {
        batch_i = entries_count;
      } else {
        batch_i = entries_count / outer_size;
        outer_i = entries_count - batch_i * outer_size;
      }
      indices_i = slices_count - entries_count * indices_size;
    }
    slice_i = i - slices_count * slice_size;
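
    // Worked example of the decomposition for the general case (both template
    // flags false), using illustrative values not taken from the original
    // source: batch_size = 2, outer_size = 3, gather_dim_size = 5,
    // indices_size = 6, slice_size = 4, so out has shape [2, 3, 6, 4] and
    // out_size = 144. For i = 100:
    //   slices_count  = 100 / 4 = 25
    //   entries_count = 25 / 6  = 4
    //   batch_i = 4 / 3 = 1,  outer_i = 4 - 1 * 3 = 1
    //   indices_i = 25 - 4 * 6 = 1,  slice_i = 100 - 25 * 4 = 0
    // i.e. i corresponds to out[1, 1, 1, 0], and gather_i below reads
    // indices[1 * 6 + 1] = indices[7].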

    // Index into the gather axis to use for i.
    Index gather_i = ldg(indices + batch_i * indices_size + indices_i);

    // Check gather_i is in [0, gather_dim_size).
    if (!FastBoundsCheck(gather_i, gather_dim_size)) {
      // Set indices out of range to zero
      // TODO(fpmc): Log an error for transfer back to host.
      out[i] = ValueOrVec(0);
    } else {
      // Read params[batch_i, outer_i, gather_i, slice_i] and write it to the
      // i'th position in out.
      Index params_i = (
          (batch_i * outer_size + outer_i) * gather_dim_size + gather_i
      ) * slice_size + slice_i;
      out[i] = params[params_i];
    }
  }
}
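
// Note: ValueOrVec is either a plain element type T or an
// AlignedVector<T, vec_size> (see LaunchGatherKernelVectorized below), so the
// same kernel body serves both the scalar and the vectorized launch paths.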

namespace detail {

template <bool is_axis_zero, bool is_batch_dims_zero>
struct LaunchGatherKernelVectorized {
  template <int vec_size>
  struct Impl {
    template <typename T, typename Index>
    Status operator()(const GPUDevice& d, const T* params, const Index* indices,
                      T* out, int64 outer_size, int64 gather_dim_size,
                      int64 indices_size, int64 slice_size, int64 out_size) {
      DCHECK_EQ(slice_size % vec_size, 0);
      DCHECK_EQ(out_size % vec_size, 0);
      DCHECK_EQ(reinterpret_cast<std::uintptr_t>(params) % vec_size, 0);
      DCHECK_EQ(reinterpret_cast<std::uintptr_t>(out) % vec_size, 0);
      int64 out_size_vec = out_size / vec_size;
      int64 slice_size_vec = slice_size / vec_size;
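
      // Reinterpret the buffers as vectors of vec_size elements so each
      // kernel iteration moves vec_size values per load/store.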
      using Tvec = AlignedVector<T, vec_size>;
      const Tvec* params_vec = reinterpret_cast<const Tvec*>(params);
      Tvec* out_vec = reinterpret_cast<Tvec*>(out);

      GpuLaunchConfig config = GetGpuLaunchConfig(
          out_size_vec, d,
          &GatherOpKernel<Tvec, Index, is_axis_zero, is_batch_dims_zero>,
          /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
      return GpuLaunchKernel(
          GatherOpKernel<Tvec, Index, is_axis_zero, is_batch_dims_zero>,
          config.block_count, config.thread_per_block, 0, d.stream(),
          params_vec, indices, out_vec, outer_size, gather_dim_size,
          indices_size, slice_size_vec, out_size_vec);
    }
  };
};

}  // namespace detail

template <bool is_axis_zero, bool is_batch_dims_zero, typename T,
          typename Index>
Status LaunchGatherKernel(const GPUDevice& d, const T* params,
                          const Index* indices, T* out, int64 outer_size,
                          int64 gather_dim_size, int64 indices_size,
                          int64 slice_size, int64 out_size) {
  // Note that the GPU memory allocator always returns aligned buffers, so the
  // alignment of data pointers is expected to be deterministic.
  // There will be performance cliffs when slice_size is not aligned, but there
  // is no easy way to handle the misalignment because each row will be aligned
  // differently.
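  // MinAlignmentOf and DispatchToVectorized are expected to select the widest
  // vector width compatible with both the pointer alignments and slice_size,
  // falling back to a scalar (vec_size = 1) launch otherwise.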
  return DispatchToVectorized<
      T, detail::LaunchGatherKernelVectorized<
             is_axis_zero, is_batch_dims_zero>::template Impl>(
      MinAlignmentOf(params, out, slice_size), d, params, indices, out,
      outer_size, gather_dim_size, indices_size, slice_size, out_size);
}

namespace functor {
template <typename T, typename Index>
struct GatherFunctorBatched<GPUDevice, T, Index> {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 4>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 4>::Tensor out) {
    const GPUDevice& d = ctx->eigen_gpu_device();
    const int64 out_size = out.size();
    if (out_size == 0) {
      // We need a check here since the CPU version does useful error checking
      // work if there are nonempty indices but empty slices, so the kernel is
      // executed in that case.  In the GPU case we don't know how to do error
      // checking, so we skip the loop entirely.
      return -1;
    }
    const bool is_batch_dims_zero = params.dimension(0) == 1;
    const bool is_axis_zero = params.dimension(1) == 1;
    const int64 outer_size = params.dimension(1);
    const int64 gather_dim_size = params.dimension(2);
    const int64 indices_size = indices.size() / params.dimension(0);
    const int64 slice_size = params.dimension(3);
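
    // params is viewed as [batch_size, outer_size, gather_dim_size,
    // slice_size]. batch_size == 1 means there are no batch dimensions
    // (is_batch_dims_zero), and outer_size == 1 means the gather axis is the
    // leading non-batch dimension (is_axis_zero).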

    const auto function =
        is_axis_zero
            ? (is_batch_dims_zero ? LaunchGatherKernel<true, true, T, Index>
                                  : LaunchGatherKernel<true, false, T, Index>)
            : (is_batch_dims_zero ? LaunchGatherKernel<false, true, T, Index>
                                  : LaunchGatherKernel<false, false, T, Index>);
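    // The two flags are compile-time template parameters, so each of the four
    // instantiations can drop the index arithmetic it does not need inside
    // GatherOpKernel.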
    TF_CHECK_OK(function(d, params.data(), indices.data(), out.data(),
                         outer_size, gather_dim_size, indices_size, slice_size,
                         out_size));
    // TODO(fpmc): enable indices validation on GPU.
    // Right now, checking for out-of-bound indices in the kernel would
    // require copying code between GPU and CPU, and would thus be slow.
    return -1;
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_