1 // aten_interned_strings.h includes the names of all operators
2 #undef TORCH_ASSERT_ONLY_METHOD_OPERATORS
3
4 #include <ATen/core/interned_strings.h>
5 #include <cstdint>
6 #include <cstring>
7 #include <mutex>
8 #include <sstream>
9 #include <string>
10 #include <c10/util/Exception.h>
11 #include <ATen/core/interned_strings_class.h>
12
13 namespace c10 {
14
// Prefix shared by every domain string produced or parsed in this file
// (see Symbol::domainString and Symbol::fromDomainAndUnqualString).
// The string is created once and lives for the rest of the program.
const std::string& domain_prefix() {
  static const std::string kDomainPrefix{"org.pytorch."};
  return kDomainPrefix;
}
19
symbol(const std::string & s)20 Symbol InternedStrings::symbol(const std::string& s) {
21 std::lock_guard<std::mutex> guard(mutex_);
22 return _symbol(s);
23 }
24
string(Symbol sym)25 std::pair<const char*, const char*> InternedStrings::string(Symbol sym) {
26 // Builtin Symbols are also in the maps, but
27 // we can bypass the need to acquire a lock
28 // to read the map for Builtins because we already
29 // know their string value
30 #if defined C10_MOBILE
31 return customString(sym);
32 #else
33 switch (sym) {
34 #define DEFINE_CASE(ns, s) \
35 case static_cast<unique_t>(ns::s): \
36 return {#ns "::" #s, #s};
37 FORALL_NS_SYMBOLS(DEFINE_CASE)
38 #undef DEFINE_CASE
39 default:
40 return customString(sym);
41 }
42 #endif
43 }
44
ns(Symbol sym)45 Symbol InternedStrings::ns(Symbol sym) {
46 #if defined C10_MOBILE
47 std::lock_guard<std::mutex> guard(mutex_);
48 return sym_to_info_.at(sym).ns;
49 #else
50 switch (sym) {
51 #define DEFINE_CASE(ns, s) \
52 case static_cast<unique_t>(ns::s): \
53 return namespaces::ns;
54 // NOLINTNEXTLINE(bugprone-branch-clone)
55 FORALL_NS_SYMBOLS(DEFINE_CASE)
56 #undef DEFINE_CASE
57 default: {
58 std::lock_guard<std::mutex> guard(mutex_);
59 return sym_to_info_.at(sym).ns;
60 }
61 }
62 #endif
63 }
64
_symbol(const std::string & s)65 Symbol InternedStrings::_symbol(const std::string& s) {
66 auto it = string_to_sym_.find(s);
67 if (it != string_to_sym_.end())
68 return it->second;
69
70 auto pos = s.find("::");
71 if (pos == std::string::npos) {
72 std::stringstream ss;
73 ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
74 throw std::runtime_error(ss.str());
75 }
76 Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
77
78 Symbol sym(sym_to_info_.size());
79 string_to_sym_[s] = sym;
80 sym_to_info_.push_back({ns, s, s.substr(pos + strlen("::"))});
81 return sym;
82 }
83
customString(Symbol sym)84 std::pair<const char*, const char*> InternedStrings::customString(Symbol sym) {
85 std::lock_guard<std::mutex> guard(mutex_);
86 SymbolInfo& s = sym_to_info_.at(sym);
87 return {s.qual_name.c_str(), s.unqual_name.c_str()};
88 }
89
globalStrings()90 static InternedStrings & globalStrings() {
91 static InternedStrings s;
92 return s;
93 }
94
fromQualString(const std::string & s)95 Symbol Symbol::fromQualString(const std::string & s) {
96 return globalStrings().symbol(s);
97 }
98
toUnqualString() const99 const char * Symbol::toUnqualString() const {
100 return globalStrings().string(*this).second;
101 }
102
toQualString() const103 const char * Symbol::toQualString() const {
104 return globalStrings().string(*this).first;
105 }
106
toDisplayString() const107 const char * Symbol::toDisplayString() const {
108 // TODO: Make this actually return something that's "user friendly".
109 // The trouble is that, for this to be usable in printf-style assert
110 // statements, this has to return a const char* (whose lifetime is
111 // global), so we can't actually assemble a string on the fly.
112 return toQualString();
113 }
114
ns() const115 Symbol Symbol::ns() const {
116 return globalStrings().ns(*this);
117 }
118
domainString() const119 std::string Symbol::domainString() const {
120 return domain_prefix() + ns().toUnqualString();
121 }
122
fromDomainAndUnqualString(const std::string & d,const std::string & s)123 Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
124 if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
125 std::ostringstream ss;
126 ss << "Symbol: domain string is expected to be prefixed with '"
127 << domain_prefix() << "', e.g. 'org.pytorch.aten'";
128 throw std::runtime_error(ss.str());
129 }
130 std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
131 return fromQualString(qualString);
132 }
133
is_attr() const134 bool Symbol::is_attr() const { return ns() == namespaces::attr; }
is_aten() const135 bool Symbol::is_aten() const { return ns() == namespaces::aten; }
is_cuda() const136 bool Symbol::is_cuda() const { return ns() == namespaces::cuda; }
is_prim() const137 bool Symbol::is_prim() const { return ns() == namespaces::prim; }
is_prims() const138 bool Symbol::is_prims() const { return ns() == namespaces::prims; }
is_nvprims() const139 bool Symbol::is_nvprims() const { return ns() == namespaces::nvprims; }
is_onnx() const140 bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
is_user() const141 bool Symbol::is_user() const { return ns() == namespaces::user; }
is_caffe2() const142 bool Symbol::is_caffe2() const { return ns() == namespaces::_caffe2; }
is_dimname() const143 bool Symbol::is_dimname() const { return ns() == namespaces::dimname; }
144
145 } // namespace c10
146