/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include <memory>
#include <vector>

#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/concat_lib_gpu.h"
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace {

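// Copies `input_ptr_data.size` row-major inputs, each `split_size` columns
// wide, into `output`, viewed as a total_rows x total_cols matrix (total_cols
// is expected to be num_inputs * split_size). Because every input has the
// same width, each thread finds its source tensor with a division and its
// source column with a modulo; e.g. with split_size == 3, output column 4
// comes from input 1, column 1.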
template <typename T, typename IntType>
__global__ void concat_fixed_kernel(
    GpuDeviceArrayStruct<const T*> input_ptr_data, int split_size,
    int total_rows, int total_cols, T* __restrict__ output) {
  const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
  IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;

  for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
    IntType gidy = blockIdx.y * blockDim.y + threadIdx.y;

    IntType split = gidx / split_size;
    const T* input_ptr = input_ptrs[split];
    IntType col_offset = gidx % split_size;
#pragma unroll
    for (; gidy < total_rows; gidy += blockDim.y * gridDim.y) {
      output[gidy * total_cols + gidx] =
          input_ptr[gidy * split_size + col_offset];
    }
  }
}

}  // end namespace

// cannot be in anonymous namespace due to extern shared memory
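// Concatenates inputs whose column widths may differ. `output_scan` is
// expected to hold the running prefix sum of the input widths (the starting
// output column of each input, followed by the total column count), so the
// source input for an output column is located by an upper_bound search over
// that scan. When useSmem is true the scan is first staged in dynamic shared
// memory so the repeated lookups stay on-chip.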
template <typename T, typename IntType, bool useSmem>
__global__ void concat_variable_kernel(
    GpuDeviceArrayStruct<const T*> input_ptr_data,
    GpuDeviceArrayStruct<IntType> output_scan, IntType total_rows,
    IntType total_cols, T* output) {
  const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
  IntType* col_scan = GetGpuDeviceArrayOnDevice(&output_scan);

  // do upper_bound on col to find which pointer we should be using
  IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
  IntType num_inputs = input_ptr_data.size;

  // verbose declaration needed due to template
  GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(T), unsigned char, smem);
  IntType* smem_col_scan = reinterpret_cast<IntType*>(smem);
  if (useSmem) {
    IntType lidx = threadIdx.y * blockDim.x + threadIdx.x;
    IntType blockSize = blockDim.x * blockDim.y;

    for (IntType i = lidx; i < output_scan.size; i += blockSize) {
      smem_col_scan[i] = col_scan[i];
    }

    __syncthreads();

    col_scan = smem_col_scan;
  }

  // do an initial binary search and then scan linearly from there;
  // this works well both when there are many small segments and when the
  // segments are much longer
  IntType segment =
      gpu_helper::upper_bound<IntType>(col_scan, num_inputs, gidx) - 1;

  IntType curr_offset = col_scan[segment];
  IntType curr_segment = segment;
  for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
    IntType curr_col_offset;
    while ((curr_col_offset = col_scan[curr_segment + 1]) <= gidx) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    IntType local_col = gidx - curr_offset;
    IntType segment_width = curr_col_offset - curr_offset;
    const T* input_ptr = input_ptrs[curr_segment];

    IntType gidy = blockIdx.y * blockDim.y + threadIdx.y;
    for (; gidy < total_rows; gidy += blockDim.y * gridDim.y)
      output[gidy * total_cols + gidx] =
          input_ptr[gidy * segment_width + local_col];
  }
}

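// Concatenates the flattened 2-D inputs with one Eigen slice assignment per
// input, advancing the column offset by each input's width. When IntType is
// int32, the tensors go through To32Bit so Eigen can use 32-bit indexing.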
template <typename T, typename IntType>
void ConcatGPUSlice(
    const Eigen::GpuDevice& gpu_device,
    const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
        inputs_flat,
    typename TTypes<T, 2>::Matrix* output) {
  Eigen::array<IntType, 2> offset{0, 0};
  for (int i = 0; i < inputs_flat.size(); ++i) {
    Eigen::array<IntType, 2> size;
    size[0] = inputs_flat[i]->dimension(0);
    size[1] = inputs_flat[i]->dimension(1);
    if (std::is_same<IntType, int32>::value) {
      To32Bit(*output).slice(offset, size).device(gpu_device) =
          To32Bit(*inputs_flat[i]);
    } else {
      output->slice(offset, size).device(gpu_device) = *inputs_flat[i];
    }

    offset[1] += size[1];
  }
}

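// Launches one of the kernels above over a 2-D grid (x across output columns,
// y across rows). Same-width inputs take the cheaper fixed-size kernel;
// otherwise the variable-size kernel is used, with the column scan kept in
// shared memory whenever it fits under both the device limit and the 16 KiB
// performance cap below.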
template <typename T, typename IntType>
void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
                   const GpuDeviceArrayStruct<const T*>& input_ptrs,
                   const GpuDeviceArrayStruct<IntType>& output_scan,
                   bool fixed_size, int split_size,
                   typename TTypes<T, 2>::Matrix* output) {
  auto config = GetGpu2DLaunchConfig(output->dimension(1), output->dimension(0),
                                     gpu_device);

  if (fixed_size) {
    TF_CHECK_OK(GpuLaunchKernel(
        concat_fixed_kernel<T, IntType>, config.block_count,
        config.thread_per_block, 0, gpu_device.stream(), input_ptrs, split_size,
        static_cast<int>(output->dimension(0)),
        static_cast<int>(output->dimension(1)), output->data()));
  } else {
    IntType smem_max = gpu_device.sharedMemPerBlock();
    IntType smem_usage = output_scan.size * sizeof(IntType);
    // The performance crossover sits below the maximum available shared
    // memory on most processors, possibly due to decreasing occupancy.
    // A 16 KiB cap still allows roughly 4096 int32 inputs, which is a lot,
    // so most calls take the smem path.
    const int32 kMaxSmemBytesPerformance = 16384;
    if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance) {
      TF_CHECK_OK(GpuLaunchKernel(
          concat_variable_kernel<T, IntType, true>, config.block_count,
          config.thread_per_block, smem_usage, gpu_device.stream(), input_ptrs,
          output_scan, static_cast<IntType>(output->dimension(0)),
          static_cast<IntType>(output->dimension(1)), output->data()));
    } else {
      TF_CHECK_OK(GpuLaunchKernel(
          concat_variable_kernel<T, IntType, false>, config.block_count,
          config.thread_per_block, 0, gpu_device.stream(), input_ptrs,
          output_scan, static_cast<IntType>(output->dimension(0)),
          static_cast<IntType>(output->dimension(1)), output->data()));
    }
  }
}

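// Explicit instantiations: the REGISTER_* macros stamp out ConcatGPUSlice and
// ConcatGPUImpl for each supported element type with both 32-bit and 64-bit
// index types.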
#define REGISTER_GPUCONCAT32(T)                                               \
  template void ConcatGPUSlice<T, int32>(                                     \
      const Eigen::GpuDevice& gpu_device,                                     \
      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
          inputs_flat,                                                        \
      typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPUCONCAT64(T)                                               \
  template void ConcatGPUSlice<T, int64>(                                     \
      const Eigen::GpuDevice& gpu_device,                                     \
      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
          inputs_flat,                                                        \
      typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPU32(T)                                              \
  template void ConcatGPUImpl<T, int32>(                               \
      const Eigen::GpuDevice& d,                                       \
      const GpuDeviceArrayStruct<const T*>& input_ptrs,                \
      const GpuDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size, \
      int split_size, typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPU64(T)                                                \
  template void ConcatGPUImpl<T, int64>(                                 \
      const Eigen::GpuDevice& d,                                         \
      const GpuDeviceArrayStruct<const T*>& input_ptrs,                  \
      const GpuDeviceArrayStruct<int64_t>& ptr_offsets, bool fixed_size, \
      int split_size, typename TTypes<T, 2>::Matrix* output);

TF_CALL_INTEGRAL_TYPES(REGISTER_GPUCONCAT32);  // int32 Needed for TensorLists.
TF_CALL_bfloat16(REGISTER_GPUCONCAT32);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPUCONCAT32);

TF_CALL_INTEGRAL_TYPES(REGISTER_GPUCONCAT64);  // int32 Needed for TensorLists.
TF_CALL_bfloat16(REGISTER_GPUCONCAT64);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPUCONCAT64);

TF_CALL_INTEGRAL_TYPES(REGISTER_GPU32);  // int32 Needed for TensorLists.
TF_CALL_bfloat16(REGISTER_GPU32);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU32);

TF_CALL_INTEGRAL_TYPES(REGISTER_GPU64);  // int32 Needed for TensorLists.
TF_CALL_bfloat16(REGISTER_GPU64);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU64);

#undef REGISTER_GPUCONCAT32
#undef REGISTER_GPUCONCAT64
#undef REGISTER_GPU32
#undef REGISTER_GPU64

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM