/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/tensor_to_hash_bucket_op.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "third_party/farmhash_gpu/src/farmhash_gpu.h"

namespace tensorflow {

namespace {

// The per-thread buffer holds 20 chars, enough for the longest decimal
// rendering of any integral type: 20 digits for uint64 max, or 19 digits
// plus a leading '-' for int64 min.
constexpr int kSharedMemBufferSizePerThread = 20;

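// Writes the `num_digits` decimal digits of `val` into `buf`, most
// significant first, starting at offset `*i`, and advances `*i` past them.
// The sign character for negative values is emitted by the caller.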
template <typename T>
__device__ __forceinline__ void FillDigits(T val, int num_digits, int* i,
                                           char* buf) {
  eigen_assert(num_digits <= kSharedMemBufferSizePerThread - (*i));

  int factor = (val < 0 ? -1 : 1);

  int num_digits_a = num_digits;
  do {
    int digit = static_cast<int>((val % 10) * factor);
    buf[(*i) + num_digits - 1] = digit + '0';
    val /= 10;
    num_digits--;
  } while (val != 0);

  (*i) += num_digits_a;
}

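// Formats `val` as a decimal string in `buf` (with a leading '-' when
// negative) and returns the number of characters written.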
template <typename T>
__device__ __forceinline__ int IntegerToString(T val, char* buf) {
  int num_digits = 0;
  T val_a = val;
  do {
    val_a = val_a / 10;
    num_digits++;
  } while (val_a != 0);

  int i = 0;
  if (val < 0) {
    buf[i++] = '-';
  }

  FillDigits(val, num_digits, &i, buf);

  return i;
}

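// One thread per element: each thread renders its value as a decimal string
// in a private slice of dynamic shared memory, hashes the string with
// farmhash's Fingerprint64, and maps the result into [0, num_buckets).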
template <typename T>
__global__ void ComputeHashes(const T* __restrict__ vals, int vals_size,
                              int64 num_buckets, int64* __restrict__ hashes) {
  extern __shared__ char s[];

  GPU_1D_KERNEL_LOOP(tid, vals_size) {
    int size = IntegerToString(vals[tid],
                               s + threadIdx.x * kSharedMemBufferSizePerThread);
    uint64_t a_hash = ::util_gpu::Fingerprint64(
        s + threadIdx.x * kSharedMemBufferSizePerThread, size);
    int64 a_bucket = static_cast<int64_t>(a_hash % num_buckets);
    hashes[tid] = a_bucket;
  }
}

}  // end namespace

namespace functor {

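// GPU specialization of the launch functor: picks a block size that fits the
// per-thread scratch buffers into shared memory, then dispatches
// ComputeHashes on the device stream.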
template <typename T>
void LaunchTensorToHashBucket<Eigen::GpuDevice, T>::operator()(
    OpKernelContext* c, const int64 num_buckets, const T* input,
    const int num_elems, int64* output) {
  auto* stream = c->op_device_context()->stream();
  const Eigen::GpuDevice& d = c->eigen_gpu_device();
  if (num_elems > 0) {
    constexpr size_t kThreadsLimitInBlock = 1024;

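    // Each thread needs kSharedMemBufferSizePerThread bytes of scratch in
    // shared memory, so the block size is capped by both the hardware limit
    // and the number of per-thread buffers that fit in a block's budget.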
    size_t smem_bytes_allowed =
        stream->parent()->GetDeviceDescription().shared_memory_per_block();
    auto smem_bytes_per_thread = kSharedMemBufferSizePerThread * sizeof(char);
    size_t thread_per_block = std::min(
        kThreadsLimitInBlock, smem_bytes_allowed / smem_bytes_per_thread);

    auto smem_bytes_per_block = thread_per_block * smem_bytes_per_thread;
    GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
        num_elems, d, ComputeHashes<T>, smem_bytes_per_block, thread_per_block);
    OP_REQUIRES_OK(
        c, GpuLaunchKernel(ComputeHashes<T>, config.block_count,
                           config.thread_per_block, smem_bytes_per_block,
                           d.stream(), input, num_elems, num_buckets, output));
  }
}

}  // namespace functor

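// Explicitly instantiate the GPU launcher for each integral type covered by
// TF_CALL_INTEGRAL_TYPES.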
#define REGISTER_FUNCTORS(type) \
  template struct functor::LaunchTensorToHashBucket<Eigen::GpuDevice, type>;

TF_CALL_INTEGRAL_TYPES(REGISTER_FUNCTORS);

#undef REGISTER_FUNCTORS

}  // namespace tensorflow
#endif  // GOOGLE_CUDA