#pragma once

#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/core/TensorBase.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/macros/Macros.h>
#include <ATen/native/Copy.h>

#include <math.h>

//
// This file contains pointwise operation functions and kernels that
// work on both contiguous and non-contiguous tensor arguments of
// arbitrary dimension (up to MAX_TENSORINFO_DIMS) without copying or
// temporary storage.
//

/*
  NOTE [ CUDA_tensor_applyN helpers ]

  The following CUDA_tensor_applyN (where N currently can be 1, 2, 3, or 4)
  functions apply a pointwise operator to N tensor(s).

  The calling convention is

  1. The template arguments should be, sequentially,
    - First N typename args specify the scalar types of each of the N tensors.
    - (Optional) `int step` arg specifies the number of elements processed
      together at the same time.
      Default is 1.
    - A usually omitted (i.e., inferred) typename arg specifies the type of the
      function/functor applied on `N * step` values in each iteration of each
      CUDA thread.
  2. The arguments should be, sequentially,
    - N tensors
    - op: a function/functor that processes `N * step` values at the same time.
      - If `step == 1`, it must have signature
        `void(*)(scalar1_t&, scalar2_t&, ..., scalarN_t&)`, where
        `scalar*_t`s are the first N typename template args, and the inputs
        are the `N` values from the `N` tensors retrieved at a common index.
      - Otherwise, it must have signature
          void(*)(int n, scalar1_t&, scalar1_t&, ..., scalar1_t&,  // repeat `step` times
                         scalar2_t&, scalar2_t&, ..., scalar2_t&,  // repeat `step` times
                         ...,
                         scalarN_t&, scalarN_t&, ..., scalarN_t&)  // repeat `step` times
        Different from the `step == 1` case, it processes `N * step` values taken
        from `step` common indices. Moreover, the first input `n` represents the
        number of valid indices (it will always have `0 < n <= step`). It will
        almost always be `step`, but at the boundary we may not have full `step`
        elements and `n` can be a lesser value.

        E.g., if `step == 4` and `N == 2`, `op` could be

          [](int n, scalar1_t &u1, scalar1_t &u2, scalar1_t &u3, scalar1_t &u4,
                    scalar2_t &v1, scalar2_t &v2, scalar2_t &v3, scalar2_t &v4) {
            // Only process u1, ..., un and v1, ..., vn.
            // So if `n == 3`, `u4` and `v4` need not be considered.
          }

      In both cases, the references can actually be const, but at least one of
      them should be non-const in order to write the output.
    - (Optional, but recommended) N TensorArgType args that specify for each
      tensor whether `op` reads AND writes (i.e., TensorArgType::ReadWrite),
      or only reads (i.e., TensorArgType::ReadOnly).
      Default is TensorArgType::ReadWrite for the first tensor, and
                 TensorArgType::ReadOnly for the rest.

  E.g.,

  to compute a = b^2 for a and b of the same dtype, we can call

  CUDA_tensor_apply2<scalar, scalar>(
    a, b,
    [] __device__ (scalar &a_val, const scalar &b_val) { a_val = b_val * b_val; }
  );

  to work on 2 values at the same time, we can call

  CUDA_tensor_apply2<scalar1, scalar2, 2>(
    a, b,
    [] __device__ (int n, scalar1 &a_val1, scalar1 &a_val2,
                          const scalar2 &b_val1, const scalar2 &b_val2) {
      // call special vectorized op here, or just do elementwise and enjoy unrolling...
      // if n == 1, only process a_val1 and b_val1
    }
  );
*/
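
/*
  As a further illustration (a minimal sketch, not part of this header's API;
  the function name `square_into`, the use of `float` tensors, and the choice
  to zero out the input are assumptions made only for the example), the
  optional TensorArgType arguments can be passed explicitly when `op` writes
  to both tensors:

    #include <ATen/cuda/CUDAApplyUtils.cuh>

    void square_into(at::TensorBase out, at::TensorBase in) {
      // `out` receives in * in; `in` is also zeroed, so both tensors are
      // declared ReadWrite (this lets the helper handle overlapping indices
      // correctly for each of them).
      at::cuda::CUDA_tensor_apply2<float, float>(
          out, in,
          [] __device__ (float &o, float &i) {
            o = i * i;
            i = 0.0f;
          },
          at::cuda::TensorArgType::ReadWrite,
          at::cuda::TensorArgType::ReadWrite);
    }

  Compiling such a caller requires nvcc with extended (__device__) lambdas
  enabled.
*/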

namespace at::cuda {

// TODO: combine with TensorArg?  So far that's been for debugging, and this is functional...
enum class TensorArgType { ReadWrite, ReadOnly };

namespace {

// Rearrange dimensions for pointwise operations so that strides are in
// decreasing order as much as possible, so that kernels have better memory
// access patterns.
//
// For example, consider a binary operation on two "transposed" 2-dim tensors:
//    sizes:          256 512
//    aInfo->strides:   1 256
//    bInfo->strides:   1 256
//
// Given this, each concurrent memory access inside kernelPointwiseApply2() is
// exactly 256 elements apart, resulting in poor performance.
//
// This function exchanges dimensions so that memory access is contiguous:
//    sizes:          512 256
//    aInfo->strides: 256   1
//    bInfo->strides: 256   1
//
// (Actually, it becomes even better because now collapseDims() can turn each
// input into one contiguous array.)
//
// In general, given M (<=4) TensorInfo's with N dimensions, we can view each
// strides[i] (0 <= i < N) as an M-tuple.  Given each pair i < j, we exchange
// strides[i] and strides[j] if
//    (1) strides[i][k] < strides[j][k] for some k (0 <= k < M)
//        (exchanging them will benefit input #k), and
//    (2) strides[i][k] <= strides[j][k] for all k
//        (exchanging them will not make any input worse).
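//
// As a worked illustration (the concrete numbers are made up for this
// example, not taken from a specific caller): suppose M == 2, N == 2 with
//    sizes:          256 512
//    aInfo->strides: 512   1   (row-major contiguous)
//    bInfo->strides:   1 256   (transposed)
// For the pair (i, j) = (0, 1), swapping would help b (1 < 256) but hurt a
// (512 > 1), so condition (2) fails and the dimensions are left alone.  The
// swap happens only when every tensor is at least indifferent to it.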
template <typename T1, typename IndexType,
          typename T2 = void, typename T3 = void, typename T4 = void>
inline void rearrangeDims(detail::TensorInfo<T1, IndexType>* aInfo,
                          detail::TensorInfo<T2, IndexType>* bInfo = nullptr,
                          detail::TensorInfo<T3, IndexType>* cInfo = nullptr,
                          detail::TensorInfo<T4, IndexType>* dInfo = nullptr) {
  int numInfos = 1;
  int dims = aInfo->dims;
  IndexType *sizes[4] = { aInfo->sizes, };
  IndexType *strides[4] = { aInfo->strides, };

  if (bInfo != nullptr) {
    ++numInfos;
    if (bInfo->dims != dims) return;
    sizes[1] = bInfo->sizes;
    strides[1] = bInfo->strides;
  }

  if (cInfo != nullptr) {
    ++numInfos;
    if (cInfo->dims != dims) return;
    sizes[2] = cInfo->sizes;
    strides[2] = cInfo->strides;
  }

  if (dInfo != nullptr) {
    ++numInfos;
    if (dInfo->dims != dims) return;
    sizes[3] = dInfo->sizes;
    strides[3] = dInfo->strides;
  }

  // Bail out if sizes do not match: we are using "deprecated pointwise
  // behavior" among tensors of different shapes but same number of elements.
  for (int i = 1; i < numInfos; ++i) {
    for (int j = 0; j < dims; ++j) {
      if (sizes[i][j] != sizes[0][j]) return;
    }
  }

  for (int i = 0; i < dims - 1; ++i) {
    // No need to consider dimensions of size 1.
    if (sizes[0][i] == 1) continue;

    for (int j = i + 1; j < dims; ++j) {
      if (sizes[0][j] == 1) continue;

      // Compare the relative sizes of strides between dim #i and dim #j.
      bool hasIncreasingStrides = false;
      bool hasDecreasingStrides = false;

      for (int k = 0; k < numInfos; k++) {
        IndexType stride_i = strides[k][i];
        IndexType stride_j = strides[k][j];
        if (stride_i < stride_j) {
          hasIncreasingStrides = true;
        } else if (stride_i > stride_j) {
          hasDecreasingStrides = true;
        }
      }

      if (hasIncreasingStrides && !hasDecreasingStrides) {
        for (int k = 0; k < numInfos; k++) {
          IndexType size = sizes[k][i];
          sizes[k][i] = sizes[k][j];
          sizes[k][j] = size;

          IndexType stride = strides[k][i];
          strides[k][i] = strides[k][j];
          strides[k][j] = stride;
        }
      }
    }
  }
}

// The `remaining_steps` argument is used to support an Op that operates on
// multiple elements at the same time. Generally, the strategy of ApplyOpN is to
//  1. Initialize `remaining_steps = step`, where `step` is the template arg of
//     CUDA_tensor_applyN helpers. The input arg `n` to `apply()` represents the
//     number of elements in bound for this call. It will almost always be equal
//     to `step` except at boundaries.
//  2. If `remaining_steps > 0` convert the current linearIndex to an offset (if
//     in bound), and recursively call `ApplyOpN` with `remaining_steps - 1`.
//  3. At `remaining_steps = 0`,
//       if `step = 1`, call `op(tensor1_val, tensor2_val, ...)`;
//       if `step > 1`, call `op(n, tensor1_val1, tensor1_val2, ..., tensor1_valstep,
//                                  tensor2_val1, tensor2_val2, ..., tensor2_valstep,
//                                       ...
//                                  tensorN_val1, tensorN_val2, ..., tensorN_valstep);`
//
// See NOTE [ CUDA_tensor_applyN helpers ] above for what Op may look like.
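//
// As a concrete trace (a sketch of how the template recursion below unrolls,
// not additional code): with N == 1 and step == 2, the kernel calls
//   ApplyOp1<Op, scalar, IndexType, ADims, 2>::apply(a, op, n, i)
// which computes the offset of linear index i (since 0 < n) and recurses into
//   ApplyOp1<Op, scalar, IndexType, ADims, 1, const IndexType>::apply(
//       a, op, n, i + 1, offset_i)
// which computes the offset of i + 1 only when 1 < n (otherwise it passes 0)
// and finally hits the `remaining_steps == 0` specialization, which calls
//   op(n, a.data[offset_i], a.data[offset_i_plus_1]);
// Offsets belonging to out-of-bound indices are 0 and must not be touched by
// `op`, which is why `n` is forwarded all the way down.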

template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          int remaining_steps,
          typename... Offsets>
struct ApplyOp1 {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op, int n,
                  IndexType linearIndex, Offsets... aOffsets) {
  // Convert `linearIndex` into an offset of `a`
  const IndexType aOffset = sizeof...(Offsets) < n ?
    detail::IndexToOffset<scalar, IndexType, ADims>::get(linearIndex, a) : 0;

  ApplyOp1<Op, scalar, IndexType, ADims, remaining_steps - 1, const IndexType, Offsets...>::apply(
    a, op, n, linearIndex + 1, aOffsets..., aOffset
  );
}
};

// Specialize the `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
// We don't need to pass in how many elements need to be processed in this case.
template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          typename Offset>
struct ApplyOp1<Op, scalar, IndexType, ADims, 0, Offset> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op,
                  int n, IndexType linearIndex, Offset offset) {
  op(a.data[offset]);
}
};

template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          typename... Offsets>
struct ApplyOp1<Op, scalar, IndexType, ADims, 0, Offsets...> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op, int n,
                 IndexType linearIndex, Offsets... offsets) {
  op(n, a.data[offsets]...);
}
};

template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          int step>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void kernelPointwiseApply1(detail::TensorInfo<scalar, IndexType> a,
                                      IndexType totalElements, const Op op) {
  for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x * step) {
    ApplyOp1<Op, scalar, IndexType, ADims, step>::apply(
      a, op, ::min(step, static_cast<int>(totalElements - linearIndex)), linearIndex);
  }
}


template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims,
          int BDims,
          int remaining_steps,
          typename... Offsets>
struct ApplyOp2 {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  const Op &op, int64_t n, IndexType linearIndex,
                  Offsets... aOffsets, Offsets... bOffsets) {
  // Convert `linearIndex` into an offset of `a`
  const IndexType aOffset = static_cast<int64_t>(sizeof...(Offsets)) < n ?
    detail::IndexToOffset<scalar1, IndexType, ADims>::get(linearIndex, a) : 0;

  // Convert `linearIndex` into an offset of `b`
  const IndexType bOffset = static_cast<int64_t>(sizeof...(Offsets)) < n ?
    detail::IndexToOffset<scalar2, IndexType, BDims>::get(linearIndex, b) : 0;

  ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, remaining_steps - 1, const IndexType, Offsets...>::apply(
    a, b, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset
  );
}
};

// Specialize the `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
// We don't need to pass in how many elements need to be processed in this case.
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims,
          int BDims,
          typename Offset>
struct ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offset> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  const Op &op, int /*n*/, IndexType /*linearIndex*/,
                  Offset aOffset, Offset bOffset) {
  op(a.data[aOffset], b.data[bOffset]);
}
};

template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims,
          int BDims,
          typename... Offsets>
struct ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offsets...> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  const Op &op, int n, IndexType linearIndex,
                  Offsets... aOffsets, Offsets... bOffsets) {
  op(n, a.data[aOffsets]..., b.data[bOffsets]...);
}
};

template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims, int BDims,
          int step,
          int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
          int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm)
#endif
__global__ void
kernelPointwiseApply2(detail::TensorInfo<scalar1, IndexType> a,
                      detail::TensorInfo<scalar2, IndexType> b,
                      IndexType totalElements,
                      const Op op) {
  for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x * step) {
    ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, step>::apply(
      a, b, op, ::min(step, static_cast<int>(totalElements - linearIndex)),
      linearIndex);
  }
}

} // anonymous namespace

template <typename scalar1, typename scalar2, int step, typename Op,
          int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
          int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
inline bool CUDA_tensor_apply2(at::TensorBase a,
                               at::TensorBase b,
                               const Op op,
                               TensorArgType aType = TensorArgType::ReadWrite,
                               TensorArgType bType = TensorArgType::ReadOnly) {
  TORCH_CHECK(a.device().is_cuda() && b.device().is_cuda(),
              "CUDA_tensor_apply2: Expected tensors to have CUDA DeviceType, but got "
              "tensors with type ", a.device().type(), " and ", b.device().type());
  int64_t totalElements = a.numel();

  if (totalElements != b.numel()) {
    return false;
  }

  if (a.dim() > MAX_TENSORINFO_DIMS ||
      b.dim() > MAX_TENSORINFO_DIMS) {
    return false;
  }

  if (a.numel() == 0) {
    // Empty tensor; do nothing
    return true;
  }
  const dim3 block = getApplyBlock(max_threads_per_block);

  dim3 grid;
  auto curDevice = current_device();
  if (curDevice == -1) return false;
  if (!getApplyGrid<step>(totalElements, grid, curDevice, max_threads_per_block)) {
    return false;
  }

  /*
  Expands readable/writable tensors whose indices may be "overlapped."
  This ensures that each element of the tensor is operated on once and only
  once.
  */
  TensorBase oldA;
  TensorBase oldB;

  if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) {
    // Must perform in contiguous space
    oldA = std::exchange(a, a.contiguous());
  }
  if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) {
    // Must perform in contiguous space
    oldB = std::exchange(b, b.contiguous());
  }

  // It is possible that the tensor dimensions are able to be collapsed,
  // and thus we can reduce the actual code complexity of the copy by
  // exploiting this knowledge statically, since the div/mod is the
  // most expensive part of the operation, more so than memory accesses.
  // For instance, when copying a non-contiguous to a contiguous tensor
  // (or vice versa), the contiguous tensor can be collapsed to one
  // dimension, and the loop to translate the linear index to the array
  // index can be similarly collapsed. That is what this unrolling is for.
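  //
  // As a worked example of why the unrolling pays off (the shape is purely
  // illustrative): a contiguous tensor of sizes {2, 3, 4, 5} collapses to a
  // single dimension of size 120 with stride 1, so the `dims == 1`
  // instantiation below is selected and the per-element index computation
  // needs no div/mod at all (just a multiply by the single stride), while a
  // genuinely strided operand falls back to the generic `-1` instantiation
  // that loops over its (collapsed) dimensions.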

#define HANDLE_CASE(TYPE, A, B)                                        \
  kernelPointwiseApply2<Op,                                            \
                        scalar1,                                       \
                        scalar2,                                       \
                        TYPE, A, B, step,                              \
                        max_threads_per_block,                         \
                        min_blocks_per_sm>                             \
   <<<grid, block, 0, at::cuda::getCurrentCUDAStream(curDevice)>>>(    \
       aInfo, bInfo, static_cast<TYPE>(totalElements), op);            \
  C10_CUDA_KERNEL_LAUNCH_CHECK();

#define HANDLE_B_CASE(TYPE, A, B) {         \
  switch (B) {                              \
    case 1:                                 \
      HANDLE_CASE(TYPE, A, 1);              \
      break;                                \
    case 2:                                 \
      HANDLE_CASE(TYPE, A, 2);              \
      break;                                \
    default:                                \
      HANDLE_CASE(TYPE, A, -1);             \
      break;                                \
  }                                         \
}

#define HANDLE_A_CASE(TYPE, A, B) {         \
  switch (A) {                              \
    case 1:                                 \
      HANDLE_B_CASE(TYPE, 1, B);            \
      break;                                \
    case 2:                                 \
      HANDLE_B_CASE(TYPE, 2, B);            \
      break;                                \
    default:                                \
      HANDLE_B_CASE(TYPE, -1, B);           \
      break;                                \
  }                                         \
}

  if (detail::canUse32BitIndexMath(a) &&
      detail::canUse32BitIndexMath(b)) {
    detail::TensorInfo<scalar1, unsigned int> aInfo =
      detail::getTensorInfo<scalar1, unsigned int>(a);

    detail::TensorInfo<scalar2, unsigned int> bInfo =
      detail::getTensorInfo<scalar2, unsigned int>(b);
    rearrangeDims(&aInfo, &bInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();

    HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims);
  } else {
    detail::TensorInfo<scalar1, uint64_t> aInfo =
      detail::getTensorInfo<scalar1, uint64_t>(a);

    detail::TensorInfo<scalar2, uint64_t> bInfo =
      detail::getTensorInfo<scalar2, uint64_t>(b);
    rearrangeDims(&aInfo, &bInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();

    /*
    Only instantiates the all-1D special case and the fallback all-nD case for
    large (64-bit indexed) tensors, to reduce compilation time.
    */
    if (aInfo.dims == 1 && bInfo.dims == 1) {
      HANDLE_CASE(uint64_t, 1, 1);
    } else {
      HANDLE_CASE(uint64_t, -1, -1);
    }
  }
#undef HANDLE_CASE
#undef HANDLE_B_CASE
#undef HANDLE_A_CASE

  if (oldA.defined()) {
    at::native::copy_ignoring_overlaps(oldA, a);
  }

  if (oldB.defined()) {
    at::native::copy_ignoring_overlaps(oldB, b);
  }

  return true;
}

/* Provides default step = 1 to CUDA_tensor_apply2. */
template <typename scalar1, typename scalar2, typename Op,
          int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
          int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
inline bool CUDA_tensor_apply2(const at::TensorBase &a,
                               const at::TensorBase &b,
                               const Op op,
                               TensorArgType aType = TensorArgType::ReadWrite,
                               TensorArgType bType = TensorArgType::ReadOnly) {
  return CUDA_tensor_apply2<scalar1, scalar2, 1, Op,
                            max_threads_per_block, min_blocks_per_sm>(a, b, op, aType, bType);
}

} // namespace at::cuda