// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/NamedTensorUtils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/NamedTensorUtils.h>
#include <ATen/TensorNames.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <c10/util/irange.h>

#include <bitset>
#include <sstream>

namespace at {

#ifndef STRIP_ERROR_MESSAGES
// Returns "Tensor['N', 'C', 'H', 'W']" for a tensor with names ('N', 'C', 'H', 'W').
static std::string toDimnameRepr(const Tensor& tensor) {
  std::ostringstream os;
  os << "Tensor" << tensor.names();
  return os.str();
}
#endif

int64_t dimname_to_position(const Tensor& tensor, Dimname dim) {
  TORCH_CHECK(dim.type() != NameType::WILDCARD,
      "Please look up dimensions by name, got: name = None.");
  TORCH_CHECK(tensor.has_names(),
      "Name ", dim, " not found in ", toDimnameRepr(tensor), ".");
  const auto names = tensor.names();

  const auto it = std::find(names.begin(), names.end(), dim);
  TORCH_CHECK(it != names.end(),
      "Name ", dim, " not found in ", toDimnameRepr(tensor), ".");

  return std::distance(names.begin(), it);
}
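// Illustrative usage sketch (hypothetical tensor `t`, not part of this file):
//
//   // assuming `t` has names ('N', 'C', 'H', 'W')
//   auto pos = at::dimname_to_position(
//       t, Dimname::fromSymbol(Symbol::dimname("H")));
//   // pos == 2
//
// Wildcard lookups and names absent from `t` fail via the TORCH_CHECKs above.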

std::vector<int64_t> dimnames_to_positions(const Tensor& tensor, DimnameList dims) {
  std::vector<int64_t> result;
  result.reserve(dims.size());
  for (const auto& name : dims) {
    result.push_back(dimname_to_position(tensor, name));
  }
  return result;
}

static void report_positional_error(
    const Dimname& name,
    const Dimname& other_name,
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds
  TORCH_CHECK(false,
      "Error when attempting to ", action, " dims ", names, " and dims ",
      other_names, ": dim ", name, " and dim ", other_name, " are at the same position "
      "from the right but do not match.");
}

static void check_for_misalignment(
    const Dimname& name,
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  if (name.isWildcard()) {
    return;
  }
  auto it = std::find(other_names.begin(), other_names.end(), name);
  // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds
  TORCH_CHECK(it == other_names.end(),
      "Misaligned dims when attempting to ", action, " dims ", names, " and dims ",
      other_names, ": dim ", name, " appears in a different position from the right "
      "across both lists.");
}

// Assumption: a DimnameList contains no duplicate full names, with the
// exception of wildcards.
std::vector<Dimname> unify_from_right(
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  const auto wildcard = Dimname::wildcard();
  const auto size = std::max(names.size(), other_names.size());
  auto result = std::vector<Dimname>(size, wildcard);

  auto names_it = names.rbegin();
  auto other_it = other_names.rbegin();
  auto result_it = result.rbegin();
  while (names_it != names.rend() || other_it != other_names.rend()) {
    const auto& name = names_it == names.rend() ? wildcard : *names_it;
    const auto& other_name = other_it == other_names.rend() ? wildcard : *other_it;

    // Step 1: Check that the names match
    const auto maybeName = name.unify(other_name);
    if (!maybeName) {
      report_positional_error(name, other_name, names, other_names, action);
    }
    *result_it = *maybeName;

    // Step 2: Check that the names are not misaligned
    if (!name.isBasic() || !other_name.isBasic()) {
      // Let: N = max(len(names), len(other_names))
      //      K = # of special names among names and other_names.
      // This search (including the outer loop) is O(N*K) but typically # of dims is small.
      check_for_misalignment(name, names, other_names, action);
      check_for_misalignment(other_name, other_names, names, action);
    }

    if (names_it != names.rend()) {
      ++names_it;
    }
    if (other_it != other_names.rend()) {
      ++other_it;
    }
    ++result_it;
  }
  return result;
}
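// Illustrative sketch of right-aligned unification (hypothetical name lists):
//
//   unify_from_right(['N', None, 'H'], ['C', 'H'], "broadcast")
//       -> ['N', 'C', 'H']   // None unifies with 'C'; missing dims act as None
//   unify_from_right(['N', 'C'], ['N', None, 'C'], "broadcast")
//       -> error             // 'N' sits at different positions from the right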

namespace namedinference {

static std::bitset<dim_bitset_size>
compute_included_idxs(IntArrayRef excluded_idxs, int64_t ndims) {
  auto result = dim_list_to_bitset(excluded_idxs, ndims);
  result.flip();
  return result;
}

static void assert_names_equal(DimnameList a, DimnameList b) {
  TORCH_CHECK(a == b,
      "Name mismatch: specified out tensor with names ", a,
      " are not the same as the computed output names ", b,
      ". Please rename the out tensor's dims with `Tensor.rename`.");
}

const Tensor& propagate_names_if_present_and_nonempty(const Tensor& result,
    std::optional<DimnameList> maybe_names,
    bool validate_names) {
  auto maybe_name_list = maybe_names.value_or(at::ArrayRef<Dimname>{});
  propagate_names_if_nonempty(result.unsafeGetTensorImpl(), maybe_name_list, validate_names);
  return result;
}

const Tensor& propagate_names_if_nonempty(const Tensor& result,
    DimnameList maybe_names,
    bool validate_names) {
  propagate_names_if_nonempty(result.unsafeGetTensorImpl(), maybe_names, validate_names);
  return result;
}

TensorImpl* propagate_names_if_nonempty(TensorImpl* result,
    DimnameList maybe_names,
    bool validate_names) {
  if (maybe_names.empty()) {
    return result;
  }
  return propagate_names(result, maybe_names, validate_names);
}

const Tensor& propagate_names(const Tensor& result, DimnameList names, bool validate_names) {
  propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
  return result;
}

TensorImpl* propagate_names(TensorImpl* result, DimnameList names, bool validate_names) {
  if (result->dim() > 0) {
    TORCH_INTERNAL_ASSERT(
        !names.empty(),
        "propagate_names: passed in empty names to propagate to result with",
        " shape ", result->sizes(), ". Empty names means that name inference did",
        " not occur; use `propagate_names_if_nonempty` instead of `propagate_names`.");
  }
  if (!impl::has_names(result)) {
    impl::internal_set_names_inplace(result, names, validate_names);
  } else {
    assert_names_equal(impl::get_names(result), names);
  }
  return result;
}

void propagate_names_except(const Tensor& result, const Tensor& src, IntArrayRef excluded_idxs) {
  if (!result.has_names() && !src.has_names()) {
    return;
  }
  const auto src_names = src.names();
  const auto result_dim = static_cast<int64_t>(result.dim());
  const auto src_dim = static_cast<int64_t>(src_names.size());
  const auto excluded_dim = static_cast<int64_t>(excluded_idxs.size());
  TORCH_INTERNAL_ASSERT(src_dim - excluded_dim == result_dim);

  // Fast path: exactly one excluded index.
  if (excluded_idxs.size() == 1) {
    std::vector<Dimname> outnames = src_names.vec();
    outnames.erase(outnames.begin() + maybe_wrap_dim(excluded_idxs[0], src_dim));
    propagate_names(result, outnames);
    return;
  }

  std::vector<Dimname> outnames;
  outnames.reserve(result_dim);
  auto included_idxs = compute_included_idxs(excluded_idxs, src_dim);
  for (const auto dim : c10::irange(src_dim)) {
    if (included_idxs[dim]) {
      outnames.push_back(src_names[dim]);
    }
  }
  propagate_names(result, outnames);
}
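// Illustrative sketch (hypothetical names): dropping dim 1 of a source
// Tensor['N', 'C', 'H', 'W'] propagates the remaining names onto the result:
//
//   propagate_names_except(result, src, /*excluded_idxs=*/{1});
//   // result is named ('N', 'H', 'W')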

void propagate_names_for_reduction(const Tensor& result, const Tensor& src, IntArrayRef reduced_dims, bool keepdim) {
  if (keepdim) {
    propagate_names(result, src);
    return;
  }
  // An empty `reduced_dims` actually means "full reduction": without keepdim
  // the result has no dims left to name.
  if (reduced_dims.empty()) {
    return;
  }
  propagate_names_except(result, src, reduced_dims);
}

void propagate_names(const Tensor& result, const Tensor& src) {
  propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
}

void propagate_names(TensorImpl* result, TensorImpl* src) {
  if (result == src) {
    return;
  }
  if (!impl::has_names(result) && !impl::has_names(src)) {
    return;
  }
  propagate_names(result, impl::get_names(src));
}

std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor) {
  if (!tensor.has_names()) {
    return {};
  }
  std::vector<Dimname> outnames;
  auto tensor_names = tensor.names();
  for (const auto d : c10::irange(tensor.dim())) {
    if (tensor.sym_sizes()[d] != 1) {
      outnames.push_back(tensor_names[d]);
    }
  }
  return outnames;
}
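// Illustrative sketch (hypothetical tensor): squeezing shape (2, 1, 3) with
// names ('N', 'C', 'H') drops the name of every size-1 dim:
//
//   compute_squeeze_outnames(t)  // -> ['N', 'H']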

std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor, std::bitset<dim_bitset_size> dims) {
  if (!tensor.has_names()) {
    return {};
  }
  std::vector<Dimname> outnames;
  auto tensor_names = tensor.names();
  for (const auto d : c10::irange(tensor.dim())) {
    if (!dims.test(d) || tensor.sym_sizes()[d] != 1) {
      outnames.push_back(tensor_names[d]);
    }
  }
  return outnames;
}

std::vector<Dimname> compute_diagonal_outnames(
    const Tensor& tensor,
    int64_t dim1,
    int64_t dim2) {
  if (!tensor.has_names()) {
    return {};
  }
  std::vector<Dimname> outnames;
  auto tensor_names = tensor.names();
  for (const auto d : c10::irange(tensor.dim())) {
    if (d == dim1 || d == dim2) {
      continue;
    }
    outnames.push_back(tensor_names[d]);
  }
  outnames.push_back(Dimname::wildcard());
  return outnames;
}
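// Illustrative sketch (hypothetical names): taking the diagonal of
// Tensor['N', 'C', 'H', 'W'] over dims 2 and 3 drops both names and appends
// a wildcard for the new diagonal dim:
//
//   compute_diagonal_outnames(t, /*dim1=*/2, /*dim2=*/3)  // -> ['N', 'C', None]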

static void check_feature_names_are_distinct(
    DimnameList self_names,
    DimnameList other_names,
    const DimnameList& outnames) {
  if (self_names.size() < 2 || other_names.size() < 2) {
    // There are fewer than 2 feature dims in outnames, so there is nothing to check.
    return;
  }
  auto feature0 = outnames[outnames.size() - 2];
  auto feature1 = outnames[outnames.size() - 1];
  TORCH_CHECK(
    feature0 == Dimname::wildcard() || feature0 != feature1,
    "Matrix multiplying Tensor", self_names,
    " with Tensor", other_names,
    " would produce output tensor with duplicate names ",
    outnames,
    ". Please rename the input tensors with `Tensor.rename` to prevent this.");
}

static int64_t num_batch_dims(DimnameList names) {
  if (names.size() <= 2) {
    return 0;
  }
  return static_cast<int64_t>(names.size() - 2);
}

static std::vector<Dimname> compute_matmul_outnames(
    DimnameList self_names,
    DimnameList other_names) {
  TORCH_CHECK(!self_names.empty() && !other_names.empty(),
      "both arguments to matmul need to be at least 1D, but they are ",
      self_names.size(), "D and ", other_names.size(), "D");

  // matmul performs a batch matrix multiply between self and other, each of
  // which can be:
  // - a batch of matrices (if dim > 2)
  // - a matrix (if dim == 2)
  // - a vector (if dim == 1)
  //
  // To compute output names, we unify the batch dimensions, because those are
  // broadcastable, to get the output batch dimensions.
  //
  // After that, we append the names that result from the matmul without batch
  // dimensions. Those names are computed by removing the names of the
  // dimensions that were contracted away. We always contract the last dim of
  // the first tensor with the first feature dimension of the second.

  // Get the output's batch dimension names
  auto wrapped_self_names = TensorNames(self_names, 0, num_batch_dims(self_names));
  const auto wrapped_other_names = TensorNames(other_names, 0, num_batch_dims(other_names));
  auto& working_names = wrapped_self_names.unifyFromRightInplace(wrapped_other_names, "matmul");

  // Append the result of each individual (non-batched) matmul.
  // If either self or other has dim 1, it is a vector. Vectors get completely
  // contracted away during matmul, so we don't take any names from them.
  if (self_names.size() >= 2) {
    working_names.append(TensorName(self_names, -2));
  }
  if (other_names.size() >= 2) {
    working_names.append(TensorName(other_names, -1));
  }
  auto result = working_names.toDimnameVec();

  check_feature_names_are_distinct(self_names, other_names, result);
  return result;
}
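// Worked example (hypothetical names): for a batched matmul of
//   self:  Tensor['B',  'M', 'K']   (batch dims: ['B'])
//   other: Tensor[None, 'K', 'N']   (batch dims: [None])
// the batch dims unify from the right to ['B'], and self[-2] and other[-1]
// survive the contraction:
//
//   compute_matmul_outnames(self_names, other_names)  // -> ['B', 'M', 'N']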

std::vector<Dimname> propagate_names_for_addmv(
    const Tensor& mat,
    const Tensor& vec,
    const Tensor& bias) {
  if (!mat.has_names() &&
      !vec.has_names() && !bias.has_names()) {
    return std::vector<Dimname>{};
  }
  auto mv_outnames = compute_matmul_outnames(mat.names(), vec.names());
  return unify_from_right(mv_outnames, bias.names());
}

std::vector<Dimname> propagate_names_for_addmm(
    const Tensor& m1,
    const Tensor& m2,
    const Tensor& bias) {
  if (!m1.has_names() && !m2.has_names() &&
      !bias.has_names()) {
    return std::vector<Dimname>{};
  }

  auto mm_outnames = compute_matmul_outnames(m1.names(), m2.names());
  return unify_from_right(mm_outnames, bias.names());
}
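// Illustrative sketch (hypothetical names): for addmm with
//   m1: Tensor['M', 'K'], m2: Tensor['K', 'N'], bias: Tensor[None, 'N'],
// the mm names are ['M', 'N'], which then unify with the bias names:
//
//   propagate_names_for_addmm(m1, m2, bias)  // -> ['M', 'N']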

void check_names_for_dot(
    TensorImpl* vec1,
    TensorImpl* vec2) {
  if (!impl::has_names(vec1) && !impl::has_names(vec2)) {
    return;
  }
  compute_matmul_outnames(impl::get_names(vec1), impl::get_names(vec2));
}

// expand adds new None dimensions. This is consistent with name inference
// rules for binary ops, which expect the named dims to line up positionally
// from the right. e.g.,
// Tensor[H, W].expand(3, 3, 3, 3) -> Tensor[None, None, H, W]
void propagate_names_for_expand(const Tensor& result, const Tensor& self) {
  if (!self.has_names()) {
    return;
  }
  auto result_dim = result.dim();
  if (self.dim() == result_dim) {
    propagate_names(result, self);
    return;
  }
  std::vector<Dimname> outnames(result_dim, Dimname::wildcard());
  std::copy(
      self.opt_names()->begin(),
      self.opt_names()->end(),
      outnames.begin() + result_dim - self.dim());
  propagate_names(result, outnames);
}

std::vector<Dimname> compute_broadcast_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  return unify_from_right(self.names(), other.names());
}

std::vector<Dimname> broadcast_to_outnames(
    const Tensor& tensor,
    const Tensor& reference_tensor,
    const char* op_name) {
  if (!tensor.has_names() && !reference_tensor.has_names()) {
    return {};
  }
  auto reference_names = reference_tensor.names();
  auto tensor_names = tensor.names();
  TORCH_CHECK(
      reference_names.size() >= tensor_names.size(),
      op_name, ": attempted to broadcast Tensor", tensor_names, " to Tensor",
      reference_names, " but the number of dims (", tensor_names.size(),
      ") must be less than or equal to the number of dims in the reference tensor (",
      reference_names.size(), ")");
  return unify_from_right(reference_names, tensor_names);
}
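// Illustrative sketch (hypothetical names): broadcasting Tensor['C'] to a
// reference Tensor['N', 'C'] unifies the name lists from the right:
//
//   broadcast_to_outnames(t, ref, "broadcast_to")  // -> ['N', 'C']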

std::vector<Dimname> compute_cat_outnames(const MaterializedITensorListRef& tensors) {
  if (!at::has_names(tensors)) {
    return {};
  }
  std::vector<Dimname> result;
  for (const Tensor& tensor : tensors) {
    const auto tensor_names = tensor.names();
    TORCH_CHECK(!tensor_names.empty(), "zero-dimensional tensor cannot be concatenated");
    TORCH_CHECK(result.empty() || tensor_names.size() == result.size(),
        "Tensors must have same number of dimensions: got ", result.size(),
        " and ", tensor_names.size());
    result = unify_from_right(result, tensor_names, "cat");
  }
  return result;
}
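// Illustrative sketch (hypothetical names): concatenating Tensor['N', 'C']
// with Tensor['N', None] unifies the name lists pairwise from the right:
//
//   compute_cat_outnames({a, b})  // -> ['N', 'C']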

std::vector<Dimname> compute_matmul_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  return compute_matmul_outnames(self.names(), other.names());
}

std::vector<Dimname> compute_cdist_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  const auto self_names = self.names();
  const auto other_names = other.names();

  auto self_batch = TensorNames(self_names, 0, num_batch_dims(self_names));
  const auto other_batch = TensorNames(other_names, 0, num_batch_dims(other_names));

  auto& result = self_batch.unifyFromRightInplace(other_batch, "cdist");

  // cdist treats self and other like batches of M x D and N x D tensors, respectively.
  // It computes the pairwise distance between each of the M vectors (of size D)
  // in `self` and each of the N vectors in `other`, returning a batch of M x N
  // distance values. We propagate the names of the dimension of size M (in self)
  // and the dimension of size N (in other), both of which are second-from-last.
  result.append(TensorName(self_names, -2));
  result.append(TensorName(other_names, -2));
  result.checkUnique("cdist");

  return result.toDimnameVec();
}
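// Illustrative sketch (hypothetical names): for
//   self:  Tensor['B', 'M', 'D'],  other: Tensor['B', 'N', 'D'],
// the batch dims unify to ['B'] and both second-from-last dims are appended:
//
//   compute_cdist_outnames(self, other)  // -> ['B', 'M', 'N']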

std::vector<Dimname> compute_bmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other) {
  if (!result.has_names() && !self.has_names() && !other.has_names()) {
    return {};
  }
  return compute_matmul_outnames(self.names(), other.names());
}

std::vector<Dimname> compute_baddbmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other,
    const Tensor& bias) {
  if (!result.has_names() && !self.has_names()
      && !other.has_names() && !bias.has_names()) {
    return {};
  }
  auto bmm_names = compute_matmul_outnames(self.names(), other.names());
  auto baddbmm_names = unify_from_right(bias.names(), bmm_names);
  return baddbmm_names;
}

bool are_names_equal(TensorImpl* self, TensorImpl* other) {
  if (!impl::has_names(self) && !impl::has_names(other)) {
    return true;
  }
  return impl::get_names(self) == impl::get_names(other);
}

} // namespace namedinference
} // namespace at