xref: /aosp_15_r20/external/pytorch/aten/src/ATen/TensorNames.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/TensorNames.h>
2 #include <ATen/WrapDimUtils.h>
3 #include <c10/util/irange.h>
4 
5 namespace at::namedinference {
6 
7 
toDimname() const8 Dimname TensorName::toDimname() const {
9   return name_;
10 }
11 
// Reconciles this name with `other` and returns a reference to whichever of
// the two should be propagated (either *this or `other`).
//
// Rules, where None is a wildcard and A, B are concrete names:
//   (None, None) -> *this
//   (A, A)       -> *this
//   (A, None)    -> *this, unless `other`'s origin list already contains A
//                   somewhere, in which case an error is raised.
//   (None, A)    -> handled by swapping the operands, i.e. other.unify(*this).
//   (A, B)       -> error (A != B).
//
// op_name is only used to prefix error messages.
const TensorName& TensorName::unify(const TensorName& other, const char* op_name) const {
  const bool self_is_wildcard = name_.isWildcard();
  const bool other_is_wildcard = other.name_.isWildcard();

  // (None, None) or (A, A): nothing to reconcile.
  if ((self_is_wildcard && other_is_wildcard) || name_ == other.name_) {
    return *this;
  }

  // (A, None): accept A, but only if the wildcard's originating name list
  // does not already hold A elsewhere (that would suggest misalignment).
  if (other_is_wildcard) {
    const auto origin_end = other.origin_.end();
    const bool already_present =
        std::find(other.origin_.begin(), origin_end, name_) != origin_end;
    TORCH_CHECK(!already_present,
        op_name, ":",
        " Cannot match ", *this, " with ", other,
        " because the latter names already have ", name_, ".",
        " Are your tensors misaligned?");
    return *this;
  }

  // (None, A): flip the operands so the concrete name is on the left.
  if (self_is_wildcard) {
    return other.unify(*this, op_name);
  }

  // (A, B) with A != B: always an error at this point.
  TORCH_CHECK(name_ == other.name_,
      op_name, ":",
      " Expected ", *this,
      " to match ", other,
      " but they do not match.");
  return *this;
}
47 
TensorNames(ArrayRef<Dimname> names)48 TensorNames::TensorNames(ArrayRef<Dimname> names) {
49   names_.reserve(names.size());
50   for (const auto idx : c10::irange(names.size())) {
51     names_.emplace_back(names, idx);
52   }
53 }
54 
TensorNames(ArrayRef<Dimname> names,int64_t start,int64_t end)55 TensorNames::TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end) {
56   int64_t names_size = static_cast<int64_t>(names.size());
57   start = maybe_wrap_dim(start, names_size);
58   end = maybe_wrap_dim(end, names_size);
59   names_.reserve(end - start);
60   for (const auto idx : c10::irange(start, end)) {
61     names_.emplace_back(names, idx);
62   }
63 }
64 
// Right-aligns `other` against this name list and unifies overlapping
// positions in place, using TensorName::unify.
//
// If this list is shorter than `other`, it is first left-padded with
// `other`'s leading entries. Those padded entries are taken as-is and are
// not unified. Returns *this to allow chaining.
TensorNames& TensorNames::unifyFromRightInplace(const TensorNames& other, const char* op_name) {
  const auto& theirs = other.names_;

  // Index in names_ at which the pairwise unification starts.
  size_t first_overlap = 0;
  if (names_.size() > theirs.size()) {
    first_overlap = names_.size() - theirs.size();
  } else {
    const size_t missing = theirs.size() - names_.size();
    names_.insert(names_.begin(), theirs.begin(), theirs.begin() + missing);
    first_overlap = missing;
  }

  // After any padding, names_[idx] lines up with theirs[idx - shift].
  const size_t shift = names_.size() - theirs.size();
  for (size_t idx = first_overlap; idx < names_.size(); ++idx) {
    names_[idx] = names_[idx].unify(theirs[idx - shift], op_name);
  }
  return *this;
}
86 
append(TensorName name)87 void TensorNames::append(TensorName name) {
88   names_.emplace_back(name);
89 }
90 
checkUnique(const char * op_name) const91 void TensorNames::checkUnique(const char* op_name) const {
92   // O(N^2), but named tensors can have at most N = 64 dimensions, so this
93   // doesn't matter unless benchmarking tells us it does. The alternative is
94   // to create some sort of set data structure but the overhead of that
95   // might dominate for small sizes.
96   for (auto it = names_.begin(); it != names_.end(); ++it) {
97     const auto name = it->toDimname();
98     if (name.isWildcard()) continue;
99 
100     auto dup = std::find_if(it + 1, names_.end(),
101         [&](const TensorName& other) { return other.toDimname() == name; });
102     TORCH_CHECK(dup == names_.end(),
103         op_name, ": ",
104         "Attempted to propagate dims ", *it, " and ", *dup, " to the output, ",
105         "but that would create a tensor with duplicate names [", toDimnameVec(),
106         "]. Please rename your inputs with Tensor.rename to prevent this.");
107   }
108 }
109 
110 // Let's say the TensorName represents 'C' in ['N', 'C', 'H, 'W'].
111 // It should print like:
112 // 'C' (index 1 of ['N', 'C', 'H', 'W'])
operator <<(std::ostream & out,const TensorName & tensorname)113 std::ostream& operator<<(std::ostream& out, const TensorName& tensorname) {
114   out << tensorname.name_ << " (index ";
115   out << tensorname.origin_idx_ << " of ";
116   out << tensorname.origin_ << ")";
117   return out;
118 }
119 
toDimnameVec() const120 std::vector<Dimname> TensorNames::toDimnameVec() const {
121   std::vector<Dimname> result;
122   result.reserve(names_.size());
123   for (const auto& tensor_name : names_) {
124     result.emplace_back(tensor_name.toDimname());
125   }
126   return result;
127 }
128 
129 
130 } // namespace at::namedinference
131