xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/interned_strings.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // aten_interned_strings.h includes the names of all operators
2 #undef TORCH_ASSERT_ONLY_METHOD_OPERATORS
3 
4 #include <ATen/core/interned_strings.h>
5 #include <cstdint>
6 #include <cstring>
7 #include <mutex>
8 #include <sstream>
9 #include <string>
10 #include <c10/util/Exception.h>
11 #include <ATen/core/interned_strings_class.h>
12 
13 namespace c10 {
14 
// Prefix prepended to a namespace name to form its "domain" string
// (see Symbol::domainString / Symbol::fromDomainAndUnqualString).
// Returned by reference to a function-local static, so the same object is
// handed back on every call.
const std::string& domain_prefix() {
  static const std::string kDomainPrefix{"org.pytorch."};
  return kDomainPrefix;
}
19 
symbol(const std::string & s)20 Symbol InternedStrings::symbol(const std::string& s) {
21   std::lock_guard<std::mutex> guard(mutex_);
22   return _symbol(s);
23 }
24 
string(Symbol sym)25 std::pair<const char*, const char*> InternedStrings::string(Symbol sym) {
26   // Builtin Symbols are also in the maps, but
27   // we can bypass the need to acquire a lock
28   // to read the map for Builtins because we already
29   // know their string value
30 #if defined C10_MOBILE
31   return customString(sym);
32 #else
33   switch (sym) {
34 #define DEFINE_CASE(ns, s) \
35   case static_cast<unique_t>(ns::s): \
36     return {#ns "::" #s, #s};
37     FORALL_NS_SYMBOLS(DEFINE_CASE)
38 #undef DEFINE_CASE
39     default:
40       return customString(sym);
41   }
42 #endif
43 }
44 
ns(Symbol sym)45 Symbol InternedStrings::ns(Symbol sym) {
46 #if defined C10_MOBILE
47   std::lock_guard<std::mutex> guard(mutex_);
48   return sym_to_info_.at(sym).ns;
49 #else
50   switch (sym) {
51 #define DEFINE_CASE(ns, s) \
52   case static_cast<unique_t>(ns::s): \
53     return namespaces::ns;
54     // NOLINTNEXTLINE(bugprone-branch-clone)
55     FORALL_NS_SYMBOLS(DEFINE_CASE)
56 #undef DEFINE_CASE
57     default: {
58       std::lock_guard<std::mutex> guard(mutex_);
59       return sym_to_info_.at(sym).ns;
60     }
61   }
62 #endif
63 }
64 
_symbol(const std::string & s)65 Symbol InternedStrings::_symbol(const std::string& s) {
66   auto it = string_to_sym_.find(s);
67   if (it != string_to_sym_.end())
68     return it->second;
69 
70   auto pos = s.find("::");
71   if (pos == std::string::npos) {
72     std::stringstream ss;
73     ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
74     throw std::runtime_error(ss.str());
75   }
76   Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
77 
78   Symbol sym(sym_to_info_.size());
79   string_to_sym_[s] = sym;
80   sym_to_info_.push_back({ns, s, s.substr(pos + strlen("::"))});
81   return sym;
82 }
83 
customString(Symbol sym)84 std::pair<const char*, const char*> InternedStrings::customString(Symbol sym) {
85   std::lock_guard<std::mutex> guard(mutex_);
86   SymbolInfo& s = sym_to_info_.at(sym);
87   return {s.qual_name.c_str(), s.unqual_name.c_str()};
88 }
89 
globalStrings()90 static InternedStrings & globalStrings() {
91   static InternedStrings s;
92   return s;
93 }
94 
fromQualString(const std::string & s)95 Symbol Symbol::fromQualString(const std::string & s) {
96   return globalStrings().symbol(s);
97 }
98 
toUnqualString() const99 const char * Symbol::toUnqualString() const {
100   return globalStrings().string(*this).second;
101 }
102 
toQualString() const103 const char * Symbol::toQualString() const {
104   return globalStrings().string(*this).first;
105 }
106 
toDisplayString() const107 const char * Symbol::toDisplayString() const {
108   // TODO: Make this actually return something that's "user friendly".
109   // The trouble is that, for this to be usable in printf-style assert
110   // statements, this has to return a const char* (whose lifetime is
111   // global), so we can't actually assemble a string on the fly.
112   return toQualString();
113 }
114 
ns() const115 Symbol Symbol::ns() const {
116   return globalStrings().ns(*this);
117 }
118 
domainString() const119 std::string Symbol::domainString() const {
120   return domain_prefix() + ns().toUnqualString();
121 }
122 
fromDomainAndUnqualString(const std::string & d,const std::string & s)123 Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
124   if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
125     std::ostringstream ss;
126     ss << "Symbol: domain string is expected to be prefixed with '"
127        << domain_prefix() << "', e.g. 'org.pytorch.aten'";
128     throw std::runtime_error(ss.str());
129   }
130   std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
131   return fromQualString(qualString);
132 }
133 
is_attr() const134 bool Symbol::is_attr() const { return ns() == namespaces::attr; }
is_aten() const135 bool Symbol::is_aten() const { return ns() == namespaces::aten; }
is_cuda() const136 bool Symbol::is_cuda() const { return ns() == namespaces::cuda; }
is_prim() const137 bool Symbol::is_prim() const { return ns() == namespaces::prim; }
is_prims() const138 bool Symbol::is_prims() const { return ns() == namespaces::prims; }
is_nvprims() const139 bool Symbol::is_nvprims() const { return ns() == namespaces::nvprims; }
is_onnx() const140 bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
is_user() const141 bool Symbol::is_user() const { return ns() == namespaces::user; }
is_caffe2() const142 bool Symbol::is_caffe2() const { return ns() == namespaces::_caffe2; }
is_dimname() const143 bool Symbol::is_dimname() const { return ns() == namespaces::dimname; }
144 
145 } // namespace c10
146