xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/library.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/library.h>
2 
3 #include <ATen/core/dispatch/Dispatcher.h>
4 
5 namespace torch {
6 
7 namespace {
8   // TODO: Consider representing debug info as a struct instead so you
9   // don't have to allocate strings all the time
debugString(const char * file,uint32_t line)10   std::string debugString(const char* file, uint32_t line) {
11 #ifdef STRIP_ERROR_MESSAGES
12     return std::string();
13 #else
14     return c10::str("registered at ", file, ":", line);
15 #endif
16   }
17 
debugString(std::string debug,const char * file,uint32_t line)18   std::string debugString(std::string debug, const char* file, uint32_t line) {
19 #ifdef STRIP_ERROR_MESSAGES
20     return std::string();
21 #else
22     if (debug.empty()) {
23       return debugString(file, line);
24     } else {
25       return debug;
26     }
27 #endif
28   }
29 
30 #ifndef STRIP_ERROR_MESSAGES
toString(Library::Kind kind)31   const char* toString(Library::Kind kind) {
32     switch (kind) {
33       case Library::DEF:
34         return "TORCH_LIBRARY";
35       case Library::IMPL:
36         return "TORCH_LIBRARY_IMPL";
37       case Library::FRAGMENT:
38         return "TORCH_LIBRARY_FRAGMENT";
39     }
40     return "(unknown)";
41   }
42 #endif
43 
44   constexpr auto CatchAll = c10::DispatchKey::CatchAll;
45 } // anonymous namespace
46 
CppFunction(c10::KernelFunction func,std::optional<c10::impl::CppSignature> cpp_signature,std::unique_ptr<c10::FunctionSchema> schema)47 CppFunction::CppFunction(c10::KernelFunction func, std::optional<c10::impl::CppSignature> cpp_signature, std::unique_ptr<c10::FunctionSchema> schema)
48   : func_(std::move(func))
49   , cpp_signature_(cpp_signature)
50   , schema_(std::move(schema))
51   , debug_()
52   {}
53 
54 CppFunction::~CppFunction() = default;
55 
reset()56 void Library::reset() {
57   registrars_.clear();
58 }
59 
60 #define ERROR_CONTEXT "(Error occurred while processing ", toString(kind_), " block at ", file_, ":", line_, ")"
61 
Library(Kind kind,std::string ns,std::optional<c10::DispatchKey> k,const char * file,uint32_t line)62 Library::Library(Kind kind, std::string ns, std::optional<c10::DispatchKey> k, const char* file, uint32_t line)
63   : kind_(kind)
64   , ns_(ns == "_" ? std::nullopt : std::make_optional(std::move(ns)))
65   , dispatch_key_(k.value_or(CatchAll) == CatchAll ? std::optional<c10::DispatchKey>() : k)
66   , file_(file)
67   , line_(line)
68   {
69     switch (kind_) {
70       case DEF:
71         // Only DEFs require library uniqueness; fragments
72         // don't register a library
73         registrars_.emplace_back(
74           c10::Dispatcher::singleton().registerLibrary(
75             // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
76             *ns_, debugString(file_, line_)
77           )
78         );
79         [[fallthrough]];
80       case FRAGMENT:
81         TORCH_CHECK(
82           ns_.has_value(),
83           toString(kind_), ": cannot define ", toString(kind_), " with the wildcard namespace _ "
84           "(every ", toString(kind_), " defines operators for a distinct namespace!) "
85           "Did you mean to use TORCH_LIBRARY_IMPL instead?  "
86           ERROR_CONTEXT
87         );
88         TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
89         break;
90       case IMPL:
91         // Nothing to do, everything is OK
92         break;
93     }
94   }
95 
96 // TODO: Error if an operator is def'ed multiple times.  Right now we just
97 // merge everything
98 
99 #define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
_def(c10::FunctionSchema && schema,c10::OperatorName * out_name,const std::vector<at::Tag> & tags,_RegisterOrVerify rv)100 Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name, const std::vector<at::Tag>& tags, _RegisterOrVerify rv) & {
101   TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
102     DEF_PRELUDE,
103     "Cannot define an operator inside of a ", toString(kind_), " block.  "
104     "All def()s should be placed in the (unique) TORCH_LIBRARY block for their namespace.  ",
105     ERROR_CONTEXT
106   );
107   TORCH_INTERNAL_ASSERT(ns_.has_value(), ERROR_CONTEXT);
108   TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
109   auto ns_opt = schema.getNamespace();
110   if (ns_opt.has_value()) {
111     // Note [Redundancy in registration code is OK]
112     // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
113     // In an earlier version of this code, I made it an error to explicitly
114     // specify the namespace, even when the namespaces match.  I've decided
115     // to relax this constraint because sometimes we code generate registrations
116     // and you cannot conveniently tell what the enclosing context will be;
117     // in these cases, it is simpler (and less error prone) to place all
118     // of the information in the registration site, which will be cross-checked
119     // in the end in any case (and if it turns out you DON'T have the right
120     // information at the site, as is the case with backend specific
121     // per-op registrations, you will get the right behavior!)
122     TORCH_CHECK(*ns_opt == *ns_,
123       "Explicitly provided namespace (", *ns_opt, ") in schema string "
124       "does not match namespace of enclosing ", toString(kind_), " block (", *ns_, ").  "
125       "Move this definition to the (unique) TORCH_LIBRARY block corresponding to this namespace "
126       "(and consider deleting the namespace from your schema string.)  ",
127       ERROR_CONTEXT
128     );
129   } else {
130     bool b = schema.setNamespaceIfNotSet(ns_->c_str());
131     TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
132   }
133   if (out_name) {
134     *out_name = schema.operator_name(); // copy!
135   }
136   switch (rv) {
137     case _RegisterOrVerify::REGISTER:
138       if (python_module_.has_value()) {
139         registrars_.emplace_back(
140           c10::Dispatcher::singleton().registerPythonModule(
141             schema.operator_name(),
142             python_module_->first,
143             python_module_->second)
144         );
145       }
146       registrars_.emplace_back(
147         c10::Dispatcher::singleton().registerDef(
148           std::move(schema),
149           debugString(file_, line_),
150           tags
151         )
152       );
153       break;
154     case _RegisterOrVerify::VERIFY:
155       c10::Dispatcher::singleton().waitForDef(schema);
156       break;
157   }
158   return *this;
159 }
160 #undef DEF_PRELUDE
161 
162 // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
_def(std::variant<c10::OperatorName,c10::FunctionSchema> && name_or_schema,CppFunction && f,const std::vector<at::Tag> & tags)163 Library& Library::_def(std::variant<c10::OperatorName, c10::FunctionSchema>&& name_or_schema, CppFunction&& f, const std::vector<at::Tag>& tags) & {
164   c10::FunctionSchema schema = [&] {
165     if (std::holds_alternative<c10::FunctionSchema>(name_or_schema)){
166       return std::get<c10::FunctionSchema>(std::move(name_or_schema));
167     } else {
168       // it's a name; use the inferred schema
169       c10::OperatorName name = std::get<c10::OperatorName>(std::move(name_or_schema));
170       TORCH_CHECK(f.schema_,
171         "def(\"", name, "\"): "
172         "Full schema string was not specified, and we couldn't infer schema either.  ",
173         "Please explicitly provide a schema string.  ",
174         ERROR_CONTEXT
175       );
176       c10::FunctionSchema s = f.schema_->cloneWithName(std::move(name.name), std::move(name.overload_name));
177       s.setAliasAnalysis(c10::AliasAnalysisKind::CONSERVATIVE);
178       return s;
179     }
180   }();
181   c10::OperatorName name("", "");  // Get the namespaced name for the impl call
182   // First define the schema...
183   _def(std::move(schema), &name, tags);
184   // Then register the implementation...
185   auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
186   registrars_.emplace_back(
187     c10::Dispatcher::singleton().registerImpl(
188       std::move(name),
189       dispatch_key,
190       std::move(f.func_),
191       f.cpp_signature_,
192       std::move(f.schema_),
193       debugString(std::move(f.debug_), file_, line_)
194     )
195   );
196   return *this;
197 }
198 
199 #define IMPL_PRELUDE "impl(\"", name_str, "\", ...): "
_parseNameForLib(const char * name_str) const200 at::OperatorName Library::_parseNameForLib(const char* name_str) const {
201   auto name = torch::jit::parseName(name_str);
202   auto ns_opt = name.getNamespace();
203   // This is a copy paste of Library::_impl
204   if (ns_opt.has_value()) {
205     // See Note [Redundancy in registration code is OK]
206     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
207     TORCH_CHECK(*ns_opt == *ns_,
208       IMPL_PRELUDE,
209       "Explicitly provided namespace (", *ns_opt, ") in operator name "
210       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
211       "does not match namespace of enclosing ", toString(kind_), " block (", *ns_, ").  "
212       "Move this definition to the ", toString(kind_), " block corresponding to this namespace "
213       "(and consider deleting the namespace from your schema string.)  ",
214       ERROR_CONTEXT
215     );
216   } else {
217     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
218     bool b = name.setNamespaceIfNotSet(ns_->c_str());
219     TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
220   }
221   return name;
222 }
223 
224 // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
_impl(const char * name_str,CppFunction && f,_RegisterOrVerify rv)225 Library& Library::_impl(const char* name_str, CppFunction&& f, _RegisterOrVerify rv) & {
226   at::OperatorName name = _parseNameForLib(name_str);
227   // See Note [Redundancy in registration code is OK]
228   TORCH_CHECK(!(f.dispatch_key_.has_value() &&
229                 dispatch_key_.has_value() &&
230                 *f.dispatch_key_ != *dispatch_key_),
231     IMPL_PRELUDE,
232     "Explicitly provided dispatch key (", *f.dispatch_key_, ") is inconsistent "
233     "with the dispatch key of the enclosing ", toString(kind_), " block (", *dispatch_key_, ").  "
234     "Please declare a separate ", toString(kind_), " block for this dispatch key and "
235     "move your impl() there.  "
236     ERROR_CONTEXT
237   );
238   auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
239   switch (rv) {
240     case _RegisterOrVerify::REGISTER:
241       registrars_.emplace_back(
242         c10::Dispatcher::singleton().registerImpl(
243           std::move(name),
244           dispatch_key,
245           std::move(f.func_),
246           f.cpp_signature_,
247           std::move(f.schema_),
248           debugString(std::move(f.debug_), file_, line_)
249         )
250       );
251       break;
252     case _RegisterOrVerify::VERIFY:
253       c10::Dispatcher::singleton().waitForImpl(name, dispatch_key);
254       break;
255   }
256   return *this;
257 }
258 
_resolve(const char * name_str) const259 c10::OperatorName Library::_resolve(const char* name_str) const {
260   return _parseNameForLib(name_str);
261 }
262 #undef IMPL_PRELUDE
263 
264 // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
_fallback(CppFunction && f)265 Library& Library::_fallback(CppFunction&& f) & {
266   TORCH_CHECK(kind_ == IMPL,
267     "fallback(...): Cannot define an operator inside of a ", toString(kind_), " block.  "
268     "Did you mean to call this function inside a TORCH_LIBRARY_IMPL block?  ",
269     ERROR_CONTEXT);
270   auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
271   TORCH_INTERNAL_ASSERT(dispatch_key.has_value(), ERROR_CONTEXT);
272   TORCH_CHECK(!ns_.has_value(),
273     "fallback(...): Fallback functions which apply to only a single namespace ",
274     "(you specified ", *ns_, ") are not supported.  If you intended to apply ",
275     "this fallback function globally, please define a separate block:\n\n",
276     "    TORCH_LIBRARY_IMPL(_, ", *dispatch_key, ", m) { m.fallback(...); }\n\n",
277     ERROR_CONTEXT);
278   // Note if dispatch_key is DispatchKey::Undefined, it'll be ignored here since Undefined
279   // isn't a runtime key, you shouldn't register anything to it at all.
280   for (auto k : c10::getRuntimeDispatchKeySet(*dispatch_key)) {
281     // mobile doesn't use all dispatch keys, so skip any fallback registrations for the unused keys.
282     auto idx = getDispatchTableIndexForDispatchKey(k);
283     if (idx < 0) continue;
284     registrars_.emplace_back(
285       c10::Dispatcher::singleton().registerFallback(
286         k,
287         f.func_,
288         debugString(f.debug_, file_, line_)
289       )
290     );
291   }
292   return *this;
293 }
294 
295 
296 } // namespace torch
297