xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/SortingUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/NumericUtils.h>
4 #include <ATen/native/Resize.h>
5 #include <c10/util/irange.h>
6 
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/Functions.h>
9 #else
10 #include <ATen/ops/empty.h>
11 #endif
12 
13 namespace at::native {
14 
15 // ensure we get good values and indices for kthvalue, mode
16 // this will always be with the reducing dim as 1-d
_reduction_with_indices_allocate_or_resize_output(Tensor & values,Tensor & indices,const Tensor & self,int64_t dim_,bool keepdim)17 inline void _reduction_with_indices_allocate_or_resize_output(
18     Tensor& values,
19     Tensor& indices,
20     const Tensor& self,
21     int64_t dim_,
22     bool keepdim) {
23   int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
24   auto result_sizes = self.sizes().vec();
25   if (!result_sizes.empty()) {
26     result_sizes[dim] = 1;
27   }
28   if (values.defined()) {
29     TORCH_CHECK(
30         self.options().type_equal(values.options()),
31         "output values must be of same type as input");
32     if (!keepdim && values.dim() == self.dim() - 1) {
33       // unsqueeze to preserve passed in noncontiguous tensor in resize
34       values.unsqueeze_(dim);
35     }
36     resize_output(values, result_sizes);
37   } else {
38     values = at::empty(result_sizes, self.options());
39   }
40   if (indices.defined()) {
41     TORCH_CHECK(
42         indices.dtype() == kLong, "output indices must be of scalar type Long");
43     TORCH_CHECK(
44         indices.device() == self.device(),
45         "output indices must be on same device as input");
46     if (!keepdim && indices.dim() == self.dim() - 1) {
47       // unsqueeze to preserve passed in noncontiguous tensor in resize
48       indices.unsqueeze_(dim);
49     }
50     resize_output(indices, result_sizes);
51   } else {
52     indices = at::empty(result_sizes, self.options().dtype(kLong));
53   }
54 }
55 
56 // ensure we get good values and indices for topk
_allocate_or_resize_output_with_indices(Tensor & values,Tensor & indices,const Tensor & self,int64_t dim_,int64_t k)57 inline void _allocate_or_resize_output_with_indices(
58     Tensor& values,
59     Tensor& indices,
60     const Tensor& self,
61     int64_t dim_,
62     int64_t k) {
63   int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
64   auto result_sizes = self.sizes().vec();
65   if (!result_sizes.empty()) {
66     result_sizes[dim] = k;
67   }
68   if (values.defined()) {
69     TORCH_CHECK(
70         self.options().type_equal(values.options()),
71         "output values must be of same type as input");
72     values.resize_(result_sizes);
73   } else {
74     values = at::empty(result_sizes, self.options());
75   }
76   if (indices.defined()) {
77     TORCH_CHECK(
78         indices.dtype() == kLong, "output indices must be of scalar type Long");
79     TORCH_CHECK(
80         indices.device() == self.device(),
81         "output indices must be on same device as input");
82     indices.resize_(result_sizes);
83   } else {
84     indices = at::empty(result_sizes, self.options().dtype(kLong));
85   }
86 }
87 
88 } // namespace at::native
89