1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/NamedTensorUtils.h>
4
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/Functions.h>
7 #include <ATen/NativeFunctions.h>
8 #else
9 #include <ATen/ops/align_as_native.h>
10 #include <ATen/ops/align_tensors_native.h>
11 #include <ATen/ops/align_to_native.h>
12 #include <ATen/ops/gather_native.h>
13 #include <ATen/ops/index_add_native.h>
14 #include <ATen/ops/index_copy_native.h>
15 #include <ATen/ops/index_fill.h>
16 #include <ATen/ops/index_fill_native.h>
17 #include <ATen/ops/index_select_native.h>
18 #include <ATen/ops/refine_names_native.h>
19 #include <ATen/ops/rename_native.h>
20 #include <ATen/ops/scatter_add_native.h>
21 #include <ATen/ops/scatter_native.h>
22 #include <ATen/ops/sort_native.h>
23 #include <ATen/ops/squeeze.h>
24 #include <ATen/ops/squeeze_native.h>
25 #include <ATen/ops/zeros_like_ops.h>
26 #endif
27
28 #include <c10/util/irange.h>
29
30 #include <bitset>
31
32 namespace at::native {
33
rename_(Tensor & self,std::optional<DimnameList> names)34 Tensor& rename_(Tensor& self, std::optional<DimnameList> names) {
35 at::internal_set_names_inplace(self, names);
36 return self;
37 }
38
// Out-of-place rename: the returned tensor is an alias (shares storage with
// `self`); only the names metadata differs.
Tensor rename(const Tensor& self, std::optional<DimnameList> names) {
  auto aliased = self.alias();
  at::internal_set_names_inplace(aliased, names);
  return aliased;
}
44
report_moving_unnamed_dim_error(DimnameList names,DimnameList other,bool is_aligning_two_tensors)45 static void report_moving_unnamed_dim_error(
46 DimnameList names, DimnameList other, bool is_aligning_two_tensors) {
47 if (is_aligning_two_tensors) {
48 TORCH_CHECK(false,
49 "Aligning Tensor", names, " and Tensor", other,
50 " would change the absolute position from the right of an unnamed dimension. ",
51 "Please name unnamed dimensions to avoid ambiguity.");
52 } else {
53 TORCH_CHECK(false,
54 "Aligning Tensor", names, " to `names` ", other,
55 " would change the absolute position from the right of an unnamed dimension. ",
56 "Please name unnamed dimensions to avoid ambiguity.");
57 }
58 }
59
report_not_a_subsequence_error(DimnameList names,DimnameList other,bool is_aligning_two_tensors)60 static void report_not_a_subsequence_error(
61 DimnameList names, DimnameList other, bool is_aligning_two_tensors) {
62 if (is_aligning_two_tensors) {
63 #ifndef STRIP_ERROR_MESSAGES
64 auto shorter = names.size() > other.size() ? other : names;
65 auto longer = names.size() > other.size() ? names : other;
66 #endif
67 TORCH_CHECK(false,
68 "Could not align Tensor", shorter, " and Tensor", longer,
69 " because ", shorter, " is not a subsequence of ", longer, ". ");
70 } else {
71 TORCH_CHECK(false,
72 "Could not align Tensor", names, " to `names` ", other,
73 " because ", names, " is not a subsequence of `names`.");
74 }
75 }
76
77
78 // Let tensor `t` have size `tensor_sizes` and `tensor_names`.
79 // This helper function computes the resulting size of `t` after aligning it
80 // to `aligned_names`. Enforces the alignment rules in Note [Alignment rules].
aligned_size(IntArrayRef tensor_sizes,DimnameList tensor_names,DimnameList aligned_names,bool is_aligning_two_tensors)81 static std::vector<int64_t> aligned_size(
82 IntArrayRef tensor_sizes,
83 DimnameList tensor_names,
84 DimnameList aligned_names,
85 bool is_aligning_two_tensors) {
86 std::vector<int64_t> expanded_sizes(aligned_names.size(), 1);
87 ptrdiff_t dim = (ptrdiff_t)tensor_sizes.size() - 1;
88 ptrdiff_t idx = (ptrdiff_t)aligned_names.size() - 1;
89 for (; idx >= 0 && dim >= 0; --idx) {
90 if (tensor_names[dim] != aligned_names[idx]) {
91 continue;
92 }
93 // We've found a None name in `shorter` and `longer`. If their absolute positions
94 // from the right are not equal, then aligning the two names would require
95 // changing the absolute position from right of one of the None names,
96 // violating condition 2 of our [Alignment rules].
97 //
98 // For example:
99 // *, c, a, b
100 // *, a
101 // [*, a] is a subsequence of [*, c, a, b], but in order to align them,
102 // we'd have to move the * to create [*, c: 1, a, b: 1]
103 if (tensor_names[dim].isWildcard() &&
104 tensor_sizes.size() - dim != aligned_names.size() - idx) {
105 report_moving_unnamed_dim_error(
106 tensor_names, aligned_names, /*is_aligning_two_tensors=*/false);
107 }
108 expanded_sizes[idx] = tensor_sizes[dim];
109 --dim;
110 }
111 if (dim != -1) {
112 report_not_a_subsequence_error(
113 tensor_names, aligned_names, /*is_aligning_two_tensors=*/false);
114 }
115
116 return expanded_sizes;
117 }
118
// Returns an alias of `self` whose names are refined to `names`.
// Refinement only allows making names more specific: a wildcard dim may take
// any name, but an already-named dim must match exactly.
Tensor refine_names(const Tensor& self, DimnameList names) {
  const auto self_names = self.names();
  TORCH_CHECK(self_names.size() == names.size(),
      "refine_names: cannot coerce Tensor", self_names, " to Tensor", names,
      " because they have a different number of dims (",
      self_names.size(), " and ", names.size(), " respectively).");
  check_names_valid_for(self, names);

  for (const auto i : c10::irange(self_names.size())) {
    const auto& current = self_names[i];
    const auto& requested = names[i];
    // Identical names are a no-op; a wildcard can be refined to anything.
    if (current == requested || current.isWildcard()) {
      continue;
    }
    // A concrete name may not be "un-refined" back to a wildcard.
    if (requested.isWildcard()) {
      TORCH_CHECK(false,
          "refine_names: cannot coerce Tensor", self_names, " to Tensor", names,
          " because ", current, " is more specific than ", requested, " at index ",
          i);
    }
    // Two different concrete names can never be coerced.
    TORCH_CHECK(false,
        "refine_names: cannot coerce Tensor", self_names, " to Tensor", names,
        " because ", current, " is different from ", requested, " at index ",
        i);
    TORCH_INTERNAL_ASSERT(false); // done handling errors
  }

  auto result = self.alias();
  internal_set_names_inplace(result, names);
  return result;
}
150
151 // [Alignment rules]
152 // Aligns `tensor` to names with the following rules:
153 // 1) Check that tensor.names is a subsequence (not necessarily contiguous) of `names`.
154 // 2) Aligning tensor.names to names must not change the absolute position from the
155 // right of any unnamed dimension.
156 //
157 // is_aligning_two_tensors tunes the error message to better match the following cases:
158 // 1) tensor.align_to(names) (is_aligning_two_tensors=false)
159 // 2) torch.align_tensors([tensor, other]) (is_aligning_two_tensors=true)
align(const Tensor & tensor,DimnameList names,bool is_aligning_two_tensors)160 static Tensor align(const Tensor& tensor, DimnameList names, bool is_aligning_two_tensors) {
161 std::vector<int64_t> expanded_sizes = aligned_size(
162 tensor.sizes(),
163 tensor.names(),
164 names,
165 is_aligning_two_tensors);
166 auto result = tensor.rename(std::nullopt).view(expanded_sizes);
167 at::internal_set_names_inplace(result, names);
168 return result;
169 }
170
countUnset(std::bitset<kMaxNamedTensorDim> set,int64_t up_to_idx)171 static int64_t countUnset(std::bitset<kMaxNamedTensorDim> set, int64_t up_to_idx) {
172 int64_t result = 0;
173 for (const auto i : c10::irange(up_to_idx)) {
174 if (!set.test(i)) result++;
175 }
176 return result;
177 }
178
// Handles `tensor.align_to(*order)` in the case where there is an ellipsis.
//
// Let tensor: Tensor[N, C, H, W]. Consider `tensor.align_to('W', ..., 'N')`
// We expand the `...` to "all unmentioned dimensions, in the order which they
// appear in the original tensor."
//
// `order` is passed in **without** the ellipsis name. This is because ellipsis
// is not a valid name in cpp right now. Future work should be done on making
// ellipsis a valid name.
//
// `ellipsis_idx` is where the ellipsis occurs in the Python call.
// In our example, `tensor.align_to('W', ..., 'N')`, order = ['W', 'N'] and
// ellipsis_idx = 1.
Tensor align_to(const Tensor& tensor, DimnameList order, int64_t ellipsis_idx) {
  const auto tensor_names = tensor.names();
  const auto tensor_sizes = tensor.sizes();
  const auto tensor_strides = tensor.strides();
  const auto tensor_dim = tensor.sizes().size();
  constexpr int64_t not_found = -1;

  // General strategy.
  //
  // Step 1: We compute the following 3 things:
  //   1. How many names the ellipsis should expand to
  //   2. Which names in `tensor.names` are not mentioned in `order`.
  //   3. Where names in `order` occur in tensor, if at all.
  //
  // Step 2: Compute the new sizes/strides/names.
  //   First, determine the ndim of the output tensor (this is not obvious)
  //   by counting the number of names in `tensor` that are not in `order`.
  //   Next, fill in output sizes/strides/names by using `order` and knowledge
  //   of which dimensions in `tensor` are unmentioned in `order`.

  // Bit i is set iff tensor dim i is explicitly mentioned in `order`.
  std::bitset<kMaxNamedTensorDim> order_has_tensor_name;

  // tensor_idx_for[i] = j means that the ith name in `order`
  // appears in the jth element of tensor.
  std::vector<int64_t> tensor_idx_for(order.size(), not_found);

  for (const auto order_idx : c10::irange(order.size())) {
    const auto name = order[order_idx];
    TORCH_CHECK(name.isBasic(),
        "align_to: the desired order of dimensions cannot contain a None name, got ",
        order);
    auto it = std::find(tensor_names.begin(), tensor_names.end(), name);
    if (it == tensor_names.end()) {
      // Name not present in the input tensor: it becomes a new size-1 dim
      // in Step 2 (tensor_idx_for stays `not_found`).
      continue;
    }
    auto idx_in_tensor = std::distance(tensor_names.begin(), it);
    tensor_idx_for[order_idx] = idx_in_tensor;
    order_has_tensor_name.set(idx_in_tensor);
  }

  // Tensor dims whose bit is still unset are exactly the ones the ellipsis
  // must expand to.
  const auto num_ellipsis_names = countUnset(order_has_tensor_name, tensor_dim);
  const auto out_dim = num_ellipsis_names + order.size();

  // Step 2: Now that we know the size of the output tensor, we can use the
  // metadata obtained from Step 1 to fill in the new sizes/strides/names
  // Defaults (size 1, stride 0, wildcard name) describe a fresh size-1 dim;
  // slots for existing dims are overwritten below.
  std::vector<int64_t> new_sizes(out_dim, 1);
  std::vector<int64_t> new_strides(out_dim, 0);
  std::vector<Dimname> new_names(out_dim, Dimname::wildcard());

  // Copies size/stride/name of input dim `tensor_dim` into output slot
  // `out_dim` (lambda params shadow the outer variables of the same names).
  auto setNewSizesStridesNamesFor = [&](int64_t out_dim, int64_t tensor_dim) {
    new_sizes[out_dim] = tensor_sizes[tensor_dim];
    new_strides[out_dim] = tensor_strides[tensor_dim];
    new_names[out_dim] = tensor_names[tensor_dim];
  };

  // Fill in the non-ellipsis dimensions
  for (const auto order_idx : c10::irange(static_cast<int64_t>(order.size()))) {
    auto out_idx = order_idx;
    if (order_idx >= ellipsis_idx) {
      // Names at/after the ellipsis are shifted right by however many dims
      // the ellipsis expanded to.
      out_idx = order_idx + num_ellipsis_names;
    }
    const auto tensor_idx = tensor_idx_for[order_idx];
    if (tensor_idx == not_found) {
      // We are adding a new size-one dimension
      new_names[out_idx] = order[order_idx];
      continue;
    }
    setNewSizesStridesNamesFor(out_idx, tensor_idx);
  }

  // Fill in the ellipsis dimensions
  // NOTE: `ellipsis_idx` is reused as a write cursor here — unmentioned input
  // dims are packed, in their original order, into the slots the ellipsis
  // expanded to.
  for (const auto tensor_idx : c10::irange(tensor_dim)) {
    if (order_has_tensor_name.test(tensor_idx)) {
      continue;
    }
    setNewSizesStridesNamesFor(ellipsis_idx, tensor_idx);
    ellipsis_idx++;
  }

  check_names_valid_for(out_dim, new_names);

  Tensor result;
  {
    // Build the strided view with named-tensor handling disabled; the final
    // names are attached explicitly below.
    NoNamesGuard guard;
    result = tensor.as_strided(new_sizes, new_strides);
  }
  internal_set_names_inplace(result, std::move(new_names), /*validate_names=*/false);
  return result;
}
281
// Handles `tensor.align_to(names)` when there is no ellipsis: every dim of
// `tensor` must be named and must appear in `names`; the result is a strided
// view whose dims are permuted into the order of `names`, with size-1
// (stride-0) dims inserted for names not present in `tensor`.
Tensor align_to(const Tensor& tensor, DimnameList names) {
  auto tensor_names = tensor.names();
  auto tensor_sizes = tensor.sizes();
  auto tensor_strides = tensor.strides();
  // Defaults describe a fresh size-1 dim; slots for existing dims are
  // overwritten in the loop.
  std::vector<int64_t> new_sizes(names.size(), 1);
  std::vector<int64_t> new_strides(names.size(), 0);

  for (const auto idx : c10::irange(tensor_names.size())) {
    const auto& dim = tensor_names[idx];
    TORCH_CHECK(dim.isBasic(),
        "align_to: All input dims must be named. Found unnamed dim at index ",
        idx, " of Tensor", tensor_names);
    auto it = std::find(names.begin(), names.end(), dim);
    // BUGFIX: report the *tensor's* names after "from Tensor" — the message
    // previously streamed `names` twice, hiding which tensor failed.
    TORCH_CHECK(it != names.end(),
        "align_to: Cannot find dim ", dim, " from Tensor", tensor_names,
        " in desired alignment ", names, ".");
    int64_t new_idx = std::distance(names.begin(), it);
    new_sizes[new_idx] = tensor_sizes[idx];
    new_strides[new_idx] = tensor_strides[idx];
  }
  Tensor result;
  {
    // Build the strided view with named-tensor handling disabled; the final
    // names are attached explicitly below.
    NoNamesGuard guard;
    result = tensor.as_strided(new_sizes, new_strides);
  }
  internal_set_names_inplace(result, names);
  return result;
}
310
// tensor.align_as(other) is sugar for tensor.align_to(other.names()).
Tensor align_as(const Tensor& tensor, const Tensor& other) {
  const auto target_names = other.names();
  return native::align_to(tensor, target_names);
}
314
align_tensors_to(TensorList tensors,DimnameList names)315 static std::vector<Tensor> align_tensors_to(TensorList tensors, DimnameList names) {
316 std::vector<Tensor> result;
317 result.reserve(tensors.size());
318 for (const auto& tensor : tensors) {
319 result.emplace_back(align(tensor, names, /*is_aligning_two_tensors=*/true));
320 }
321 return result;
322 }
323
align_tensors(TensorList tensors)324 std::vector<Tensor> align_tensors(TensorList tensors) {
325 auto longest_dim = std::max_element(
326 tensors.begin(), tensors.end(),
327 [](const Tensor& a, const Tensor& b) {
328 return a.dim() < b.dim();
329 });
330 return align_tensors_to(tensors, longest_dim->names());
331 }
332
// Misc. Dimname overloads that don't have homes. Maybe we should move
// all of them here or autogenerate them because they look so similar.
//
// The overloads below that call reportNYIDimnameOverload report the op as
// not yet implemented for named dimensions instead of computing a result.
Tensor gather(const Tensor& self, Dimname dim, const Tensor& index, bool sparse_grad) {
  reportNYIDimnameOverload("gather");
}
Tensor& gather_out(const Tensor& self, Dimname dim, const Tensor& index, bool sparse_grad, Tensor& result) {
  reportNYIDimnameOverload("gather");
}
Tensor index_add(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source, const Scalar &alpha) {
  reportNYIDimnameOverload("index_add");
}
// index_fill overloads: resolve the named dim to its positional index and
// forward to the positional implementation.
Tensor index_fill(const Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) {
  const auto pos = dimname_to_position(self, dim);
  return at::index_fill(self, pos, index, source);
}
Tensor& index_fill_(Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) {
  const auto pos = dimname_to_position(self, dim);
  return self.index_fill_(pos, index, source);
}
Tensor index_fill(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
  const auto pos = dimname_to_position(self, dim);
  return at::index_fill(self, pos, index, source);
}
Tensor& index_fill_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
  const auto pos = dimname_to_position(self, dim);
  return self.index_fill_(pos, index, source);
}
// The overloads below report the op as not yet implemented for named
// dimensions instead of computing a result.
Tensor index_copy(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
  reportNYIDimnameOverload("index_copy");
}
Tensor& index_copy_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
  reportNYIDimnameOverload("index_copy");
}
Tensor& index_select_out(const Tensor& self, Dimname dim, const Tensor& index, Tensor& out) {
  reportNYIDimnameOverload("index_select");
}
Tensor index_select(const Tensor& self, Dimname dim, const Tensor& index) {
  reportNYIDimnameOverload("index_select");
}
Tensor scatter(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
  reportNYIDimnameOverload("scatter");
}
Tensor scatter(const Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) {
  reportNYIDimnameOverload("scatter");
}
Tensor scatter_add(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
  reportNYIDimnameOverload("scatter_add");
}
std::tuple<Tensor&, Tensor&> sort_out(const Tensor& self, std::optional<bool> stable, Dimname dim, bool keepdim, Tensor& values, Tensor& indices) {
  reportNYIDimnameOverload("sort");
}
std::tuple<Tensor&, Tensor&> sort_out(const Tensor& self, Dimname dim, bool keepdim, Tensor& values, Tensor& indices) {
  reportNYIDimnameOverload("sort");
}
std::tuple<Tensor, Tensor> sort(const Tensor& self, std::optional<bool> stable, Dimname dim, bool keepdim) {
  reportNYIDimnameOverload("sort");
}
std::tuple<Tensor, Tensor> sort(const Tensor& self, Dimname dim, bool keepdim) {
  reportNYIDimnameOverload("sort");
}
Tensor& squeeze_(Tensor& self, Dimname dim) {
  reportNYIDimnameOverload("squeeze");
}
// Named-dim squeeze: resolve the name to its position and dispatch to the
// positional overload.
Tensor squeeze(const Tensor& self, Dimname dim) {
  const auto pos = dimname_to_position(self, dim);
  return at::squeeze(self, pos);
}
395
396
397 } // namespace at::native
398