#include <ATen/NamedTensorUtils.h>
#include <ATen/TensorNames.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <c10/util/irange.h>

#include <bitset>
#include <sstream>

namespace at {

#ifndef STRIP_ERROR_MESSAGES
// Returns "Tensor['N', 'C', 'H', 'W']" for a tensor with names ('N', 'C', 'H', 'W').
static std::string toDimnameRepr(const Tensor& tensor) {
  std::ostringstream os;
  os << "Tensor" << tensor.names();
  return os.str();
}
#endif

int64_t dimname_to_position(const Tensor& tensor, Dimname dim) {
  TORCH_CHECK(dim.type() != NameType::WILDCARD,
      "Please look up dimensions by name, got: name = None.");
  TORCH_CHECK(tensor.has_names(),
      "Name ", dim, " not found in ", toDimnameRepr(tensor), ".");
  const auto names = tensor.names();

  const auto it = std::find(names.begin(), names.end(), dim);
  TORCH_CHECK(it != names.end(),
      "Name ", dim, " not found in ", toDimnameRepr(tensor), ".");

  return std::distance(names.begin(), it);
}
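// e.g., for a tensor with names ('N', 'C', 'H', 'W'), looking up 'C' returns
// position 1; looking up an absent name (or the wildcard None) throws.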

std::vector<int64_t> dimnames_to_positions(const Tensor& tensor, DimnameList dims) {
  std::vector<int64_t> result;
  result.reserve(dims.size());
  for (const auto& name : dims) {
    result.push_back(dimname_to_position(tensor, name));
  }
  return result;
}

static void report_positional_error(
    const Dimname& name,
    const Dimname& other_name,
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds
  TORCH_CHECK(false,
      "Error when attempting to ", action, " dims ", names, " and dims ",
      other_names, ": dim ", name, " and dim ", other_name, " are at the same position "
      "from the right but do not match.");
}

static void check_for_misalignment(
    const Dimname& name,
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  if (name.isWildcard()) {
    return;
  }
  auto it = std::find(other_names.begin(), other_names.end(), name);
  // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds
  TORCH_CHECK(it == other_names.end(),
      "Misaligned dims when attempting to ", action, " dims ", names, " and dims ",
      other_names, ": dim ", name, " appears in a different position from the right "
      "across both lists.");
}
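// e.g., unifying ['N', 'C'] with ['C', None] from the right pairs 'C' with
// None, but 'C' also appears at a different position in the other list, so
// this check reports a misalignment rather than silently dropping the name.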

// Assumption: A DimnameList can have no duplicate full names with
// the exception of wildcards
std::vector<Dimname> unify_from_right(
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  const auto wildcard = Dimname::wildcard();
  const auto size = std::max(names.size(), other_names.size());
  auto result = std::vector<Dimname>(size, wildcard);

  auto names_it = names.rbegin();
  auto other_it = other_names.rbegin();
  auto result_it = result.rbegin();
  while (names_it != names.rend() || other_it != other_names.rend()) {
    const auto& name = names_it == names.rend() ? wildcard : *names_it;
    const auto& other_name = other_it == other_names.rend() ? wildcard : *other_it;

    // Step 1: Check that the names match
    const auto maybeName = name.unify(other_name);
    if (!maybeName) {
      report_positional_error(name, other_name, names, other_names, action);
    }
    *result_it = *maybeName;

    // Step 2: Check that the names are not misaligned
    if (!name.isBasic() || !other_name.isBasic()) {
      // Let: N = max(len(names), len(other_names))
      //      K = # of special names among names and other_names.
      // This search (including the outer loop) is O(N*K) but typically # of dims is small.
      check_for_misalignment(name, names, other_names, action);
      check_for_misalignment(other_name, other_names, names, action);
    }

    if (names_it != names.rend()) {
      ++names_it;
    }
    if (other_it != other_names.rend()) {
      ++other_it;
    }
    ++result_it;
  }
  return result;
}
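// e.g., unify_from_right(['N', None, 'H'], [None, 'C', 'H'], "add") aligns the
// lists from the right and yields ['N', 'C', 'H']; a shorter list is padded
// with wildcards on the left, so unifying ['C'] with ['N', 'C'] gives ['N', 'C'].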

namespace namedinference {

static std::bitset<dim_bitset_size>
compute_included_idxs(IntArrayRef excluded_idxs, int64_t ndims) {
  auto result = dim_list_to_bitset(excluded_idxs, ndims);
  result.flip();
  return result;
}

static void assert_names_equal(DimnameList a, DimnameList b) {
  TORCH_CHECK(a == b,
      "Name mismatch: the specified out tensor has names ", a,
      ", which are not the same as the computed output names ", b,
      ". Please rename the out tensor's dims with `Tensor.rename`.");
}

const Tensor& propagate_names_if_present_and_nonempty(const Tensor& result,
    std::optional<DimnameList> maybe_names,
    bool validate_names) {
  auto maybe_name_list = maybe_names.value_or(at::ArrayRef<Dimname>{});
  propagate_names_if_nonempty(result.unsafeGetTensorImpl(), maybe_name_list, validate_names);
  return result;
}

const Tensor& propagate_names_if_nonempty(const Tensor& result,
    DimnameList maybe_names,
    bool validate_names) {
  propagate_names_if_nonempty(result.unsafeGetTensorImpl(), maybe_names, validate_names);
  return result;
}

TensorImpl* propagate_names_if_nonempty(TensorImpl* result,
    DimnameList maybe_names,
    bool validate_names) {
  if (maybe_names.empty()) {
    return result;
  }
  return propagate_names(result, maybe_names, validate_names);
}

const Tensor& propagate_names(const Tensor& result, DimnameList names, bool validate_names) {
  propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
  return result;
}

TensorImpl* propagate_names(TensorImpl* result, DimnameList names, bool validate_names) {
  if (result->dim() > 0) {
    TORCH_INTERNAL_ASSERT(
        !names.empty(),
        "propagate_names: passed in empty names to propagate to result with",
        " shape ", result->sizes(), ". Empty names means that name inference did",
        " not occur; use `propagate_names_if_nonempty` instead of `propagate_names`.");
  }
  if (!impl::has_names(result)) {
    impl::internal_set_names_inplace(result, names, validate_names);
  } else {
    assert_names_equal(impl::get_names(result), names);
  }
  return result;
}

void propagate_names_except(const Tensor& result, const Tensor& src, IntArrayRef excluded_idxs) {
  if (!result.has_names() && !src.has_names()) {
    return;
  }
  const auto src_names = src.names();
  const auto result_dim = static_cast<int64_t>(result.dim());
  const auto src_dim = static_cast<int64_t>(src_names.size());
  const auto excluded_dim = static_cast<int64_t>(excluded_idxs.size());
  TORCH_INTERNAL_ASSERT(src_dim - excluded_dim == result_dim);

  // fast path
  if (excluded_idxs.size() == 1) {
    std::vector<Dimname> outnames = src_names.vec();
    outnames.erase(outnames.begin() + maybe_wrap_dim(excluded_idxs[0], src_dim));
    propagate_names(result, outnames);
    return;
  }

  std::vector<Dimname> outnames;
  outnames.reserve(result_dim);
  auto included_idxs = compute_included_idxs(excluded_idxs, src_dim);
  for (const auto dim : c10::irange(src_dim)) {
    if (included_idxs[dim]) {
      outnames.push_back(src_names[dim]);
    }
  }
  propagate_names(result, outnames);
}
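// e.g., with src named ('N', 'C', 'H', 'W') and excluded_idxs = {1, 3}, the
// result is given the names ['N', 'H'].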

void propagate_names_for_reduction(const Tensor& result, const Tensor& src, IntArrayRef reduced_dims, bool keepdim) {
  if (keepdim) {
    propagate_names(result, src);
    return;
  }
  // This actually means "full reduction"
  if (reduced_dims.empty()) {
    return;
  }
  propagate_names_except(result, src, reduced_dims);
}

void propagate_names(const Tensor& result, const Tensor& src) {
  propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
}

void propagate_names(TensorImpl* result, TensorImpl* src) {
  if (result == src) {
    return;
  }
  if (!impl::has_names(result) && !impl::has_names(src)) {
    return;
  }
  propagate_names(result, impl::get_names(src));
}

std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor) {
  if (!tensor.has_names()) {
    return {};
  }
  std::vector<Dimname> outnames;
  auto tensor_names = tensor.names();
  for (const auto d : c10::irange(tensor.dim())) {
    if (tensor.sym_sizes()[d] != 1) {
      outnames.push_back(tensor_names[d]);
    }
  }
  return outnames;
}
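// e.g., squeezing a tensor named ('N', 'C', 'H') with shape (2, 1, 3) keeps
// only the non-size-1 dims, producing ['N', 'H'].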

std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor, std::bitset<dim_bitset_size> dims) {
  if (!tensor.has_names()) {
    return {};
  }
  std::vector<Dimname> outnames;
  auto tensor_names = tensor.names();
  for (const auto d : c10::irange(tensor.dim())) {
    if (!dims.test(d) || tensor.sym_sizes()[d] != 1) {
      outnames.push_back(tensor_names[d]);
    }
  }
  return outnames;
}

std::vector<Dimname> compute_diagonal_outnames(
    const Tensor& tensor,
    int64_t dim1,
    int64_t dim2) {
  if (!tensor.has_names()) {
    return {};
  }
  std::vector<Dimname> outnames;
  auto tensor_names = tensor.names();
  for (const auto d : c10::irange(tensor.dim())) {
    if (d == dim1 || d == dim2) {
      continue;
    }
    outnames.push_back(tensor_names[d]);
  }
  outnames.push_back(Dimname::wildcard());
  return outnames;
}
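// e.g., taking the diagonal of a tensor named ('N', 'C', 'H', 'W') over
// dim1 = 2 and dim2 = 3 yields ['N', 'C', None]: the two diagonal dims are
// dropped and the new diagonal dim is unnamed.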

static void check_feature_names_are_distinct(
    DimnameList self_names,
    DimnameList other_names,
    const DimnameList& outnames) {
  if (self_names.size() < 2 || other_names.size() < 2) {
    // There are fewer than 2 feature dims in outnames, so there is nothing to check
    return;
  }
  auto feature0 = outnames[outnames.size() - 2];
  auto feature1 = outnames[outnames.size() - 1];
  TORCH_CHECK(
      feature0 == Dimname::wildcard() || feature0 != feature1,
      "Matrix multiplying Tensor", self_names,
      " with Tensor", other_names,
      " would produce output tensor with duplicate names ",
      outnames,
      ". Please rename the input tensors with `Tensor.rename` to prevent this.");
}

static int64_t num_batch_dims(DimnameList names) {
  if (names.size() <= 2) {
    return 0;
  }
  return static_cast<int64_t>(names.size() - 2);
}

static std::vector<Dimname> compute_matmul_outnames(
    DimnameList self_names,
    DimnameList other_names) {
  TORCH_CHECK(!self_names.empty() && !other_names.empty(),
      "both arguments to matmul need to be at least 1D, but they are ",
      self_names.size(), "D and ", other_names.size(), "D");

  // matmul performs a batch matrix multiply between self and other, each of which
  // can either be:
  // - a batch of matrices (if dim > 2)
  // - a matrix (if dim == 2)
  // - a vector (if dim == 1)
  //
  // To compute output names, we unify the batch dimensions, because those are
  // broadcastable, to get the output batch dimensions.
  //
  // After that, we append some names that are equal to the result of the matmul
  // without batch dimensions. Those names are computed by removing the names
  // of the dimensions that were contracted away. We always contract the
  // last dim of the first tensor with the first feature dimension of the second.

  // Get the output's batch dimension names
  auto wrapped_self_names = TensorNames(self_names, 0, num_batch_dims(self_names));
  const auto wrapped_other_names = TensorNames(other_names, 0, num_batch_dims(other_names));
  auto& working_names = wrapped_self_names.unifyFromRightInplace(wrapped_other_names, "matmul");

  // Append the result of each individual (non-batched) matmul.
  // If either self or other has dim 1, it is a vector. Vectors get
  // completely contracted away during matmul, so we don't take any names from them.
  if (self_names.size() >= 2) {
    working_names.append(TensorName(self_names, -2));
  }
  if (other_names.size() >= 2) {
    working_names.append(TensorName(other_names, -1));
  }
  auto result = working_names.toDimnameVec();

  check_feature_names_are_distinct(self_names, other_names, result);
  return result;
}
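// e.g., compute_matmul_outnames(['B', 'M', 'K'], ['B', 'K', 'N']) unifies the
// batch dims to ['B'], then appends 'M' (from self[-2]) and 'N' (from
// other[-1]), giving ['B', 'M', 'N']. For a matrix-vector product,
// (['M', 'K'], ['K']) contracts the vector away entirely and yields ['M'].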

std::vector<Dimname> propagate_names_for_addmv(
    const Tensor& mat,
    const Tensor& vec,
    const Tensor& bias) {
  if (!mat.has_names() &&
      !vec.has_names() && !bias.has_names()) {
    return std::vector<Dimname>{};
  }
  auto mv_outnames = compute_matmul_outnames(mat.names(), vec.names());
  return unify_from_right(mv_outnames, bias.names());
}

std::vector<Dimname> propagate_names_for_addmm(
    const Tensor& m1,
    const Tensor& m2,
    const Tensor& bias) {
  if (!m1.has_names() && !m2.has_names() &&
      !bias.has_names()) {
    return std::vector<Dimname>{};
  }

  auto mm_outnames = compute_matmul_outnames(m1.names(), m2.names());
  return unify_from_right(mm_outnames, bias.names());
}

void check_names_for_dot(
    TensorImpl* vec1,
    TensorImpl* vec2) {
  if (!impl::has_names(vec1) && !impl::has_names(vec2)) {
    return;
  }
  compute_matmul_outnames(impl::get_names(vec1), impl::get_names(vec2));
}

// expand adds new None dimensions. This is consistent with name inference
// rules for binary ops that expect the named dims to line up positionally
// from the right. i.e.,
// Tensor[H, W].expand(3, 3, 3, 3) -> Tensor[None, None, H, W]
void propagate_names_for_expand(const Tensor& result, const Tensor& self) {
  if (!self.has_names()) {
    return;
  }
  auto result_dim = result.dim();
  if (self.dim() == result_dim) {
    propagate_names(result, self);
    return;
  }
  std::vector<Dimname> outnames(result_dim, Dimname::wildcard());
  std::copy(
      self.opt_names()->begin(),
      self.opt_names()->end(),
      outnames.begin() + result_dim - self.dim());
  propagate_names(result, outnames);
}

std::vector<Dimname> compute_broadcast_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  return unify_from_right(self.names(), other.names());
}

std::vector<Dimname> broadcast_to_outnames(
    const Tensor& tensor,
    const Tensor& reference_tensor,
    const char* op_name) {
  if (!tensor.has_names() && !reference_tensor.has_names()) {
    return {};
  }
  auto reference_names = reference_tensor.names();
  auto tensor_names = tensor.names();
  TORCH_CHECK(
      reference_names.size() >= tensor_names.size(),
      op_name, ": attempted to broadcast Tensor", tensor_names, " to Tensor",
      reference_names, " but the number of dims (", tensor_names.size(),
      ") must be less than or equal to the number of dims in the tensor (",
      reference_names.size(), ")");
  return unify_from_right(reference_names, tensor_names);
}
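// e.g., broadcasting Tensor['C'] to a reference Tensor['N', 'C'] returns
// ['N', 'C']; broadcasting in the other direction (more dims than the
// reference) is rejected by the check above.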

std::vector<Dimname> compute_cat_outnames(const MaterializedITensorListRef& tensors) {
  if (!at::has_names(tensors)) {
    return {};
  }
  std::vector<Dimname> result;
  for (const Tensor& tensor : tensors) {
    const auto tensor_names = tensor.names();
    TORCH_CHECK(!tensor_names.empty(), "zero-dimensional tensor cannot be concatenated");
    TORCH_CHECK(result.empty() || tensor_names.size() == result.size(),
        "Tensors must have same number of dimensions: got ", result.size(),
        " and ", tensor_names.size());
    result = unify_from_right(result, tensor_names, "cat");
  }
  return result;
}
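// e.g., concatenating Tensor['N', 'C'] with Tensor['N', None] unifies each
// position from the right, so the result is named ['N', 'C'].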

std::vector<Dimname> compute_matmul_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  return compute_matmul_outnames(self.names(), other.names());
}

std::vector<Dimname> compute_cdist_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  const auto self_names = self.names();
  const auto other_names = other.names();

  auto self_batch = TensorNames(self_names, 0, num_batch_dims(self_names));
  const auto other_batch = TensorNames(other_names, 0, num_batch_dims(other_names));

  auto& result = self_batch.unifyFromRightInplace(other_batch, "cdist");

  // cdist treats self and other like batches of M x D and N x D tensors, respectively.
  // It computes the pairwise distance between each of the M vectors (of size D)
  // in `self` and each of the N vectors in `other`, returning a batch of M x N
  // distance values. We propagate the names of the dimension of size M (in self)
  // and the dimension of size N (in other), both of which are second-from-last.
  result.append(TensorName(self_names, -2));
  result.append(TensorName(other_names, -2));
  result.checkUnique("cdist");

  return result.toDimnameVec();
}
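// e.g., cdist on self named ('B', 'M', 'D') and other named ('B', 'N', 'D')
// unifies the batch dims to ['B'] and appends 'M' and 'N', producing
// ['B', 'M', 'N'].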

std::vector<Dimname> compute_bmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other) {
  if (!result.has_names() && !self.has_names() && !other.has_names()) {
    return {};
  }
  return compute_matmul_outnames(self.names(), other.names());
}

std::vector<Dimname> compute_baddbmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other,
    const Tensor& bias) {
  if (!result.has_names() && !self.has_names()
      && !other.has_names() && !bias.has_names()) {
    return {};
  }
  auto bmm_names = compute_matmul_outnames(self.names(), other.names());
  auto baddbmm_names = unify_from_right(bias.names(), bmm_names);
  return baddbmm_names;
}

bool are_names_equal(TensorImpl* self, TensorImpl* other) {
  if (!impl::has_names(self) && !impl::has_names(other)) {
    return true;
  }
  return impl::get_names(self) == impl::get_names(other);
}

} // namespace namedinference
} // namespace at