/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_
#define TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <type_traits>

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda_fp16.h"
#endif
#include "tensorflow/core/util/gpu_cuda_alias.h"
#include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_launch_config.h"

#if GOOGLE_CUDA
#define TF_RED_WARPSIZE 32
#elif TENSORFLOW_USE_ROCM
#define TF_RED_WARPSIZE 64
#endif

// Deprecated, use 'for(int i : GpuGridRangeX(n))' instead.
#define GPU_1D_KERNEL_LOOP(i, n) \
  for (int i : ::tensorflow::GpuGridRangeX<int>(n))
#define CUDA_1D_KERNEL_LOOP(i, n) \
  for (int i : ::tensorflow::GpuGridRangeX<int>(n))

// Deprecated, use 'for(int i : GpuGridRange?(n))' instead, where '?' is the
// axis: X, Y, or Z.
#define GPU_AXIS_KERNEL_LOOP(i, n, axis) \
  for (int i : ::tensorflow::GpuGridRange##axis<int>(n))
#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
  for (int i : ::tensorflow::GpuGridRange##axis<int>(n))
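
// Illustrative sketch (not part of this header): a kernel body using the
// non-deprecated grid-stride loop directly. 'data', 'n', and 'alpha' are
// hypothetical kernel parameters.
//
//   __global__ void ScaleKernel(float* data, int n, float alpha) {
//     for (int i : ::tensorflow::GpuGridRangeX<int>(n)) {
//       data[i] *= alpha;
//     }
//   }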

#if GOOGLE_CUDA
#define gpuSuccess cudaSuccess
using gpuStream_t = cudaStream_t;
using gpuError_t = cudaError_t;
#elif TENSORFLOW_USE_ROCM
#define gpuSuccess hipSuccess
using gpuStream_t = hipStream_t;
using gpuError_t = hipError_t;
#endif

// Macro wrapper to declare dynamic shared memory.
#if GOOGLE_CUDA

#define GPU_DYNAMIC_SHARED_MEM_DECL(ALIGN, TYPE, NAME) \
  extern __shared__ __align__(ALIGN) TYPE NAME[]

#elif TENSORFLOW_USE_ROCM

#define GPU_DYNAMIC_SHARED_MEM_DECL(ALIGN, TYPE, NAME) \
  HIP_DYNAMIC_SHARED(TYPE, NAME)

#endif
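
// Illustrative sketch (assumption, not from this header): declaring the
// dynamic shared-memory buffer inside a kernel. Its size is determined by the
// shared_memory_size_bytes argument supplied at launch time.
//
//   __global__ void BlockSumKernel(const float* in, float* out, int n) {
//     GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(float), float, shared_buf);
//     // ... use shared_buf[threadIdx.x] for a block-level reduction ...
//   }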

namespace tensorflow {

#if GOOGLE_CUDA
// cudaGetErrorString is available to both host and device.
__host__ __device__ inline const char* GpuGetErrorString(cudaError_t error) {
  return cudaGetErrorString(error);
}
#elif TENSORFLOW_USE_ROCM
// hipGetErrorString is available on the host side only.
inline const char* GpuGetErrorString(hipError_t error) {
  return hipGetErrorString(error);
}
#endif

// Returns a raw reference to the current GPU stream. Required by a
// number of kernel calls (for which StreamInterface* does not work),
// e.g. CUB and certain cuBLAS primitives.
inline const gpuStream_t& GetGpuStream(OpKernelContext* context) {
  const gpuStream_t* ptr = CHECK_NOTNULL(
      reinterpret_cast<const gpuStream_t*>(context->op_device_context()
                                               ->stream()
                                               ->implementation()
                                               ->GpuStreamMemberHack()));
  return *ptr;
}
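
// Illustrative sketch (assumption, not from this header): passing the raw
// stream to a CUB primitive from inside an op's Compute() method.
//
//   gpuStream_t stream = GetGpuStream(context);
//   size_t temp_bytes = 0;
//   cub::DeviceReduce::Sum(/*d_temp_storage=*/nullptr, temp_bytes, d_in,
//                          d_out, num_items, stream);
//   // ... allocate temp storage, then call cub::DeviceReduce::Sum again ...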

// Launches a GPU kernel through cudaLaunchKernel in a CUDA environment, or
// hipLaunchKernelGGL in a ROCm environment, with the given arguments.
//
// The kernel parameters 'Ts' must be constructible from the arguments 'Args'.
template <typename... Ts, typename... Args>
Status GpuLaunchKernel(void (*function)(Ts...), dim3 grid_dim, dim3 block_dim,
                       size_t shared_memory_size_bytes, gpuStream_t stream,
                       Args... arguments) {
  static_assert(detail::NoneIsReference<Ts...>(),
                "Kernels with reference arguments have undefined behaviour.");
#if GOOGLE_CUDA
  auto func_ptr = absl::bit_cast<const void*>(function);
  // Cast arguments and forward them as an array of pointers.
  auto args_tuple = std::tuple<Ts...>(arguments...);
  auto arg_ptrs = detail::GetArrayOfElementPointers(&args_tuple);
  auto result = cudaLaunchKernel(func_ptr, grid_dim, block_dim, arg_ptrs.data(),
                                 shared_memory_size_bytes, stream);
  if (result != cudaSuccess) {
    return errors::Internal(cudaGetErrorString(result));
  }
#elif TENSORFLOW_USE_ROCM
  hipLaunchKernelGGL(function, grid_dim, block_dim, shared_memory_size_bytes,
                     stream, std::forward<Args>(arguments)...);
#endif
  return OkStatus();
}

// Perfect forwarding to make CudaLaunchKernel available to both ROCm and CUDA
// builds.
template <typename... Args>
auto CudaLaunchKernel(Args&&... args)
    -> decltype(GpuLaunchKernel(std::forward<Args>(args)...)) {
  return GpuLaunchKernel(std::forward<Args>(args)...);
}
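
// Illustrative sketch (assumption, not from this header): launching the
// hypothetical ScaleKernel from the sketch above through GpuLaunchKernel,
// using a launch configuration from gpu_launch_config.h.
//
//   const auto& device = context->eigen_gpu_device();
//   GpuLaunchConfig config = GetGpuLaunchConfig(n, device);
//   TF_CHECK_OK(GpuLaunchKernel(ScaleKernel, config.block_count,
//                               config.thread_per_block,
//                               /*shared_memory_size_bytes=*/0,
//                               device.stream(), data, n, alpha));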

__host__ __device__ inline tensorflow::bfloat16 GpuLdg(
    const tensorflow::bfloat16* address) {
  return Eigen::numext::bit_cast<tensorflow::bfloat16>(
      GpuLdg(reinterpret_cast<const uint16_t*>(address)));
}
// Already aliased in gpu_device_functions.h

template <typename T>
__host__ __device__ inline T ldg(const T* ptr) {
  return GpuLdg(ptr);
}

template <typename T>
__host__ __device__ inline const T& tf_min(const T& x, const T& y) {
  return x < y ? x : y;
}

template <typename T>
__host__ __device__ inline const T& tf_max(const T& x, const T& y) {
  return x < y ? y : x;
}

// Overloads of the above functions for float and double.
__host__ __device__ inline float tf_min(float x, float y) {
  return fminf(x, y);
}
__host__ __device__ inline double tf_min(double x, double y) {
  return fmin(x, y);
}
__host__ __device__ inline float tf_max(float x, float y) {
  return fmaxf(x, y);
}
__host__ __device__ inline double tf_max(double x, double y) {
  return fmax(x, y);
}
// ROCm TODO: re-enable these after adding fp16 support logic.
#if GOOGLE_CUDA
__device__ inline Eigen::half GpuShuffleSync(unsigned mask, Eigen::half value,
                                             int src_lane,
                                             int width = warpSize) {
  return Eigen::half(
      GpuShuffleSync(mask, static_cast<uint16>(value), src_lane, width));
}
// Aliased in gpu_device_functions.h

__device__ EIGEN_ALWAYS_INLINE Eigen::half GpuShuffleUpSync(
    unsigned mask, Eigen::half value, int delta, int width = warpSize) {
  return Eigen::half(
      GpuShuffleUpSync(mask, static_cast<uint16>(value), delta, width));
}
// Aliased in gpu_device_functions.h

__device__ EIGEN_ALWAYS_INLINE Eigen::half GpuShuffleDownSync(
    unsigned mask, Eigen::half value, int delta, int width = warpSize) {
  return Eigen::half(
      GpuShuffleDownSync(mask, static_cast<uint16>(value), delta, width));
}
// Aliased in gpu_device_functions.h

__device__ EIGEN_ALWAYS_INLINE Eigen::half GpuShuffleXorSync(
    unsigned mask, Eigen::half value, int lane_mask, int width = warpSize) {
  return Eigen::half(
      GpuShuffleXorSync(mask, static_cast<uint16>(value), lane_mask, width));
}
// Aliased in gpu_device_functions.h
#endif

#ifdef __CUDA_ARCH__
#define UNROLL_ON_DEVICE _Pragma("unroll")
#else
#define UNROLL_ON_DEVICE
#endif

// Represents an aligned array of N elements of T. Data pointers can be
// reinterpreted as this type to generate vectorized loads/stores in a kernel.
template <typename T, int N>
class alignas(alignof(T) * N) AlignedVector {
 public:
  typedef T value_type;
  static constexpr const int kSize = N;

  AlignedVector() = default;

  // Uniform initialization.
  __host__ __device__ explicit AlignedVector(value_type uniform) {
    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] = uniform; }
  }
  // Uniform initialization with explicit conversion.
  // Note: This is required for T=Eigen::half because it only supports explicit
  // conversions from other types and its template constructor is too relaxed
  // to be able to use std::is_constructible.
  template <typename U, typename std::enable_if<std::is_arithmetic<U>::value,
                                                int>::type = 0>
  __host__ __device__ explicit AlignedVector(U uniform_u) {
    value_type uniform(uniform_u);
    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] = uniform; }
  }
  // Implicit conversion.
  template <typename U, typename std::enable_if<
                            std::is_convertible<U, T>::value, int>::type = 0>
  __host__ __device__ AlignedVector(const AlignedVector<U, N>& other) {
    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] = other[i]; }
  }
  // Explicit conversion.
  template <typename U,
            typename std::enable_if<!std::is_convertible<U, T>::value &&
                                        std::is_constructible<T, U>::value,
                                    int>::type = 0>
  __host__ __device__ explicit AlignedVector(const AlignedVector<U, N>& other) {
    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) {
      values_[i] = T(other[i]);
    }
  }

  __host__ __device__ value_type& operator[](int i) { return values_[i]; }
  __host__ __device__ const value_type& operator[](int i) const {
    return values_[i];
  }

#define DEFINE_BINARY_UPDATE_OPERATOR(op)                                      \
  __host__ __device__ AlignedVector& operator op(const AlignedVector& rhs) {   \
    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] op rhs[i]; } \
    return *this;                                                              \
  }
  DEFINE_BINARY_UPDATE_OPERATOR(+=)
  DEFINE_BINARY_UPDATE_OPERATOR(-=)
  DEFINE_BINARY_UPDATE_OPERATOR(*=)
  DEFINE_BINARY_UPDATE_OPERATOR(/=)
#undef DEFINE_BINARY_UPDATE_OPERATOR

#define DEFINE_BINARY_OPERATOR(op)                          \
  friend __host__ __device__ AlignedVector operator op(     \
      const AlignedVector& lhs, const AlignedVector& rhs) { \
    AlignedVector ret;                                      \
    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) {      \
      ret[i] = lhs[i] op rhs[i];                            \
    }                                                       \
    return ret;                                             \
  }
  DEFINE_BINARY_OPERATOR(+)
  DEFINE_BINARY_OPERATOR(-)
  DEFINE_BINARY_OPERATOR(*)
  DEFINE_BINARY_OPERATOR(/)
#undef DEFINE_BINARY_OPERATOR

#define DEFINE_BINARY_FUNCTION(func)                                        \
  friend __host__ __device__ AlignedVector func(const AlignedVector& lhs,   \
                                                const AlignedVector& rhs) { \
    AlignedVector ret;                                                      \
    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) {                      \
      ret[i] = func(lhs[i], rhs[i]);                                        \
    }                                                                       \
    return ret;                                                             \
  }
  DEFINE_BINARY_FUNCTION(min)
  DEFINE_BINARY_FUNCTION(max)
#undef DEFINE_BINARY_FUNCTION

 private:
  value_type values_[N];
};

#undef UNROLL_ON_DEVICE
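
// Illustrative sketch (assumption, not from this header): reinterpreting
// 16-byte-aligned float pointers as AlignedVector<float, 4> so the compiler
// can emit a single vectorized load/store per thread.
//
//   __global__ void CopyVec4Kernel(const float* in, float* out, int n_vec) {
//     using Vec = AlignedVector<float, 4>;
//     const Vec* in_vec = reinterpret_cast<const Vec*>(in);
//     Vec* out_vec = reinterpret_cast<Vec*>(out);
//     for (int i : GpuGridRangeX<int>(n_vec)) {
//       out_vec[i] = in_vec[i];
//     }
//   }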

// Returns the maximum power-of-two alignment (in units of elements, not bytes)
// of a stride or pointer value.
inline int64_t alignment_of(int64_t element_stride) {
  // A zero/nullptr value means that the stride/pointer is not used, so it
  // effectively has infinite alignment.
  constexpr int64_t kMaxAlignment = 512;
  if (element_stride == 0) return kMaxAlignment;
  return element_stride & -element_stride;
}

template <typename T>
inline int64_t alignment_of(T* ptr) {
  const intptr_t ptr_val = reinterpret_cast<std::uintptr_t>(ptr);
  // Pointers should always be aligned to sizeof(T) bytes.
  DCHECK_EQ(ptr_val % sizeof(T), 0);
  // Note that we want the alignment in elements, not bytes.
  return alignment_of(ptr_val / sizeof(T));
}

template <typename... Args>
int64_t MinAlignmentOf(Args... args) {
  return std::min({alignment_of(args)...});
}
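
// Illustrative example (assumption, not from this header): the alignment is
// the largest power of two dividing the value, measured in elements. With
// float* pointers 'in' and 'out' and a row stride of 12 elements,
//
//   int64_t max_vec_size = MinAlignmentOf(in, out, int64_t{12});
//
// is at most 4, because 12 = 4 * 3 and 4 is its largest power-of-two factor.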

namespace detail {

template <int64_t VecSize, template <int vec_size> class Functor>
struct DispatchToVectorizedHelper {
  template <typename... Args>
  Status operator()(int64_t max_vec_size, Args&&... args) const {
    if (max_vec_size >= VecSize) {
      return Functor<VecSize>()(std::forward<Args>(args)...);
    }
    return DispatchToVectorizedHelper<VecSize / 2, Functor>()(
        max_vec_size, std::forward<Args>(args)...);
  }
};
template <template <int vec_size> class Functor>
struct DispatchToVectorizedHelper<1, Functor> {
  template <typename... Args>
  Status operator()(int64_t max_vec_size, Args&&... args) const {
    return Functor<1>()(std::forward<Args>(args)...);
  }
};

}  // namespace detail

// Calls Functor<vec_size>()(args...) with vec_size set to the optimal GPU
// vector instruction size for type T that is <= max_vec_size. The max_vec_size
// argument should be set to the minimum alignment of all relevant parameters.
// Requires sizeof(T) to be a power of 2.
template <typename T, template <int vec_size> class Functor, typename... Args>
Status DispatchToVectorized(int64_t max_vec_size, Args&&... args) {
  static_assert((sizeof(T) & (sizeof(T) - 1)) == 0,
                "sizeof(T) must be a power of 2");
  if (max_vec_size <= 0) {
    return errors::InvalidArgument("DispatchToVectorized: max_vec_size (",
                                   max_vec_size,
                                   ") must be greater than zero.");
  }
  constexpr const int kOptimalVecSizeBytes = 16;
  // The optimal number of (aligned) elements of T to load/store in a
  // single instruction inside a kernel.
  constexpr const int optimal_vec_size =
      (kOptimalVecSizeBytes - 1) / sizeof(T) + 1;
  return detail::DispatchToVectorizedHelper<optimal_vec_size, Functor>()(
      max_vec_size, std::forward<Args>(args)...);
}
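
// Illustrative sketch (assumption, not from this header): a functor templated
// on vec_size, dispatched according to the runtime alignment of its pointers.
// The kernel it launches would load/store AlignedVector<float, vec_size>.
//
//   template <int vec_size>
//   struct CopyFunctor {
//     Status operator()(const float* in, float* out, int64_t n) {
//       // ... launch a kernel that copies AlignedVector<float, vec_size> ...
//       return OkStatus();
//     }
//   };
//
//   Status DoCopy(const float* in, float* out, int64_t n) {
//     int64_t max_vec_size = MinAlignmentOf(in, out, n);
//     return DispatchToVectorized<float, CopyFunctor>(max_vec_size, in, out,
//                                                     n);
//   }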

namespace gpu_helper {

// Similar to std::upper_bound, this returns the index of the first element in
// [first, first + count) that is greater than `val`, or `count` if no such
// element is found. Assumes [first, first + count) is sorted.
template <typename T, typename OutType = int32, typename Iterator = const T*>
__device__ OutType upper_bound(Iterator first, OutType count, T val) {
  Iterator orig = first;
  OutType step = 0;
  while (count > 0) {
    Iterator it = first;
    step = count / 2;
    it += step;
    if (!(val < *it)) {
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }

  return first - orig;
}

// Similar to std::lower_bound, this returns the index of the first element in
// [first, first + count) that is not less than `val`, or `count` if no such
// element is found. Assumes [first, first + count) is sorted.
template <typename T, typename OutType = int32, typename Iterator = const T*>
__device__ OutType lower_bound(Iterator first, OutType count, T val) {
  Iterator orig = first;
  OutType step = 0;
  while (count > 0) {
    Iterator it = first;
    step = count / 2;
    it += step;
    if (*it < val) {
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }

  return first - orig;
}
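
// Illustrative sketch (assumption, not from this header): using upper_bound
// inside a kernel to map each value to its bucket given sorted boundaries.
//
//   __global__ void BucketizeKernel(const float* values, int n,
//                                   const float* boundaries,
//                                   int num_boundaries, int32* bucket_ids) {
//     for (int i : GpuGridRangeX<int>(n)) {
//       bucket_ids[i] = gpu_helper::upper_bound<float>(
//           boundaries, num_boundaries, values[i]);
//     }
//   }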

}  // namespace gpu_helper

#ifndef TENSORFLOW_USE_ROCM
namespace cuda_helper = gpu_helper;
#endif

// For integer division, we can substitute fast multiplication for slow
// division. For detailed information see:
// https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
//
// Warning: This implementation only works when the divisor is in
// [1, INT32_MAX] and the numerator is in [0, INT32_MAX]. This is sufficient
// for our purpose of computing integer indices.
// Basics: the typical integer division can be written as:
//   n / d = (m * n) / 2^(32 + s)
// where 'n' is the numerator and 'd' is the divisor. For a given 'd', we
// need to find a magic number 'm' and a shift 's'. See update_magic().
struct FastDividerUint32 {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC FastDividerUint32(uint32_t d)
      : divisor(d) {
    assert(divisor >= 1 && divisor <= INT32_MAX);
    update_magic();
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void update_magic() {
    // (1). The shift 's' is calculated as ceil(log2(d)).
#if defined(__CUDA_ARCH__)
    shift = 32 - __clz(divisor - 1);
#else
    for (shift = 0; shift < 32; shift++) {
      if ((1U << shift) >= divisor) break;
    }
#endif

    // (2). The magic number 'm' is calculated by:
    //   m = 2^(32 + s) / d + 1
    // Note: the trailing '+ 1' rounds 'm' up, so that 'm * n' never
    // underestimates the quotient before the later rounding-down division by
    // 2^(32 + s). In practice, 'm' is a 33-bit value. To fit the 32-bit
    // range, we introduce:
    //   magic = m - 2^32
    //         = 2^(32 + s) / d - 2^32 + 1
    //         = 2^32 * 2^s / d - 2^32 * d / d + 1
    //         = (2^32 * (2^s - d)) / d + 1, so 'magic' fits in 32 bits.
    uint64_t m = (0x100000000ull * ((0x1ull << shift) - divisor)) / divisor + 1;
    magic = static_cast<uint32_t>(m);
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC FastDividerUint32& operator=(
      uint32_t d) {
    assert(d >= 1 && d <= INT32_MAX);
    this->divisor = d;
    update_magic();
    return *this;
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC operator uint32_t() const {
    return divisor;
  }

  uint32_t divisor;
  uint32_t magic;
  uint32_t shift;
};

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint32_t
operator/(const uint32_t n, const FastDividerUint32& fdiv) {
  // (3). We use the 32-bit 'magic' instead of 'm' in the formula:
  //   n / d = (m * n) / 2^(32 + s)
  //         = (magic + 2^32) * n / 2^(32 + s)
  //         = (magic * n) / 2^(32 + s) + n / 2^s
  //         = (magic * n) / 2^32 / 2^s + n / 2^s
  //         = (magic * n / 2^32 + n) / 2^s
#if defined(__CUDA_ARCH__)
  uint32_t q = __umulhi(n, fdiv.magic);
#else
  uint32_t q =
      static_cast<uint32_t>((static_cast<uint64_t>(n) * fdiv.magic) >> 32);
#endif
  return (n + q) >> fdiv.shift;
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint32_t
operator%(const uint32_t n, const FastDividerUint32& fdiv) {
  return n - (n / fdiv) * fdiv.divisor;
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint32_t
operator/(const int n, const FastDividerUint32& fdiv) {
  return static_cast<uint32_t>(n) / fdiv;
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint32_t
operator%(const int n, const FastDividerUint32& fdiv) {
  return static_cast<uint32_t>(n) % fdiv;
}
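
// Illustrative worked example (assumption, not from this header): for d = 3,
// shift = ceil(log2(3)) = 2 and
//   magic = (2^32 * (2^2 - 3)) / 3 + 1 = 1431655766.
// Dividing n = 9:
//   q = (magic * 9) >> 32 = 3, and (9 + 3) >> 2 = 3 = 9 / 3.
//
//   FastDividerUint32 fdiv(3);
//   uint32_t quotient = 9u / fdiv;   // 3
//   uint32_t remainder = 9u % fdiv;  // 0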

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif  // TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_