xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/NamedTensor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/NamedTensorUtils.h>
4 
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/Functions.h>
7 #include <ATen/NativeFunctions.h>
8 #else
9 #include <ATen/ops/align_as_native.h>
10 #include <ATen/ops/align_tensors_native.h>
11 #include <ATen/ops/align_to_native.h>
12 #include <ATen/ops/gather_native.h>
13 #include <ATen/ops/index_add_native.h>
14 #include <ATen/ops/index_copy_native.h>
15 #include <ATen/ops/index_fill.h>
16 #include <ATen/ops/index_fill_native.h>
17 #include <ATen/ops/index_select_native.h>
18 #include <ATen/ops/refine_names_native.h>
19 #include <ATen/ops/rename_native.h>
20 #include <ATen/ops/scatter_add_native.h>
21 #include <ATen/ops/scatter_native.h>
22 #include <ATen/ops/sort_native.h>
23 #include <ATen/ops/squeeze.h>
24 #include <ATen/ops/squeeze_native.h>
25 #include <ATen/ops/zeros_like_ops.h>
26 #endif
27 
28 #include <c10/util/irange.h>
29 
30 #include <bitset>
31 
32 namespace at::native {
33 
// In-place rename: sets (or, when `names` is nullopt, clears) the dimension
// names of `self` and returns `self` to allow chaining.
Tensor& rename_(Tensor& self, std::optional<DimnameList> names) {
  at::internal_set_names_inplace(self, names);
  return self;
}
38 
// Out-of-place rename: returns a view (alias) of `self` whose dimension
// names are set to `names` (or cleared when `names` is nullopt). The input
// tensor itself is left untouched.
Tensor rename(const Tensor& self, std::optional<DimnameList> names) {
  Tensor renamed = self.alias();
  at::internal_set_names_inplace(renamed, names);
  return renamed;
}
44 
report_moving_unnamed_dim_error(DimnameList names,DimnameList other,bool is_aligning_two_tensors)45 static void report_moving_unnamed_dim_error(
46     DimnameList names, DimnameList other, bool is_aligning_two_tensors) {
47   if (is_aligning_two_tensors) {
48     TORCH_CHECK(false,
49         "Aligning Tensor", names, " and Tensor", other,
50         " would change the absolute position from the right of an unnamed dimension. ",
51         "Please name unnamed dimensions to avoid ambiguity.");
52   } else {
53     TORCH_CHECK(false,
54         "Aligning Tensor", names, " to `names` ", other,
55         " would change the absolute position from the right of an unnamed dimension. ",
56         "Please name unnamed dimensions to avoid ambiguity.");
57   }
58 }
59 
report_not_a_subsequence_error(DimnameList names,DimnameList other,bool is_aligning_two_tensors)60 static void report_not_a_subsequence_error(
61     DimnameList names, DimnameList other, bool is_aligning_two_tensors) {
62   if (is_aligning_two_tensors) {
63 #ifndef STRIP_ERROR_MESSAGES
64     auto shorter = names.size() > other.size() ? other : names;
65     auto longer = names.size() > other.size() ? names : other;
66 #endif
67     TORCH_CHECK(false,
68         "Could not align Tensor", shorter, " and Tensor", longer,
69         " because ", shorter, " is not a subsequence of ", longer, ". ");
70   } else {
71     TORCH_CHECK(false,
72         "Could not align Tensor", names, " to `names` ", other,
73         " because ", names, " is not a subsequence of `names`.");
74   }
75 }
76 
77 
78 // Let tensor `t` have size `tensor_sizes` and `tensor_names`.
79 // This helper function computes the resulting size of `t` after aligning it
80 // to `aligned_names`. Enforces the alignment rules in Note [Alignment rules].
aligned_size(IntArrayRef tensor_sizes,DimnameList tensor_names,DimnameList aligned_names,bool is_aligning_two_tensors)81 static std::vector<int64_t> aligned_size(
82     IntArrayRef tensor_sizes,
83     DimnameList tensor_names,
84     DimnameList aligned_names,
85     bool is_aligning_two_tensors) {
86   std::vector<int64_t> expanded_sizes(aligned_names.size(), 1);
87   ptrdiff_t dim = (ptrdiff_t)tensor_sizes.size() - 1;
88   ptrdiff_t idx = (ptrdiff_t)aligned_names.size() - 1;
89   for (; idx >= 0 && dim >= 0; --idx) {
90     if (tensor_names[dim] != aligned_names[idx]) {
91       continue;
92     }
93     // We've found a None name in `shorter` and `longer`. If their absolute positions
94     // from the right are not equal, then aligning the two names would require
95     // changing the absolute position from right of one of the None names,
96     // violating condition 2 of our [Alignment rules].
97     //
98     // For example:
99     // *, c, a, b
100     //       *, a
101     // [*, a] is a subsequence of [*, c, a, b], but in order to align them,
102     // we'd have to move the * to create [*, c: 1, a, b: 1]
103     if (tensor_names[dim].isWildcard() &&
104         tensor_sizes.size() - dim != aligned_names.size() - idx) {
105       report_moving_unnamed_dim_error(
106           tensor_names, aligned_names, /*is_aligning_two_tensors=*/false);
107     }
108     expanded_sizes[idx] = tensor_sizes[dim];
109     --dim;
110   }
111   if (dim != -1) {
112     report_not_a_subsequence_error(
113         tensor_names, aligned_names, /*is_aligning_two_tensors=*/false);
114   }
115 
116   return expanded_sizes;
117 }
118 
// Refines the (possibly wildcard) names of `self` to the more-specific
// `names`, returning an aliased view carrying the new names. Each existing
// name must either equal the requested name or be a wildcard; a named dim
// may never be coerced to a wildcard or to a different name.
Tensor refine_names(const Tensor& self, DimnameList names) {
  const auto self_names = self.names();
  TORCH_CHECK(self_names.size() == names.size(),
      "refine_names: cannot coerce Tensor", self_names, " to Tensor", names,
      " because they have a different number of dims (",
      self_names.size(), " and ", names.size(), " respectively).");
  check_names_valid_for(self, names);

  for (const auto idx : c10::irange(self_names.size())) {
    const auto& current = self_names[idx];
    const auto& requested = names[idx];
    // An exact match or an existing wildcard can always be refined.
    if (current == requested || current.isWildcard()) {
      continue;
    }
    // A concrete name cannot be loosened back to a wildcard...
    TORCH_CHECK(!requested.isWildcard(),
        "refine_names: cannot coerce Tensor", self_names, " to Tensor", names,
        " because ", current, " is more specific than ", requested, " at index ",
        idx);
    // ...nor renamed to a different concrete name.
    TORCH_CHECK(false,
        "refine_names: cannot coerce Tensor", self_names, " to Tensor", names,
        " because ", current, " is different from ", requested, " at index ",
        idx);
    TORCH_INTERNAL_ASSERT(false); // done handling errors
  }

  Tensor refined = self.alias();
  internal_set_names_inplace(refined, names);
  return refined;
}
150 
151 // [Alignment rules]
152 // Aligns `tensor` to names with the following rules:
153 // 1) Check that tensor.names is a subsequence (not necessarily contiguous) of `names`.
154 // 2) Aligning tensor.names to names must not change the absolute position from the
155 //    right of any unnamed dimension.
156 //
157 // is_aligning_two_tensors tunes the error message to better match the following cases:
158 // 1) tensor.align_to(names)  (is_aligning_two_tensors=false)
159 // 2) torch.align_tensors([tensor, other])  (is_aligning_two_tensors=true)
// Expands `tensor` to match `names` per Note [Alignment rules]: computes the
// target sizes, views the (temporarily unnamed) tensor into that shape, then
// stamps `names` onto the result.
static Tensor align(const Tensor& tensor, DimnameList names, bool is_aligning_two_tensors) {
  const auto expanded_sizes = aligned_size(
      tensor.sizes(), tensor.names(), names, is_aligning_two_tensors);
  // Drop names before the view so the reshape isn't subject to named-tensor
  // propagation, then install the aligned names on the result.
  Tensor result = tensor.rename(std::nullopt).view(expanded_sizes);
  at::internal_set_names_inplace(result, names);
  return result;
}
170 
countUnset(std::bitset<kMaxNamedTensorDim> set,int64_t up_to_idx)171 static int64_t countUnset(std::bitset<kMaxNamedTensorDim> set, int64_t up_to_idx) {
172   int64_t result = 0;
173   for (const auto i : c10::irange(up_to_idx)) {
174     if (!set.test(i)) result++;
175   }
176   return result;
177 }
178 
179 // Handles `tensor.align_to(*order)` in the case where there is an ellipsis.
180 //
181 // Let tensor: Tensor[N, C, H, W]. Consider `tensor.align_to('W', ..., 'N')`
182 // We expand the `...` to "all unmentioned dimensions, in the order which they
183 // appear in the original tensor."
184 //
185 // `order` is passed in **without** the ellipsis name. This is because ellipsis
186 // is not a valid name in cpp right now. Future work should be done on making
187 // ellipsis a valid name.
188 //
189 // `ellipsis_idx` is where the ellipsis occurs in the Python call.
190 // In our example, `tensor.align_to('W', ..., 'N')`, order = ['W', 'N'] and
191 // ellipsis_idx = 1.
Tensor align_to(const Tensor& tensor, DimnameList order, int64_t ellipsis_idx) {
  const auto tensor_names = tensor.names();
  const auto tensor_sizes = tensor.sizes();
  const auto tensor_strides = tensor.strides();
  const auto tensor_dim = tensor.sizes().size();
  constexpr int64_t not_found = -1;

  // General strategy.
  //
  // Step 1: We compute the following 3 things:
  // 1. How many names the ellipsis should expand to
  // 2. Which names in `tensor.names` are not mentioned in `order`.
  // 3. Where names in `order` occur in tensor, if at all.
  //
  // Step 2: Compute the new sizes/strides/names.
  // First, determine the ndim of the output tensor (this is not obvious)
  // by counting the number of names in `tensor` that are not in `order`.
  // Next, fill in output sizes/strides/names by using `order` and knowledge
  // of which dimensions in `tensor` are unmentioned in `order`.

  // Bit i is set iff tensor dim i is explicitly mentioned in `order`
  // (i.e. it will NOT be part of the ellipsis expansion).
  std::bitset<kMaxNamedTensorDim> order_has_tensor_name;

  // tensor_idx_for[i] = j means that the ith name in `order`
  // appears in the jth element of tensor.
  std::vector<int64_t> tensor_idx_for(order.size(), not_found);

  // Step 1: resolve each name in `order` against the tensor's names.
  for (const auto order_idx : c10::irange(order.size())) {
    const auto name = order[order_idx];
    TORCH_CHECK(name.isBasic(),
        "align_to: the desired order of dimensions cannot contain a None name, got ",
        order);
    auto it = std::find(tensor_names.begin(), tensor_names.end(), name);
    if (it == tensor_names.end()) {
      // Name not present in the tensor: stays `not_found` and becomes a new
      // size-one dimension in Step 2.
      continue;
    }
    auto idx_in_tensor = std::distance(tensor_names.begin(), it);
    tensor_idx_for[order_idx] = idx_in_tensor;
    order_has_tensor_name.set(idx_in_tensor);
  }

  const auto num_ellipsis_names = countUnset(order_has_tensor_name, tensor_dim);
  const auto out_dim = num_ellipsis_names + order.size();

  // Step 2: Now that we know the size of the output tensor, we can use the
  // metadata obtained from Step 1 to fill in the new sizes/strides/names
  std::vector<int64_t> new_sizes(out_dim, 1);
  std::vector<int64_t> new_strides(out_dim, 0);
  std::vector<Dimname> new_names(out_dim, Dimname::wildcard());

  // Copies size/stride/name of tensor dim `tensor_dim` into output slot
  // `out_dim`. NOTE: the lambda parameters intentionally shadow the outer
  // `tensor_dim` local.
  auto setNewSizesStridesNamesFor = [&](int64_t out_dim, int64_t tensor_dim) {
    new_sizes[out_dim] = tensor_sizes[tensor_dim];
    new_strides[out_dim] = tensor_strides[tensor_dim];
    new_names[out_dim] = tensor_names[tensor_dim];
  };

  // Fill in the non-ellipsis dimensions
  for (const auto order_idx : c10::irange(static_cast<int64_t>(order.size()))) {
    auto out_idx = order_idx;
    if (order_idx >= ellipsis_idx) {
      // Names after the ellipsis are shifted right by however many dims the
      // ellipsis expands to.
      out_idx = order_idx + num_ellipsis_names;
    }
    const auto tensor_idx = tensor_idx_for[order_idx];
    if (tensor_idx == not_found) {
      // We are adding a new size-one dimension
      new_names[out_idx] = order[order_idx];
      continue;
    }
    setNewSizesStridesNamesFor(out_idx, tensor_idx);
  }

  // Fill in the ellipsis dimensions
  // (`ellipsis_idx` is reused as a write cursor here: unmentioned tensor
  // dims are appended in their original order starting at the ellipsis slot.)
  for (const auto tensor_idx : c10::irange(tensor_dim)) {
    if (order_has_tensor_name.test(tensor_idx)) {
      continue;
    }
    setNewSizesStridesNamesFor(ellipsis_idx, tensor_idx);
    ellipsis_idx++;
  }

  check_names_valid_for(out_dim, new_names);

  Tensor result;
  {
    // Build the restrided view with named-tensor handling suppressed;
    // the names are installed explicitly afterwards.
    NoNamesGuard guard;
    result = tensor.as_strided(new_sizes, new_strides);
  }
  internal_set_names_inplace(result, std::move(new_names), /*validate_names=*/false);
  return result;
}
281 
// Non-ellipsis overload of align_to: every dim of `tensor` must be named and
// must appear in `names`; dims in `names` absent from the tensor become
// size-one (stride-zero) dims. Returns a restrided view carrying `names`.
Tensor align_to(const Tensor& tensor, DimnameList names) {
  auto tensor_names = tensor.names();
  auto tensor_sizes = tensor.sizes();
  auto tensor_strides = tensor.strides();
  // Unmatched output dims default to size 1 / stride 0.
  std::vector<int64_t> new_sizes(names.size(), 1);
  std::vector<int64_t> new_strides(names.size(), 0);

  for (const auto idx : c10::irange(tensor_names.size())) {
    const auto& dim = tensor_names[idx];
    TORCH_CHECK(dim.isBasic(),
        "align_to: All input dims must be named. Found unnamed dim at index ",
        idx, " of Tensor", tensor_names);
    auto it = std::find(names.begin(), names.end(), dim);
    // Bug fix: report the tensor's own names here, not `names` twice —
    // the old message printed the target alignment in both slots, yielding
    // a self-contradictory error ("cannot find X from Tensor[names] in
    // desired alignment [names]").
    TORCH_CHECK(it != names.end(),
        "align_to: Cannot find dim ", dim, " from Tensor", tensor_names,
        " in desired alignment ", names, ".");
    int64_t new_idx = std::distance(names.begin(), it);
    new_sizes[new_idx] = tensor_sizes[idx];
    new_strides[new_idx] = tensor_strides[idx];
  }
  Tensor result;
  {
    // Build the restrided view with named-tensor handling suppressed, then
    // install the requested names explicitly.
    NoNamesGuard guard;
    result = tensor.as_strided(new_sizes, new_strides);
  }
  internal_set_names_inplace(result, names);
  return result;
}
310 
// Aligns `tensor` to the names of `other` (thin wrapper over align_to).
Tensor align_as(const Tensor& tensor, const Tensor& other) {
  const auto target_names = other.names();
  return native::align_to(tensor, target_names);
}
314 
align_tensors_to(TensorList tensors,DimnameList names)315 static std::vector<Tensor> align_tensors_to(TensorList tensors, DimnameList names) {
316   std::vector<Tensor> result;
317   result.reserve(tensors.size());
318   for (const auto& tensor : tensors) {
319     result.emplace_back(align(tensor, names, /*is_aligning_two_tensors=*/true));
320   }
321   return result;
322 }
323 
align_tensors(TensorList tensors)324 std::vector<Tensor> align_tensors(TensorList tensors) {
325   auto longest_dim = std::max_element(
326       tensors.begin(), tensors.end(),
327       [](const Tensor& a, const Tensor& b) {
328         return a.dim() < b.dim();
329       });
330   return align_tensors_to(tensors, longest_dim->names());
331 }
332 
333 // Misc. Dimname overloads that don't have homes. Maybe we should move
334 // all of them here or autogenerate them because they look so similar.
// The following Dimname overloads are not yet implemented for named tensors;
// reportNYIDimnameOverload raises a "not yet implemented" error, so these
// functions never actually return despite their signatures.
Tensor gather(const Tensor& self, Dimname dim, const Tensor& index, bool sparse_grad) {
  reportNYIDimnameOverload("gather");
}
Tensor& gather_out(const Tensor& self, Dimname dim, const Tensor& index, bool sparse_grad, Tensor& result) {
  reportNYIDimnameOverload("gather");
}
Tensor index_add(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source, const Scalar &alpha) {
  reportNYIDimnameOverload("index_add");
}
// Dimname overload: resolves the name to a positional dim, then defers to
// the positional index_fill (Scalar fill value).
Tensor index_fill(const Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) {
  const auto pos = dimname_to_position(self, dim);
  return at::index_fill(self, pos, index, source);
}
index_fill_(Tensor & self,Dimname dim,const Tensor & index,const Scalar & source)347 Tensor& index_fill_(Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) {
348   return self.index_fill_(dimname_to_position(self, dim), index, source);
349 }
// Dimname overload: resolves the name to a positional dim, then defers to
// the positional index_fill (Tensor fill value).
Tensor index_fill(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
  const auto pos = dimname_to_position(self, dim);
  return at::index_fill(self, pos, index, source);
}
index_fill_(Tensor & self,Dimname dim,const Tensor & index,const Tensor & source)353 Tensor& index_fill_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
354   return self.index_fill_(dimname_to_position(self, dim), index, source);
355 }
// Not-yet-implemented Dimname overloads; each raises via
// reportNYIDimnameOverload and never returns.
Tensor index_copy(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
  reportNYIDimnameOverload("index_copy");
}
Tensor& index_copy_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
  reportNYIDimnameOverload("index_copy");
}
Tensor& index_select_out(const Tensor& self, Dimname dim, const Tensor& index, Tensor& out) {
  reportNYIDimnameOverload("index_select");
}
Tensor index_select(const Tensor& self, Dimname dim, const Tensor& index) {
  reportNYIDimnameOverload("index_select");
}
// Not-yet-implemented Dimname overloads; each raises via
// reportNYIDimnameOverload and never returns.
Tensor scatter(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
  reportNYIDimnameOverload("scatter");
}
Tensor scatter(const Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) {
  reportNYIDimnameOverload("scatter");
}
Tensor scatter_add(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
  reportNYIDimnameOverload("scatter_add");
}
std::tuple<Tensor&, Tensor&> sort_out(const Tensor& self, std::optional<bool> stable, Dimname dim, bool keepdim, Tensor& values, Tensor& indices) {
  reportNYIDimnameOverload("sort");
}
std::tuple<Tensor&, Tensor&> sort_out(const Tensor& self, Dimname dim, bool keepdim, Tensor& values, Tensor& indices) {
  reportNYIDimnameOverload("sort");
}
std::tuple<Tensor, Tensor> sort(const Tensor& self, std::optional<bool> stable, Dimname dim, bool keepdim) {
  reportNYIDimnameOverload("sort");
}
std::tuple<Tensor, Tensor> sort(const Tensor& self, Dimname dim, bool keepdim) {
  reportNYIDimnameOverload("sort");
}
Tensor& squeeze_(Tensor& self, Dimname dim) {
  reportNYIDimnameOverload("squeeze");
}
// Dimname overload of squeeze: resolves the name to a positional dim and
// defers to the positional at::squeeze.
Tensor squeeze(const Tensor& self, Dimname dim) {
  const auto pos = dimname_to_position(self, dim);
  return at::squeeze(self, pos);
}
395 
396 
397 }  // namespace at::native
398