#pragma once

/// \file
///
/// This header provides an API for extending PyTorch's core library
/// of operators with user defined operators and data types. This
/// API can be used in a few ways:
///
/// * You can define new custom operators and classes with TORCH_LIBRARY(),
///   making them available for use in both eager Python as well as in
///   TorchScript. This API is modeled off of pybind11's `PYBIND11_MODULE`
///   macro, as the provided functionality is similar (pybind11 lets you bind
///   C++ to Python only; `torch/library.h` lets you bind C++ simultaneously to
///   Python and TorchScript).
///
/// * You can override existing operators with TORCH_LIBRARY_IMPL(),
///   providing a new implementation for these operators for a custom
///   backend (e.g., XLA). When you call operators with tensors of your custom
///   backend, your overridden implementations will be called instead
///   of the standard implementations.
///
/// * You can use both capabilities at the same time, allowing you
///   to write custom operators that register CPU/CUDA/Autograd
///   implementations without having to write the boilerplate
///   conditionals yourself.
///
/// For a tutorial style introduction to the library API, check
/// out the [Extending TorchScript with Custom C++
/// Operators](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html)
/// tutorial.
///
/// ```
/// // Define a library whose operators live in the namespace 'myops'.
/// // You must define all of the operators for this library in
/// // this namespace.
/// TORCH_LIBRARY(myops, m) {
///   // Define an operator with exactly one implementation for all backends.
///   m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl);
///
///   // Define a schema for an operator, but provide no implementation
///   // (use this syntax if you want to use the dispatcher).
///   m.def("mul(Tensor self, Tensor other) -> Tensor");
///
///   // Provide an implementation for a defined operator (you can
///   // provide multiple; one per backend). The dispatcher takes care of
///   // calling the correct implementation depending on if we get a CPU
///   // tensor or a CUDA tensor.
///   m.impl("mul", torch::kCPU, &mul_cpu_impl);
///   m.impl("mul", torch::kCUDA, &mul_cuda_impl);
/// }
///
/// // Define implementations for operators for a non-standard backend,
/// // e.g., XLA (valid values are entries of DispatchKey). This can
/// // be used to define operators in a different file than the initial
/// // TORCH_LIBRARY definition (e.g., if it is in an external library).
/// TORCH_LIBRARY_IMPL(myops, XLA, m) {
///   m.impl("mul", &mul_xla_impl);
/// }
/// ```

#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/op_registration/op_allowlist.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/DispatchKey.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>

// Just for inferFunctionSchemaFromFunctor
#include <ATen/core/enum_tag.h>
#include <ATen/core/op_registration/op_registration.h>

namespace torch {

#if defined C10_MOBILE
/**
 * The NoInferSchemaTag is a type name used to indicate that this call to the
 * CppFunction constructor should not trigger schema inference from functor.
 * Schema inference from functor utilizes template meta-programming, and is
 * costly from a size perspective. Ideally, one would expect that the schema
 * inference would require very little binary size since most of the
 * computation can be done by the compiler at build time, but that isn't
 * necessarily the case.
 *
 * Schema inference is elided only for mobile use-cases where we don't need
 * the additional runtime cost or size overhead on client devices.
 */
struct NoInferSchemaTag {};
#endif

#define HAS_PT2_COMPLIANT_TAG

// For multipy/torchdeploy use case
enum class _RegisterOrVerify { REGISTER, VERIFY };

template <class CurClass>
class class_;

#define HAS_IMPL_ABSTRACT_PYSTUB

/// Represents a C++ function that implements an operator. Most users won't
/// interact directly with this class, except via error messages: the
/// constructors of this class define the set of permissible "function"-like
/// things you can bind via the interface.
///
/// This class erases the type of the passed in function, but durably records
/// the type via an inferred schema for the function.
class TORCH_API CppFunction final {
  // TODO: This is morally the same thing as KernelRegistrationConfig, but it's
  // opaque to the user.

 public:
  /// This overload accepts function pointers, e.g., `CppFunction(&add_impl)`
  template <typename Func>
  explicit CppFunction(
      Func* f,
      std::enable_if_t<
          c10::guts::is_function_type<Func>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
        cpp_signature_(c10::impl::CppSignature::make<Func>()),
        schema_(
            c10::detail::inferFunctionSchemaFromFunctor<std::decay_t<Func>>()),
        debug_() {}

  /// This overload accepts compile time function pointers, e.g.,
  /// `CppFunction(TORCH_FN(add_impl))`
  template <typename FuncPtr>
  explicit CppFunction(
      FuncPtr f,
      std::enable_if_t<
          c10::is_compile_time_function_pointer<FuncPtr>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
        cpp_signature_(
            c10::impl::CppSignature::make<typename FuncPtr::FuncType>()),
        schema_(c10::detail::inferFunctionSchemaFromFunctor<
                typename FuncPtr::FuncType>()),
        debug_() {}

  /// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) {
  /// ... })`
  template <typename Lambda>
  explicit CppFunction(
      Lambda&& f,
      std::enable_if_t<
          c10::guts::is_functor<std::decay_t<Lambda>>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedLambda(
            std::forward<Lambda>(f))),
        cpp_signature_(c10::impl::CppSignature::make<Lambda>()),
        schema_(c10::detail::inferFunctionSchemaFromFunctor<
                std::decay_t<Lambda>>()),
        debug_() {}

#if defined C10_MOBILE
  /// This overload accepts function pointers, e.g., `CppFunction(&add_impl,
  /// NoInferSchemaTag())`
  template <typename Func>
  explicit CppFunction(
      Func* f,
      NoInferSchemaTag,
      std::enable_if_t<
          c10::guts::is_function_type<Func>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
        cpp_signature_(c10::impl::CppSignature::make<Func>())
        // TODO: Don't go through WrapRuntimeKernelFunctor
        ,
        schema_(nullptr),
        debug_() {}

  /// This overload accepts compile time function pointers, e.g.,
  /// `CppFunction(TORCH_FN(add_impl), NoInferSchemaTag())`
  template <typename FuncPtr>
  explicit CppFunction(
      FuncPtr f,
      NoInferSchemaTag,
      std::enable_if_t<
          c10::is_compile_time_function_pointer<FuncPtr>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
        cpp_signature_(
            c10::impl::CppSignature::make<typename FuncPtr::FuncType>())
        // TODO: Don't go through WrapRuntimeKernelFunctor
        ,
        schema_(nullptr),
        debug_() {}

  /// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) {
  /// ... }, NoInferSchemaTag())`
  template <typename Lambda>
  explicit CppFunction(
      Lambda&& f,
      NoInferSchemaTag,
      std::enable_if_t<
          c10::guts::is_functor<std::decay_t<Lambda>>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedLambda(
            std::forward<Lambda>(f))),
        cpp_signature_(c10::impl::CppSignature::make<Lambda>())
        // TODO: Don't go through WrapRuntimeKernelFunctor
        ,
        schema_(nullptr),
        debug_() {}
#endif

  ~CppFunction();

  CppFunction(CppFunction&&) noexcept = default;

  CppFunction& operator=(CppFunction&&) = default;

  /// \private
  /// Creates a function from a type-erased boxed kernel.
  static CppFunction makeFromBoxedKernel(c10::BoxedKernel kernel) {
    return CppFunction(
        c10::KernelFunction::makeFromBoxedKernel(std::move(kernel)),
        /* cpp_signature */ std::nullopt, // not known for boxed functions
        /* schema */ nullptr);
  }

  /// This creates a fallthrough function. Fallthrough functions
  /// immediately redispatch to the next available dispatch key,
  /// but are implemented more efficiently than a hand written
  /// function done in the same way.
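  ///
  /// ```
  /// // A minimal sketch: let every operator in a hypothetical `myops`
  /// // namespace skip the Autograd key and fall through to the backend.
  /// TORCH_LIBRARY_IMPL(myops, Autograd, m) {
  ///   m.fallback(torch::CppFunction::makeFallthrough());
  /// }
  /// ```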
  static CppFunction makeFallthrough() {
    return makeFromBoxedKernel(c10::BoxedKernel::makeFallthrough());
  }

  /// \private
  ///
  /// Creates a function that raises an error saying that named tensors
  /// are not supported when called.
  static CppFunction makeNamedNotSupported() {
    return makeFromBoxedKernel(c10::BoxedKernel::makeNamedNotSupported());
  }

  /// Create a function from a boxed kernel function with signature
  /// `void(const OperatorHandle&, Stack*)`; i.e., they receive a
  /// stack of arguments in a boxed calling convention, rather than
  /// in the native C++ calling convention. Boxed functions are
  /// typically only used to register backend fallbacks via
  /// torch::Library::fallback().
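  ///
  /// ```
  /// // A sketch of a boxed fallback (the function name and the XLA key
  /// // choice are illustrative); it uses the keyset variant below to
  /// // redispatch to the next backend in line.
  /// void my_fallback(
  ///     const c10::OperatorHandle& op,
  ///     c10::DispatchKeySet ks,
  ///     torch::jit::Stack* stack) {
  ///   // Inspect or transform *stack here, then redispatch below our key.
  ///   op.redispatchBoxed(
  ///       ks & c10::DispatchKeySet(
  ///                c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::XLA),
  ///       stack);
  /// }
  /// // m.fallback(torch::CppFunction::makeFromBoxedFunction<&my_fallback>());
  /// ```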
  template <c10::BoxedKernel::BoxedKernelFunction* func>
  static CppFunction makeFromBoxedFunction() {
    return makeFromBoxedKernel(c10::BoxedKernel::makeFromFunction<func>());
  }

  // Variant that takes in a boxed kernel function with a plumbed
  // DispatchKeySet. See Note [Plumbing Keys Through The Dispatcher] for
  // details.
  template <c10::BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
  static CppFunction makeFromBoxedFunction() {
    return makeFromBoxedKernel(c10::BoxedKernel::makeFromFunction<func>());
  }

  /// Create a function from a boxed kernel functor which defines
  /// `operator()(const OperatorHandle&, DispatchKeySet, Stack*)`
  /// (receiving arguments from boxed calling convention) and inherits
  /// from `c10::OperatorKernel`. Unlike makeFromBoxedFunction, functions
  /// registered in this way can also carry additional state which
  /// is managed by the functor; this is useful if you're writing an
  /// adapter to some other implementation, e.g., a Python callable, which
  /// is dynamically associated with the registered kernel.
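  ///
  /// ```
  /// // A sketch of a stateful boxed functor (the type name and its state
  /// // are illustrative, not part of this API):
  /// struct MyBoxedFunctor : c10::OperatorKernel {
  ///   explicit MyBoxedFunctor(std::string tag) : tag_(std::move(tag)) {}
  ///   void operator()(
  ///       const c10::OperatorHandle& op,
  ///       c10::DispatchKeySet ks,
  ///       torch::jit::Stack* stack) {
  ///     // Pop inputs from *stack, compute, and push outputs back,
  ///     // consulting the functor state (tag_) as needed.
  ///   }
  ///   std::string tag_;
  /// };
  /// // m.impl("add", torch::CppFunction::makeFromBoxedFunctor(
  /// //     std::make_unique<MyBoxedFunctor>("debug")));
  /// ```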
  template <class KernelFunctor>
  static CppFunction makeFromBoxedFunctor(
      std::unique_ptr<KernelFunctor> kernelFunctor) {
    return makeFromBoxedKernel(
        c10::BoxedKernel::makeFromFunctor(std::move(kernelFunctor)));
  }

  /// Create a function from an unboxed kernel function.
  /// This is typically used to register common operators.
  template <
      typename FuncPtr,
      std::enable_if_t<
          c10::guts::is_function_type<FuncPtr>::value,
          std::nullptr_t> = nullptr>
  static CppFunction makeFromUnboxedFunction(FuncPtr* f) {
    return CppFunction(f);
  }

  /// Create a function from a compile time unboxed kernel function pointer.
  /// This is typically used to register common operators.
  /// Compile time function pointers can be used to allow the compiler
  /// to optimize (e.g. inline) calls to it.
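  ///
  /// ```
  /// // A minimal sketch, assuming a hypothetical kernel `add_impl`:
  /// // TORCH_FN produces a compile-time function pointer, so calls to
  /// // it can be inlined by the compiler.
  /// m.impl("add", torch::CppFunction::makeFromUnboxedFunction(TORCH_FN(add_impl)));
  /// ```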
  template <
      typename FuncPtr,
      std::enable_if_t<
          c10::is_compile_time_function_pointer<FuncPtr>::value,
          std::nullptr_t> = nullptr>
  static CppFunction makeFromUnboxedFunction(FuncPtr f) {
    return CppFunction(f);
  }

  CppFunction&& debug(std::string d) && {
    debug_ = std::move(d);
    return std::move(*this);
  }

 private:
  std::optional<c10::DispatchKey> dispatch_key_;
  c10::KernelFunction func_;
  std::optional<c10::impl::CppSignature> cpp_signature_;
  std::unique_ptr<c10::FunctionSchema> schema_;
  std::string debug_;

  // The "setter" for dispatch_key_
  template <typename Func>
  friend CppFunction dispatch(c10::DispatchKey, Func&&);

  // The only class which actually pulls out values from CppFunction (does so
  // destructively, felt too lazy to write accessors that I don't even
  // want users to use)
  friend class Library;

  CppFunction(
      c10::KernelFunction func,
      std::optional<c10::impl::CppSignature> cpp_signature,
      std::unique_ptr<c10::FunctionSchema> schema);
};

/// \defgroup torch-dispatch-overloads torch::dispatch overloads

/// Create a torch::CppFunction which is associated with a specific
/// dispatch key. torch::CppFunctions that are tagged with a
/// c10::DispatchKey don't get invoked unless the dispatcher determines
/// that this particular c10::DispatchKey is the one that should be
/// dispatched to.
///
/// This function is generally not used directly; instead, prefer using
/// TORCH_LIBRARY_IMPL(), which will implicitly set the c10::DispatchKey
/// for all registration calls inside of its body.
///
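/// ```
/// // A minimal sketch, assuming a hypothetical kernel `mul_cpu_impl`:
/// TORCH_LIBRARY(myops, m) {
///   m.def("mul(Tensor self, Tensor other) -> Tensor");
///   m.impl("mul", torch::dispatch(c10::DispatchKey::CPU, &mul_cpu_impl));
/// }
/// ```
///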
/// \ingroup torch-dispatch-overloads
template <typename Func>
inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) {
  CppFunction f(std::forward<Func>(raw_f));
  if (k == c10::DispatchKey::CatchAll) {
    f.dispatch_key_ = std::nullopt;
  } else {
    f.dispatch_key_ = k;
  }
  return f;
}

/// Convenience overload of dispatch() which accepts c10::DeviceType
///
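/// ```
/// // A minimal sketch, assuming a hypothetical kernel `mul_cuda_impl`;
/// // this is what the Library::impl convenience overload expands to.
/// m.impl("mul", torch::dispatch(c10::DeviceType::CUDA, &mul_cuda_impl));
/// ```
///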
/// \ingroup torch-dispatch-overloads
template <typename Func>
inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
  auto deviceTypeToDispatchKey = [](c10::DeviceType t) {
    switch (t) {
      // This list is synchronized with the k-constants in c10/core/DeviceType.h
      case c10::DeviceType::CPU:
        return c10::DispatchKey::CPU;
      case c10::DeviceType::CUDA:
        return c10::DispatchKey::CUDA;
      case c10::DeviceType::IPU:
        return c10::DispatchKey::IPU;
      case c10::DeviceType::XLA:
        return c10::DispatchKey::XLA;
      case c10::DeviceType::Lazy:
        return c10::DispatchKey::Lazy;
      case c10::DeviceType::XPU:
        return c10::DispatchKey::XPU;
      case c10::DeviceType::MPS:
        return c10::DispatchKey::MPS;
      case c10::DeviceType::Meta:
        return c10::DispatchKey::Meta;
      case c10::DeviceType::HIP:
        return c10::DispatchKey::HIP;
      case c10::DeviceType::MAIA:
        return c10::DispatchKey::MAIA;
      case c10::DeviceType::HPU:
        return c10::DispatchKey::HPU;
      case c10::DeviceType::MTIA:
        return c10::DispatchKey::MTIA;
      case c10::DeviceType::PrivateUse1:
        return c10::DispatchKey::PrivateUse1;
      default:
        TORCH_CHECK(
            false,
            "Device type ",
            t,
            " cannot be overloaded at dispatch time, "
            "please file a bug report explaining what you were trying to do.");
    }
  };
  return dispatch(deviceTypeToDispatchKey(type), std::forward<Func>(raw_f));
}

/// \defgroup torch-schema-overloads torch::schema overloads

/// Construct a c10::FunctionSchema from a string, with an explicitly
/// specified c10::AliasAnalysisKind. Ordinarily, schemas are simply
/// passed in as strings, but if you need to specify a custom alias
/// analysis, you can replace the string with a call to this function.
///
/// ```
/// // Default alias analysis (FROM_SCHEMA)
/// m.def("def3(Tensor self) -> Tensor");
/// // Pure function alias analysis
/// m.def(torch::schema("def3(Tensor self) -> Tensor",
///                     c10::AliasAnalysisKind::PURE_FUNCTION));
/// ```
///
/// \ingroup torch-schema-overloads
inline c10::FunctionSchema schema(
    const char* str,
    c10::AliasAnalysisKind k,
    bool allow_typevars = false) {
  c10::FunctionSchema s =
      torch::jit::parseSchema(str, /*allow_typevars*/ allow_typevars);
  s.setAliasAnalysis(k);
  return s;
}

/// Function schemas can be directly constructed from string literals.
///
/// \ingroup torch-schema-overloads
inline c10::FunctionSchema schema(const char* s, bool allow_typevars = false) {
  return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA, allow_typevars);
}

/// \private
///
/// Already constructed function schemas are accepted if they are
/// rvalues.
///
/// \ingroup torch-schema-overloads
inline c10::FunctionSchema&& schema(c10::FunctionSchema&& s) {
  return std::move(s);
}

namespace detail {

inline std::variant<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(
    c10::FunctionSchema&& s) {
  return std::move(s);
}
inline std::variant<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(
    c10::OperatorName&& n) {
  return std::move(n);
}
inline std::variant<c10::OperatorName, c10::FunctionSchema>
constructSchemaOrName(const char* str) {
  auto s = torch::jit::parseSchemaOrName(str);
  if (std::holds_alternative<c10::FunctionSchema>(s)) {
    std::get<c10::FunctionSchema>(s).setAliasAnalysis(
        c10::AliasAnalysisKind::FROM_SCHEMA);
  }
  return s;
}

class TorchLibraryInit;

} // namespace detail

// Note [Selective build]
// ~~~~~~~~~~~~~~~~~~~~~~
// In some settings, especially mobile, it is important to avoid compiling any
// references to functions that you aren't actually going to use, so that they
// can be eliminated by the linker. We call this capability "selective build".
//
// A very easy way to implement selective build which results in a lot of
// boilerplate is to just add ifdef's around every registration call, but this
// means you have to write a lot of extra lines of code at every registration
// site, and it also means you have to define some munging scheme to map
// operators to macros.
//
// Instead of doing this, we have a different mechanism centered around the
// concept of a SelectiveStr. A selective name is like a const char* string,
// except it also carries at compile time a boolean saying whether or not a
// registration should actually happen. We then have extra overloads which
// bypass registration entirely if a selective name is disabled. We do a
// constexpr test to see if an operator should be enabled or not; this is
// currently implemented in ATen/core/op_registration/op_allowlist.h

namespace detail {

// dummy class for non selected custom torchbind classes
class ClassNotSelected {
 public:
  ClassNotSelected& def_pickle(...) {
    return *this;
  }
  ClassNotSelected& def(...) {
    return *this;
  }
};

// A SelectiveStr is like a const char*, except that it also comes
// with a type brand that says whether or not the name is enabled.
// If the string is disabled, then (at compile time) we DON'T generate
// a registration call for it. This class is not intended to be used
// directly; use the TORCH_SELECTIVE_NAME or TORCH_SELECTIVE_SCHEMA macros
// below to create it.
template <bool enabled>
class SelectiveStr {
 public:
  constexpr explicit SelectiveStr(const char* name) : name_(name) {}
  constexpr operator const char*() {
    return name_;
  }

 private:
  const char* name_;
};

#define TORCH_SELECTIVE_CLASS(n) \
  torch::detail::SelectiveStr<c10::impl::custom_class_allowlist_check(n)>(n)
#define TORCH_SELECTIVE_NAME(n) \
  torch::detail::SelectiveStr<c10::impl::op_allowlist_check(n)>(n)
#define TORCH_SELECTIVE_SCHEMA(n) \
  torch::detail::SelectiveStr<c10::impl::schema_allowlist_check(n)>(n)
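
// A brief usage sketch (the operator name is illustrative): wrapping the
// schema string in TORCH_SELECTIVE_SCHEMA makes the registration disappear
// at compile time when the operator is not on the allowlist.
//
//   TORCH_LIBRARY(myops, m) {
//     m.def(TORCH_SELECTIVE_SCHEMA(
//         "myops::add(Tensor self, Tensor other) -> Tensor"));
//   }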

} // namespace detail

/// This object provides the API for defining operators and providing
/// implementations at dispatch keys. Typically, a torch::Library
/// is not allocated directly; instead it is created by the
/// TORCH_LIBRARY() or TORCH_LIBRARY_IMPL() macros.
///
/// Most methods on torch::Library return a reference to itself,
/// supporting method chaining.
///
/// ```
/// // Examples:
///
/// TORCH_LIBRARY(torchvision, m) {
///   // m is a torch::Library
///   m.def("roi_align", ...);
///   ...
/// }
///
/// TORCH_LIBRARY_IMPL(aten, XLA, m) {
///   // m is a torch::Library
///   m.impl("add", ...);
///   ...
/// }
/// ```
///
class TORCH_API Library final {
 public:
  /// \private
  ///
  /// Which type of macro produced this Library
  enum Kind {
    DEF, // from TORCH_LIBRARY (no qualifier)
    IMPL,
    FRAGMENT,
  };

  /// \private
  ///
  /// Use TORCH_LIBRARY() or TORCH_LIBRARY_IMPL() instead of using these
  /// constructors directly
  Library(
      Kind kind,
      std::string ns,
      std::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line);

  Library(const Library&) = delete;
  Library& operator=(const Library&) = delete;
  Library(Library&&) = default;
  Library& operator=(Library&&) = default;

  // Some notes about the API design here. We had the following constraints:
  //
  // - We need to support multiple "types" of arguments for schema and
  //   functions (e.g., unnamed lambda types, regular functions, const char*,
  //   fully instantiated schemas)
  // - We don't want to write exponentially many overloads
  // - We don't want to rely on implicit conversion to a common type,
  //   because the C++ compiler will only be willing to do a single
  //   implicit conversion (reducing the set of valid types which you
  //   can invoke with); also error messages are worse when an implicit
  //   conversion is not selected (as the compiler will not explain
  //   why it didn't select an implicit conversion; this is different
  //   from overloads where it will explain each candidate overload and
  //   why it didn't apply)
  //
  // To solve all of these constraints at the same time, we use a trick taken
  // from the pybind11 library: template over the argument in the user visible
  // API, and inside of the templated function explicitly call an overloaded
  // function to resolve the argument to a real type. You get the good error
  // messages from overloads, but at the same time you only need to write the
  // overload for any given argument type once.

  /// Declare an operator with a schema, but don't provide any implementations
  /// for it. You're expected to then provide implementations using the
  /// impl() method. All template arguments are inferred.
  ///
  /// \param raw_schema The schema of the operator to be defined.
  ///     Typically, this is a `const char*` string literal, but any type
  ///     accepted by torch::schema() is accepted here.
  ///
  /// ```
  /// // Example:
  /// TORCH_LIBRARY(myops, m) {
  ///   m.def("add(Tensor self, Tensor other) -> Tensor");
  /// }
  /// ```
  template <typename Schema>
  Library& def(
      Schema&& raw_schema,
      const std::vector<at::Tag>& tags = {},
      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
    c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
    return _def(std::move(s), nullptr, tags, rv);
  }

  /// Declares that for all operators that are subsequently def'ed, their
  /// fake impls may be found in the given Python module (pymodule).
  /// This registers some help text that is used if the fake impl
  /// cannot be found.
  ///
  /// Args:
  /// - pymodule: the python module
  /// - context: We may include this in the error message.
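  ///
  /// ```
  /// // A minimal sketch (the module path is illustrative): operators
  /// // def'ed after this call will point users at that module when a
  /// // fake impl is missing.
  /// TORCH_LIBRARY(myops, m) {
  ///   m.set_python_module("myops._fake_impls");
  ///   m.def("add(Tensor self, Tensor other) -> Tensor");
  /// }
  /// ```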
  Library& set_python_module(const char* pymodule, const char* context = "") {
    python_module_ = {pymodule, context};
    return *this;
  }

  /// Deprecated; use set_python_module instead
  Library& impl_abstract_pystub(const char* pymodule, const char* context = "") {
    return set_python_module(pymodule, context);
  }

  /// Define an operator for a schema and then register an implementation for
  /// it. This is typically what you would use if you aren't planning
  /// on making use of the dispatcher to structure your operator
  /// implementation. It's roughly equivalent to calling def() and
  /// then impl(), but if you omit the schema of the operator, we will
  /// infer it from the type of your C++ function. All template
  /// arguments are inferred.
  ///
  /// \param raw_name_or_schema The schema of the operator to be
  ///     defined, or just the name of the operator if the schema is to be
  ///     inferred from `raw_f`. Typically a `const char*` literal.
  /// \param raw_f The C++ function that implements this operator.
  ///     Any valid constructor of torch::CppFunction is accepted here;
  ///     typically you provide a function pointer or lambda.
  ///
  /// ```
  /// // Example:
  /// TORCH_LIBRARY(myops, m) {
  ///   m.def("add", add_fn);
  /// }
  /// ```
  template <typename NameOrSchema, typename Func>
  Library& def(
      NameOrSchema&& raw_name_or_schema,
      Func&& raw_f,
      const std::vector<at::Tag>& tags = {}) & {
    CppFunction f(std::forward<Func>(raw_f));
    return _def(
        detail::constructSchemaOrName(
            ::std::forward<NameOrSchema>(raw_name_or_schema)),
        ::std::move(f),
        tags);
  }

  /// Register an implementation for an operator. You may register multiple
  /// implementations for a single operator at different dispatch keys
  /// (see torch::dispatch()). Implementations must have a corresponding
  /// declaration (from def()), otherwise they are invalid. If you plan
  /// to register multiple implementations, DO NOT provide a function
  /// implementation when you def() the operator.
  ///
  /// \param name The name of the operator to implement. Do NOT provide
  ///     schema here.
  /// \param raw_f The C++ function that implements this operator. Any
  ///     valid constructor of torch::CppFunction is accepted here;
  ///     typically you provide a function pointer or lambda.
  ///
  /// ```
  /// // Example:
  /// TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  ///   m.impl("add", add_cuda);
  /// }
  /// ```
  template <typename Name, typename Func>
  Library& impl(
      Name name,
      Func&& raw_f,
      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
    // TODO: need to raise an error when you impl a function that has a
    // catch all def
#if defined C10_MOBILE
    CppFunction f(std::forward<Func>(raw_f), NoInferSchemaTag());
#else
    CppFunction f(std::forward<Func>(raw_f));
#endif
    return _impl(name, std::move(f), rv);
  }

#if defined C10_MOBILE
  // Note: This overload is needed only for C10_MOBILE, since the automatically
  // defined copy constructor for the CppFunction doesn't have the additional
  // NoInferSchemaTag argument. We define the overload for the impl() function
  // to accept a CppFunction&& argument. The already constructed CppFunction
  // object may or may not have the inferred schema, but it doesn't matter
  // for our purposes since if it already has the inferred schema, then we
  // might as well just pass it through directly.
  //
  template <typename Name>
  Library& impl(Name name, CppFunction&& raw_f) & {
    // TODO: need to raise an error when you impl a function that has a
    // catch all def
    CppFunction f(std::forward<CppFunction>(raw_f));
    return _impl(name, std::move(f));
  }
#endif

  // Helper for getting an OperatorName for a const char*. You probably
  // don't need this.
  c10::OperatorName _resolve(const char* name) const;

  /// \private
  ///
  /// Convenience overload for directly specifying the dispatch key when
  /// impl(). You probably don't need this; instead, prefer specifying
  /// the dispatch key for the entire block in TORCH_LIBRARY_IMPL()
  template <typename Name, typename Dispatch, typename Func>
  Library& impl(Name name, Dispatch&& key, Func&& raw_f) & {
    return impl(
        name, dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
  }

  template <typename Name, typename Func>
  Library& impl_UNBOXED(Name /*name*/, Func* /*raw_f*/) & {
    static_assert(
        c10::guts::false_t<Func>(),
        ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
    return *this;
  }

  // These overloads cover cases when a SelectiveStr (see Note [Selective
  // build]) has been disabled at compile time. In that case, don't generate
  // any code referencing the passed in functions at all.
  Library& def(
      detail::SelectiveStr<false>,
      const std::vector<at::Tag>& tags [[maybe_unused]] = {}) & {
    return *this;
  }
  Library& def(
      detail::SelectiveStr<true> raw_schema,
      const std::vector<at::Tag>& tags = {}) & {
    return def(raw_schema.operator const char*(), tags);
  }
  template <typename Func>
  Library& def(
      detail::SelectiveStr<false>,
      Func&& /*raw_f*/,
      const std::vector<at::Tag>& tags [[maybe_unused]] = {}) & {
    return *this;
  }
  template <typename Func>
  Library& def(
      detail::SelectiveStr<true> raw_name_or_schema,
      Func&& raw_f,
      const std::vector<at::Tag>& tags = {}) & {
    return def(
        raw_name_or_schema.operator const char*(),
        std::forward<Func>(raw_f),
        tags);
  }

  template <typename Func>
  // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
  Library& impl(detail::SelectiveStr<false>, Func&& /*raw_f*/) & {
    return *this;
  }
  template <typename Dispatch, typename Func>
  Library& impl(
      detail::SelectiveStr<false>,
      // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
      Dispatch&& /*key*/,
      // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
      Func&& /*raw_f*/) & {
    return *this;
  }
  template <typename Func>
  Library& impl_UNBOXED(
      detail::SelectiveStr<false> /*name*/,
      Func* /*raw_f*/) & {
    static_assert(
        c10::guts::false_t<Func>(),
        ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
    return *this;
  }

  template <typename Func>
  Library& impl(detail::SelectiveStr<true> name, Func&& raw_f) & {
    return impl(name.operator const char*(), std::forward<Func>(raw_f));
  }
  template <typename Dispatch, typename Func>
  Library& impl(
      detail::SelectiveStr<true> name,
      Dispatch&& key,
      Func&& raw_f) & {
    return impl(
        name.operator const char*(),
        std::forward<Dispatch>(key),
        std::forward<Func>(raw_f));
  }
  template <typename Func>
  Library& impl_UNBOXED(
      detail::SelectiveStr<true> /*name*/,
      Func* /*raw_f*/) & {
    static_assert(
        c10::guts::false_t<Func>(),
        ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
    return *this;
  }

  /// Register a fallback implementation for all operators which will be used
  /// if there is not a specific implementation for an operator available.
  /// There MUST be a DispatchKey associated with a fallback; e.g.,
  /// only call this from TORCH_LIBRARY_IMPL() with namespace `_`.
  ///
  /// \param raw_f The function that implements the fallback. Unboxed
  ///     functions typically do not work as fallback functions, as
  ///     fallback functions must work for every operator (even though
  ///     they have varying type signatures). Typical arguments are
  ///     CppFunction::makeFallthrough() or
  ///     CppFunction::makeFromBoxedFunction()
  ///
  /// ```
  /// // Example:
  ///
  /// TORCH_LIBRARY_IMPL(_, AutogradXLA, m) {
  ///   // If there is not a kernel explicitly registered
  ///   // for AutogradXLA, fall through to the next
  ///   // available kernel
  ///   m.fallback(torch::CppFunction::makeFallthrough());
  /// }
  ///
  /// // See aten/src/ATen/core/dispatch/backend_fallback_test.cpp
  /// // for a full example of boxed fallback
  /// ```
  template <typename Func>
  Library& fallback(Func&& raw_f) & {
    CppFunction f((std::forward<Func>(raw_f)));
    return _fallback(std::move(f));
  }

  template <class CurClass>
  inline torch::class_<CurClass> class_(const std::string& className);

  // These overloads enable the use of selective build on classes registered
  // within a library. The API is the same as before with one minor change.
  // Instead of m.class_<foo>("foo") you instead do
  // m.class_<foo>(TORCH_SELECTIVE_CLASS("foo"))
  template <class CurClass>
  inline torch::class_<CurClass> class_(detail::SelectiveStr<true> className);

  template <class CurClass>
  inline detail::ClassNotSelected class_(detail::SelectiveStr<false> className);

  // De-registers all registrations created with this Library
  void reset();

 private:
  Kind kind_;
  std::optional<std::string> ns_;
  std::optional<c10::DispatchKey> dispatch_key_;
  std::optional<std::pair<const char*, const char*>> python_module_;
  const char* file_;
  uint32_t line_;

  std::vector<c10::RegistrationHandleRAII> registrars_;

  friend class detail::TorchLibraryInit;

  // Non-user visible actual implementations of functions. These aren't
  // public because we only implement the & qualifier and not the && qualifier.
  Library& _def(
      c10::FunctionSchema&& schema,
      c10::OperatorName* out_name = nullptr,
      const std::vector<at::Tag>& tags = {},
      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &;
  Library& _def(
      std::variant<c10::OperatorName, c10::FunctionSchema>&&,
      CppFunction&& f,
      const std::vector<at::Tag>& tags = {}) &;
  Library& _impl(
      const char* name,
      CppFunction&& f,
      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &;
  Library& _fallback(CppFunction&& f) &;

  at::OperatorName _parseNameForLib(const char* name_str) const;
};

namespace detail {

class TorchLibraryInit final {
 private:
  using InitFn = void(Library&);
  Library lib_;

 public:
  TorchLibraryInit(
      Library::Kind kind,
      InitFn* fn,
      const char* ns,
      std::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line)
      : lib_(kind, ns, k, file, line) {
    fn(lib_);
  }
};

} // namespace detail

} // namespace torch

// NB: The EXACT NAMING of the initializer functions (e.g.,
// TORCH_LIBRARY_init_aten) matters for the code analyzer;
// see the regexes at tools/code_analyzer/run_analyzer.sh

/// Macro for defining a function that will be run at static
/// initialization time to define a library of operators in the
/// namespace `ns` (must be a valid C++ identifier, no quotes).
/// Use this macro when you want to define a new set of custom operators
/// that do not already exist in PyTorch.
///
/// Example usage:
///
/// ```
/// TORCH_LIBRARY(myops, m) {
///   // m is a torch::Library; methods on it will define
///   // operators in the myops namespace
///   m.def("add", add_impl);
/// }
/// ```
///
/// The `m` argument is bound to a torch::Library that is used to
/// register operators. There may only be one TORCH_LIBRARY()
/// for any given namespace.
#define TORCH_LIBRARY(ns, m)                                                    \
  static void TORCH_LIBRARY_init_##ns(torch::Library&);                        \
  static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_##ns( \
      torch::Library::DEF,                                                     \
      &TORCH_LIBRARY_init_##ns,                                                \
      #ns,                                                                     \
      std::nullopt,                                                            \
      __FILE__,                                                                \
      __LINE__);                                                               \
  void TORCH_LIBRARY_init_##ns(torch::Library& m)

/// \private
///
/// This macro is a version of TORCH_LIBRARY() that doesn't enforce that there
/// is only one library (it is a "fragment"). This is used inside the
/// PerOpRegistration.cpp file, as well as in places where all op registrations
/// within the same namespace cannot be easily put into one macro block
/// (this is mostly the case for custom ops in fbcode that were ported from
/// the old API)
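///
/// ```
/// // A minimal sketch: two translation units can each contribute a
/// // fragment to the same (hypothetical) namespace.
/// TORCH_LIBRARY_FRAGMENT(myops, m) {
///   m.def("cos(Tensor self) -> Tensor");
/// }
/// ```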
#define TORCH_LIBRARY_FRAGMENT(ns, m) _TORCH_LIBRARY_FRAGMENT(ns, m, C10_UID)

/// \private
///
/// The above macro requires an extra unique identifier (uid) to prevent
/// variable name collisions. This can happen if TORCH_LIBRARY_FRAGMENT is
/// called multiple times with the same namespace in the same translation
/// unit. Note that the TORCH_LIBRARY variant doesn't run into this problem,
/// because it enforces that it can only be called once for a given namespace.
#define _TORCH_LIBRARY_FRAGMENT(ns, m, uid)                       \
  static void C10_CONCATENATE(                                    \
      TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(torch::Library&); \
  static const torch::detail::TorchLibraryInit C10_CONCATENATE(   \
      TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)(           \
      torch::Library::FRAGMENT,                                   \
      &C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \
      #ns,                                                        \
      std::nullopt,                                               \
      __FILE__,                                                   \
      __LINE__);                                                  \
  void C10_CONCATENATE(                                           \
      TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(torch::Library & m)

/// Macro for defining a function that will be run at static
/// initialization time to define operator overrides for dispatch key
/// `k` (must be an unqualified enum member of c10::DispatchKey) in
/// namespace `ns` (must be a valid C++ identifier, no quotes). Use this
/// macro when you want to implement a preexisting set of custom
/// operators on a new dispatch key (e.g., you want to provide CUDA
/// implementations of already existing operators). One common usage
/// pattern is to use TORCH_LIBRARY() to define schema for all new
/// operators you want to define, and then use several
/// TORCH_LIBRARY_IMPL() blocks to provide implementations of the
/// operator for CPU, CUDA and Autograd.
///
/// In some cases, you need to define something that applies to all namespaces,
/// not just one namespace (usually a fallback). In that case, use the reserved
/// namespace _, e.g.,
///
/// ```
/// TORCH_LIBRARY_IMPL(_, XLA, m) {
///   m.fallback(xla_fallback);
/// }
/// ```
///
/// Example usage:
///
/// ```
/// TORCH_LIBRARY_IMPL(myops, CPU, m) {
///   // m is a torch::Library; methods on it will define
///   // CPU implementations of operators in the myops namespace.
///   // It is NOT valid to call torch::Library::def()
///   // in this context.
///   m.impl("add", add_cpu_impl);
/// }
/// ```
///
/// If ``add_cpu_impl`` is an overloaded function, use a
/// ``static_cast`` to specify which overload you want
/// (by providing the full type).
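///
/// ```
/// // A minimal sketch (the overloads are illustrative): disambiguate an
/// // overloaded kernel by spelling out the full function pointer type.
/// // at::Tensor add_cpu_impl(const at::Tensor&, const at::Tensor&);
/// // at::Tensor add_cpu_impl(const at::Tensor&, const at::Scalar&);
/// m.impl(
///     "add",
///     static_cast<at::Tensor (*)(const at::Tensor&, const at::Tensor&)>(
///         &add_cpu_impl));
/// ```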
///
// NB: if the dispatch key is not allowlisted, we simply omit the Library
// call entirely
#define TORCH_LIBRARY_IMPL(ns, k, m) _TORCH_LIBRARY_IMPL(ns, k, m, C10_UID)

/// \private
///
/// The above macro requires an extra unique identifier (uid) to prevent
/// variable name collisions. This can happen if TORCH_LIBRARY_IMPL is called
/// multiple times with the same namespace and dispatch key in the same
/// translation unit.
#define _TORCH_LIBRARY_IMPL(ns, k, m, uid)                                \
  static void C10_CONCATENATE(                                            \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&);       \
  static const torch::detail::TorchLibraryInit C10_CONCATENATE(           \
      TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)(                 \
      torch::Library::IMPL,                                               \
      (c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::k)       \
           ? &C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid) \
           : [](torch::Library&) -> void {}),                             \
      #ns,                                                                \
      std::make_optional(c10::DispatchKey::k),                            \
      __FILE__,                                                           \
      __LINE__);                                                          \
  void C10_CONCATENATE(                                                   \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library & m)

// These are variants of the macros above which are to be used for testing
// (they don't set up the static initializer, so you can control the
// visibility of the allocated library yourself).
//
// DO NOT use these in production code, they are NOT understood by the
// code analyzer and will be incorrectly analyzed in those situations.

/// \private
#define MAKE_TORCH_LIBRARY(ns) \
  torch::Library(torch::Library::DEF, #ns, std::nullopt, __FILE__, __LINE__)
/// \private
#define MAKE_TORCH_LIBRARY_IMPL(ns, k)           \
  torch::Library(                                \
      torch::Library::IMPL,                      \
      #ns,                                       \
      std::make_optional(c10::DispatchKey::k),   \
      __FILE__,                                  \
      __LINE__)
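
// A test-only usage sketch (the names are illustrative): the Library object
// is a local here, so its registrations are dropped when it is reset or
// destroyed, which is the point of these macros.
//
//   auto lib = MAKE_TORCH_LIBRARY(myops);
//   lib.def("add(Tensor self, Tensor other) -> Tensor");
//   auto cpu_lib = MAKE_TORCH_LIBRARY_IMPL(myops, CPU);
//   cpu_lib.impl("add", add_cpu_impl);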

// Make the custom class API visible, so it is available from
// torch::Library.

#include <torch/custom_class.h>