1 #include <torch/library.h>
2
3 #include <ATen/core/dispatch/Dispatcher.h>
4
5 namespace torch {
6
7 namespace {
8 // TODO: Consider representing debug info as a struct instead so you
9 // don't have to allocate strings all the time
debugString(const char * file,uint32_t line)10 std::string debugString(const char* file, uint32_t line) {
11 #ifdef STRIP_ERROR_MESSAGES
12 return std::string();
13 #else
14 return c10::str("registered at ", file, ":", line);
15 #endif
16 }
17
debugString(std::string debug,const char * file,uint32_t line)18 std::string debugString(std::string debug, const char* file, uint32_t line) {
19 #ifdef STRIP_ERROR_MESSAGES
20 return std::string();
21 #else
22 if (debug.empty()) {
23 return debugString(file, line);
24 } else {
25 return debug;
26 }
27 #endif
28 }
29
30 #ifndef STRIP_ERROR_MESSAGES
toString(Library::Kind kind)31 const char* toString(Library::Kind kind) {
32 switch (kind) {
33 case Library::DEF:
34 return "TORCH_LIBRARY";
35 case Library::IMPL:
36 return "TORCH_LIBRARY_IMPL";
37 case Library::FRAGMENT:
38 return "TORCH_LIBRARY_FRAGMENT";
39 }
40 return "(unknown)";
41 }
42 #endif
43
44 constexpr auto CatchAll = c10::DispatchKey::CatchAll;
45 } // anonymous namespace
46
CppFunction(c10::KernelFunction func,std::optional<c10::impl::CppSignature> cpp_signature,std::unique_ptr<c10::FunctionSchema> schema)47 CppFunction::CppFunction(c10::KernelFunction func, std::optional<c10::impl::CppSignature> cpp_signature, std::unique_ptr<c10::FunctionSchema> schema)
48 : func_(std::move(func))
49 , cpp_signature_(cpp_signature)
50 , schema_(std::move(schema))
51 , debug_()
52 {}
53
54 CppFunction::~CppFunction() = default;
55
reset()56 void Library::reset() {
57 registrars_.clear();
58 }
59
60 #define ERROR_CONTEXT "(Error occurred while processing ", toString(kind_), " block at ", file_, ":", line_, ")"
61
Library(Kind kind,std::string ns,std::optional<c10::DispatchKey> k,const char * file,uint32_t line)62 Library::Library(Kind kind, std::string ns, std::optional<c10::DispatchKey> k, const char* file, uint32_t line)
63 : kind_(kind)
64 , ns_(ns == "_" ? std::nullopt : std::make_optional(std::move(ns)))
65 , dispatch_key_(k.value_or(CatchAll) == CatchAll ? std::optional<c10::DispatchKey>() : k)
66 , file_(file)
67 , line_(line)
68 {
69 switch (kind_) {
70 case DEF:
71 // Only DEFs require library uniqueness; fragments
72 // don't register a library
73 registrars_.emplace_back(
74 c10::Dispatcher::singleton().registerLibrary(
75 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
76 *ns_, debugString(file_, line_)
77 )
78 );
79 [[fallthrough]];
80 case FRAGMENT:
81 TORCH_CHECK(
82 ns_.has_value(),
83 toString(kind_), ": cannot define ", toString(kind_), " with the wildcard namespace _ "
84 "(every ", toString(kind_), " defines operators for a distinct namespace!) "
85 "Did you mean to use TORCH_LIBRARY_IMPL instead? "
86 ERROR_CONTEXT
87 );
88 TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
89 break;
90 case IMPL:
91 // Nothing to do, everything is OK
92 break;
93 }
94 }
95
96 // TODO: Error if an operator is def'ed multiple times. Right now we just
97 // merge everything
98
99 #define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
_def(c10::FunctionSchema && schema,c10::OperatorName * out_name,const std::vector<at::Tag> & tags,_RegisterOrVerify rv)100 Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name, const std::vector<at::Tag>& tags, _RegisterOrVerify rv) & {
101 TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
102 DEF_PRELUDE,
103 "Cannot define an operator inside of a ", toString(kind_), " block. "
104 "All def()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. ",
105 ERROR_CONTEXT
106 );
107 TORCH_INTERNAL_ASSERT(ns_.has_value(), ERROR_CONTEXT);
108 TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
109 auto ns_opt = schema.getNamespace();
110 if (ns_opt.has_value()) {
111 // Note [Redundancy in registration code is OK]
112 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
113 // In an earlier version of this code, I made it an error to explicitly
114 // specify the namespace, even when the namespaces match. I've decided
115 // to relax this constraint because sometimes we code generate registrations
116 // and you cannot conveniently tell what the enclosing context will be;
117 // in these cases, it is simpler (and less error prone) to place all
118 // of the information in the registration site, which will be cross-checked
119 // in the end in any case (and if it turns out you DON'T have the right
120 // information at the site, as is the case with backend specific
121 // per-op registrations, you will get the right behavior!)
122 TORCH_CHECK(*ns_opt == *ns_,
123 "Explicitly provided namespace (", *ns_opt, ") in schema string "
124 "does not match namespace of enclosing ", toString(kind_), " block (", *ns_, "). "
125 "Move this definition to the (unique) TORCH_LIBRARY block corresponding to this namespace "
126 "(and consider deleting the namespace from your schema string.) ",
127 ERROR_CONTEXT
128 );
129 } else {
130 bool b = schema.setNamespaceIfNotSet(ns_->c_str());
131 TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
132 }
133 if (out_name) {
134 *out_name = schema.operator_name(); // copy!
135 }
136 switch (rv) {
137 case _RegisterOrVerify::REGISTER:
138 if (python_module_.has_value()) {
139 registrars_.emplace_back(
140 c10::Dispatcher::singleton().registerPythonModule(
141 schema.operator_name(),
142 python_module_->first,
143 python_module_->second)
144 );
145 }
146 registrars_.emplace_back(
147 c10::Dispatcher::singleton().registerDef(
148 std::move(schema),
149 debugString(file_, line_),
150 tags
151 )
152 );
153 break;
154 case _RegisterOrVerify::VERIFY:
155 c10::Dispatcher::singleton().waitForDef(schema);
156 break;
157 }
158 return *this;
159 }
160 #undef DEF_PRELUDE
161
162 // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
_def(std::variant<c10::OperatorName,c10::FunctionSchema> && name_or_schema,CppFunction && f,const std::vector<at::Tag> & tags)163 Library& Library::_def(std::variant<c10::OperatorName, c10::FunctionSchema>&& name_or_schema, CppFunction&& f, const std::vector<at::Tag>& tags) & {
164 c10::FunctionSchema schema = [&] {
165 if (std::holds_alternative<c10::FunctionSchema>(name_or_schema)){
166 return std::get<c10::FunctionSchema>(std::move(name_or_schema));
167 } else {
168 // it's a name; use the inferred schema
169 c10::OperatorName name = std::get<c10::OperatorName>(std::move(name_or_schema));
170 TORCH_CHECK(f.schema_,
171 "def(\"", name, "\"): "
172 "Full schema string was not specified, and we couldn't infer schema either. ",
173 "Please explicitly provide a schema string. ",
174 ERROR_CONTEXT
175 );
176 c10::FunctionSchema s = f.schema_->cloneWithName(std::move(name.name), std::move(name.overload_name));
177 s.setAliasAnalysis(c10::AliasAnalysisKind::CONSERVATIVE);
178 return s;
179 }
180 }();
181 c10::OperatorName name("", ""); // Get the namespaced name for the impl call
182 // First define the schema...
183 _def(std::move(schema), &name, tags);
184 // Then register the implementation...
185 auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
186 registrars_.emplace_back(
187 c10::Dispatcher::singleton().registerImpl(
188 std::move(name),
189 dispatch_key,
190 std::move(f.func_),
191 f.cpp_signature_,
192 std::move(f.schema_),
193 debugString(std::move(f.debug_), file_, line_)
194 )
195 );
196 return *this;
197 }
198
199 #define IMPL_PRELUDE "impl(\"", name_str, "\", ...): "
_parseNameForLib(const char * name_str) const200 at::OperatorName Library::_parseNameForLib(const char* name_str) const {
201 auto name = torch::jit::parseName(name_str);
202 auto ns_opt = name.getNamespace();
203 // This is a copy paste of Library::_impl
204 if (ns_opt.has_value()) {
205 // See Note [Redundancy in registration code is OK]
206 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
207 TORCH_CHECK(*ns_opt == *ns_,
208 IMPL_PRELUDE,
209 "Explicitly provided namespace (", *ns_opt, ") in operator name "
210 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
211 "does not match namespace of enclosing ", toString(kind_), " block (", *ns_, "). "
212 "Move this definition to the ", toString(kind_), " block corresponding to this namespace "
213 "(and consider deleting the namespace from your schema string.) ",
214 ERROR_CONTEXT
215 );
216 } else {
217 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
218 bool b = name.setNamespaceIfNotSet(ns_->c_str());
219 TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
220 }
221 return name;
222 }
223
224 // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
_impl(const char * name_str,CppFunction && f,_RegisterOrVerify rv)225 Library& Library::_impl(const char* name_str, CppFunction&& f, _RegisterOrVerify rv) & {
226 at::OperatorName name = _parseNameForLib(name_str);
227 // See Note [Redundancy in registration code is OK]
228 TORCH_CHECK(!(f.dispatch_key_.has_value() &&
229 dispatch_key_.has_value() &&
230 *f.dispatch_key_ != *dispatch_key_),
231 IMPL_PRELUDE,
232 "Explicitly provided dispatch key (", *f.dispatch_key_, ") is inconsistent "
233 "with the dispatch key of the enclosing ", toString(kind_), " block (", *dispatch_key_, "). "
234 "Please declare a separate ", toString(kind_), " block for this dispatch key and "
235 "move your impl() there. "
236 ERROR_CONTEXT
237 );
238 auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
239 switch (rv) {
240 case _RegisterOrVerify::REGISTER:
241 registrars_.emplace_back(
242 c10::Dispatcher::singleton().registerImpl(
243 std::move(name),
244 dispatch_key,
245 std::move(f.func_),
246 f.cpp_signature_,
247 std::move(f.schema_),
248 debugString(std::move(f.debug_), file_, line_)
249 )
250 );
251 break;
252 case _RegisterOrVerify::VERIFY:
253 c10::Dispatcher::singleton().waitForImpl(name, dispatch_key);
254 break;
255 }
256 return *this;
257 }
258
_resolve(const char * name_str) const259 c10::OperatorName Library::_resolve(const char* name_str) const {
260 return _parseNameForLib(name_str);
261 }
262 #undef IMPL_PRELUDE
263
264 // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
_fallback(CppFunction && f)265 Library& Library::_fallback(CppFunction&& f) & {
266 TORCH_CHECK(kind_ == IMPL,
267 "fallback(...): Cannot define an operator inside of a ", toString(kind_), " block. "
268 "Did you mean to call this function inside a TORCH_LIBRARY_IMPL block? ",
269 ERROR_CONTEXT);
270 auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
271 TORCH_INTERNAL_ASSERT(dispatch_key.has_value(), ERROR_CONTEXT);
272 TORCH_CHECK(!ns_.has_value(),
273 "fallback(...): Fallback functions which apply to only a single namespace ",
274 "(you specified ", *ns_, ") are not supported. If you intended to apply ",
275 "this fallback function globally, please define a separate block:\n\n",
276 " TORCH_LIBRARY_IMPL(_, ", *dispatch_key, ", m) { m.fallback(...); }\n\n",
277 ERROR_CONTEXT);
278 // Note if dispatch_key is DispatchKey::Undefined, it'll be ignored here since Undefined
279 // isn't a runtime key, you shouldn't register anything to it at all.
280 for (auto k : c10::getRuntimeDispatchKeySet(*dispatch_key)) {
281 // mobile doesn't use all dispatch keys, so skip any fallback registrations for the unused keys.
282 auto idx = getDispatchTableIndexForDispatchKey(k);
283 if (idx < 0) continue;
284 registrars_.emplace_back(
285 c10::Dispatcher::singleton().registerFallback(
286 k,
287 f.func_,
288 debugString(f.debug_, file_, line_)
289 )
290 );
291 }
292 return *this;
293 }
294
295
296 } // namespace torch
297