xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/ScanKernels.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <cstdint>
3 
namespace at {

// Forward declaration only. The complete type is not required here because
// every launcher below receives its tensors by const reference.
class TensorBase;

namespace native {

// CUDA scan launchers. Only the declarations live in this header; the
// definitions are elsewhere (presumably the matching .cu sources — confirm).
//
// NOTE: these functions require output tensors to be contiguous.
//
// NOTE(review): argument order is not uniform across this group.
//   * launch_cummax_cuda_kernel / launch_cummin_cuda_kernel take `self` first,
//     followed by `values` and `indices`, then `dim`.
//   * launch_logcumsumexp_cuda_kernel / launch_cumsum_cuda_kernel /
//     launch_cumprod_cuda_kernel take `result` first, then `self`, then `dim`.
// Presumably `values`, `indices` and `result` are the output tensors, `self`
// is the input, and `dim` is the dimension being scanned. This is inferred
// from the parameter names only; verify against the definitions before
// relying on it.

// Launchers that take two tensors (`values`, `indices`) after `self`.
void launch_cummax_cuda_kernel(
    const TensorBase& self,
    const TensorBase& values,
    const TensorBase& indices,
    int64_t dim);

void launch_cummin_cuda_kernel(
    const TensorBase& self,
    const TensorBase& values,
    const TensorBase& indices,
    int64_t dim);

// Launchers that take a single `result` tensor ahead of `self`.
void launch_logcumsumexp_cuda_kernel(
    const TensorBase& result,
    const TensorBase& self,
    int64_t dim);

void launch_cumsum_cuda_kernel(
    const TensorBase& result,
    const TensorBase& self,
    int64_t dim);

void launch_cumprod_cuda_kernel(
    const TensorBase& result,
    const TensorBase& self,
    int64_t dim);

}  // namespace native
}  // namespace at
19