#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/TensorTopK.h>

#include <ATen/core/Tensor.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/cuda/Sort.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/CUDAFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/sort_cuda_dispatch.h>
#include <ATen/ops/topk_native.h>
#endif

namespace at::native {

// TODO: remove this when CUDA <11.6 is no longer supported
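// Computes topk by fully sorting `self` along `dim` and narrowing the first k
// elements of the sorted result into the (pre-allocated) outputs. Only used
// when should_use_sort() below decides a full sort is likely to be faster.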
void topk_out_with_sort(
  const Tensor& self,
  int64_t k, int64_t dim, bool largest,
  const Tensor& values,
  const Tensor& indices
) {
  auto [sorted_values, sorted_indices] = at::cuda::sort(self, /* stable= */false, dim, largest);
  values.copy_(sorted_values.narrow(dim, 0, k));
  indices.copy_(sorted_indices.narrow(dim, 0, k));
}

// TODO: remove this when CUDA <11.6 is no longer supported
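// Declared here, defined elsewhere in the CUDA native code; returning true
// disables the sort-based topk path below.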
bool disable_sort_for_topk();
bool should_use_sort(const Tensor& self, int64_t dim) {
  if (disable_sort_for_topk()) return false;
  // This heuristic is based on the experiment in https://github.com/pytorch/pytorch/pull/68632
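  // For example, an input of shape [4, 500000] reduced along dim=1 has
  // num_slices == 4 and slice_size == 500000, so the sort path is taken;
  // a [1024, 1024] input (1024 slices of 1024 elements) uses the topk kernel.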
  if (self.dim() == 0) return false;
  if (self.dtype() == kBool) return false; // Bool is not supported by topk
  int64_t slice_size = self.size(dim);
  if (slice_size == 0) return false;
  int64_t num_slices = self.numel() / slice_size;
  return num_slices <= 10 && slice_size >= 100000;
}

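// Structured-kernel implementation behind topk on CUDA. A rough caller-side
// sketch (illustrative names, not part of this file):
//   auto t = at::randn({8, 1024}, at::kCUDA);
//   auto [vals, idx] = t.topk(/*k=*/5, /*dim=*/1, /*largest=*/true, /*sorted=*/true);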
TORCH_IMPL_FUNC(topk_out_cuda)
(const Tensor& self,
 int64_t k, int64_t dim, bool largest, bool sorted,
 const Tensor& values,
 const Tensor& indices) {
  TensorArg topK_arg{values, "topK", 1}, indices_arg{indices, "indices", 2}, input_arg{self, "self", 3};
  checkAllSameGPU(__func__, {topK_arg, indices_arg, input_arg});

  dim = at::maybe_wrap_dim(dim, self);

  if (should_use_sort(self, dim)) {
    topk_out_with_sort(self, k, dim, largest, values, indices);
    return;
  }

  // If k is 0 the result is an empty tensor, so we don't need to launch a kernel.
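  // (The structured meta function has already sized `values` and `indices` to
  // hold k elements along `dim`, so an early return leaves valid empty outputs.)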
  if (k == 0) {
    return;
  }

  launch_gather_topk_kernel(self, k, dim, largest, values, indices);

  // Sort the results if the user wants them sorted, since our
  // selection routine does not ensure sorting
  if (sorted && values.numel() > 1) {
    if (should_use_small_sort(values, dim)) {
      // This avoids any memory allocations and performs all sorting
      // work inplace along the slice

      sortKeyValueInplace(values, indices, dim, largest);
    } else {
      // Depend upon the backup sort that returns indices, which we
      // can use in conjunction with gather to produce the original
      // indices.
      // This is not the most efficient implementation, especially since
      // there are memory allocations performed here. If the user desires
      // greater performance, they should torch.gather() the results
      // themselves using the reported indices, providing previously
      // allocated tensors to receive the results.
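      //
      // A rough caller-side sketch of that approach using the C++ API
      // (names here are illustrative, not part of this file):
      //   auto [vals, idx] = input.topk(k, dim, largest, /*sorted=*/false);
      //   auto [sorted_vals, order] = vals.sort(dim, /*descending=*/largest);
      //   auto original_idx = idx.gather(dim, order);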

      Tensor sortedIndices = at::empty_like(indices);
      Tensor sortedValues = at::empty_like(values);
      at::cuda::sort_outf(values, /* stable= */ false, dim, largest, sortedValues, sortedIndices);
      indices.copy_(indices.gather(dim, sortedIndices));
      values.copy_(sortedValues);
    }
  }
}

} // namespace at::native