xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/symbol.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/macros/Export.h>
3 #include <cstdint>
4 #include <functional>  // For std::hash
5 #include <string>
6 
7 
8 namespace c10 {
9 
10 // 'prim' symbols are synthetic operators that occur only in the IR
11 // and don't have corresponding implementations in ATen.
12 
13 // 'onnx' symbols correspond to ONNX operators.  Their semantics
14 // are defined in https://github.com/onnx/onnx/blob/master/docs/Operators.md
15 // The particular version we are targeting is specified by '_onnx_opset_version'
16 // in torch.onnx.symbolic_helper
17 //
18 // In general, most ONNX operators won't get an entry here, because they
19 // are handled from the Python end.  However, you may occasionally need
20 // to intern an ONNX symbol here so that you can conveniently write an
21 // optimization on ONNX operations.
22 
23 // 'attr' symbols are attribute keys.  They are shared between both ONNX and ATen
24 // operators (you disambiguate their meaning by looking at the operator itself).
25 // In general, you only need to define attribute keys that are used by
26 // onnx or prim; ATen attributes are automatically generated in FORALL_ATTR_BASE_SYMBOLS.
27 
28 // Note [Symbol allocation]
29 // ~~~~~~~~~~~~~~~~~~~~~~~~
30 //
//  1. The space of symbols is split up into namespaces (attr, aten, onnx, prim, ...).
32 //
33 //  2. The intended access pattern for built-in symbols is onnx::MatMul
34 //  in the c10 namespace (this is a Symbol).
35 //
36 
37 // Built-in constant definition strategy:
38 // - Enum is the most convenient way to generate a contiguous sequence
39 //   of numbers for an identifier.
40 // - However, an enum gives you a fresh type.  We want onnx::MatMul to
41 //   be type Symbol, not some random enum type!
// - Therefore, after using enums to generate the sequence of integers,
//   we then declare constexpr Symbols so that everything has the actual
//   Symbol type we want.  Symbols must be constexpr so that they can be
//   used as "case" labels.
45 
// Integer type used for a Symbol's interned id; Symbol stores one of these
// and converts to it implicitly so that Symbols can be switched on.
using unique_t = std::uint32_t;
47 
// Returns a reference to the prefix shared by Symbol domain strings.
// NOTE(review): presumably "org.pytorch." given the "org.pytorch.aten" /
// "org.pytorch.attr" domain examples used by Symbol; confirm against the
// out-of-line definition.
std::string const& domain_prefix();
49 
50 // A Symbol is like an interned string, but with a little extra
51 // structure; it is namespaced via SymbolNamespace and the resulting
52 // intern pointers support efficient namespace testing.
53 struct TORCH_API Symbol {
SymbolSymbol54   explicit constexpr Symbol() : value(0) {};
SymbolSymbol55   explicit constexpr Symbol(unique_t uniq)
56   : value(uniq) {}
57 
58   // Get a Symbol for a qualified string like "attr::bar"
59   static Symbol fromQualString(const std::string & s);
60 
61   // Get a Symbol from a domain and an unqualified string like "org.pytorch.attr" and "bar"
62   static Symbol fromDomainAndUnqualString(const std::string & d, const std::string & s);
63 
64   // Constructors for our various namespaced strings.  This will construct
65   // the appropriate namespaced string, e.g., "attr::foo" for the
66   // argument "foo", and then attempt to intern it.  DO NOT USE THIS
67   // with a string literal; attr::foo should be available in that case
68   // (and if it's not, you should add it to the built-ins list above.)
69   static Symbol attr(const std::string & s);
70   static Symbol aten(const std::string & s);
71   static Symbol cuda(const std::string & s);
72   static Symbol onnx(const std::string & s);
73   static Symbol prim(const std::string & s);
74   static Symbol user(const std::string & s);
75   static Symbol caffe2(const std::string & s);
76   static Symbol dimname(const std::string & s);
77   // TODO: eliminate me
78   static Symbol scope(const std::string & s);
79 
80   bool is_attr() const;
81   bool is_aten() const;
82   bool is_cuda() const;
83   bool is_prim() const;
84   bool is_prims() const;
85   bool is_nvprims() const;
86   bool is_onnx() const;
87   bool is_user() const;
88   bool is_caffe2() const;
89   bool is_dimname() const;
90 
91   // So we can switch on this
unique_tSymbol92   constexpr operator unique_t() const {
93     return value;
94   }
95 
96   Symbol ns() const;
97 
98   // Give a string corresponding to the unqualified version of this name, e.g.,
99   // "mm". Use this in a context where the intended namespace of the string is
100   // obvious; this is a *lossy* conversion.
101   const char * toUnqualString() const;
102 
103   // Give a string corresponding to the qualified version of this name,
104   // e.g., "aten::mm".  This string format is made available to Python bindings
105   // (so we know how to parse it.)
106   const char * toQualString() const;
107 
108   // This describes a symbol in a case where humans read it.  At the moment it's
109   // the same as toQualString.  This has to be a const char* returned because
110   // a lot of printf style macros use it.
111   const char * toDisplayString() const;
112 
113   // Give a string corresponding to the domain name for the symbol,
114   // e.g., "org.pytorch.aten".
115   std::string domainString() const;
116 
117 private:
118 
119   explicit Symbol(Symbol ns, const std::string & s);
120   unique_t value;
121 };
122 
123 static inline bool operator==(Symbol lhs, Symbol rhs) {
124   return static_cast<unique_t>(lhs) == static_cast<unique_t>(rhs);
125 }
126 
attr(const std::string & s)127 inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); }
aten(const std::string & s)128 inline Symbol Symbol::aten(const std::string & s)  { return Symbol::fromQualString("aten::" + s); }
cuda(const std::string & s)129 inline Symbol Symbol::cuda(const std::string & s)  { return Symbol::fromQualString("cuda::" + s); }
onnx(const std::string & s)130 inline Symbol Symbol::onnx(const std::string & s)  { return Symbol::fromQualString("onnx::" + s); }
prim(const std::string & s)131 inline Symbol Symbol::prim(const std::string & s)  { return Symbol::fromQualString("prim::" + s); }
scope(const std::string & s)132 inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
user(const std::string & s)133 inline Symbol Symbol::user(const std::string & s) { return Symbol::fromQualString("user::" + s); }
caffe2(const std::string & s)134 inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualString("_caffe2::" + s); }
dimname(const std::string & s)135 inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); }
136 
137 } // namespace c10
138 
139 // make symbol behave like an integer in hash tables
140 namespace std {
141 template <>
142 struct hash<c10::Symbol> {
143   size_t operator()(c10::Symbol s) const {
144     return std::hash<uint32_t>()(static_cast<uint32_t>(s));
145   }
146 };
147 }
148