// aten/src/ATen/native/cuda/TensorTopK.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/TensorTopK.h>

#include <ATen/core/Tensor.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/cuda/Sort.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/CUDAFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/sort_cuda_dispatch.h>
#include <ATen/ops/topk_native.h>
#endif

namespace at::native {

// TODO: remove this when CUDA <11.6 is no longer supported
// Fallback path: fully sort the slice along `dim` and keep the leading k
// values and their indices.
void topk_out_with_sort(
  const Tensor& self,
  int64_t k, int64_t dim, bool largest,
  const Tensor& values,
  const Tensor& indices
) {
  auto [sorted_values, sorted_indices] = at::cuda::sort(self, /* stable= */false, dim, largest);
  values.copy_(sorted_values.narrow(dim, 0, k));
  indices.copy_(sorted_indices.narrow(dim, 0, k));
}

// TODO: remove this when CUDA <11.6 is no longer supported
bool disable_sort_for_topk();
bool should_use_sort(const Tensor& self, int64_t dim) {
  if (disable_sort_for_topk()) return false;
  // This heuristic is based on the experiment in https://github.com/pytorch/pytorch/pull/68632
  if (self.dim() == 0) return false;
  if (self.dtype() == kBool) return false; // Bool is not supported by topk
  int64_t slice_size = self.size(dim);
  if (slice_size == 0) return false;
  int64_t num_slices = self.numel() / slice_size;
  return num_slices <= 10 && slice_size >= 100000;
}
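
// Example (illustrative shapes, not from this file): a [4, 200000] input with
// dim == 1 has num_slices == 4 and slice_size == 200000, so it satisfies both
// conditions and takes the sort-based path; a [1000, 50000] input fails both
// conditions and falls through to the dedicated top-k kernel below.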

TORCH_IMPL_FUNC(topk_out_cuda)
  (const Tensor& self,
   int64_t k, int64_t dim, bool largest, bool sorted,
   const Tensor& values,
   const Tensor& indices) {
  TensorArg topK_arg{values, "topK", 1}, indices_arg{indices, "indices", 2}, input_arg{self, "self", 3};
  checkAllSameGPU(__func__, {topK_arg, indices_arg, input_arg});

  dim = at::maybe_wrap_dim(dim, self);

  if (should_use_sort(self, dim)) {
    topk_out_with_sort(self, k, dim, largest, values, indices);
    return;
  }

  // If k is 0 the result is an empty tensor, so we don't need to launch a kernel.
  if (k == 0) {
    return;
  }

  launch_gather_topk_kernel(self, k, dim, largest, values, indices);

  // Sort the results if the user wants them sorted, since our
  // selection routine does not ensure sorting
  if (sorted && values.numel() > 1) {
    if (should_use_small_sort(values, dim)) {
      // This avoids any memory allocations and performs all sorting
      // work inplace along the slice

      sortKeyValueInplace(values, indices, dim, largest);
    } else {
      // Depend upon the backup sort that returns indices, which we
      // can use in conjunction with gather to produce the original
      // indices.
      // This is not the most efficient implementation, especially since
      // there are memory allocations performed here. If the user desires
      // greater performance, they should torch.gather() the results
      // themselves using the reported indices, providing previously
      // allocated tensors to receive the results.
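      // A rough caller-side sketch of that pattern in ATen (illustrative
      // only; these names are not from this file):
      //   auto [vals, idx]   = input.topk(k, dim, largest, /*sorted=*/false);
      //   auto [svals, perm] = vals.sort(dim, /*descending=*/largest);
      //   auto sidx          = idx.gather(dim, perm);
      // with svals/sidx optionally written into preallocated outputs via the
      // corresponding *_out overloads.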

      Tensor sortedIndices = at::empty_like(indices);
      Tensor sortedValues = at::empty_like(values);
      at::cuda::sort_outf(values, /* stable= */ false, dim, largest, sortedValues, sortedIndices);
      indices.copy_(indices.gather(dim, sortedIndices));
      values.copy_(sortedValues);
    }
  }
}
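
// A minimal dispatch-side usage sketch (assuming a CUDA build; the tensor
// shape and variable names are illustrative, not from this file):
//   auto x = at::rand({8, 1024}, at::kCUDA);
//   auto [vals, idx] = at::topk(x, /*k=*/5, /*dim=*/1, /*largest=*/true, /*sorted=*/true);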

} // namespace at::native