/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_
#define TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <type_traits>

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda_fp16.h"
#endif
#include "tensorflow/core/util/gpu_cuda_alias.h"
#include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_launch_config.h"

#if GOOGLE_CUDA
#define TF_RED_WARPSIZE 32
#elif TENSORFLOW_USE_ROCM
#define TF_RED_WARPSIZE 64
#endif

// Deprecated, use 'for (int i : GpuGridRangeX(n))' instead.
#define GPU_1D_KERNEL_LOOP(i, n) \
  for (int i : ::tensorflow::GpuGridRangeX<int>(n))
#define CUDA_1D_KERNEL_LOOP(i, n) \
  for (int i : ::tensorflow::GpuGridRangeX<int>(n))

// Deprecated, use 'for (int i : GpuGridRange{X,Y,Z}(n))' instead.
#define GPU_AXIS_KERNEL_LOOP(i, n, axis) \
  for (int i : ::tensorflow::GpuGridRange##axis<int>(n))
#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
  for (int i : ::tensorflow::GpuGridRange##axis<int>(n))
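
// Example (an illustrative sketch; 'MyKernel', 'in', 'out', and 'n' are
// hypothetical names, not part of this header):
//
//   __global__ void MyKernel(const float* in, float* out, int n) {
//     // Grid-stride loop that covers all n elements regardless of the
//     // launch configuration.
//     GPU_1D_KERNEL_LOOP(i, n) { out[i] = 2.0f * in[i]; }
//   }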

#if GOOGLE_CUDA
#define gpuSuccess cudaSuccess
using gpuStream_t = cudaStream_t;
using gpuError_t = cudaError_t;
#elif TENSORFLOW_USE_ROCM
#define gpuSuccess hipSuccess
using gpuStream_t = hipStream_t;
using gpuError_t = hipError_t;
#endif

// Macro wrapper to declare dynamic shared memory.
#if GOOGLE_CUDA

#define GPU_DYNAMIC_SHARED_MEM_DECL(ALIGN, TYPE, NAME) \
  extern __shared__ __align__(ALIGN) TYPE NAME[]

#elif TENSORFLOW_USE_ROCM

#define GPU_DYNAMIC_SHARED_MEM_DECL(ALIGN, TYPE, NAME) \
  HIP_DYNAMIC_SHARED(TYPE, NAME)

#endif
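
// Example (an illustrative sketch; 'SharedKernel' and 'shared_buf' are
// hypothetical, and the 16-byte alignment is just one reasonable choice):
// declares a dynamically sized shared-memory buffer whose byte size is
// supplied at launch time via shared_memory_size_bytes.
//
//   __global__ void SharedKernel(const float* in, float* out, int n) {
//     GPU_DYNAMIC_SHARED_MEM_DECL(16, float, shared_buf);
//     ...
//   }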

namespace tensorflow {

#if GOOGLE_CUDA
// cudaGetErrorString is available to both host and device
__host__ __device__ inline const char* GpuGetErrorString(cudaError_t error) {
  return cudaGetErrorString(error);
}
#elif TENSORFLOW_USE_ROCM
// hipGetErrorString is available on host side only
inline const char* GpuGetErrorString(hipError_t error) {
  return hipGetErrorString(error);
}
#endif

// Returns a raw reference to the current CUDA stream. Required by a
// number of kernel calls (for which StreamInterface* does not work),
// e.g. CUB and certain cuBLAS primitives.
inline const gpuStream_t& GetGpuStream(OpKernelContext* context) {
  const gpuStream_t* ptr = CHECK_NOTNULL(
      reinterpret_cast<const gpuStream_t*>(context->op_device_context()
                                               ->stream()
                                               ->implementation()
                                               ->GpuStreamMemberHack()));
  return *ptr;
}

// Launches a GPU kernel through cudaLaunchKernel in a CUDA environment, or
// hipLaunchKernelGGL in a ROCm environment, with the given arguments.
//
// The kernel parameters 'Ts' must be constructible from the arguments 'Args'.
template <typename... Ts, typename... Args>
Status GpuLaunchKernel(void (*function)(Ts...), dim3 grid_dim, dim3 block_dim,
                       size_t shared_memory_size_bytes, gpuStream_t stream,
                       Args... arguments) {
  static_assert(detail::NoneIsReference<Ts...>(),
                "Kernels with reference arguments have undefined behaviour.");
#if GOOGLE_CUDA
  auto func_ptr = absl::bit_cast<const void*>(function);
  // Cast arguments and forward them as an array of pointers.
  auto args_tuple = std::tuple<Ts...>(arguments...);
  auto arg_ptrs = detail::GetArrayOfElementPointers(&args_tuple);
  auto result = cudaLaunchKernel(func_ptr, grid_dim, block_dim, arg_ptrs.data(),
                                 shared_memory_size_bytes, stream);
  if (result != cudaSuccess) {
    return errors::Internal(cudaGetErrorString(result));
  }
#elif TENSORFLOW_USE_ROCM
  hipLaunchKernelGGL(function, grid_dim, block_dim, shared_memory_size_bytes,
                     stream, std::forward<Args>(arguments)...);
#endif
  return OkStatus();
}
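
// Example (an illustrative sketch; 'MyKernel', 'd_in', 'd_out', and 'n' are
// hypothetical, 'ctx' is an OpKernelContext*, and GetGpuLaunchConfig comes
// from the included gpu_launch_config.h):
//
//   auto cfg = GetGpuLaunchConfig(n, ctx->eigen_gpu_device());
//   TF_CHECK_OK(GpuLaunchKernel(MyKernel, cfg.block_count,
//                               cfg.thread_per_block,
//                               /*shared_memory_size_bytes=*/0,
//                               GetGpuStream(ctx), d_in, d_out, n));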

// Perfect forwarding to make CudaLaunchKernel available to both ROCm and CUDA
// builds
template <typename... Args>
auto CudaLaunchKernel(Args&&... args)
    -> decltype(GpuLaunchKernel(std::forward<Args>(args)...)) {
  return GpuLaunchKernel(std::forward<Args>(args)...);
}

__host__ __device__ inline tensorflow::bfloat16 GpuLdg(
    const tensorflow::bfloat16* address) {
  return Eigen::numext::bit_cast<tensorflow::bfloat16>(
      GpuLdg(reinterpret_cast<const uint16_t*>(address)));
}
// Already aliased in gpu_device_functions.h

template <typename T>
__host__ __device__ inline T ldg(const T* ptr) {
  return GpuLdg(ptr);
}

template <typename T>
__host__ __device__ inline const T& tf_min(const T& x, const T& y) {
  return x < y ? x : y;
}

template <typename T>
__host__ __device__ inline const T& tf_max(const T& x, const T& y) {
  return x < y ? y : x;
}

// Overloads of the above functions for float and double.
__host__ __device__ inline float tf_min(float x, float y) {
  return fminf(x, y);
}
__host__ __device__ inline double tf_min(double x, double y) {
  return fmin(x, y);
}
__host__ __device__ inline float tf_max(float x, float y) {
  return fmaxf(x, y);
}
__host__ __device__ inline double tf_max(double x, double y) {
  return fmax(x, y);
}

// ROCM TODO re-enable them after adding fp16 support logic
#if GOOGLE_CUDA
__device__ inline Eigen::half GpuShuffleSync(unsigned mask, Eigen::half value,
                                             int src_lane,
                                             int width = warpSize) {
  return Eigen::half(
      GpuShuffleSync(mask, static_cast<uint16>(value), src_lane, width));
}
// Aliased in gpu_device_functions.h

__device__ EIGEN_ALWAYS_INLINE Eigen::half GpuShuffleUpSync(
    unsigned mask, Eigen::half value, int delta, int width = warpSize) {
  return Eigen::half(
      GpuShuffleUpSync(mask, static_cast<uint16>(value), delta, width));
}
// Aliased in gpu_device_functions.h

__device__ EIGEN_ALWAYS_INLINE Eigen::half GpuShuffleDownSync(
    unsigned mask, Eigen::half value, int delta, int width = warpSize) {
  return Eigen::half(
      GpuShuffleDownSync(mask, static_cast<uint16>(value), delta, width));
}
// Aliased in gpu_device_functions.h

__device__ EIGEN_ALWAYS_INLINE Eigen::half GpuShuffleXorSync(
    unsigned mask, Eigen::half value, int lane_mask, int width = warpSize) {
  return Eigen::half(
      GpuShuffleXorSync(mask, static_cast<uint16>(value), lane_mask, width));
}
// Aliased in gpu_device_functions.h
#endif

#ifdef __CUDA_ARCH__
#define UNROLL_ON_DEVICE _Pragma("unroll")
#else
#define UNROLL_ON_DEVICE
#endif

// Represents an aligned array of N elements of T. Data pointers can be
// reinterpreted as this type to generate vectorized loads/stores in a kernel.
template <typename T, int N>
class alignas(alignof(T) * N) AlignedVector {
 public:
  typedef T value_type;
  static constexpr const int kSize = N;

  AlignedVector() = default;

  // Uniform initialization.
  __host__ __device__ explicit AlignedVector(value_type uniform) {
    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] = uniform; }
  }
  // Uniform initialization with explicit conversion.
  // Note: This is required for T=Eigen::half because it only supports explicit
  // conversions from other types and its template constructor is too relaxed
  // to be able to use std::is_constructible.
  template <typename U, typename std::enable_if<std::is_arithmetic<U>::value,
                                                int>::type = 0>
  __host__ __device__ explicit AlignedVector(U uniform_u) {
    value_type uniform(uniform_u);
    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] = uniform; }
  }
  // Implicit conversion.
  template <typename U, typename std::enable_if<
                            std::is_convertible<U, T>::value, int>::type = 0>
  __host__ __device__ AlignedVector(const AlignedVector<U, N>& other) {
    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] = other[i]; }
  }
  // Explicit conversion.
  template <typename U,
            typename std::enable_if<!std::is_convertible<U, T>::value &&
                                        std::is_constructible<T, U>::value,
                                    int>::type = 0>
  __host__ __device__ explicit AlignedVector(const AlignedVector<U, N>& other) {
    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) {
      values_[i] = T(other[i]);
    }
  }

  __host__ __device__ value_type& operator[](int i) { return values_[i]; }
  __host__ __device__ const value_type& operator[](int i) const {
    return values_[i];
  }

#define DEFINE_BINARY_UPDATE_OPERATOR(op)                                      \
  __host__ __device__ AlignedVector& operator op(const AlignedVector& rhs) {   \
    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] op rhs[i]; } \
    return *this;                                                              \
  }
  DEFINE_BINARY_UPDATE_OPERATOR(+=)
  DEFINE_BINARY_UPDATE_OPERATOR(-=)
  DEFINE_BINARY_UPDATE_OPERATOR(*=)
  DEFINE_BINARY_UPDATE_OPERATOR(/=)
#undef DEFINE_BINARY_UPDATE_OPERATOR

#define DEFINE_BINARY_OPERATOR(op)                          \
  friend __host__ __device__ AlignedVector operator op(     \
      const AlignedVector& lhs, const AlignedVector& rhs) { \
    AlignedVector ret;                                      \
    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) {      \
      ret[i] = lhs[i] op rhs[i];                            \
    }                                                       \
    return ret;                                             \
  }
  DEFINE_BINARY_OPERATOR(+)
  DEFINE_BINARY_OPERATOR(-)
  DEFINE_BINARY_OPERATOR(*)
  DEFINE_BINARY_OPERATOR(/)
#undef DEFINE_BINARY_OPERATOR

#define DEFINE_BINARY_FUNCTION(func)                                        \
  friend __host__ __device__ AlignedVector func(const AlignedVector& lhs,   \
                                                const AlignedVector& rhs) { \
    AlignedVector ret;                                                      \
    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) {                      \
      ret[i] = func(lhs[i], rhs[i]);                                        \
    }                                                                       \
    return ret;                                                             \
  }
  DEFINE_BINARY_FUNCTION(min)
  DEFINE_BINARY_FUNCTION(max)
#undef DEFINE_BINARY_FUNCTION

 private:
  value_type values_[N];
};

#undef UNROLL_ON_DEVICE
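
// Example (an illustrative sketch; 'Vec4', 'in', 'out', and 'i' are
// hypothetical): a pointer whose alignment permits it (see MinAlignmentOf
// below) can be reinterpreted as AlignedVector to load and store four floats
// with a single vectorized instruction.
//
//   using Vec4 = AlignedVector<float, 4>;
//   Vec4 v = reinterpret_cast<const Vec4*>(in)[i];     // vectorized load
//   reinterpret_cast<Vec4*>(out)[i] = v * Vec4(2.0f);  // vectorized store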

// Returns the maximum power-of-two alignment (in units of elements, not bytes)
// of a stride or pointer value.
inline int64_t alignment_of(int64_t element_stride) {
  // A zero/nullptr value means that the stride/pointer is not used, so it
  // effectively has infinite alignment.
  constexpr int64_t kMaxAlignment = 512;
  if (element_stride == 0) return kMaxAlignment;
  return element_stride & -element_stride;
}

template <typename T>
inline int64_t alignment_of(T* ptr) {
  const intptr_t ptr_val = reinterpret_cast<std::uintptr_t>(ptr);
  // Pointers should always be aligned to sizeof(T) bytes.
  DCHECK_EQ(ptr_val % sizeof(T), 0);
  // Note that we want the alignment in elements, not bytes.
  return alignment_of(ptr_val / sizeof(T));
}

template <typename... Args>
int64_t MinAlignmentOf(Args... args) {
  return std::min({alignment_of(args)...});
}

namespace detail {

template <int64_t VecSize, template <int vec_size> class Functor>
struct DispatchToVectorizedHelper {
  template <typename... Args>
  Status operator()(int64_t max_vec_size, Args&&... args) const {
    if (max_vec_size >= VecSize) {
      return Functor<VecSize>()(std::forward<Args>(args)...);
    }
    return DispatchToVectorizedHelper<VecSize / 2, Functor>()(
        max_vec_size, std::forward<Args>(args)...);
  }
};
template <template <int vec_size> class Functor>
struct DispatchToVectorizedHelper<1, Functor> {
  template <typename... Args>
  Status operator()(int64_t max_vec_size, Args&&... args) const {
    return Functor<1>()(std::forward<Args>(args)...);
  }
};

}  // namespace detail

// Calls Functor<vec_size>()(args...) with vec_size set to the optimal GPU
// vector instruction size for type T that is <= max_vec_size. The max_vec_size
// argument should be set to the minimum alignment of all relevant parameters.
// Requires sizeof(T) to be a power of 2.
template <typename T, template <int vec_size> class Functor, typename... Args>
Status DispatchToVectorized(int64_t max_vec_size, Args&&... args) {
  static_assert((sizeof(T) & (sizeof(T) - 1)) == 0,
                "sizeof(T) must be a power of 2");
  if (max_vec_size <= 0) {
    return errors::InvalidArgument("DispatchToVectorized: max_vec_size (",
                                   max_vec_size,
                                   ") must be greater than zero.");
  }
  constexpr const int kOptimalVecSizeBytes = 16;
  // The optimal number of (aligned) elements of T to load/store in a
  // single instruction inside a kernel.
  constexpr const int optimal_vec_size =
      (kOptimalVecSizeBytes - 1) / sizeof(T) + 1;
  return detail::DispatchToVectorizedHelper<optimal_vec_size, Functor>()(
      max_vec_size, std::forward<Args>(args)...);
}
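
// Example (an illustrative sketch; 'ScaleFunctor', 'in', 'out', and 'n' are
// hypothetical): choose the widest safe vector width from the data pointers
// and dispatch to a functor templated on that width.
//
//   template <int vec_size>
//   struct ScaleFunctor {
//     Status operator()(const float* in, float* out, int64_t n) const {
//       // ... launch a kernel that uses AlignedVector<float, vec_size> ...
//       return OkStatus();
//     }
//   };
//
//   const int64_t max_vec_size = MinAlignmentOf(in, out);
//   TF_RETURN_IF_ERROR(
//       (DispatchToVectorized<float, ScaleFunctor>(max_vec_size, in, out, n)));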

namespace gpu_helper {

// Similar to std::upper_bound, this returns the index of the first element in
// [first, first + count) that is greater than `val`, or `count` if no such
// element is found. Assumes [first, first + count) is sorted.
template <typename T, typename OutType = int32, typename Iterator = const T*>
__device__ OutType upper_bound(Iterator first, OutType count, T val) {
  Iterator orig = first;
  OutType step = 0;
  while (count > 0) {
    Iterator it = first;
    step = count / 2;
    it += step;
    if (!(val < *it)) {
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }

  return first - orig;
}

// Similar to std::lower_bound, this returns the index of the first element in
// [first, first + count) that is not less than `val`, or `count` if no such
// element is found. Assumes [first, first + count) is sorted.
template <typename T, typename OutType = int32, typename Iterator = const T*>
__device__ OutType lower_bound(Iterator first, OutType count, T val) {
  Iterator orig = first;
  OutType step = 0;
  while (count > 0) {
    Iterator it = first;
    step = count / 2;
    it += step;
    if (*it < val) {
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }

  return first - orig;
}
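
// Example (an illustrative sketch; 'boundaries', 'num_boundaries', and 'x' are
// hypothetical): find which bucket of a sorted boundary array 'x' falls into.
//
//   int32 bucket = gpu_helper::upper_bound<float>(boundaries,
//                                                 num_boundaries, x);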

}  // namespace gpu_helper

#ifndef TENSORFLOW_USE_ROCM
namespace cuda_helper = gpu_helper;
#endif

// For integer division, we can substitute a fast multiplication for the slow
// division. For detailed information see:
//   https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
//
// Warning: This implementation only works when the divisor is in
//          [1, INT32_MAX] and the numerator is in [0, INT32_MAX]. This is
//          enough for our purpose of computing integer indices.
// Basics: a typical integer division can be written as:
//   n / d = (m * n) / 2^(32 + s)
// where 'n' is the numerator and 'd' is the divisor. For a given 'd', we
// need to find a magic number 'm' and a shift 's'. See update_magic().
struct FastDividerUint32 {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC FastDividerUint32(uint32_t d)
      : divisor(d) {
    assert(divisor >= 1 && divisor <= INT32_MAX);
    update_magic();
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void update_magic() {
    // (1). The shift 's' is calculated by log2ceil(d).
#if defined(__CUDA_ARCH__)
    shift = 32 - __clz(divisor - 1);
#else
    for (shift = 0; shift < 32; shift++) {
      if ((1U << shift) >= divisor) break;
    }
#endif

    // (2). The magic number 'm' is calculated by:
    //   m = 2^(32 + s) / d + 1
    // Note, the '+ 1' rounds 'm' up so that 'm * n' is not rounded down too
    // far by the later division by 2^(32 + s). In practice, 'm' is a 33-bit
    // value. To fit the 32-bit range, we introduce:
    //   magic = m - 2^32
    //         = 2^(32 + s) / d - 2^32 + 1
    //         = 2^32 * 2^s / d - 2^32 * d / d + 1
    //         = (2^32 * (2^s - d)) / d + 1, where 'magic' fits in 32 bits.
    uint64_t m = (0x100000000ull * ((0x1ull << shift) - divisor)) / divisor + 1;
    magic = static_cast<uint32_t>(m);
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC FastDividerUint32& operator=(
      uint32_t d) {
    assert(d >= 1 && d <= INT32_MAX);
    this->divisor = d;
    update_magic();
    return *this;
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC operator uint32_t() const {
    return divisor;
  }

  uint32_t divisor;
  uint32_t magic;
  uint32_t shift;
};

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint32_t
operator/(const uint32_t n, const FastDividerUint32& fdiv) {
  // (3). We use the 32-bit 'magic' instead of 'm' in the formula:
  //   n / d = (m * n) / 2^(32 + s)
  //         = (magic + 2^32) * n / 2^(32 + s)
  //         = (magic * n) / 2^(32 + s) + n / 2^s
  //         = (magic * n) / 2^32 / 2^s + n / 2^s
  //         = (magic * n / 2^32 + n) / 2^s
#if defined(__CUDA_ARCH__)
  uint32_t q = __umulhi(n, fdiv.magic);
#else
  uint32_t q =
      static_cast<uint32_t>((static_cast<uint64_t>(n) * fdiv.magic) >> 32);
#endif
  return (n + q) >> fdiv.shift;
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint32_t
operator%(const uint32_t n, const FastDividerUint32& fdiv) {
  return n - (n / fdiv) * fdiv.divisor;
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint32_t
operator/(const int n, const FastDividerUint32& fdiv) {
  return static_cast<uint32_t>(n) / fdiv;
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint32_t
operator%(const int n, const FastDividerUint32& fdiv) {
  return static_cast<uint32_t>(n) % fdiv;
}
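
// Example (an illustrative sketch; 'width' and 'index' are hypothetical):
// construct the divider once and reuse it to replace repeated integer
// division/modulo, e.g. when decomposing flat indices inside a kernel.
//
//   FastDividerUint32 fast_width(width);
//   uint32_t row = index / fast_width;  // uses the multiply-and-shift path
//   uint32_t col = index % fast_width;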

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif  // TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_