// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/ScanUtils.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once
#include <ATen/NumericUtils.h>
#include <ATen/core/TensorBase.h>
#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <c10/util/Load.h>
#include <limits>
#include <cmath>

namespace at {
namespace native {

template <typename integer>
constexpr inline integer ceil_div(integer n, integer m) {
  return (n + m - 1) / m;
}

template <typename integer>
constexpr inline integer get_log_num_threads_x_inner_scan(integer num_rows, integer row_size) {
  integer log_num_threads_x = 0;
  integer log_num_threads_y = 0;
  while (((integer)1 << log_num_threads_x) < row_size) {
    ++log_num_threads_x;
  }
  while (((integer)1 << log_num_threads_y) < num_rows) {
    ++log_num_threads_y;
  }
  // we want to keep the ratio between the x-threads and y-threads about the same as
  // the ratio between the row_size and num_rows, but the total number of threads in
  // a block should be about 512
  integer diff = log_num_threads_x - log_num_threads_y;
  // 9 is from log2(512)
  log_num_threads_x = ((integer)9 + diff) / (integer)2;
  // A larger log_num_threads_x can give a significant speed-up in some cases but is
  // detrimental in others, so keep the lower bound at log2(16) == 4 to stay similar
  // to the previous implementation.
  // Keep the upper bound at log2(512) == 9, the maximum number of threads assumed per block.
  log_num_threads_x = std::min(std::max((integer)4, log_num_threads_x), (integer)9);
  return log_num_threads_x;
}
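// Example (illustrative arithmetic for the heuristic above): with num_rows == 8192 and
// row_size == 8, the loops give log_num_threads_x == 3 and log_num_threads_y == 13, so
// diff == -10 and (9 + diff) / 2 == 0, which the clamp raises to 4 -> a 16 x 32 block.
// With num_rows == 1 and row_size == 65536 the unclamped value is 12, which the clamp
// lowers to 9 -> a 512 x 1 block.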

template<typename scalar_t, typename idx_t, typename BinaryOperation>
__device__ void binary_op_update(const scalar_t lhs, scalar_t& rhs, const idx_t lhs_idx, idx_t& rhs_idx, BinaryOperation binary_op) {
  if(!at::_isnan(rhs) && (at::_isnan(lhs) || !binary_op(rhs, lhs))) {
    rhs = lhs;
    rhs_idx = lhs_idx;
  }
}
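// Note on binary_op_update above: the (rhs, rhs_idx) pair keeps its value only if rhs
// is already NaN or if binary_op(rhs, lhs) holds; otherwise it is overwritten by
// (lhs, lhs_idx). This makes NaNs "sticky" in the *_with_indices scans below: once a
// NaN enters the running value, it is carried forward together with its index.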
/* Perform an inclusive scan along the innermost dimension of a tensor.
 *
 * - num_rows is the size of the flattened outer dimensions;
 * - row_size is the size of the innermost dimension;
 *
 * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
 * considered as having 'num_rows' rows of size 'row_size'.
 * Each thread block processes one or more sets of contiguous rows (processing multiple rows
 * per thread block is quicker than processing a single row, especially for short rows).
 */
template<typename scalar_t, class BinaryFunction>
__global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_,
                                                int num_rows, int row_size,
                                                const uint32_t num_threads, const uint32_t log_num_threads_x,
                                                scalar_t init, BinaryFunction binary_op) {
  // dynamic memory allocation for vbuf and ibuf
  alignas(sizeof(double)) extern __shared__ char buf[];
  scalar_t* vbuf = reinterpret_cast<scalar_t*>(buf); // the size is num_threads * 2
  int64_t* ibuf = reinterpret_cast<int64_t*>(vbuf + num_threads * 2);
  const uint32_t num_threads_x = 1 << log_num_threads_x;
  scalar_t* row_buf = vbuf + 2 * num_threads_x * threadIdx.y;
  int64_t* row_idx_buf = ibuf + 2 * num_threads_x * threadIdx.y;

  for (int block_row = blockIdx.x * blockDim.y;
       block_row < num_rows;
       block_row += blockDim.y * gridDim.x) {
    int row = block_row + threadIdx.y;
    const scalar_t *row_self = self_ + row * row_size;
    scalar_t *row_values = values_ + row * row_size;
    int64_t *row_indices = indices_ + row * row_size;
    scalar_t block_total = init;
    int64_t block_idx_final = 0;
    const bool row_exists = row < num_rows;
    // Perform scan on one block at a time, keeping track of the total value of
    // all blocks processed so far.
    for (int block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
      // Load data into shared memory (two values per thread).
      int col1 = block_col + threadIdx.x;
      int col2 = block_col + num_threads_x + threadIdx.x;
      if (row_exists) {
        if (col1 < row_size) {
          row_buf[threadIdx.x] = c10::load(&row_self[col1]);
          row_idx_buf[threadIdx.x] = col1;
        } else {
          row_buf[threadIdx.x] = init;
          // No need to set the index here as the value in init will never be selected
        }

        if (col2 < row_size) {
          row_buf[num_threads_x + threadIdx.x] = c10::load(&row_self[col2]);
          row_idx_buf[num_threads_x + threadIdx.x] = col2;
        } else {
          row_buf[num_threads_x + threadIdx.x] = init;
          // No need to set the index here as the value in init will never be selected
        }

        // Add the total value of all previous blocks to the first value of this block.
        if (threadIdx.x == 0) {
          binary_op_update(block_total, row_buf[0], block_idx_final, row_idx_buf[0], binary_op);
        }
      }
      __syncthreads();

      // Parallel reduction with the Sklansky method. A diagram can be found in this paper:
      // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back
      for (uint32_t s = 1; s <= num_threads_x; s <<= 1) {
        if (row_exists) {
          uint32_t a = (threadIdx.x / s) * (2 * s) + s;
          uint32_t ti = a + (threadIdx.x % s);
          uint32_t si = a - 1;
          binary_op_update(row_buf[si], row_buf[ti], row_idx_buf[si], row_idx_buf[ti], binary_op);
        }
        __syncthreads();
      }

      // Write back to output.
      if (row_exists) {
        if (col1 < row_size) {
          row_values[col1] = row_buf[threadIdx.x];
          row_indices[col1] = row_idx_buf[threadIdx.x];
        }
        if (col2 < row_size) {
          row_values[col2] = row_buf[num_threads_x + threadIdx.x];
          row_indices[col2] = row_idx_buf[num_threads_x + threadIdx.x];
        }
      }
      block_total = row_buf[2 * num_threads_x - 1];
      block_idx_final = row_idx_buf[2 * num_threads_x - 1];
      __syncthreads();
    }
  }
}
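// Shared-memory layout expected by the kernel above: 2 * num_threads scalar_t values
// (vbuf) followed by 2 * num_threads int64_t indices (ibuf); each threadIdx.y slice owns
// a contiguous window of 2 * num_threads_x entries in both buffers. The host launcher
// scan_innermost_dim_with_indices below sizes the dynamic allocation as
// 2 * num_threads * (sizeof(scalar_t) + sizeof(int64_t)) to match.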

/* Perform an inclusive scan along an outer dimension of a tensor.
 *
 * - num_orows is the size of the flattened outer dimensions;
 * - num_irows is the size of the flattened inner dimensions;
 * - row_size is the size of the dimension along which to scan;
 *
 * The dimensions to the outside and inside of the specified dimension are considered as flattened.
 * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened
 * outer dimensions, which contains several "inner rows").
 * Each thread processes a single inner row at a time.
 */
template<typename scalar_t, class BinaryFunction>
__global__ void tensor_kernel_scan_outer_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_,
                  const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) {
  for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
    for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
      const scalar_t *self = self_ + orow * row_size * num_irows + irow;
      scalar_t *values = values_ + orow * row_size * num_irows + irow;
      int64_t *indices = indices_ + orow * row_size * num_irows + irow;
      scalar_t out = init;
      int64_t out_idx = 0;

      for (auto col = decltype(row_size){0}; col < row_size; ++col) {
        const auto val = c10::load(self);
        if(at::_isnan(val) || (!at::_isnan(out) && binary_op(val, out))) {
          out = val;
          out_idx = col;
        }
        *values = out;
        *indices = out_idx;
        self += num_irows;
        values += num_irows;
        indices += num_irows;
      }
    }
  }
}
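// Access-pattern note for the kernel above: each thread scans one inner row serially
// (advancing by num_irows elements per step along the scan dimension), while consecutive
// threads handle consecutive irow values, so the loads and stores issued by a warp at a
// given col fall on adjacent addresses and coalesce well.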

inline void check_fits_in_unsigned(int64_t val, const char* name) {
  constexpr auto umax = std::numeric_limits<uint32_t>::max();
  TORCH_CHECK(
      val >= 0 && val <= umax, name, " must fit in a uint32_t value");
}


template<typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim_with_indices(
    const TensorBase& self, const TensorBase& values, const TensorBase& indices,
    int dim, scalar_t init, BinaryFunction binary_op) {
  int64_t row_size = self.size(dim);
  auto sizes = self.sizes();

  // Treat all outer dimensions (i.e. dimensions before dim) as one.
  const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);

  // Treat all inner dimensions (i.e. dimensions after dim) as one.
  const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());
  // For performance reasons, the CUDA kernels use uint32_t for loops over irows, orows and row,
  // so make sure the input is not bigger than what uint32_t supports.
  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(row_size, "row_size");


  dim3 threads(std::min(512, int(num_irows)));
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));
  tensor_kernel_scan_outer_dim_with_indices<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
    self.const_data_ptr<scalar_t>(), values.mutable_data_ptr<scalar_t>(), indices.mutable_data_ptr<int64_t>(),
    num_orows, num_irows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
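// Launch shape used above: blockIdx.x strides over the num_orows outer rows and
// blockIdx.y * blockDim.x + threadIdx.x over the num_irows inner rows, with up to 512
// threads per block along x; both grid extents are clamped to maxGridSize[1], and the
// grid-stride loops in the kernel cover any remainder.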

template <typename scalar_t, class BinaryFunction>
__host__ void scan_innermost_dim_with_indices(
    const TensorBase& self, const TensorBase& values, const TensorBase& indices,
    scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  // Treat all outer dimensions as a single dimension.
  int row_size = self.size(ndim - 1);
  int num_rows = self.numel() / row_size;

  // assuming max_num_threads per block is 512
  const uint32_t num_threads = 512;
  const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan<uint32_t>(num_rows, row_size);
  const uint32_t num_threads_x = (1 << log_num_threads_x);
  const uint32_t num_threads_y = num_threads / num_threads_x;
  dim3 threads(num_threads_x, num_threads_y);
  dim3 grid(std::min(at::cuda::getCurrentDeviceProperties()->maxGridSize[0], ceil_div(num_rows, int(threads.y))));

  const uint32_t mem_size = 2 * num_threads * (sizeof(scalar_t) + sizeof(int64_t));
  tensor_kernel_scan_innermost_dim_with_indices<scalar_t><<<grid, threads, mem_size,
                                                            at::cuda::getCurrentCUDAStream()>>>(
    self.const_data_ptr<scalar_t>(), values.mutable_data_ptr<scalar_t>(), indices.mutable_data_ptr<int64_t>(),
    num_rows, row_size, num_threads, log_num_threads_x, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename scalar_t, typename BinaryFunction>
void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, const TensorBase& indices,
     int64_t dim, scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  auto self_ = self.expect_contiguous();
  TORCH_INTERNAL_ASSERT(values.is_contiguous() && indices.is_contiguous());
  if (dim == ndim - 1) {
    scan_innermost_dim_with_indices<scalar_t>(*self_, values, indices, init, binary_op);
  } else {
    scan_outer_dim_with_indices<scalar_t>(*self_, values, indices, dim, init, binary_op);
  }
}
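// Usage sketch (illustrative only, not part of this header): a cummax-style caller
// might dispatch roughly as
//
//   AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half,
//       self.scalar_type(), "cummax_cuda", [&]() {
//     scalar_t init = at::numeric_limits<scalar_t>::lower_bound();
//     scan_dim_with_indices<scalar_t>(self, values, indices, dim, init,
//                                     std::greater_equal<scalar_t>());
//   });
//
// The dispatch macro, op name, init value and comparator above are examples, not a
// prescription; the caller chooses the identity element and the binary predicate.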

// TODO: The implementations of `tensor_kernel_scan_outer_dim` and
// `tensor_kernel_scan_innermost_dim` below are similar to
// `tensor_kernel_scan_outer_dim_with_indices` and
// `tensor_kernel_scan_innermost_dim_with_indices` above, and should be refactored
// to remove the duplication.

/* Perform an inclusive scan along an outer dimension of a tensor.
 *
 * - num_orows is the size of the flattened outer dimensions;
 * - num_irows is the size of the flattened inner dimensions;
 * - row_size is the size of the dimension along which to scan;
 *
 * The dimensions to the outside and inside of the specified dimension are considered as flattened.
 * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened
 * outer dimensions, which contains several "inner rows").
 * Each thread processes a single inner row at a time.
 */
template<typename scalar_t, class BinaryOp>
__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_,
                                              const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
                                              const scalar_t init, BinaryOp binary_op)
{
  for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
    for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
      const scalar_t *src = src_ + orow * row_size * num_irows + irow;
      scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
      scalar_t acc = init;

      for (uint32_t col = 0; col < row_size; ++col) {
        acc = binary_op(acc, c10::load(src));
        *tgt = acc;

        src += num_irows;
        tgt += num_irows;
      }
    }
  }
}

/* Perform an inclusive scan along the innermost dimension of a tensor.
 *
 * - num_rows is the size of the flattened outer dimensions;
 * - row_size is the size of the innermost dimension;
 *
 * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
 * considered as having 'num_rows' rows of size 'row_size'.
 * Each thread block processes one or more sets of contiguous rows (processing multiple rows
 * per thread block is quicker than processing a single row, especially for short rows).
 */
template<typename T, class BinaryFunction>
__device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, const T *src_,
                                      const uint32_t num_rows, const uint32_t row_size,
                                      const uint32_t log_num_threads_x,
                                      T init, BinaryFunction binary_op) {
  const uint32_t num_threads_x = 1 << log_num_threads_x;
  for (uint32_t block_row = blockIdx.x * blockDim.y;
       block_row < num_rows;
       block_row += blockDim.y * gridDim.x) {
    uint32_t row = block_row + threadIdx.y;
    T block_total = init;

    const T *row_src = src_ + row * row_size;
    T *row_tgt = tgt_ + row * row_size;
    const bool row_exists = row < num_rows;

    // Perform scan on one block at a time, keeping track of the total value of
    // all blocks processed so far.
    for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
      // Load data into shared memory (two values per thread).
      uint32_t col1 = block_col + threadIdx.x;
      uint32_t col2 = block_col + num_threads_x + threadIdx.x;
      if (row_exists) {
        if (col1 < row_size) {
          row_buf[threadIdx.x] = row_src[col1];
        } else {
          row_buf[threadIdx.x] = init;
        }

        if (col2 < row_size) {
          row_buf[num_threads_x + threadIdx.x] = row_src[col2];
        } else {
          row_buf[num_threads_x + threadIdx.x] = init;
        }

        // Add the total value of all previous blocks to the first value of this block.
        if (threadIdx.x == 0) {
          row_buf[0] = binary_op(row_buf[0], block_total);
        }
      }
      __syncthreads();
      // Parallel reduction with the Sklansky method. A diagram can be found in this paper:
      // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back
      for (uint32_t m = 0; m <= log_num_threads_x; ++m) {
        if (row_exists) {
          uint32_t s = 1 << m; // s = 2 ^ m
          uint32_t a = ((threadIdx.x >> m) << (m + 1)) | s; // a = (threadIdx.x / s) * (2 * s) + s
          uint32_t ti = a + (threadIdx.x % s);
          uint32_t si = a - 1;
          row_buf[ti] = binary_op(row_buf[ti], row_buf[si]);
        }
        __syncthreads();
      }

      // Write back to output.
      if (row_exists) {
        if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x];
        if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x];
      }
      block_total = row_buf[2 * num_threads_x - 1];
      __syncthreads();
    }
  }
}
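// The Sklansky indexing in the reduction loop above, worked out for the first two
// rounds: for m == 0, a == 2 * threadIdx.x + 1, so each thread combines element
// 2 * tid into element 2 * tid + 1 (adjacent pairs). For m == 1,
// a == 4 * (threadIdx.x / 2) + 2, so pairs of threads combine element 4k + 1 into
// elements 4k + 2 and 4k + 3 of their group of four, doubling the group size each round.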

template <
    typename T,
    class BinaryFunction>
__global__ void tensor_kernel_scan_innermost_dim(
    T* tgt_,
    const T* src_,
    const uint32_t num_rows,
    const uint32_t row_size,
    const uint32_t log_num_threads_x,
    T init,
    BinaryFunction binary_op) {
  alignas(sizeof(double)) extern __shared__ char sbuf[];
  T* sbuf2 = reinterpret_cast<T*>(sbuf);
  const uint32_t num_threads_x = 1 << log_num_threads_x;
  T* row_buf = reinterpret_cast<T*>(sbuf2 + num_threads_x * 2 * threadIdx.y);

  tensor_kernel_scan_innermost_dim_impl<T>(
      row_buf, tgt_, src_, num_rows, row_size, log_num_threads_x, init, binary_op);
}


template<typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
                             int dim, scalar_t init, BinaryFunction binary_op) {
  const int64_t row_size = self.size(dim);
  auto sizes = self.sizes();

  // Treat all outer dimensions (i.e. dimensions before dim) as one.
  const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);

  // Treat all inner dimensions (i.e. dimensions after dim) as one.
  const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());

  dim3 threads(std::min(512, int(num_irows)));
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));

  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
    result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
    num_orows, num_irows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename scalar_t, class BinaryFunction>
void scan_innermost_dim(const TensorBase& self, const TensorBase& result,
                        scalar_t init, BinaryFunction binary_op) {
  int64_t ndim = self.dim();
  // Treat all outer dimensions as a single dimension.
  int64_t row_size = self.size(ndim - 1);
  int64_t num_rows = self.numel() / row_size;

  // assuming max_num_threads per block is 512
  const uint32_t num_threads = 512;
  const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan<uint32_t>(num_rows, row_size);
  const uint32_t num_threads_x = (1 << log_num_threads_x);
  const uint32_t num_threads_y = num_threads / num_threads_x;
  dim3 threads(num_threads_x, num_threads_y);
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
  dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y})));

  check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_innermost_dim<scalar_t><<<grid, threads, num_threads * 2 * sizeof(scalar_t),
                                               at::cuda::getCurrentCUDAStream()>>>(
    result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
    num_rows, row_size, log_num_threads_x, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
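// Note: the value-only inner scan above needs only num_threads * 2 * sizeof(scalar_t)
// bytes of dynamic shared memory (one row_buf window per threadIdx.y slice); unlike the
// *_with_indices variant, there is no trailing int64_t index buffer.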

template<typename scalar_t, typename BinaryFunction>
void scan_dim(const TensorBase& self, const TensorBase& result,
     int64_t dim, scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  auto self_ = self.expect_contiguous();
  TORCH_INTERNAL_ASSERT(result.is_contiguous());

  if (self.numel() == self.size(dim)) {
    cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
  } else if (dim == ndim - 1) {
    scan_innermost_dim<scalar_t>(*self_, result, init, binary_op);
  } else {
    scan_outer_dim<scalar_t>(*self_, result, dim, init, binary_op);
  }
}
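// Usage sketch (illustrative only, not part of this header): a cumsum-style caller
// might dispatch roughly as
//
//   AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half,
//       self.scalar_type(), "cumsum_cuda", [&]() {
//     scan_dim<scalar_t>(self, result, dim, /*init=*/scalar_t(0), std::plus<scalar_t>());
//   });
//
// The dispatch macro, op name and init value are examples only. When the scan covers the
// whole tensor (self.numel() == self.size(dim)), scan_dim takes the cub::inclusive_scan
// fast path and the init value is not used.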

}}  // namespace at::native