xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ScatterGatherChecks.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <vector>
4 #include <ATen/core/Tensor.h>
5 #include <ATen/native/ReduceOpsUtils.h>
6 #include <c10/util/irange.h>
7 
8 namespace at::native {
9 
10 namespace {
11 
12 // checks whether index.dtype == int64
13 // and self.dtype == src.dtype if src is a Tensor
14 inline void scatter_gather_dtype_check(
15   const std::string& method_name,
16   const Tensor& self,
17   const Tensor& index,
18   const std::optional<Tensor>& src_opt = std::nullopt
19 ) {
20   if (index.numel() != 0) {
21     TORCH_CHECK(
22       index.scalar_type() == at::ScalarType::Long,
23       method_name, "(): Expected dtype int64 for index"
24     );
25   }
26 
27   if (src_opt.has_value()) {
28     const auto& src = src_opt.value();
29     TORCH_CHECK(
30       self.scalar_type() == src.scalar_type(),
31       method_name, "(): Expected self.dtype to be equal to src.dtype"
32     );
33   }
34 }
35 
36 // Used for `gather`-like methods
37 // Note: self means the input tensor here
38 // Test:
39 // 1. index.size(d) <= self.size(d) for all d != dim
40 // 2. index.dim() == self.dim()
gather_shape_check(const Tensor & self,int64_t dim,const Tensor & index)41 inline void gather_shape_check(const Tensor& self, int64_t dim,
42   const Tensor& index
43 ) {
44   auto self_dims = ensure_nonempty_dim(self.dim());
45   TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()),
46     "Index tensor must have the same number of dimensions as input tensor"
47   );
48 
49   for (const auto i : c10::irange(self_dims)) {
50     if (i != dim) {
51       TORCH_CHECK(
52         ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
53         "Size does not match at dimension ", i,
54         " expected index ", index.sizes(),
55         " to be smaller than self ", self.sizes(),
56         " apart from dimension ", dim
57       );
58     }
59   }
60 }
61 
62 // Used for `scatter` and `scatter_add`
63 // Tests:
64 //  1. index.size(d) <= self.size(d) for all d != dim
65 //  2. index.size(d) <= src.size(d) for all d if src is a Tensor
66 //  3. index.dim() == self.dim() == src.dim()
67 inline void scatter_shape_check(
68   const Tensor& self, int64_t dim, const Tensor& index,
69   const std::optional<Tensor>& src_opt = std::nullopt
70 ) {
71   if (index.numel() == 0) return;
72   TORCH_CHECK(
73     ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
74     "Index tensor must have the same number of dimensions as self tensor"
75   );
76 
77   bool is_wrong_shape = false;
78   int64_t self_dims = ensure_nonempty_dim(self.dim());
79 
80   //  Check: index.size(d) <= self.size(d) for all d != dim
81   for (const auto d : c10::irange(self_dims)) {
82     int64_t index_d_size = ensure_nonempty_size(index, d);
83     if (d == dim) continue;
84     if (index_d_size > ensure_nonempty_size(self, d)) {
85       is_wrong_shape = true;
86       break;
87     }
88   }
89 
90   //  Check: index.size(d) <= src.size(d) for all d if src is Tensor
91   if (!is_wrong_shape && src_opt.has_value()) {
92     const auto& src = src_opt.value();
93     for (const auto d : c10::irange(self_dims)) {
94       int64_t index_d_size = ensure_nonempty_size(index, d);
95       if (index_d_size > ensure_nonempty_size(src, d)) {
96         is_wrong_shape = true;
97         break;
98       }
99     }
100   }
101 
102   if (src_opt.has_value()) {
103     const auto& src = src_opt.value();
104 
105     TORCH_CHECK(
106       ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()),
107       "Index tensor must have the same number of dimensions as src tensor"
108     );
109 
110     TORCH_CHECK(!is_wrong_shape,
111       "Expected index ", index.sizes(),
112       " to be smaller than self ", self.sizes(),
113       " apart from dimension ", dim,
114       " and to be smaller size than src ", src.sizes()
115     );
116   }
117   else {
118     TORCH_CHECK(!is_wrong_shape,
119       "Expected index ", index.sizes(),
120       " to be smaller than self ", self.sizes(),
121       " apart from dimension ", dim
122     );
123   }
124 }
125 
126 } // anonymous namespace
127 
128 } // namespace at::native
129