1 #pragma once
2
3 #include <vector>
4 #include <ATen/core/Tensor.h>
5 #include <ATen/native/ReduceOpsUtils.h>
6 #include <c10/util/irange.h>
7
8 namespace at::native {
9
10 namespace {
11
12 // checks whether index.dtype == int64
13 // and self.dtype == src.dtype if src is a Tensor
14 inline void scatter_gather_dtype_check(
15 const std::string& method_name,
16 const Tensor& self,
17 const Tensor& index,
18 const std::optional<Tensor>& src_opt = std::nullopt
19 ) {
20 if (index.numel() != 0) {
21 TORCH_CHECK(
22 index.scalar_type() == at::ScalarType::Long,
23 method_name, "(): Expected dtype int64 for index"
24 );
25 }
26
27 if (src_opt.has_value()) {
28 const auto& src = src_opt.value();
29 TORCH_CHECK(
30 self.scalar_type() == src.scalar_type(),
31 method_name, "(): Expected self.dtype to be equal to src.dtype"
32 );
33 }
34 }
35
36 // Used for `gather`-like methods
37 // Note: self means the input tensor here
38 // Test:
39 // 1. index.size(d) <= self.size(d) for all d != dim
40 // 2. index.dim() == self.dim()
gather_shape_check(const Tensor & self,int64_t dim,const Tensor & index)41 inline void gather_shape_check(const Tensor& self, int64_t dim,
42 const Tensor& index
43 ) {
44 auto self_dims = ensure_nonempty_dim(self.dim());
45 TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()),
46 "Index tensor must have the same number of dimensions as input tensor"
47 );
48
49 for (const auto i : c10::irange(self_dims)) {
50 if (i != dim) {
51 TORCH_CHECK(
52 ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
53 "Size does not match at dimension ", i,
54 " expected index ", index.sizes(),
55 " to be smaller than self ", self.sizes(),
56 " apart from dimension ", dim
57 );
58 }
59 }
60 }
61
62 // Used for `scatter` and `scatter_add`
63 // Tests:
64 // 1. index.size(d) <= self.size(d) for all d != dim
65 // 2. index.size(d) <= src.size(d) for all d if src is a Tensor
66 // 3. index.dim() == self.dim() == src.dim()
67 inline void scatter_shape_check(
68 const Tensor& self, int64_t dim, const Tensor& index,
69 const std::optional<Tensor>& src_opt = std::nullopt
70 ) {
71 if (index.numel() == 0) return;
72 TORCH_CHECK(
73 ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
74 "Index tensor must have the same number of dimensions as self tensor"
75 );
76
77 bool is_wrong_shape = false;
78 int64_t self_dims = ensure_nonempty_dim(self.dim());
79
80 // Check: index.size(d) <= self.size(d) for all d != dim
81 for (const auto d : c10::irange(self_dims)) {
82 int64_t index_d_size = ensure_nonempty_size(index, d);
83 if (d == dim) continue;
84 if (index_d_size > ensure_nonempty_size(self, d)) {
85 is_wrong_shape = true;
86 break;
87 }
88 }
89
90 // Check: index.size(d) <= src.size(d) for all d if src is Tensor
91 if (!is_wrong_shape && src_opt.has_value()) {
92 const auto& src = src_opt.value();
93 for (const auto d : c10::irange(self_dims)) {
94 int64_t index_d_size = ensure_nonempty_size(index, d);
95 if (index_d_size > ensure_nonempty_size(src, d)) {
96 is_wrong_shape = true;
97 break;
98 }
99 }
100 }
101
102 if (src_opt.has_value()) {
103 const auto& src = src_opt.value();
104
105 TORCH_CHECK(
106 ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()),
107 "Index tensor must have the same number of dimensions as src tensor"
108 );
109
110 TORCH_CHECK(!is_wrong_shape,
111 "Expected index ", index.sizes(),
112 " to be smaller than self ", self.sizes(),
113 " apart from dimension ", dim,
114 " and to be smaller size than src ", src.sizes()
115 );
116 }
117 else {
118 TORCH_CHECK(!is_wrong_shape,
119 "Expected index ", index.sizes(),
120 " to be smaller than self ", self.sizes(),
121 " apart from dimension ", dim
122 );
123 }
124 }
125
126 } // anonymous namespace
127
128 } // namespace at::native
129