#pragma once

#include <ATen/NumericUtils.h>
#include <ATen/native/Resize.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

namespace at::native {

// Ensure the `values` and `indices` outputs for kthvalue and mode are
// allocated (or resized) with the correct shape, dtype, and device.
// The reducing dim always has size 1 in the outputs prepared here.
inline void _reduction_with_indices_allocate_or_resize_output(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim_,
    bool keepdim) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  auto result_sizes = self.sizes().vec();
  if (!result_sizes.empty()) {
    result_sizes[dim] = 1;
  }
  if (values.defined()) {
    TORCH_CHECK(
        self.options().type_equal(values.options()),
        "output values must be of same type as input");
    if (!keepdim && values.dim() == self.dim() - 1) {
      // unsqueeze to preserve passed in noncontiguous tensor in resize
      values.unsqueeze_(dim);
    }
    resize_output(values, result_sizes);
  } else {
    values = at::empty(result_sizes, self.options());
  }
  if (indices.defined()) {
    TORCH_CHECK(
        indices.dtype() == kLong, "output indices must be of scalar type Long");
    TORCH_CHECK(
        indices.device() == self.device(),
        "output indices must be on same device as input");
    if (!keepdim && indices.dim() == self.dim() - 1) {
      // unsqueeze to preserve passed in noncontiguous tensor in resize
      indices.unsqueeze_(dim);
    }
    resize_output(indices, result_sizes);
  } else {
    indices = at::empty(result_sizes, self.options().dtype(kLong));
  }
}
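
// A hedged usage sketch (not part of this header): a kthvalue-style kernel
// would typically call the helper above before filling its outputs. The
// wrapper name below is hypothetical and only illustrates the call pattern.
//
//   void kthvalue_out_sketch(
//       Tensor& values, Tensor& indices,
//       const Tensor& self, int64_t k, int64_t dim, bool keepdim) {
//     _reduction_with_indices_allocate_or_resize_output(
//         values, indices, self, dim, keepdim);
//     // ... compute the k-th value along `dim`, writing into values/indices,
//     // which now have size 1 along `dim` ...
//     if (!keepdim) {
//       values.squeeze_(dim);
//       indices.squeeze_(dim);
//     }
//   }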

// Ensure the `values` and `indices` outputs for topk are allocated (or
// resized): same shape as `self`, except the selected dim has size k.
inline void _allocate_or_resize_output_with_indices(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim_,
    int64_t k) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  auto result_sizes = self.sizes().vec();
  if (!result_sizes.empty()) {
    result_sizes[dim] = k;
  }
  if (values.defined()) {
    TORCH_CHECK(
        self.options().type_equal(values.options()),
        "output values must be of same type as input");
    values.resize_(result_sizes);
  } else {
    values = at::empty(result_sizes, self.options());
  }
  if (indices.defined()) {
    TORCH_CHECK(
        indices.dtype() == kLong, "output indices must be of scalar type Long");
    TORCH_CHECK(
        indices.device() == self.device(),
        "output indices must be on same device as input");
    indices.resize_(result_sizes);
  } else {
    indices = at::empty(result_sizes, self.options().dtype(kLong));
  }
}
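
// A hedged usage sketch (not part of this header), mirroring how a topk
// implementation might use the helper above; the wrapper name is hypothetical.
//
//   void topk_out_sketch(
//       Tensor& values, Tensor& indices,
//       const Tensor& self, int64_t k, int64_t dim, bool largest, bool sorted) {
//     _allocate_or_resize_output_with_indices(values, indices, self, dim, k);
//     // ... select the top-k entries along `dim` into values/indices,
//     // which now have size k along `dim` ...
//   }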

} // namespace at::native