xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/gpu_prim.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5 
6     http://www.apache.org/licenses/LICENSE-2.0
7 
8 Unless required by applicable law or agreed to in writing, software
9 distributed under the License is distributed on an "AS IS" BASIS,
10 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 See the License for the specific language governing permissions and
12 limitations under the License.
13 ==============================================================================*/
14 #ifndef TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
15 #define TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
16 
17 #include "tensorflow/core/platform/bfloat16.h"
18 
19 #if GOOGLE_CUDA
20 #include "cub/block/block_load.cuh"
21 #include "cub/block/block_scan.cuh"
22 #include "cub/block/block_store.cuh"
23 #include "cub/device/device_histogram.cuh"
24 #include "cub/device/device_radix_sort.cuh"
25 #include "cub/device/device_reduce.cuh"
26 #include "cub/device/device_scan.cuh"
27 #include "cub/device/device_segmented_radix_sort.cuh"
28 #include "cub/device/device_segmented_reduce.cuh"
29 #include "cub/device/device_select.cuh"
30 #include "cub/iterator/counting_input_iterator.cuh"
31 #include "cub/iterator/transform_input_iterator.cuh"
32 #include "cub/thread/thread_operators.cuh"
33 #include "cub/warp/warp_reduce.cuh"
34 #include "third_party/gpus/cuda/include/cusparse.h"
35 
36 namespace gpuprim = ::cub;
37 
38 // Required for sorting Eigen::half and bfloat16.
39 namespace cub {
// Explicit specialization of ThreadStoreVolatilePtr for Eigen::half.
// Takes the raw 16-bit pattern of `val` (via Eigen::numext::bit_cast) and
// writes it to `*ptr` through a volatile uint16_t pointer. The trailing
// Int2Type<true> argument is an unnamed tag (is_primitive) and is never read.
// NOTE(review): the signature must match the primary template declared by
// CUB, which is not visible here -- confirm when upgrading CUB.
40 template <>
41 __device__ __forceinline__ void ThreadStoreVolatilePtr<Eigen::half>(
42     Eigen::half *ptr, Eigen::half val, Int2Type<true> /*is_primitive*/) {
43   volatile uint16_t *dst = reinterpret_cast<volatile uint16_t *>(ptr);
44   *dst = Eigen::numext::bit_cast<uint16_t>(val);
45 }
46 
// Explicit specialization of ThreadLoadVolatilePointer for Eigen::half.
// Reads 16 bits from `*ptr` through a volatile uint16_t pointer and returns
// them reinterpreted as an Eigen::half (via Eigen::numext::bit_cast). The
// trailing Int2Type<true> argument is an unnamed tag (is_primitive) and is
// never read.
// NOTE(review): the signature must match the primary template declared by
// CUB, which is not visible here -- confirm when upgrading CUB.
47 template <>
48 __device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer<Eigen::half>(
49     Eigen::half *ptr, Int2Type<true> /*is_primitive*/) {
50   const uint16_t raw_bits = *reinterpret_cast<volatile uint16_t *>(ptr);
51   return Eigen::numext::bit_cast<Eigen::half>(raw_bits);
52 }
53 
// Explicit specialization of ThreadStoreVolatilePtr for Eigen::bfloat16.
// Takes the raw 16-bit pattern of `val` (via Eigen::numext::bit_cast) and
// writes it to `*ptr` through a volatile uint16_t pointer. The trailing
// Int2Type<true> argument is an unnamed tag (is_primitive) and is never read.
// NOTE(review): the signature must match the primary template declared by
// CUB, which is not visible here -- confirm when upgrading CUB.
54 template <>
55 __device__ __forceinline__ void ThreadStoreVolatilePtr<Eigen::bfloat16>(
56     Eigen::bfloat16 *ptr, Eigen::bfloat16 val,
57     Int2Type<true> /*is_primitive*/) {
58   volatile uint16_t *dst = reinterpret_cast<volatile uint16_t *>(ptr);
59   *dst = Eigen::numext::bit_cast<uint16_t>(val);
60 }
61 
// Explicit specialization of ThreadLoadVolatilePointer for Eigen::bfloat16.
// Reads 16 bits from `*ptr` through a volatile uint16_t pointer and returns
// them reinterpreted as an Eigen::bfloat16 (via Eigen::numext::bit_cast).
// The trailing Int2Type<true> argument is an unnamed tag (is_primitive) and
// is never read.
// NOTE(review): the signature must match the primary template declared by
// CUB, which is not visible here -- confirm when upgrading CUB.
62 template <>
63 __device__ __forceinline__ Eigen::bfloat16
64 ThreadLoadVolatilePointer<Eigen::bfloat16>(Eigen::bfloat16 *ptr,
65                                            Int2Type<true> /*is_primitive*/) {
66   const uint16_t raw_bits = *reinterpret_cast<volatile uint16_t *>(ptr);
67   return Eigen::numext::bit_cast<Eigen::bfloat16>(raw_bits);
68 }
69 
// Specializes cub::NumericTraits for Eigen::half by deriving from BaseTraits
// with: category FLOATING_POINT, primitive = true, null-type = false,
// unsigned bit representation = uint16_t, and T = Eigen::half.
// NOTE(review): the argument order follows the inline parameter-name comments;
// BaseTraits itself is declared by CUB and is not visible here -- confirm
// against the CUB version in use.
70 template <>
71 struct NumericTraits<Eigen::half>
72     : BaseTraits</*_CATEGORY=*/FLOATING_POINT, /*_PRIMITIVE=*/true,
73                  /*_NULL_TYPE=*/false, /*_UnsignedBits=*/uint16_t,
74                  /*T=*/Eigen::half> {};
// Specializes cub::NumericTraits for tensorflow::bfloat16 by deriving from
// BaseTraits with: category FLOATING_POINT, primitive = true,
// null-type = false, unsigned bit representation = uint16_t, and
// T = tensorflow::bfloat16.
// NOTE(review): the argument order follows the inline parameter-name comments;
// BaseTraits itself is declared by CUB and is not visible here -- confirm
// against the CUB version in use.
75 template <>
76 struct NumericTraits<tensorflow::bfloat16>
77     : BaseTraits</*_CATEGORY=*/FLOATING_POINT, /*_PRIMITIVE=*/true,
78                  /*_NULL_TYPE=*/false, /*_UnsignedBits=*/uint16_t,
79                  /*T=*/tensorflow::bfloat16> {};
80 }  // namespace cub
81 #elif TENSORFLOW_USE_ROCM
82 #include "rocm/include/hipcub/hipcub.hpp"
83 namespace gpuprim = ::hipcub;
84 
85 // Required for sorting Eigen::half and bfloat16.
86 namespace rocprim {
87 namespace detail {
// Key codec for Eigen::half: derives from radix_key_codec_floating with
// uint16_t as the second template argument.
88 template <>
89 struct radix_key_codec_base<Eigen::half>
90     : radix_key_codec_floating<Eigen::half, uint16_t> {};
// Key codec for tensorflow::bfloat16: derives from radix_key_codec_floating
// with uint16_t as the second template argument.
91 template <>
92 struct radix_key_codec_base<tensorflow::bfloat16>
93     : radix_key_codec_floating<tensorflow::bfloat16, uint16_t> {};
// No ';' after a namespace's closing brace: it would be a stray empty
// declaration that -Wextra-semi / pedantic builds warn about.
94 }  // namespace detail
95 }  // namespace rocprim
96 
97 #endif  // TENSORFLOW_USE_ROCM
98 
99 #endif  // TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
100