1 #include <ATen/core/Dimname.h>
2 #include <c10/util/Exception.h>
3 #include <cctype>
4
5 namespace at {
6
7 static Symbol kWildcard = Symbol::dimname("*");
8
operator <<(std::ostream & out,const Dimname & dimname)9 std::ostream& operator<<(std::ostream& out, const Dimname& dimname) {
10 if (dimname.type() == NameType::WILDCARD) {
11 out << "None";
12 } else {
13 out << "'" << dimname.symbol().toUnqualString() << "'";
14 }
15 return out;
16 }
17
isValidName(const std::string & name)18 bool Dimname::isValidName(const std::string& name) {
19 // allow valid ASCII python identifiers: "uppercase and lowercase
20 // letters A through Z, the underscore _ and, except for the first
21 // character, the digits 0 through 9" (at least length 1)
22 // https://docs.python.org/3/reference/lexical_analysis.html#identifiers
23 if (name.empty()) {
24 return false;
25 }
26 for (auto it = name.begin(); it != name.end(); ++it) {
27 // NOLINTNEXTLINE(bugprone-branch-clone)
28 if (std::isalpha(*it) || *it == '_') {
29 continue;
30 } else if (it != name.begin() && std::isdigit(*it)) {
31 continue;
32 }
33 return false;
34 }
35 return true;
36 }
37
check_valid_identifier(const std::string & name)38 static void check_valid_identifier(const std::string& name) {
39 TORCH_CHECK(
40 Dimname::isValidName(name),
41 "Invalid name: a valid identifier contains only digits, alphabetical "
42 "characters, and/or underscore and starts with a non-digit. got: '",
43 name, "'.");
44 }
45
fromSymbol(Symbol name)46 Dimname Dimname::fromSymbol(Symbol name) {
47 TORCH_INTERNAL_ASSERT(name.is_dimname());
48 if (name == kWildcard) {
49 return Dimname::wildcard();
50 }
51 check_valid_identifier(name.toUnqualString());
52 return Dimname(name);
53 }
54
wildcard()55 Dimname Dimname::wildcard() {
56 static Dimname result(kWildcard, NameType::WILDCARD);
57 return result;
58 }
59
unify(Dimname other) const60 std::optional<Dimname> Dimname::unify(Dimname other) const {
61 if (other.type() == NameType::WILDCARD) {
62 return *this;
63 }
64 if (type_ == NameType::WILDCARD) {
65 return other;
66 }
67 if (name_ == other.symbol()) {
68 return *this;
69 }
70 return std::nullopt;
71 }
72
matches(Dimname other) const73 bool Dimname::matches(Dimname other) const {
74 return unify(other).has_value();
75 }
76
77 } // namespace at
78