1 #include <ATen/TensorNames.h>
2 #include <ATen/WrapDimUtils.h>
3 #include <c10/util/irange.h>
4
5 namespace at::namedinference {
6
7
toDimname() const8 Dimname TensorName::toDimname() const {
9 return name_;
10 }
11
// Combines this dimension name with `other` and returns whichever of the two
// represents the combined dimension:
//   - both wildcards, or identical names: returns *this;
//   - a concrete name against a wildcard: returns the concrete side, after
//     checking that the wildcard's origin list does not already contain that
//     name elsewhere (which would mean the tensors are misaligned);
//   - two different concrete names: raises an error via TORCH_CHECK.
// `op_name` is used only as a prefix for error messages. The returned
// reference aliases either *this or `other`, so it is valid only while the
// referenced object is alive.
const TensorName& TensorName::unify(const TensorName& other, const char* op_name) const {
  const bool self_is_wildcard = name_.isWildcard();
  const bool other_is_wildcard = other.name_.isWildcard();

  // unify(None, None) and unify(A, A)
  if ((self_is_wildcard && other_is_wildcard) || name_ == other.name_) {
    return *this;
  }

  // unify(A, None): the wildcard can only take on A if A does not already
  // appear somewhere else in the name list the wildcard came from.
  if (other_is_wildcard) {
    const bool already_named =
        std::find(other.origin_.begin(), other.origin_.end(), name_) != other.origin_.end();
    TORCH_CHECK(!already_named,
        op_name, ":",
        " Cannot match ", *this, " with ", other,
        " because the latter names already have ", name_, ".",
        " Are your tensors misaligned?");
    return *this;
  }

  // unify(None, A): same rule with the operands swapped.
  if (self_is_wildcard) {
    return other.unify(*this, op_name);
  }

  // unify(A, B) with A != B: the check below always fails at this point.
  TORCH_CHECK(name_ == other.name_,
      op_name, ":",
      " Expected ", *this,
      " to match ", other,
      " but they do not match.");
  return *this;
}
47
TensorNames(ArrayRef<Dimname> names)48 TensorNames::TensorNames(ArrayRef<Dimname> names) {
49 names_.reserve(names.size());
50 for (const auto idx : c10::irange(names.size())) {
51 names_.emplace_back(names, idx);
52 }
53 }
54
TensorNames(ArrayRef<Dimname> names,int64_t start,int64_t end)55 TensorNames::TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end) {
56 int64_t names_size = static_cast<int64_t>(names.size());
57 start = maybe_wrap_dim(start, names_size);
58 end = maybe_wrap_dim(end, names_size);
59 names_.reserve(end - start);
60 for (const auto idx : c10::irange(start, end)) {
61 names_.emplace_back(names, idx);
62 }
63 }
64
// Unifies this list of names with `other` in place, aligning the two lists at
// their right ends so that trailing entries are paired up. If `other` is at
// least as long as this list, its extra leading entries are first prepended
// here so both lists have the same length. Each aligned pair is combined with
// TensorName::unify (which raises on a mismatch; `op_name` prefixes those
// errors), and the result overwrites the entry in this list. Returns *this.
TensorNames& TensorNames::unifyFromRightInplace(const TensorNames& other, const char* op_name) {
  const auto self_len = names_.size();
  const auto other_len = other.names_.size();

  if (self_len <= other_len) {
    const auto pad = other_len - self_len;
    // Pad names_ on the left with other's leading entries before unification.
    names_.insert(
        names_.begin(),
        other.names_.begin(),
        other.names_.begin() + pad);
    // names_ now has other_len entries; positions line up one to one.
    for (const auto pos : c10::irange(pad, names_.size())) {
      names_[pos] = names_[pos].unify(other.names_[pos], op_name);
    }
    return *this;
  }

  // This list is longer: pair its trailing other_len entries with all of other.
  const auto offset = self_len - other_len;
  for (const auto j : c10::irange(other_len)) {
    auto& slot = names_[offset + j];
    slot = slot.unify(other.names_[j], op_name);
  }
  return *this;
}
86
append(TensorName name)87 void TensorNames::append(TensorName name) {
88 names_.emplace_back(name);
89 }
90
checkUnique(const char * op_name) const91 void TensorNames::checkUnique(const char* op_name) const {
92 // O(N^2), but named tensors can have at most N = 64 dimensions, so this
93 // doesn't matter unless benchmarking tells us it does. The alternative is
94 // to create some sort of set data structure but the overhead of that
95 // might dominate for small sizes.
96 for (auto it = names_.begin(); it != names_.end(); ++it) {
97 const auto name = it->toDimname();
98 if (name.isWildcard()) continue;
99
100 auto dup = std::find_if(it + 1, names_.end(),
101 [&](const TensorName& other) { return other.toDimname() == name; });
102 TORCH_CHECK(dup == names_.end(),
103 op_name, ": ",
104 "Attempted to propagate dims ", *it, " and ", *dup, " to the output, ",
105 "but that would create a tensor with duplicate names [", toDimnameVec(),
106 "]. Please rename your inputs with Tensor.rename to prevent this.");
107 }
108 }
109
110 // Let's say the TensorName represents 'C' in ['N', 'C', 'H, 'W'].
111 // It should print like:
112 // 'C' (index 1 of ['N', 'C', 'H', 'W'])
operator <<(std::ostream & out,const TensorName & tensorname)113 std::ostream& operator<<(std::ostream& out, const TensorName& tensorname) {
114 out << tensorname.name_ << " (index ";
115 out << tensorname.origin_idx_ << " of ";
116 out << tensorname.origin_ << ")";
117 return out;
118 }
119
toDimnameVec() const120 std::vector<Dimname> TensorNames::toDimnameVec() const {
121 std::vector<Dimname> result;
122 result.reserve(names_.size());
123 for (const auto& tensor_name : names_) {
124 result.emplace_back(tensor_name.toDimname());
125 }
126 return result;
127 }
128
129
130 } // namespace at::namedinference
131