/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/gather_functor_batched.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

template <typename ValueOrVec, typename Index, bool is_axis_zero,
          bool is_batch_dims_zero>
__global__ void GatherOpKernel(const ValueOrVec* __restrict__ params,
                               const Index* __restrict__ indices,
                               ValueOrVec* __restrict__ out, int64 outer_size,
                               int64 gather_dim_size, int64 indices_size,
                               int64 slice_size, int64 out_size) {
  // params is a tensor of shape
  // [batch_size, outer_size, gather_dim_size, slice_size].
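  // out is the corresponding tensor of shape
  // [batch_size, outer_size, indices_size, slice_size]; each flat output
  // position i is decomposed below into (batch_i, outer_i, indices_i, slice_i).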
  GPU_1D_KERNEL_LOOP(i, out_size) {
    Index batch_i = 0;    // The batch index into params to use for i.
    Index outer_i = 0;    // The outer index into params to use for i.
    Index indices_i = 0;  // The index into indices to use for i.
    Index slice_i = 0;  // Index into the current slice in params to use for i.

    const Index slices_count = i / slice_size;
    if (is_batch_dims_zero) {
      if (is_axis_zero) {
        indices_i = slices_count;
      } else {
        outer_i = slices_count / indices_size;
        indices_i = slices_count - outer_i * indices_size;
      }
    } else {
      const Index entries_count = slices_count / indices_size;
      if (is_axis_zero) {
        batch_i = entries_count;
      } else {
        batch_i = entries_count / outer_size;
        outer_i = entries_count - batch_i * outer_size;
      }
      indices_i = slices_count - entries_count * indices_size;
    }
    slice_i = i - slices_count * slice_size;
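    // Worked example (hypothetical shapes, for illustration only): with
    // batch_size = 2, outer_size = 3, indices_size = 4, slice_size = 8 and
    // neither flag set, output position i = 150 decomposes as
    //   slices_count  = 150 / 8 = 18
    //   entries_count = 18 / 4 = 4
    //   batch_i = 4 / 3 = 1, outer_i = 4 - 1 * 3 = 1
    //   indices_i = 18 - 4 * 4 = 2, slice_i = 150 - 18 * 8 = 6,
    // i.e. out[1, 1, 2, 6] is gathered from params[1, 1, indices[1, 2], 6].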

    // Index into the gather axis to use for i.
    Index gather_i = ldg(indices + batch_i * indices_size + indices_i);

    // Check gather_i is in [0, gather_dim_size).
    if (!FastBoundsCheck(gather_i, gather_dim_size)) {
      // Set indices out of range to zero
      // TODO(fpmc): Log an error for transfer back to host.
      out[i] = ValueOrVec(0);
    } else {
      // Read params[batch_i, outer_i, gather_i, slice_i] and write it to the
      // i'th position in out.
      Index params_i = (
          (batch_i * outer_size + outer_i) * gather_dim_size + gather_i
          ) * slice_size + slice_i;
      out[i] = params[params_i];
    }
  }
}

namespace detail {

template <bool is_axis_zero, bool is_batch_dims_zero>
struct LaunchGatherKernelVectorized {
  template <int vec_size>
  struct Impl {
    template <typename T, typename Index>
    Status operator()(const GPUDevice& d, const T* params, const Index* indices,
                      T* out, int64 outer_size, int64 gather_dim_size,
                      int64 indices_size, int64 slice_size, int64 out_size) {
      DCHECK_EQ(slice_size % vec_size, 0);
      DCHECK_EQ(out_size % vec_size, 0);
      DCHECK_EQ(reinterpret_cast<std::uintptr_t>(params) % vec_size, 0);
      DCHECK_EQ(reinterpret_cast<std::uintptr_t>(out) % vec_size, 0);
      int64 out_size_vec = out_size / vec_size;
      int64 slice_size_vec = slice_size / vec_size;
      using Tvec = AlignedVector<T, vec_size>;
      const Tvec* params_vec = reinterpret_cast<const Tvec*>(params);
      Tvec* out_vec = reinterpret_cast<Tvec*>(out);
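      // Illustrative (hypothetical) case: with T = float and vec_size = 4,
      // each Tvec covers 16 bytes, so the kernel above moves float4-sized
      // chunks and iterates over out_size / 4 vectorized output elements
      // instead of out_size scalars.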

      GpuLaunchConfig config = GetGpuLaunchConfig(
          out_size_vec, d,
          &GatherOpKernel<Tvec, Index, is_axis_zero, is_batch_dims_zero>,
          /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
      return GpuLaunchKernel(
          GatherOpKernel<Tvec, Index, is_axis_zero, is_batch_dims_zero>,
          config.block_count, config.thread_per_block, 0, d.stream(),
          params_vec, indices, out_vec, outer_size, gather_dim_size,
          indices_size, slice_size_vec, out_size_vec);
    }
  };
};

}  // namespace detail

template <bool is_axis_zero, bool is_batch_dims_zero, typename T,
          typename Index>
Status LaunchGatherKernel(const GPUDevice& d, const T* params,
                          const Index* indices, T* out, int64 outer_size,
                          int64 gather_dim_size, int64 indices_size,
                          int64 slice_size, int64 out_size) {
  // Note that the GPU memory allocator always returns aligned buffers, so the
  // alignment of data pointers is expected to be deterministic.
  // There will be performance cliffs when slice_size is not aligned, but there
  // is no easy way to handle the misalignment because each row will be aligned
  // differently.
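  // Illustration of such a cliff (hypothetical case): with T = float and
  // slice_size = 6, consecutive slices start 24 bytes apart; 4-element vectors
  // would not evenly divide the slice, so at most 2-element (8-byte) vector
  // accesses can be used.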
  return DispatchToVectorized<
      T, detail::LaunchGatherKernelVectorized<
             is_axis_zero, is_batch_dims_zero>::template Impl>(
      MinAlignmentOf(params, out, slice_size), d, params, indices, out,
      outer_size, gather_dim_size, indices_size, slice_size, out_size);
}

namespace functor {
template <typename T, typename Index>
struct GatherFunctorBatched<GPUDevice, T, Index> {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 4>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 4>::Tensor out) {
    const GPUDevice& d = ctx->eigen_gpu_device();
    const int64 out_size = out.size();
    if (out_size == 0) {
      // We need a check here since the CPU version does useful error checking
      // work if there are nonempty indices but empty slices, so the kernel is
      // executed in that case. In the GPU case we don't know how to do error
      // checking, so we skip the loop entirely.
      return -1;
    }
    const bool is_batch_dims_zero = params.dimension(0) == 1;
    const bool is_axis_zero = params.dimension(1) == 1;
    const int64 outer_size = params.dimension(1);
    const int64 gather_dim_size = params.dimension(2);
    const int64 indices_size = indices.size() / params.dimension(0);
    const int64 slice_size = params.dimension(3);
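    // For illustration (hypothetical shapes): params viewed as [1, 1, 16, 4]
    // with 5 indices gives is_batch_dims_zero = is_axis_zero = true,
    // indices_size = 5, slice_size = 4 and out_size = 5 * 4 = 20, so the
    // <true, true> kernel instantiation below is selected.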

    const auto function =
        is_axis_zero
            ? (is_batch_dims_zero ? LaunchGatherKernel<true, true, T, Index>
                                  : LaunchGatherKernel<true, false, T, Index>)
            : (is_batch_dims_zero ? LaunchGatherKernel<false, true, T, Index>
                                  : LaunchGatherKernel<false, false, T, Index>);
    TF_CHECK_OK(function(d, params.data(), indices.data(), out.data(),
                         outer_size, gather_dim_size, indices_size, slice_size,
                         out_size));
    // TODO(fpmc): enable indices validation on GPU.
    // Right now, checking for out-of-bounds indices in the kernel would
    // require copying code between GPU/CPU, and would thus be slow.
    return -1;
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_