#pragma once

#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/core/TensorBase.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/macros/Macros.h>
#include <ATen/native/Copy.h>

#include <math.h>

//
// This file contains pointwise operation functions and kernels that
// work on both contiguous and non-contiguous tensor arguments of
// arbitrary dimensionality (up to MAX_CUTORCH_DIMS) without copying
// or requiring temporary storage.
//

/*
  NOTE [ CUDA_tensor_applyN helpers ]

  The following CUDA_tensor_applyN (where N currently can be 1, 2, 3, or 4)
  functions apply a pointwise operator to N tensor(s).

  The calling convention is

  1. The template arguments should be, sequentially,
    - First N typename args specify the scalar types of each of the N tensors.
    - (Optional) `int step` arg specifies the number of elements processed
      together at the same time.
      Default is 1.
    - A usually omitted (i.e., inferred) typename arg specifies the type of the
      function/functor applied on `N * step` values in each iteration of each
      CUDA thread.
  2. The arguments should be, sequentially,
    - N tensors
    - op: a function/functor that processes `N * step` values at the same time.
      - If `step == 1`, it must have signature
        `void(*)(scalar1_t&, scalar2_t&, ..., scalarN_t&)`, where
        `scalar*_t`s are the first N typename template args, and the inputs
        are the `N` values from the `N` tensors retrieved at a common index.
      - Otherwise, it must have signature
          void(*)(int n, scalar1_t&, scalar1_t&, ..., scalar1_t&,  // repeat `step` times
                         scalar2_t&, scalar2_t&, ..., scalar2_t&,  // repeat `step` times
                         ...,
                         scalarN_t&, scalarN_t&, ..., scalarN_t&)  // repeat `step` times
        Different from the `step == 1` case, it processes `N * step` values taken
        from `step` common indices. Moreover, the first input `n` represents the
        number of valid indices (it will always have `0 < n <= step`). It will
        almost always be `step`, but at the boundary we may not have full `step`
        elements and `n` can be a lesser value.

        E.g., if `step == 4` and `N == 2`, `op` could be

          [](int n, scalar1_t &u1, scalar1_t &u2, scalar1_t &u3, scalar1_t &u4,
                    scalar2_t &v1, scalar2_t &v2, scalar2_t &v3, scalar2_t &v4) {
            // Only process u1, ..., un and v1, ..., vn.
            // So if `n == 3`, `u4` and `v4` need not be considered.
          }

      In both cases, the references can actually be const, but at least one of
      them should be non-const in order to write the output.
    - (Optional, but recommended) N TensorArgType args that specify for each
      tensor whether `op` reads AND writes (i.e., TensorArgType::ReadWrite),
      or only reads (i.e., TensorArgType::ReadOnly).
      Default is TensorArgType::ReadWrite for the first Tensor, and
      TensorArgType::ReadOnly for the rest.

  E.g.,

  to compute a = b^2 for a and b of the same dtype, we can call

  CUDA_tensor_apply2<scalar, scalar>(
    a, b,
    [] __device__ (scalar &a_val, const scalar &b_val) { a_val = b_val * b_val; }
  );

  to work on 2 values at the same time, we can call

  CUDA_tensor_apply2<scalar1, scalar2, 2>(
    a, b,
    [] __device__ (int n, scalar1 &a_val1, scalar1 &a_val2,
                          const scalar2 &b_val1, const scalar2 &b_val2) {
      // call special vectorized op here, or just do elementwise and enjoy unrolling...
      // if n == 1, only process a_val1 and b_val1
    }
  );
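
  to pass the TensorArgType args explicitly (an illustrative sketch; argument
  order as documented above), e.g. to compute b = a for a and b of the same
  dtype while only writing b, we can call

  CUDA_tensor_apply2<scalar, scalar>(
    a, b,
    [] __device__ (const scalar &a_val, scalar &b_val) { b_val = a_val; },
    TensorArgType::ReadOnly, TensorArgType::ReadWrite
  );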
*/

namespace at::cuda {

// TODO: combine with TensorArg? So far that's been for debugging, and this is functional...
enum class TensorArgType { ReadWrite, ReadOnly };

namespace {

// Rearrange dimensions for pointwise operations so that strides are in
// decreasing order as much as possible, so that kernels have better memory
// access patterns.
//
// For example, consider a binary operation on two "transposed" 2-dim tensors:
//    sizes:          256 512
//    aInfo->strides:   1 256
//    bInfo->strides:   1 256
//
// Given this, each concurrent memory access inside kernelPointwiseApply2() is
// exactly 256 elements apart, resulting in poor performance.
//
// This function exchanges dimensions so that memory access is contiguous:
//    sizes:          512 256
//    aInfo->strides: 256   1
//    bInfo->strides: 256   1
//
// (Actually, it becomes even better because now collapseDims() can turn each
// input into one contiguous array.)
//
// In general, given M (<=4) TensorInfo's with N dimensions, we can view each
// strides[i] (0 <= i < N) as an M-tuple.  Given each pair i < j, we exchange
// strides[i] and strides[j] if
//    (1) strides[i][k] < strides[j][k] for some k (0 <= k < M)
//        (exchanging them will benefit input #k), and
//    (2) strides[i][k] <= strides[j][k] for all k
//        (exchanging them will not make any input worse).
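//
// For example (an illustrative case, not taken from any particular op), with
// M = 2 inputs and dims #i < #j whose stride tuples are
//    strides[i] = (1, 1)    strides[j] = (256, 1),
// condition (1) holds for k = 0 and condition (2) holds for all k, so the two
// dims are exchanged.  If instead strides[i] = (1, 512), the exchange would
// make input #1 worse, so the dims are left alone.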
template <typename T1, typename IndexType,
          typename T2 = void, typename T3 = void, typename T4 = void>
inline void rearrangeDims(detail::TensorInfo<T1, IndexType>* aInfo,
                          detail::TensorInfo<T2, IndexType>* bInfo = nullptr,
                          detail::TensorInfo<T3, IndexType>* cInfo = nullptr,
                          detail::TensorInfo<T4, IndexType>* dInfo = nullptr) {
  int numInfos = 1;
  int dims = aInfo->dims;
  IndexType *sizes[4] = { aInfo->sizes, };
  IndexType *strides[4] = { aInfo->strides, };

  if (bInfo != nullptr) {
    ++numInfos;
    if (bInfo->dims != dims) return;
    sizes[1] = bInfo->sizes;
    strides[1] = bInfo->strides;
  }

  if (cInfo != nullptr) {
    ++numInfos;
    if (cInfo->dims != dims) return;
    sizes[2] = cInfo->sizes;
    strides[2] = cInfo->strides;
  }

  if (dInfo != nullptr) {
    ++numInfos;
    if (dInfo->dims != dims) return;
    sizes[3] = dInfo->sizes;
    strides[3] = dInfo->strides;
  }

  // Bail out if sizes do not match: we are using "deprecated pointwise
  // behavior" among tensors of different shapes but same number of elements.
  for (int i = 1; i < numInfos; ++i) {
    for (int j = 0; j < dims; ++j) {
      if (sizes[i][j] != sizes[0][j]) return;
    }
  }

  for (int i = 0; i < dims - 1; ++i) {
    // No need to consider dimensions of size 1.
    if (sizes[0][i] == 1) continue;

    for (int j = i + 1; j < dims; ++j) {
      if (sizes[0][j] == 1) continue;

      // Compare the relative sizes of strides between dim #i and dim #j.
      bool hasIncreasingStrides = false;
      bool hasDecreasingStrides = false;

      for (int k = 0; k < numInfos; k++) {
        IndexType stride_i = strides[k][i];
        IndexType stride_j = strides[k][j];
        if (stride_i < stride_j) {
          hasIncreasingStrides = true;
        } else if (stride_i > stride_j) {
          hasDecreasingStrides = true;
        }
      }

      if (hasIncreasingStrides && !hasDecreasingStrides) {
        for (int k = 0; k < numInfos; k++) {
          IndexType size = sizes[k][i];
          sizes[k][i] = sizes[k][j];
          sizes[k][j] = size;

          IndexType stride = strides[k][i];
          strides[k][i] = strides[k][j];
          strides[k][j] = stride;
        }
      }
    }
  }
}

// The `remaining_steps` argument is used to support Op that operates on
// multiple elements at the same time. Generally, the strategy of ApplyOpN is to
//  1. Initialize `remaining_steps = step`, where `step` is the template arg of
//     CUDA_tensor_applyN helpers. The input arg `n` to `apply()` represents the
//     number of elements in bound for this call. It will almost always equal
//     `step` except at boundaries.
//  2. If `remaining_steps > 0` convert the current linearIndex to offset (if in
//     bound), and recursively call `ApplyOpN` with `remaining_steps - 1`.
//  3. At `remaining_steps = 0`,
//       if `step = 1`, call `op(tensor1_val, tensor2_val, ...)`;
//       if `step > 1`, call `op(n, tensor1_val1, tensor1_val2, ..., tensor1_valstep,
//                                  tensor2_val1, tensor2_val2, ..., tensor2_valstep,
//                                  ...
//                                  tensorN_val1, tensorN_val2, ..., tensorN_valstep);`
//
// See NOTE [ CUDA_tensor_applyN helpers ] above for how Op may look like.
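//
// E.g., for ApplyOp1 with `step = 2` (an illustrative expansion): the kernel
// calls `apply()` with `remaining_steps = 2` and no offsets; each recursion
// level computes one offset for the current linearIndex (or 0 if that index is
// out of bounds) and appends it, so the `remaining_steps = 0` specialization
// finally receives `step` offsets and invokes `op(n, a.data[off1], a.data[off2])`.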

template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          int remaining_steps,
          typename... Offsets>
struct ApplyOp1 {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op, int n,
                  IndexType linearIndex, Offsets... aOffsets) {
  // Convert `linearIndex` into an offset of `a`
  const IndexType aOffset = sizeof...(Offsets) < n ?
    detail::IndexToOffset<scalar, IndexType, ADims>::get(linearIndex, a) : 0;

  ApplyOp1<Op, scalar, IndexType, ADims, remaining_steps - 1, const IndexType, Offsets...>::apply(
    a, op, n, linearIndex + 1, aOffsets..., aOffset
  );
}
};

// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
// We don't need to pass in how many elements need to be processed in this case.
template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          typename Offset>
struct ApplyOp1<Op, scalar, IndexType, ADims, 0, Offset> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op,
                  int n, IndexType linearIndex, Offset offset) {
  op(a.data[offset]);
}
};

template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          typename... Offsets>
struct ApplyOp1<Op, scalar, IndexType, ADims, 0, Offsets...> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op, int n,
                  IndexType linearIndex, Offsets... offsets) {
  op(n, a.data[offsets]...);
}
};

template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          int step>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void kernelPointwiseApply1(detail::TensorInfo<scalar, IndexType> a,
                                      IndexType totalElements, const Op op) {
  for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x * step) {
    ApplyOp1<Op, scalar, IndexType, ADims, step>::apply(
      a, op, ::min(step, static_cast<int>(totalElements - linearIndex)), linearIndex);
  }
}


template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims,
          int BDims,
          int remaining_steps,
          typename... Offsets>
struct ApplyOp2 {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  const Op &op, int64_t n, IndexType linearIndex,
                  Offsets... aOffsets, Offsets... bOffsets) {
  // Convert `linearIndex` into an offset of `a`
  const IndexType aOffset = static_cast<int64_t>(sizeof...(Offsets)) < n ?
    detail::IndexToOffset<scalar1, IndexType, ADims>::get(linearIndex, a) : 0;

  // Convert `linearIndex` into an offset of `b`
  const IndexType bOffset = static_cast<int64_t>(sizeof...(Offsets)) < n ?
    detail::IndexToOffset<scalar2, IndexType, BDims>::get(linearIndex, b) : 0;

  ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, remaining_steps - 1, const IndexType, Offsets...>::apply(
    a, b, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset
  );
}
};

// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
// We don't need to pass in how many elements need to be processed in this case.
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims,
          int BDims,
          typename Offset>
struct ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offset> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  const Op &op, int /*n*/, IndexType /*linearIndex*/,
                  Offset aOffset, Offset bOffset) {
  op(a.data[aOffset], b.data[bOffset]);
}
};
332
333 template <typename Op,
334 typename scalar1,
335 typename scalar2,
336 typename IndexType,
337 int ADims,
338 int BDims,
339 typename... Offsets>
340 struct ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offsets...> {
341 __device__ __forceinline__
applyat::cuda::__anonfece6e520111::ApplyOp2342 static void apply(detail::TensorInfo<scalar1, IndexType> &a,
343 detail::TensorInfo<scalar2, IndexType> &b,
344 const Op &op, int n, IndexType linearIndex,
345 Offsets... aOffsets, Offsets... bOffsets) {
346 op(n, a.data[aOffsets]..., b.data[bOffsets]...);
347 }
348 };

template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims, int BDims,
          int step,
          int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
          int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm)
#endif
__global__ void
kernelPointwiseApply2(detail::TensorInfo<scalar1, IndexType> a,
                      detail::TensorInfo<scalar2, IndexType> b,
                      IndexType totalElements,
                      const Op op) {
  for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x * step) {
    ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, step>::apply(
      a, b, op, ::min(step, static_cast<int>(totalElements - linearIndex)),
      linearIndex);
  }
}

} // anonymous namespace

template <typename scalar1, typename scalar2, int step, typename Op,
          int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
          int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
inline bool CUDA_tensor_apply2(at::TensorBase a,
                               at::TensorBase b,
                               const Op op,
                               TensorArgType aType = TensorArgType::ReadWrite,
                               TensorArgType bType = TensorArgType::ReadOnly) {
  TORCH_CHECK(a.device().is_cuda() && b.device().is_cuda(),
              "CUDA_tensor_apply2: Expected tensors to have CUDA DeviceType, but got "
              "tensors with type ", a.device().type(), " and ", b.device().type());
  int64_t totalElements = a.numel();

  if (totalElements != b.numel()) {
    return false;
  }

  if (a.dim() > MAX_TENSORINFO_DIMS ||
      b.dim() > MAX_TENSORINFO_DIMS) {
    return false;
  }

  if (a.numel() == 0) {
    // Empty tensor; do nothing
    return true;
  }
  const dim3 block = getApplyBlock(max_threads_per_block);

  dim3 grid;
  auto curDevice = current_device();
  if (curDevice == -1) return false;
  if (!getApplyGrid<step>(totalElements, grid, curDevice, max_threads_per_block)) {
    return false;
  }

  /*
  Expands readable/writable tensors whose indices may be "overlapped."
  This ensures that each element of the tensor is operated on once and only
  once.
  */
  TensorBase oldA;
  TensorBase oldB;

  if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) {
    // Must perform in contiguous space
    oldA = std::exchange(a, a.contiguous());
  }
  if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) {
    // Must perform in contiguous space
    oldB = std::exchange(b, b.contiguous());
  }

  // It is possible that the tensor dimensions are able to be collapsed,
  // and thus we can reduce the actual code complexity of the copy by
  // exploiting this knowledge statically, since the div/mod is the
  // most expensive part of the operation, more so than memory accesses.
  // For instance, when copying a non-contiguous to a contiguous tensor
  // (or vice versa), the contiguous tensor can be collapsed to one
  // dimension, and the loop to translate the linear index to the array
  // index can be similarly collapsed. That is what this unrolling is for.

#define HANDLE_CASE(TYPE, A, B)                                        \
  kernelPointwiseApply2<Op,                                            \
                        scalar1,                                       \
                        scalar2,                                       \
                        TYPE, A, B, step,                              \
                        max_threads_per_block,                         \
                        min_blocks_per_sm>                             \
   <<<grid, block, 0, at::cuda::getCurrentCUDAStream(curDevice)>>>(    \
       aInfo, bInfo, static_cast<TYPE>(totalElements), op);            \
  C10_CUDA_KERNEL_LAUNCH_CHECK();

#define HANDLE_B_CASE(TYPE, A, B) {         \
  switch (B) {                              \
    case 1:                                 \
      HANDLE_CASE(TYPE, A, 1);              \
      break;                                \
    case 2:                                 \
      HANDLE_CASE(TYPE, A, 2);              \
      break;                                \
    default:                                \
      HANDLE_CASE(TYPE, A, -1);             \
      break;                                \
  }                                         \
}

#define HANDLE_A_CASE(TYPE, A, B) {         \
  switch (A) {                              \
    case 1:                                 \
      HANDLE_B_CASE(TYPE, 1, B);            \
      break;                                \
    case 2:                                 \
      HANDLE_B_CASE(TYPE, 2, B);            \
      break;                                \
    default:                                \
      HANDLE_B_CASE(TYPE, -1, B);           \
      break;                                \
  }                                         \
}

  if (detail::canUse32BitIndexMath(a) &&
      detail::canUse32BitIndexMath(b)) {
    detail::TensorInfo<scalar1, unsigned int> aInfo =
      detail::getTensorInfo<scalar1, unsigned int>(a);

    detail::TensorInfo<scalar2, unsigned int> bInfo =
      detail::getTensorInfo<scalar2, unsigned int>(b);
    rearrangeDims(&aInfo, &bInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();

    HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims);
  } else {
    detail::TensorInfo<scalar1, uint64_t> aInfo =
      detail::getTensorInfo<scalar1, uint64_t>(a);

    detail::TensorInfo<scalar2, uint64_t> bInfo =
      detail::getTensorInfo<scalar2, uint64_t>(b);
    rearrangeDims(&aInfo, &bInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();

    /*
    Only instantiates the all 1D special case and the fallback all nD case for
    large (64-bit indexed) tensors to reduce compilation time.
    */
    if (aInfo.dims == 1 && bInfo.dims == 1) {
      HANDLE_CASE(uint64_t, 1, 1);
    } else {
      HANDLE_CASE(uint64_t, -1, -1);
    }
  }
#undef HANDLE_CASE
#undef HANDLE_B_CASE
#undef HANDLE_A_CASE

  if (oldA.defined()) {
    at::native::copy_ignoring_overlaps(oldA, a);
  }

  if (oldB.defined()) {
    at::native::copy_ignoring_overlaps(oldB, b);
  }

  return true;
}

/* Provides default step = 1 to CUDA_tensor_apply2. */
template <typename scalar1, typename scalar2, typename Op,
          int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
          int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
inline bool CUDA_tensor_apply2(const at::TensorBase &a,
                               const at::TensorBase &b,
                               const Op op,
                               TensorArgType aType = TensorArgType::ReadWrite,
                               TensorArgType bType = TensorArgType::ReadOnly) {
  return CUDA_tensor_apply2<scalar1, scalar2, 1, Op,
    max_threads_per_block, min_blocks_per_sm>(a, b, op, aType, bType);
}
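
/*
  Example usage (an illustrative sketch, assuming `dst` and `src` are float CUDA
  tensors with the same number of elements; see NOTE [ CUDA_tensor_applyN
  helpers ] above for the full calling convention):

    bool ok = at::cuda::CUDA_tensor_apply2<float, float>(
        dst, src,
        [] __device__ (float &dst_val, const float &src_val) {
          dst_val = src_val;
        });
    TORCH_CHECK(ok, "tensors must have matching numel and <= MAX_TENSORINFO_DIMS dims");
*/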

} // namespace at::cuda
