/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_GPU_PRIM_HELPERS_H_
#define TENSORFLOW_CORE_KERNELS_GPU_PRIM_HELPERS_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/stream_executor/stream.h"

namespace tensorflow {

namespace detail {
template <typename T>
__global__ void RangeInitKernel(const T start, const T delta, const T size,
                                T* out) {
  GPU_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; }
}

// Initialize out with range start, start + delta, start + 2 * delta, ...
template <typename T>
Status RangeInit(const Eigen::GpuDevice& d, const T start, const T delta,
                 const T size, T* out) {
  if (size == 0) return OkStatus();
  GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
  return GpuLaunchKernel(RangeInitKernel<T>, config.block_count,
                         config.thread_per_block, 0, d.stream(), start, delta,
                         size, out);
}
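
// Illustrative sketch (not part of the public API): filling a device buffer
// with 0, 1, 2, ..., size - 1, assuming a valid OpKernelContext* context and
// an already-allocated int32 device pointer `indices` (placeholder names for
// this example only).
//
//   const Eigen::GpuDevice& device =
//       context->eigen_device<Eigen::GpuDevice>();
//   TF_RETURN_IF_ERROR(detail::RangeInit(device, int32(0), int32(1),
//                                        int32(size), indices));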

// Computes keys_out = sorted(keys_in), and indices_out = argsort(keys_in).
// If keys_out is not required, it can be set to nullptr.
// If indices_in is nullptr, the range of input indices [0, size) will be used.
template <bool Descending, typename Tkey, typename Tindex>
Status GpuRadixSortImpl(OpKernelContext* context, int size, const Tkey* keys_in,
                        Tkey* keys_out,            // Optional
                        const Tindex* indices_in,  // Optional
                        Tindex* indices_out, int num_bits = sizeof(Tkey) * 8) {
  if (size == 0) return OkStatus();
  if (num_bits == 0) {
    // Workaround for CUB failing when begin_bit = end_bit = 0 (e.g., when all
    // keys are 0, so no sorting is needed).
    se::Stream* stream = context->op_device_context()->stream();
    if (keys_out) {
      // Copy keys_in to keys_out.
      size_t num_bytes = size * sizeof(Tkey);
      se::DeviceMemoryBase src(const_cast<Tkey*>(keys_in), num_bytes);
      se::DeviceMemoryBase dst(keys_out, num_bytes);
      if (!stream->ThenMemcpy(&dst, src, num_bytes).ok()) {
        return errors::Internal("Failed to copy keys_in to keys_out");
      }
    }
    if (indices_in) {
      // Copy indices_in to indices_out.
      size_t num_bytes = size * sizeof(Tindex);
      se::DeviceMemoryBase src(const_cast<Tindex*>(indices_in), num_bytes);
      se::DeviceMemoryBase dst(indices_out, num_bytes);
      if (!stream->ThenMemcpy(&dst, src, num_bytes).ok()) {
        return errors::Internal("Failed to copy indices_in to indices_out");
      }
    } else {
      // Set output indices to range.
      const Eigen::GpuDevice& device =
          context->eigen_device<Eigen::GpuDevice>();
      TF_RETURN_IF_ERROR(detail::RangeInit(device, Tindex(0), Tindex(1),
                                           Tindex(size), indices_out));
    }
    return OkStatus();
  }
  // Allocate temporary inputs/outputs if necessary.
  Tensor tmp_indices_in;
  if (!indices_in) {
    TF_RETURN_IF_ERROR(context->allocate_temp(
        DataTypeToEnum<Tindex>::value, TensorShape({size}), &tmp_indices_in));
    Tindex* mutable_indices_in = tmp_indices_in.flat<Tindex>().data();
    indices_in = mutable_indices_in;
    const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
    // Initialize indices_in to the input index range.
    TF_RETURN_IF_ERROR(detail::RangeInit(device, Tindex(0), Tindex(1),
                                         Tindex(size), mutable_indices_in));
  }
  Tensor tmp_keys_out;
  if (!keys_out) {
    TF_RETURN_IF_ERROR(context->allocate_temp(
        DataTypeToEnum<Tkey>::value, TensorShape({size}), &tmp_keys_out));
    keys_out = tmp_keys_out.flat<Tkey>().data();
  }
  // Determine temporary device storage requirements.
  Tensor temp_storage;
  size_t temp_storage_bytes = 0;
  const auto& cu_stream = GetGpuStream(context);
  gpuError_t err;
  if constexpr (Descending) {
    err = gpuprim::DeviceRadixSort::SortPairsDescending(
        nullptr, temp_storage_bytes, keys_in, keys_out, indices_in, indices_out,
        size, /*begin_bit=*/0, /*end_bit=*/num_bits, cu_stream);
  } else {
    err = gpuprim::DeviceRadixSort::SortPairs(
        nullptr, temp_storage_bytes, keys_in, keys_out, indices_in, indices_out,
        size, /*begin_bit=*/0, /*end_bit=*/num_bits, cu_stream);
  }
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceRadixSort::SortPairs to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  // Allocate temporary storage.
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
      &temp_storage));
  // Sort indices by keys.
  if constexpr (Descending) {
    err = gpuprim::DeviceRadixSort::SortPairsDescending(
        temp_storage.flat<int8>().data(), temp_storage_bytes, keys_in, keys_out,
        indices_in, indices_out, size, /*begin_bit=*/0, /*end_bit=*/num_bits,
        cu_stream);
  } else {
    err = gpuprim::DeviceRadixSort::SortPairs(
        temp_storage.flat<int8>().data(), temp_storage_bytes, keys_in, keys_out,
        indices_in, indices_out, size, /*begin_bit=*/0, /*end_bit=*/num_bits,
        cu_stream);
  }
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceRadixSort::SortPairs, "
        "temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  return OkStatus();
}

}  // namespace detail

template <typename Tkey, typename Tindex>
Status GpuRadixSort(OpKernelContext* context, int size, const Tkey* keys_in,
                    Tkey* keys_out,            // Optional
                    const Tindex* indices_in,  // Optional
                    Tindex* indices_out, int num_bits = sizeof(Tkey) * 8) {
  return detail::GpuRadixSortImpl</*Descending=*/false>(
      context, size, keys_in, keys_out, indices_in, indices_out, num_bits);
}
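
// Illustrative sketch of a typical ascending argsort, assuming device
// pointers d_keys (const float*) and d_order (int32*) of length n and the
// current OpKernelContext* context (all placeholder names for this example).
// keys_out and indices_in are left null, so the sorted keys are discarded and
// the identity range [0, n) is used as the values being permuted.
//
//   TF_RETURN_IF_ERROR(GpuRadixSort(
//       context, n, d_keys,
//       /*keys_out=*/static_cast<float*>(nullptr),
//       /*indices_in=*/static_cast<const int32*>(nullptr),
//       /*indices_out=*/d_order));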

template <typename Tkey, typename Tindex>
Status GpuRadixSortDescending(OpKernelContext* context, int size,
                              const Tkey* keys_in,
                              Tkey* keys_out,            // Optional
                              const Tindex* indices_in,  // Optional
                              Tindex* indices_out,
                              int num_bits = sizeof(Tkey) * 8) {
  return detail::GpuRadixSortImpl</*Descending=*/true>(
      context, size, keys_in, keys_out, indices_in, indices_out, num_bits);
}

// Computes an inclusive prefix sum of input into output, i.e.,
// output[i] = input[0] + input[1] + ... + input[i].
template <typename InputIteratorT, typename OutputIteratorT>
Status GpuInclusivePrefixSum(OpKernelContext* context, int size,
                             InputIteratorT input, OutputIteratorT output) {
  static_assert(
      !std::is_same<typename std::remove_reference<decltype(*input)>::type,
                    bool>::value,
      "GpuInclusivePrefixSum does not work correctly with booleans; please use "
      "TransformInputIterator to explicitly cast to an integer.");
  if (size == 0) return OkStatus();
  const auto& cu_stream = GetGpuStream(context);
  size_t temp_storage_bytes;
  auto err = gpuprim::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes,
                                               input, output, size, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceScan::InclusiveSum to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  Tensor temp_storage;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
      &temp_storage));
  err = gpuprim::DeviceScan::InclusiveSum(temp_storage.flat<int8>().data(),
                                          temp_storage_bytes, input, output,
                                          size, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceScan::InclusiveSum, "
        "temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  return OkStatus();
}
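
// Illustrative sketch of the bool workaround mentioned in the static_assert
// above: scanning a device array of bool flags by first casting each element
// to int through a transform iterator. The functor CastBoolToInt and the
// pointers d_flags / d_counts are placeholder names for this example only;
// gpuprim is assumed to alias cub/hipcub as set up in gpu_prim.h.
//
//   struct CastBoolToInt {
//     __host__ __device__ int operator()(bool b) const { return b ? 1 : 0; }
//   };
//   gpuprim::TransformInputIterator<int, CastBoolToInt, const bool*> it(
//       d_flags, CastBoolToInt());
//   TF_RETURN_IF_ERROR(GpuInclusivePrefixSum(context, n, it, d_counts));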

// Reduces each segment of input (as delimited by consecutive entries of
// segment_offsets) with reduce_op, seeded with initial_value, writing one
// result per segment to output. Note that this behaves deterministically
// across repeated calls on the same device.
template <typename InputIteratorT, typename OutputIteratorT,
          typename OffsetIteratorT, typename ReduceOp, typename T>
Status GpuSegmentedReduce(
    OpKernelContext* context, int num_segments, ReduceOp reduce_op,
    const T& initial_value,
    InputIteratorT input,             // [any]
    OffsetIteratorT segment_offsets,  // [num_segments + 1]
    OutputIteratorT output) {         // [num_segments]
  if (num_segments == 0) return OkStatus();
  const auto& cu_stream = GetGpuStream(context);
  size_t temp_storage_bytes;
  auto err = gpuprim::DeviceSegmentedReduce::Reduce(
      nullptr, temp_storage_bytes, input, output, num_segments, segment_offsets,
      segment_offsets + 1, reduce_op, initial_value, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceSegmentedReduce::Reduce to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  Tensor temp_storage;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
      &temp_storage));
  err = gpuprim::DeviceSegmentedReduce::Reduce(
      temp_storage.flat<int8>().data(), temp_storage_bytes, input, output,
      num_segments, segment_offsets, segment_offsets + 1, reduce_op,
      initial_value, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceSegmentedReduce::Reduce"
        ", temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  return OkStatus();
}
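
// Illustrative sketch of a segmented sum, assuming device pointers d_values
// (const float*), d_offsets (const int32*, length num_segments + 1), and
// d_sums (float*, length num_segments) supplied by the caller (placeholder
// names for this example), with gpuprim::Sum() being the CUB/hipCUB sum
// functor reached through the gpuprim alias from gpu_prim.h.
//
//   TF_RETURN_IF_ERROR(GpuSegmentedReduce(context, num_segments,
//                                         gpuprim::Sum(), 0.0f, d_values,
//                                         d_offsets, d_sums));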

// Copies input[i] to the next free position of output for each i where
// flags[i] is nonzero (stream compaction, preserving relative order), and
// writes the number of selected elements to *out_num_selected (a device
// scalar). If out_num_selected is nullptr, a temporary device scalar is
// allocated internally.
template <typename InputIteratorT, typename FlagIteratorT,
          typename OutputIteratorT, typename NumSelectedT = int>
Status GpuSelectFlagged(OpKernelContext* context, int size,
                        InputIteratorT input, FlagIteratorT flags,
                        OutputIteratorT output,
                        NumSelectedT* out_num_selected = nullptr) {
  const auto& cu_stream = GetGpuStream(context);
  Tensor out_num_selected_t;
  if (!out_num_selected) {
    TF_RETURN_IF_ERROR(
        context->allocate_temp(DataTypeToEnum<NumSelectedT>::value,
                               TensorShape({}), &out_num_selected_t));
    out_num_selected = out_num_selected_t.scalar<NumSelectedT>().data();
  }
  size_t temp_storage_bytes;
  auto err =
      gpuprim::DeviceSelect::Flagged(nullptr, temp_storage_bytes, input, flags,
                                     output, out_num_selected, size, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceSelect::Flagged to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  Tensor temp_storage;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
      &temp_storage));
  err = gpuprim::DeviceSelect::Flagged(temp_storage.flat<int8>().data(),
                                       temp_storage_bytes, input, flags, output,
                                       out_num_selected, size, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceSelect::Flagged, temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  return OkStatus();
}
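
// Illustrative sketch of stream compaction with GpuSelectFlagged, assuming
// device pointers d_values (const float*), d_flags (const bool*), and d_out
// (float*) of length n, plus a device int* d_num_selected to receive the
// count (all placeholder names for this example).
//
//   TF_RETURN_IF_ERROR(GpuSelectFlagged(context, n, d_values, d_flags, d_out,
//                                       d_num_selected));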

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_GPU_PRIM_HELPERS_H_