xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/Dimname.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/Dimname.h>
2 #include <c10/util/Exception.h>
3 #include <cctype>
4 
5 namespace at {
6 
7 static Symbol kWildcard = Symbol::dimname("*");
8 
operator <<(std::ostream & out,const Dimname & dimname)9 std::ostream& operator<<(std::ostream& out, const Dimname& dimname) {
10   if (dimname.type() == NameType::WILDCARD) {
11     out << "None";
12   } else {
13     out << "'" << dimname.symbol().toUnqualString() << "'";
14   }
15   return out;
16 }
17 
isValidName(const std::string & name)18 bool Dimname::isValidName(const std::string& name) {
19   // allow valid ASCII python identifiers: "uppercase and lowercase
20   // letters A through Z, the underscore _ and, except for the first
21   // character, the digits 0 through 9" (at least length 1)
22   // https://docs.python.org/3/reference/lexical_analysis.html#identifiers
23   if (name.empty()) {
24     return false;
25   }
26   for (auto it = name.begin(); it != name.end(); ++it) {
27     // NOLINTNEXTLINE(bugprone-branch-clone)
28     if (std::isalpha(*it) || *it == '_') {
29       continue;
30     } else if (it != name.begin() && std::isdigit(*it)) {
31       continue;
32     }
33     return false;
34   }
35   return true;
36 }
37 
check_valid_identifier(const std::string & name)38 static void check_valid_identifier(const std::string& name) {
39   TORCH_CHECK(
40       Dimname::isValidName(name),
41       "Invalid name: a valid identifier contains only digits, alphabetical "
42       "characters, and/or underscore and starts with a non-digit. got: '",
43       name, "'.");
44 }
45 
fromSymbol(Symbol name)46 Dimname Dimname::fromSymbol(Symbol name) {
47   TORCH_INTERNAL_ASSERT(name.is_dimname());
48   if (name == kWildcard) {
49     return Dimname::wildcard();
50   }
51   check_valid_identifier(name.toUnqualString());
52   return Dimname(name);
53 }
54 
wildcard()55 Dimname Dimname::wildcard() {
56   static Dimname result(kWildcard, NameType::WILDCARD);
57   return result;
58 }
59 
unify(Dimname other) const60 std::optional<Dimname> Dimname::unify(Dimname other) const {
61   if (other.type() == NameType::WILDCARD) {
62     return *this;
63   }
64   if (type_ == NameType::WILDCARD) {
65     return other;
66   }
67   if (name_ == other.symbol()) {
68     return *this;
69   }
70   return std::nullopt;
71 }
72 
matches(Dimname other) const73 bool Dimname::matches(Dimname other) const {
74   return unify(other).has_value();
75 }
76 
77 } // namespace at
78