/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_
#define TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#if CUDA_VERSION >= 9000
#define CUB_USE_COOPERATIVE_GROUPS
#endif  // CUDA_VERSION >= 9000

#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/scan_ops.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"
#include "tensorflow/core/util/permutation_input_iterator.h"
#include "tensorflow/core/util/permutation_output_iterator.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::Index Index;

namespace functor {

// Map a contiguous range to the actual memory locations depending on which
// axis the scan is taking place over and whether or not it is reversed.
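// For example, with dimx == 1, dimy == 4, dimz == 3, and reverse == false,
// id == 5 gives row == 5 % 4 == 1 and col == 5 / 4 == 1, so the mapped
// location is row * dimz + col == 1 * 3 + 1 == 4.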
struct MapIndexToLocation {
  __host__ __device__ MapIndexToLocation(int dimx, int dimy, int dimz,
                                         bool reverse = false)
      : dimx_(dimx), dimy_(dimy), dimz_(dimz), reverse_(reverse) {}

  __host__ __device__ int operator()(int id) const {
    if (dimx_ == 1) {
      int row = id % dimy_;
      int col = id / dimy_;

      if (reverse_) return (dimy_ - row - 1) * dimz_ + col;

      return row * dimz_ + col;
    } else if (dimz_ == 1) {
      if (reverse_) {
        int row = id / dimy_;
        int col = id % dimy_;
        return row * dimy_ + (dimy_ - col - 1);
      }
      return id;
    } else {
      int col = id % dimy_;
      int tmp = id / dimy_;

      int row1 = id / (dimy_ * dimz_);
      int col1 = tmp % dimz_;

      if (reverse_)
        return row1 * dimy_ * dimz_ + (dimy_ - col - 1) * dimz_ + col1;

      return row1 * dimy_ * dimz_ + col * dimz_ + col1;
    }
  }

  int dimx_;
  int dimy_;
  int dimz_;
  bool reverse_;
};

template <typename T, typename Op>
struct BlockPrefixCallbackOp {
  // Running prefix
  T running_total_;
  Op op_;

  __device__ BlockPrefixCallbackOp(T running_total, Op op)
      : running_total_(running_total), op_(op) {}

  // Callback operator to be entered by the first warp of threads in the block.
  // tid 0 is responsible for returning a value for seeding the block-wide scan.
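  // For example, with Op == Sum<int> and an initial running total of 0, the
  // first tile's call with block_aggregate == 10 returns 0 and leaves the
  // running total at 10; the second tile's call with block_aggregate == 7
  // returns 10 and leaves the running total at 17.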
  __device__ T operator()(T block_aggregate) {
    T old_prefix = running_total_;
    running_total_ = op_(old_prefix, block_aggregate);
    return old_prefix;
  }
};

template <typename T>
struct Sum {
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a + b;
  }
};

template <typename T>
struct Prod {
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a * b;
  }
};

template <typename T, typename Op>
struct IsSum {
  constexpr static bool value =
      (std::is_same<Op, Sum<T>>::value ||
       std::is_same<Op, Eigen::internal::SumReducer<T>>::value);
};

template <typename T, typename Op>
struct IsProd {
  constexpr static bool value =
      (std::is_same<Op, Prod<T>>::value ||
       std::is_same<Op, Eigen::internal::ProdReducer<T>>::value);
};

template <typename T, typename Op>
struct IsLogSumExp {
  constexpr static bool value = (std::is_same<Op, LogSumExp<T>>::value ||
                                 std::is_same<Op, LogSumExpReducer<T>>::value);
};

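// Returns the identity element of Op: 0 for sum, 1 for product, and the
// lowest representable value for logsumexp.  Used to seed the running prefix
// in the scan kernel below.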
template <typename T, typename Op>
struct IdentityValue {
  static_assert(IsSum<T, Op>::value || IsProd<T, Op>::value ||
                    IsLogSumExp<T, Op>::value,
                "IdentityValue not yet defined for this type.");

  template <typename U = T, typename OpCopy = Op>
  __host__ __device__ U operator()(
      typename std::enable_if<IsSum<U, OpCopy>::value, U>::type t = U(0)) {
    return t;
  }

  template <typename U = T, typename OpCopy = Op>
  __host__ __device__ U operator()(
      typename std::enable_if<IsProd<U, OpCopy>::value, U>::type t = U(1)) {
    return t;
  }

  template <typename U = T, typename OpCopy = Op>
  __host__ __device__ U
  operator()(typename std::enable_if<IsLogSumExp<U, OpCopy>::value, U>::type t =
                 U(Eigen::NumTraits<U>::lowest())) {
    return t;
  }
};

// Each block is mapped to one sequence.  A contiguous range is mapped to the
// appropriate locations in memory by the permutation iterators.  This is
// ideal for 1-D and row-based scans.  Column scans would be better if they
// did a block load and then locally transposed.  CUB's device-wide scan is
// not used in the large 1-D case, even though it would be more efficient,
// because it is not deterministic.
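// For example, with dimy == 1000 and the default BlockDim == 128 and
// ItemsPerThread == 4 (512 items per trip), each block makes two trips over
// its sequence: the first with 512 valid items, the second with 488.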
template <typename T, typename Op, int BlockDim = 128, int ItemsPerThread = 4>
__launch_bounds__(BlockDim) __global__
    void scan_kernel(const T* in, T* out, int dimx, int dimy, int dimz,
                     bool exclusive, bool reverse, Op op) {
  typedef gpuprim::BlockLoad<T, BlockDim, ItemsPerThread,
                             gpuprim::BLOCK_LOAD_TRANSPOSE>
      BlockLoad;
  typedef gpuprim::BlockStore<T, BlockDim, ItemsPerThread,
                              gpuprim::BLOCK_STORE_TRANSPOSE>
      BlockStore;
  typedef gpuprim::BlockScan<T, BlockDim> BlockScan;

  // Allocate aliased shared memory for BlockLoad, BlockStore, and BlockScan
  __shared__ union {
    typename BlockLoad::TempStorage load;
    typename BlockScan::TempStorage scan;
    typename BlockStore::TempStorage store;
  } temp_storage;

  int problem_length = dimy;

  // Initialize running total
  BlockPrefixCallbackOp<T, Op> prefix_op(IdentityValue<T, Op>()(), op);

  MapIndexToLocation map_op(dimx, dimy, dimz, reverse);
  int block_start = problem_length * blockIdx.x;
  // Have the block iterate over segments of items
  for (int block_offset = block_start;
       block_offset < block_start + problem_length;
       block_offset += BlockDim * ItemsPerThread) {
    int valid_items = min(BlockDim * ItemsPerThread,
                          problem_length - (block_offset % problem_length));

    // First construct a counting iterator that has the desired start point.
    typedef gpuprim::TransformInputIterator<int, MapIndexToLocation,
                                            gpuprim::CountingInputIterator<int>>
        MapIterType;

    gpuprim::CountingInputIterator<int> counting_iter(block_offset);

    // Next map the iterator to the actual locations in memory.
    MapIterType map_iter(counting_iter, map_op);

    PermutationInputIterator<T, const T*, MapIterType> permutein_iter(in,
                                                                      map_iter);
    PermutationOutputIterator<T, T*, MapIterType> permuteout_iter(out,
                                                                  map_iter);

    // Load a segment of consecutive items that are blocked across threads
    T thread_data[ItemsPerThread];
    BlockLoad(temp_storage.load).Load(permutein_iter, thread_data, valid_items);
    __syncthreads();

    // Collectively compute the block-wide scan
    if (exclusive) {
      BlockScan(temp_storage.scan)
          .ExclusiveScan(thread_data, thread_data, op, prefix_op);
    } else {
      BlockScan(temp_storage.scan)
          .InclusiveScan(thread_data, thread_data, op, prefix_op);
    }
    __syncthreads();

    // Store scanned items to output segment
    BlockStore(temp_storage.store)
        .Store(permuteout_iter, thread_data, valid_items);
    __syncthreads();
  }
}

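// Launches scan_kernel with one block per sequence: the 3-D input is viewed
// as dimx * dimz independent sequences of length dimy, and the block size is
// chosen based on the sequence length.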
template <typename T, typename Op>
void LaunchScan(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
                typename TTypes<T, 3>::Tensor out, Op op, const bool reverse,
                const bool exclusive) {
  const int items_per_thread = 4;

  int dimx = in.dimension(0);
  int dimy = in.dimension(1);
  int dimz = in.dimension(2);
  int num_blocks = dimx * dimz;

  int ideal_block_size = dimy / items_per_thread;
  const int rocm_threads_per_warp = 64;
  ideal_block_size = std::max(ideal_block_size, rocm_threads_per_warp);

  // There seems to be a bug when the type is not float and the block size is
  // 1024, so 1024-thread blocks are only used for float.  Otherwise launch
  // with the largest power-of-two block size that does not exceed the ideal
  // block size.
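  // For example, dimy == 2000 gives ideal_block_size == 500, so a 256-thread
  // block is launched and each block loops over its 2000-element sequence.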
  if (ideal_block_size >= 1024 && std::is_same<T, float>::value) {
    const int block_size = 1024;
    TF_CHECK_OK(
        GpuLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                        num_blocks, block_size, 0, d.stream(), in.data(),
                        out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  } else if (ideal_block_size >= 512) {
    const int block_size = 512;
    TF_CHECK_OK(
        GpuLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                        num_blocks, block_size, 0, d.stream(), in.data(),
                        out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  } else if (ideal_block_size >= 256) {
    const int block_size = 256;
    TF_CHECK_OK(
        GpuLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                        num_blocks, block_size, 0, d.stream(), in.data(),
                        out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  } else if (ideal_block_size >= 128) {
    const int block_size = 128;
    TF_CHECK_OK(
        GpuLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                        num_blocks, block_size, 0, d.stream(), in.data(),
                        out.data(), dimx, dimy, dimz, exclusive, reverse, op));
#if TENSORFLOW_COMPILER_IS_HIP_CLANG
    // HIP-CLANG has some kind of problem here with 32 threads (possibly
    // because the warp size is 64).  Re-enable the 32-thread path when it
    // works properly.
  } else if (true) {
#else
  } else if (ideal_block_size >= 64) {
#endif
    const int block_size = 64;
    TF_CHECK_OK(
        GpuLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                        num_blocks, block_size, 0, d.stream(), in.data(),
                        out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  } else {
    const int block_size = 32;
    TF_CHECK_OK(
        GpuLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                        num_blocks, block_size, 0, d.stream(), in.data(),
                        out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  }
}

template <typename T>
struct Scan<GPUDevice, Eigen::internal::SumReducer<T>, T> {
  void operator()(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
                  typename TTypes<T, 3>::Tensor out,
                  const Eigen::internal::SumReducer<T>& reducer,
                  const bool reverse, const bool exclusive) {
    LaunchScan<T, Sum<T>>(d, in, out, Sum<T>(), reverse, exclusive);
  }
};

template <typename T>
struct Scan<GPUDevice, Eigen::internal::ProdReducer<T>, T> {
  void operator()(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
                  typename TTypes<T, 3>::Tensor out,
                  const Eigen::internal::ProdReducer<T>& reducer,
                  const bool reverse, const bool exclusive) {
    LaunchScan<T, Prod<T>>(d, in, out, Prod<T>(), reverse, exclusive);
  }
};

template <typename T>
struct Scan<GPUDevice, LogSumExpReducer<T>, T> {
  void operator()(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
                  typename TTypes<T, 3>::Tensor out,
                  const LogSumExpReducer<T>& reducer, const bool reverse,
                  const bool exclusive) {
    LaunchScan<T, LogSumExp<T>>(d, in, out, LogSumExp<T>(), reverse, exclusive);
  }
};
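
// Illustrative sketch of how an op kernel might dispatch through these
// specializations (placeholder names, not code from this file):
//
//   const GPUDevice& d = context->eigen_device<GPUDevice>();
//   functor::Scan<GPUDevice, Eigen::internal::SumReducer<float>, float>()(
//       d, input.tensor<float, 3>(), output->tensor<float, 3>(),
//       Eigen::internal::SumReducer<float>(), /*reverse=*/false,
//       /*exclusive=*/true);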

}  // namespace functor
}  // end namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_