/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_GPU_PRIM_HELPERS_H_
#define TENSORFLOW_CORE_KERNELS_GPU_PRIM_HELPERS_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/stream_executor/stream.h"

namespace tensorflow {

namespace detail {

template <typename T>
__global__ void RangeInitKernel(const T start, const T delta, const T size,
                                T* out) {
  GPU_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; }
}

// Initialize out with range start, start + delta, start + 2 * delta, ...
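// E.g., start = 5, delta = 2, size = 4 yields out = {5, 7, 9, 11}.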
template <typename T>
Status RangeInit(const Eigen::GpuDevice& d, const T start, const T delta,
                 const T size, T* out) {
  if (size == 0) return OkStatus();
  GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
  return GpuLaunchKernel(RangeInitKernel<T>, config.block_count,
                         config.thread_per_block, 0, d.stream(), start, delta,
                         size, out);
}

// Computes keys_out = sorted(keys_in), and indices_out = argsort(keys_in).
// If keys_out is not required, it can be set to nullptr.
// If indices_in is nullptr, the range of input indices [0, size) will be used.
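// For example, keys_in = {7, 3, 5} gives keys_out = {3, 5, 7} and
// indices_out = {1, 2, 0}.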
template <bool Descending, typename Tkey, typename Tindex>
Status GpuRadixSortImpl(OpKernelContext* context, int size, const Tkey* keys_in,
                        Tkey* keys_out,            // Optional
                        const Tindex* indices_in,  // Optional
                        Tindex* indices_out, int num_bits = sizeof(Tkey) * 8) {
  if (size == 0) return OkStatus();
  if (num_bits == 0) {
    // Workaround for CUB failing when begin_bit = end_bit = 0 (e.g., when all
    // keys are 0, so no sorting is needed).
    se::Stream* stream = context->op_device_context()->stream();
    if (keys_out) {
      // Copy keys_in to keys_out.
      size_t num_bytes = size * sizeof(Tkey);
      se::DeviceMemoryBase src(const_cast<Tkey*>(keys_in), num_bytes);
      se::DeviceMemoryBase dst(keys_out, num_bytes);
      if (!stream->ThenMemcpy(&dst, src, num_bytes).ok()) {
        return errors::Internal("Failed to copy keys_in to keys_out");
      }
    }
    if (indices_in) {
      // Copy indices_in to indices_out.
      size_t num_bytes = size * sizeof(Tindex);
      se::DeviceMemoryBase src(const_cast<Tindex*>(indices_in), num_bytes);
      se::DeviceMemoryBase dst(indices_out, num_bytes);
      if (!stream->ThenMemcpy(&dst, src, num_bytes).ok()) {
        return errors::Internal("Failed to copy indices_in to indices_out");
      }
    } else {
      // Set output indices to range.
      const Eigen::GpuDevice& device =
          context->eigen_device<Eigen::GpuDevice>();
      TF_RETURN_IF_ERROR(detail::RangeInit(device, Tindex(0), Tindex(1),
                                           Tindex(size), indices_out));
    }
    return OkStatus();
  }
  // Allocate temporary inputs/outputs if necessary.
  Tensor tmp_indices_in;
  if (!indices_in) {
    TF_RETURN_IF_ERROR(context->allocate_temp(
        DataTypeToEnum<Tindex>::value, TensorShape({size}), &tmp_indices_in));
    Tindex* mutable_indices_in = tmp_indices_in.flat<Tindex>().data();
    indices_in = mutable_indices_in;
    const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
    // Initialize indices_in to the input index range.
    TF_RETURN_IF_ERROR(detail::RangeInit(device, Tindex(0), Tindex(1),
                                         Tindex(size), mutable_indices_in));
  }
  Tensor tmp_keys_out;
  if (!keys_out) {
    TF_RETURN_IF_ERROR(context->allocate_temp(
        DataTypeToEnum<Tkey>::value, TensorShape({size}), &tmp_keys_out));
    keys_out = tmp_keys_out.flat<Tkey>().data();
  }
  // Determine temporary device storage requirements.
  Tensor temp_storage;
  size_t temp_storage_bytes = 0;
  const auto& cu_stream = GetGpuStream(context);
  gpuError_t err;
  if constexpr (Descending) {
    err = gpuprim::DeviceRadixSort::SortPairsDescending(
        nullptr, temp_storage_bytes, keys_in, keys_out, indices_in, indices_out,
        size, /*begin_bit=*/0, /*end_bit=*/num_bits, cu_stream);
  } else {
    err = gpuprim::DeviceRadixSort::SortPairs(
        nullptr, temp_storage_bytes, keys_in, keys_out, indices_in, indices_out,
        size, /*begin_bit=*/0, /*end_bit=*/num_bits, cu_stream);
  }
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceRadixSort::SortPairs to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  // Allocate temporary storage.
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
      &temp_storage));
  // Sort indices by keys.
  if constexpr (Descending) {
    err = gpuprim::DeviceRadixSort::SortPairsDescending(
        temp_storage.flat<int8>().data(), temp_storage_bytes, keys_in, keys_out,
        indices_in, indices_out, size, /*begin_bit=*/0, /*end_bit=*/num_bits,
        cu_stream);
  } else {
    err = gpuprim::DeviceRadixSort::SortPairs(
        temp_storage.flat<int8>().data(), temp_storage_bytes, keys_in, keys_out,
        indices_in, indices_out, size, /*begin_bit=*/0, /*end_bit=*/num_bits,
        cu_stream);
  }
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceRadixSort::SortPairs, "
        "temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  return OkStatus();
}

}  // namespace detail
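
// Sorts keys_in into ascending order and writes the corresponding argsort
// indices to indices_out. keys_out and indices_in are optional (see
// detail::GpuRadixSortImpl above); num_bits may be lowered to restrict the
// sort to the low-order bits of the keys.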
template <typename Tkey, typename Tindex>
Status GpuRadixSort(OpKernelContext* context, int size, const Tkey* keys_in,
                    Tkey* keys_out,            // Optional
                    const Tindex* indices_in,  // Optional
                    Tindex* indices_out, int num_bits = sizeof(Tkey) * 8) {
  return detail::GpuRadixSortImpl</*Descending=*/false>(
      context, size, keys_in, keys_out, indices_in, indices_out, num_bits);
}
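
// Same as GpuRadixSort, but sorts keys_in into descending order.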
template <typename Tkey, typename Tindex>
Status GpuRadixSortDescending(OpKernelContext* context, int size,
                              const Tkey* keys_in,
                              Tkey* keys_out,            // Optional
                              const Tindex* indices_in,  // Optional
                              Tindex* indices_out,
                              int num_bits = sizeof(Tkey) * 8) {
  return detail::GpuRadixSortImpl</*Descending=*/true>(
      context, size, keys_in, keys_out, indices_in, indices_out, num_bits);
}
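
// A minimal ascending-sort usage sketch (the tensor names and int32 types here
// are hypothetical, for illustration only):
//
//   Tensor sorted_keys, order;
//   TF_RETURN_IF_ERROR(context->allocate_temp(DT_INT32, TensorShape({size}),
//                                             &sorted_keys));
//   TF_RETURN_IF_ERROR(context->allocate_temp(DT_INT32, TensorShape({size}),
//                                             &order));
//   TF_RETURN_IF_ERROR(GpuRadixSort<int32, int32>(
//       context, size, keys.flat<int32>().data(),
//       sorted_keys.flat<int32>().data(), /*indices_in=*/nullptr,
//       order.flat<int32>().data()));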
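
// Computes an inclusive prefix sum (cumulative sum) of the first `size`
// elements of `input`, writing the result to `output`.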
template <typename InputIteratorT, typename OutputIteratorT>
Status GpuInclusivePrefixSum(OpKernelContext* context, int size,
                             InputIteratorT input, OutputIteratorT output) {
  static_assert(
      !std::is_same<typename std::remove_reference<decltype(*input)>::type,
                    bool>::value,
      "GpuInclusivePrefixSum does not work correctly with booleans; please "
      "use TransformInputIterator to explicitly cast to an integer.");
  if (size == 0) return OkStatus();
  const auto& cu_stream = GetGpuStream(context);
  size_t temp_storage_bytes;
  auto err = gpuprim::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes,
                                               input, output, size, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceScan::InclusiveSum to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  Tensor temp_storage;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
      &temp_storage));
  err = gpuprim::DeviceScan::InclusiveSum(temp_storage.flat<int8>().data(),
                                          temp_storage_bytes, input, output,
                                          size, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceScan::InclusiveSum, "
        "temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  return OkStatus();
}
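
// Reduces each segment of `input` with `reduce_op`, writing one result per
// segment to `output`. Segment i spans input[segment_offsets[i],
// segment_offsets[i + 1]); empty segments produce `initial_value`. For
// example, segment_offsets = {0, 2, 5} describes the two segments
// input[0, 2) and input[2, 5).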
// Note that this behaves deterministically for repeat calls on the same device.
template <typename InputIteratorT, typename OutputIteratorT,
          typename OffsetIteratorT, typename ReduceOp, typename T>
Status GpuSegmentedReduce(
    OpKernelContext* context, int num_segments, ReduceOp reduce_op,
    const T& initial_value,
    InputIteratorT input,             // [any]
    OffsetIteratorT segment_offsets,  // [num_segments + 1]
    OutputIteratorT output) {         // [num_segments]
  if (num_segments == 0) return OkStatus();
  const auto& cu_stream = GetGpuStream(context);
  size_t temp_storage_bytes;
  auto err = gpuprim::DeviceSegmentedReduce::Reduce(
      nullptr, temp_storage_bytes, input, output, num_segments, segment_offsets,
      segment_offsets + 1, reduce_op, initial_value, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceSegmentedReduce::Reduce to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  Tensor temp_storage;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
      &temp_storage));
  err = gpuprim::DeviceSegmentedReduce::Reduce(
      temp_storage.flat<int8>().data(), temp_storage_bytes, input, output,
      num_segments, segment_offsets, segment_offsets + 1, reduce_op,
      initial_value, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceSegmentedReduce::Reduce"
        ", temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  return OkStatus();
}
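
// Compacts the elements of `input` whose corresponding entry in `flags` is
// nonzero into `output`, preserving their relative order. If
// `out_num_selected` is nullptr, a temporary scalar is allocated to hold the
// count of selected elements.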
template <typename InputIteratorT, typename FlagIteratorT,
          typename OutputIteratorT, typename NumSelectedT = int>
Status GpuSelectFlagged(OpKernelContext* context, int size,
                        InputIteratorT input, FlagIteratorT flags,
                        OutputIteratorT output,
                        NumSelectedT* out_num_selected = nullptr) {
  const auto& cu_stream = GetGpuStream(context);
  Tensor out_num_selected_t;
  if (!out_num_selected) {
    TF_RETURN_IF_ERROR(
        context->allocate_temp(DataTypeToEnum<NumSelectedT>::value,
                               TensorShape({}), &out_num_selected_t));
    out_num_selected = out_num_selected_t.scalar<NumSelectedT>().data();
  }
  size_t temp_storage_bytes;
  auto err =
      gpuprim::DeviceSelect::Flagged(nullptr, temp_storage_bytes, input, flags,
                                     output, out_num_selected, size, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceSelect::Flagged to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  Tensor temp_storage;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
      &temp_storage));
  err = gpuprim::DeviceSelect::Flagged(temp_storage.flat<int8>().data(),
                                       temp_storage_bytes, input, flags, output,
                                       out_num_selected, size, cu_stream);
  if (err != 0) {
    return errors::Internal(
        "Failed to launch gpuprim::DeviceSelect::Flagged, temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  return OkStatus();
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_GPU_PRIM_HELPERS_H_