xref: /aosp_15_r20/external/pytorch/torch/library.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 /// \file
4 ///
5 /// This header provides an API for extending PyTorch's core library
6 /// of operators with user defined operators and data types.  This
7 /// API can be used in a few ways:
8 ///
9 /// * You can define new custom operators and classes with TORCH_LIBRARY(),
10 ///   making them available for use in both eager Python as well as in
11 ///   TorchScript. This API is modeled off of pybind11's `PYBIND11_MODULE`
12 ///   macro, as the provided functionality is similar (pybind11 lets you bind
13 ///   C++ to Python only; `torch/library.h` lets you bind C++ simultaneously to
14 ///   Python and TorchScript).
15 ///
16 /// * You can override existing operators with TORCH_LIBRARY_IMPL(),
17 ///   providing a new implementation for these operators for a custom
18 ///   backend (e.g., XLA).  When you pass operators with tensors of your custom
19 ///   backend, your overridden implementations will be called instead
20 ///   of the standard implementations.
21 ///
22 /// * You can use both capabilities at the same time, allowing you
23 ///   to write custom operators that register CPU/CUDA/Autograd
24 ///   implementations without having to write the boilerplate
25 ///   conditionals yourself.
26 ///
27 /// For a tutorial style introduction to the library API, check
28 /// out the [Extending TorchScript with Custom C++
29 /// Operators](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html)
30 /// tutorial.
31 ///
32 /// ```
33 /// // Define a library whose operators live in the namespace 'myops'.
34 /// // You must define all of the operators for this library in
35 /// // this namespace.
36 /// TORCH_LIBRARY(myops, m) {
37 ///   // Define a operator with exactly one implementation for all backends.
38 ///   m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl);
39 ///
40 ///   // Define a schema for an operator, but provide no implementation
41 ///   // (use this syntax if you want to use the dispatcher)
42 ///   m.def("mul(Tensor self, Tensor other) -> Tensor");
43 ///
44 ///   // Provide an implementation for a defined operator (you can
45 ///   // provide multiple; one per backend).  The dispatcher takes care of
46 ///   // calling the correct implementation depending on if we get a CPU
47 ///   // tensor or a CUDA tensor
48 ///   m.impl("mul", torch::kCPU, &mul_cpu_impl);
49 ///   m.impl("mul", torch::kCUDA, &mul_cuda_impl);
50 /// }
51 ///
52 /// // Define implementations for operators for a non-standard backend,
53 /// // e.g., XLA (valid values are entries of DispatchKey).  This can
54 /// // be used to define operators in a different file than the initial
55 /// // TORCH_LIBRARY definition (e.g., if it is in an external library)
56 /// TORCH_LIBRARY_IMPL(myops, XLA, m) {
57 ///   m.impl("mul", &mul_xla_impl);
58 /// }
59 /// ```
60 
61 #include <ATen/core/op_registration/infer_schema.h>
62 #include <ATen/core/op_registration/op_allowlist.h>
63 #include <ATen/core/dispatch/Dispatcher.h>
64 #include <c10/core/DispatchKey.h>
65 #include <torch/csrc/jit/frontend/function_schema_parser.h>
66 
67 // Just for inferFunctionSchemaFromFunctor
68 #include <ATen/core/enum_tag.h>
69 #include <ATen/core/op_registration/op_registration.h>
70 
71 namespace torch {
72 
73 #if defined C10_MOBILE
74 /**
75  * The NoInferSchemaTag is a type name used to indicate that this call to the
76  * CppFunction constructor should not trigger schema inference from functor.
77  * Schema inference from functor utilizes template meta-programming, and is
78  * costly from a size perspective. Ideally, one would expect that the schema
79  * inference would require very little binary size since most of the
80  * computation can be done by the compiler at build time, but that isn't
81  * necessarily the case.
82  *
83  * Schema inference is elided only for mobile use-cases where we don't need
84  * the additional runtime cost or size overhead on client devices.
85  *
86  */
87 struct NoInferSchemaTag {};
88 #endif
89 
90 #define HAS_PT2_COMPLIANT_TAG
91 
92 // For multipy/torchdeploy use case
93 enum class _RegisterOrVerify { REGISTER, VERIFY };
94 
95 template <class CurClass>
96 class class_;
97 
98 #define HAS_IMPL_ABSTRACT_PYSTUB
99 
100 /// Represents a C++ function that implements an operator.  Most users won't
101 /// interact directly with this class, except via error messages: the
102 /// constructors this function define the set of permissible "function"-like
103 /// things you can bind via the interface.
104 ///
105 /// This class erases the type of the passed in function, but durably records
106 /// the type via an inferred schema for the function.
107 class TORCH_API CppFunction final {
108   // TODO: This is morally the same thing as KernelRegistrationConfig, but it's
109   // opaque to the user.
110 
111  public:
112   /// This overload accepts function pointers, e.g., `CppFunction(&add_impl)`
113   template <typename Func>
114   explicit CppFunction(
115       Func* f,
116       std::enable_if_t<
117           c10::guts::is_function_type<Func>::value,
118           std::nullptr_t> = nullptr)
func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction (f))119       : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
120         cpp_signature_(c10::impl::CppSignature::make<Func>()),
121         schema_(
122             c10::detail::inferFunctionSchemaFromFunctor<std::decay_t<Func>>()),
123         debug_() {}
124 
125   /// This overload accepts compile time function pointers, e.g.,
126   /// `CppFunction(TORCH_FN(add_impl))`
127   template <typename FuncPtr>
128   explicit CppFunction(
129       FuncPtr f,
130       std::enable_if_t<
131           c10::is_compile_time_function_pointer<FuncPtr>::value,
132           std::nullptr_t> = nullptr)
func_(c10::KernelFunction::makeFromUnboxedFunction (f))133       : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
134         cpp_signature_(
135             c10::impl::CppSignature::make<typename FuncPtr::FuncType>()),
136         schema_(c10::detail::inferFunctionSchemaFromFunctor<
137                 typename FuncPtr::FuncType>()),
138         debug_() {}
139 
140   /// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) {
141   /// ... })`
142   template <typename Lambda>
143   explicit CppFunction(
144       Lambda&& f,
145       std::enable_if_t<
146           c10::guts::is_functor<std::decay_t<Lambda>>::value,
147           std::nullptr_t> = nullptr)
func_(c10::KernelFunction::makeFromUnboxedLambda (std::forward<Lambda> (f)))148       : func_(c10::KernelFunction::makeFromUnboxedLambda(
149             std::forward<Lambda>(f))),
150         cpp_signature_(c10::impl::CppSignature::make<Lambda>()),
151         schema_(c10::detail::inferFunctionSchemaFromFunctor<
152                 std::decay_t<Lambda>>()),
153         debug_() {}
154 
155 #if defined C10_MOBILE
156   /// This overload accepts function pointers, e.g., `CppFunction(&add_impl,
157   /// NoInferSchemaTag())`
158   template <typename Func>
159   explicit CppFunction(
160       Func* f,
161       NoInferSchemaTag,
162       std::enable_if_t<
163           c10::guts::is_function_type<Func>::value,
164           std::nullptr_t> = nullptr)
func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction (f))165       : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
166         cpp_signature_(c10::impl::CppSignature::make<Func>())
167         // TODO: Don't go through WrapRuntimeKernelFunctor
168         ,
169         schema_(nullptr),
170         debug_() {}
171 
172   /// This overload accepts compile time function pointers, e.g.,
173   /// `CppFunction(TORCH_FN(add_impl), NoInferSchemaTag())`
174   template <typename FuncPtr>
175   explicit CppFunction(
176       FuncPtr f,
177       NoInferSchemaTag,
178       std::enable_if_t<
179           c10::is_compile_time_function_pointer<FuncPtr>::value,
180           std::nullptr_t> = nullptr)
func_(c10::KernelFunction::makeFromUnboxedFunction (f))181       : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
182         cpp_signature_(
183             c10::impl::CppSignature::make<typename FuncPtr::FuncType>())
184         // TODO: Don't go through WrapRuntimeKernelFunctor
185         ,
186         schema_(nullptr),
187         debug_() {}
188 
189   /// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) {
190   /// ... }. NoInferSchemaTag())`
191   template <typename Lambda>
192   explicit CppFunction(
193       Lambda&& f,
194       NoInferSchemaTag,
195       std::enable_if_t<
196           c10::guts::is_functor<std::decay_t<Lambda>>::value,
197           std::nullptr_t> = nullptr)
func_(c10::KernelFunction::makeFromUnboxedLambda (std::forward<Lambda> (f)))198       : func_(c10::KernelFunction::makeFromUnboxedLambda(
199             std::forward<Lambda>(f))),
200         cpp_signature_(c10::impl::CppSignature::make<Lambda>())
201         // TODO: Don't go through WrapRuntimeKernelFunctor
202         ,
203         schema_(nullptr),
204         debug_() {}
205 #endif
206 
207   ~CppFunction();
208 
209   CppFunction(CppFunction&&) noexcept = default;
210 
211   CppFunction& operator=(CppFunction&&) = default;
212 
213   /// \private
214   /// Creates a function from a type-erased boxed kernel.
makeFromBoxedKernel(c10::BoxedKernel kernel)215   static CppFunction makeFromBoxedKernel(c10::BoxedKernel kernel) {
216     return CppFunction(
217         c10::KernelFunction::makeFromBoxedKernel(std::move(kernel)),
218         /* cpp_signature */ std::nullopt, // not known for boxed functions
219         /* schema */ nullptr);
220   }
221 
222   /// This creates a fallthrough function.  Fallthrough functions
223   /// immediately redispatch to the next available dispatch key,
224   /// but are implemented more efficiently than a hand written
225   /// function done in the same way.
makeFallthrough()226   static CppFunction makeFallthrough() {
227     return makeFromBoxedKernel(c10::BoxedKernel::makeFallthrough());
228   }
229 
230   /// \private
231   ///
232   /// Creates a function that raises an error saying that named tensors
233   /// are not supported when called.
makeNamedNotSupported()234   static CppFunction makeNamedNotSupported() {
235     return makeFromBoxedKernel(c10::BoxedKernel::makeNamedNotSupported());
236   }
237 
238   /// Create a function from a boxed kernel function with signature
239   /// `void(const OperatorHandle&, Stack*)`; i.e., they receive a
240   /// stack of arguments in a boxed calling convention, rather than
241   /// in the native C++ calling convention.  Boxed functions are
242   /// typically only used to register backend fallbacks via
243   /// torch::Library::fallback().
244   template <c10::BoxedKernel::BoxedKernelFunction* func>
makeFromBoxedFunction()245   static CppFunction makeFromBoxedFunction() {
246     return makeFromBoxedKernel(c10::BoxedKernel::makeFromFunction<func>());
247   }
248 
249   // Variant that takes in a boxed kernel function with a plumbed
250   // DispatchKeySet. See Note [Plumbing Keys Through The Dispatcher] for
251   // details.
252   template <c10::BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
makeFromBoxedFunction()253   static CppFunction makeFromBoxedFunction() {
254     return makeFromBoxedKernel(c10::BoxedKernel::makeFromFunction<func>());
255   }
256 
257   /// Create a function from a boxed kernel functor which defines
258   /// `operator()(const OperatorHandle&, DispatchKeySet, Stack*)`
259   /// (receiving arguments from boxed calling convention) and inherits
260   /// from `c10::OperatorKernel`.  Unlike makeFromBoxedFunction, functions
261   /// registered in this way can also carry additional state which
262   /// is managed by the functor; this is useful if you're writing an
263   /// adapter to some other implementation, e.g., a Python callable, which
264   /// is dynamically associated with the registered kernel.
265   template <class KernelFunctor>
makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor)266   static CppFunction makeFromBoxedFunctor(
267       std::unique_ptr<KernelFunctor> kernelFunctor) {
268     return makeFromBoxedKernel(
269         c10::BoxedKernel::makeFromFunctor(std::move(kernelFunctor)));
270   }
271 
272   /// Create a function from an unboxed kernel function.
273   /// This is typically used to register common operators.
274   template <
275       typename FuncPtr,
276       std::enable_if_t<
277           c10::guts::is_function_type<FuncPtr>::value,
278           std::nullptr_t> = nullptr>
makeFromUnboxedFunction(FuncPtr * f)279   static CppFunction makeFromUnboxedFunction(FuncPtr* f) {
280     return CppFunction(f);
281   }
282 
283   /// Create a function from a compile time unboxed kernel function pointer.
284   /// This is typically used to register common operators.
285   /// Compile time function pointers can be used to allow the compiler
286   /// to optimize (e.g. inline) calls to it.
287   template <
288       typename FuncPtr,
289       std::enable_if_t<
290           c10::is_compile_time_function_pointer<FuncPtr>::value,
291           std::nullptr_t> = nullptr>
makeFromUnboxedFunction(FuncPtr f)292   static CppFunction makeFromUnboxedFunction(FuncPtr f) {
293     return CppFunction(f);
294   }
295 
debug(std::string d)296   CppFunction&& debug(std::string d) && {
297     debug_ = std::move(d);
298     return std::move(*this);
299   }
300 
301  private:
302   std::optional<c10::DispatchKey> dispatch_key_;
303   c10::KernelFunction func_;
304   std::optional<c10::impl::CppSignature> cpp_signature_;
305   std::unique_ptr<c10::FunctionSchema> schema_;
306   std::string debug_;
307 
308   // The "setter" for dispatch_key_
309   template <typename Func>
310   friend CppFunction dispatch(c10::DispatchKey, Func&&);
311 
312   // The only class which actually pulls out values from CppFunction (does so
313   // destructively, felt too lazy to write accessors that I don't even
314   // want users to use)
315   friend class Library;
316 
317   CppFunction(
318       c10::KernelFunction func,
319       std::optional<c10::impl::CppSignature> cpp_signature,
320       std::unique_ptr<c10::FunctionSchema> schema);
321 };
322 
323 /// \defgroup torch-dispatch-overloads torch::dispatch overloads
324 
325 /// Create a torch::CppFunction which is associated with a specific
326 /// dispatch key.  torch::CppFunctions that are tagged with a
327 /// c10::DispatchKey don't get invoked unless the dispatcher determines
328 /// that this particular c10::DispatchKey is the one that should be
329 /// dispatched to.
330 ///
331 /// This function is generally not used directly, instead, prefer using
332 /// TORCH_LIBRARY_IMPL(), which will implicitly set the c10::DispatchKey
333 /// for all registration calls inside of its body.
334 ///
335 /// \ingroup torch-dispatch-overloads
336 template <typename Func>
dispatch(c10::DispatchKey k,Func && raw_f)337 inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) {
338   CppFunction f(std::forward<Func>(raw_f));
339   if (k == c10::DispatchKey::CatchAll) {
340     f.dispatch_key_ = std::nullopt;
341   } else {
342     f.dispatch_key_ = k;
343   }
344   return f;
345 }
346 
347 /// Convenience overload of dispatch() which accepts c10::DeviceType
348 ///
349 /// \ingroup torch-dispatch-overloads
350 template <typename Func>
dispatch(c10::DeviceType type,Func && raw_f)351 inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
352   auto deviceTypeToDispatchKey = [](c10::DeviceType t) {
353     switch (t) {
354       // This list is synchronized with the k-constants in c10/core/DeviceType.h
355       case c10::DeviceType::CPU:
356         return c10::DispatchKey::CPU;
357       case c10::DeviceType::CUDA:
358         return c10::DispatchKey::CUDA;
359       case c10::DeviceType::IPU:
360         return c10::DispatchKey::IPU;
361       case c10::DeviceType::XLA:
362         return c10::DispatchKey::XLA;
363       case c10::DeviceType::Lazy:
364         return c10::DispatchKey::Lazy;
365       case c10::DeviceType::XPU:
366         return c10::DispatchKey::XPU;
367       case c10::DeviceType::MPS:
368         return c10::DispatchKey::MPS;
369       case c10::DeviceType::Meta:
370         return c10::DispatchKey::Meta;
371       case c10::DeviceType::HIP:
372         return c10::DispatchKey::HIP;
373       case c10::DeviceType::MAIA:
374         return c10::DispatchKey::MAIA;
375       case c10::DeviceType::HPU:
376         return c10::DispatchKey::HPU;
377       case c10::DeviceType::MTIA:
378         return c10::DispatchKey::MTIA;
379       case c10::DeviceType::PrivateUse1:
380         return c10::DispatchKey::PrivateUse1;
381       default:
382         TORCH_CHECK(
383             false,
384             "Device type ",
385             t,
386             " cannot be overloaded at dispatch time, "
387             "please file a bug report explaining what you were trying to do.");
388     }
389   };
390   return dispatch(deviceTypeToDispatchKey(type), std::forward<Func>(raw_f));
391 }
392 
393 /// \defgroup torch-schema-overloads torch::schema overloads
394 
395 /// Construct a c10::FunctionSchema from a string, with an explicitly
396 /// specified c10::AliasAnalysisKind.  Ordinarily, schemas are simply
397 /// passed in as strings, but if you need to specify a custom alias
398 /// analysis, you can replace the string with a call to this function.
399 ///
400 /// ```
401 /// // Default alias analysis (FROM_SCHEMA)
402 /// m.def("def3(Tensor self) -> Tensor");
403 /// // Pure function alias analysis
404 /// m.def(torch::schema("def3(Tensor self) -> Tensor",
405 /// c10::AliasAnalysisKind::PURE_FUNCTION));
406 /// ```
407 ///
408 /// \ingroup torch-schema-overloads
409 inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k, bool allow_typevars=false) {
410   c10::FunctionSchema s = torch::jit::parseSchema(str, /*allow_typevars*/allow_typevars);
411   s.setAliasAnalysis(k);
412   return s;
413 }
414 
415 /// Function schemas can be directly constructed from string literals.
416 ///
417 /// \ingroup torch-schema-overloads
418 inline c10::FunctionSchema schema(const char* s, bool allow_typevars=false) {
419   return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA, allow_typevars);
420 }
421 
422 /// \private
423 ///
424 /// Already constructed function schemas are accepted if they are
425 /// rvalues.
426 ///
427 /// \ingroup torch-schema-overloads
schema(c10::FunctionSchema && s)428 inline c10::FunctionSchema&& schema(c10::FunctionSchema&& s) {
429   return std::move(s);
430 }
431 
432 namespace detail {
433 
constructSchemaOrName(c10::FunctionSchema && s)434 inline std::variant<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(
435     c10::FunctionSchema&& s) {
436   return std::move(s);
437 }
constructSchemaOrName(c10::OperatorName && n)438 inline std::variant<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(
439     c10::OperatorName&& n) {
440   return std::move(n);
441 }
442 inline std::variant<c10::OperatorName, c10::FunctionSchema>
constructSchemaOrName(const char * str)443 constructSchemaOrName(const char* str) {
444   auto s = torch::jit::parseSchemaOrName(str);
445   if (std::holds_alternative<c10::FunctionSchema>(s)) {
446     std::get<c10::FunctionSchema>(s).setAliasAnalysis(
447         c10::AliasAnalysisKind::FROM_SCHEMA);
448   }
449   return s;
450 }
451 
452 class TorchLibraryInit;
453 
454 } // namespace detail
455 
456 // Note [Selective build]
457 // ~~~~~~~~~~~~~~~~~~~~~~
458 // In some settings, especially mobile, it is important to avoid compiling any
459 // references to functions that you aren't actually going to use, so that they
460 // can be eliminated by the linker.  We call this capability "selective build".
461 //
462 // A very easy way to implement selective build which results in a lot of
463 // boilerplate is to just add ifdef's around every registration call, but this
464 // means you have to write a lot of extra lines of code at every registration
465 // site, and it also means you have to define some munging scheme to map
466 // operators to macros.
467 //
468 // Instead of doing this, we have a different mechanism centered around the
469 // concept of a SelectiveStr.  A selective name is like a const char* string,
470 // except it also carries at compile time a boolean saying whether or not a
471 // registration should actually happen or not.  We then have extra overloads
472 // which bypass registration entirely if a selective name is disabled.  We do a
473 // constexpr test to see if a operator should be enabled or not; this is
474 // currently implemented in ATen/core/op_registration/op_allowlist.h
475 
namespace detail {

// Stand-in for a custom torchbind class that was excluded by selective
// build; every registration call on it is a no-op that returns *this so
// chained calls still compile.
class ClassNotSelected {
 public:
  ClassNotSelected& def_pickle(...) {
    return *this;
  }
  ClassNotSelected& def(...) {
    return *this;
  }
};

// A SelectiveStr is like a const char*, except that it also comes
// with a type brand that says whether or not the name is enabled or
// not.  If the string is disabled, then (at compile time) we DON'T generate
// a registration call for it.  This class is not intended to be called
// directly; use TORCH_SELECTIVE_NAME or TORCH_SELECTIVE_SCHEMA macros below
// to create it.
template <bool enabled>
class SelectiveStr {
 public:
  constexpr explicit SelectiveStr(const char* name) : name_(name) {}
  // Implicit conversion back to the raw string for registration calls.
  constexpr operator const char*() {
    return name_;
  }

 private:
  const char* name_;
};

#define TORCH_SELECTIVE_CLASS(n) \
  torch::detail::SelectiveStr<c10::impl::custom_class_allowlist_check(n)>(n)
#define TORCH_SELECTIVE_NAME(n) \
  torch::detail::SelectiveStr<c10::impl::op_allowlist_check(n)>(n)
#define TORCH_SELECTIVE_SCHEMA(n) \
  torch::detail::SelectiveStr<c10::impl::schema_allowlist_check(n)>(n)

} // namespace detail
515 
516 /// This object provides the API for defining operators and providing
517 /// implementations at dispatch keys.  Typically, a torch::Library
518 /// is not allocated directly; instead it is created by the
519 /// TORCH_LIBRARY() or TORCH_LIBRARY_IMPL() macros.
520 ///
521 /// Most methods on torch::Library return a reference to itself,
522 /// supporting method chaining.
523 ///
524 /// ```
525 /// // Examples:
526 ///
527 /// TORCH_LIBRARY(torchvision, m) {
528 ///    // m is a torch::Library
529 ///    m.def("roi_align", ...);
530 ///    ...
531 /// }
532 ///
533 /// TORCH_LIBRARY_IMPL(aten, XLA, m) {
534 ///    // m is a torch::Library
535 ///    m.impl("add", ...);
536 ///    ...
537 /// }
538 /// ```
539 ///
540 class TORCH_API Library final {
541  public:
542   /// \private
543   ///
544   /// Which type of macro produced this Library
545   enum Kind {
546     DEF, // from TORCH_LIBRARY (no qualifier)
547     IMPL,
548     FRAGMENT,
549   };
550 
551   /// \private
552   ///
553   /// Use TORCH_LIBRARY() or TORCH_LIBRARY_IMPL() instead of using these
554   /// constructors directly
555   Library(
556       Kind kind,
557       std::string ns,
558       std::optional<c10::DispatchKey> k,
559       const char* file,
560       uint32_t line);
561 
562   Library(const Library&) = delete;
563   Library& operator=(const Library&) = delete;
564   Library(Library&&) = default;
565   Library& operator=(Library&&) = default;
566 
567   // Some notes about the API design here.  We had the following constraints:
568   //
569   //  - We need to support multiple "types" of arguments for schema and
570   //    functions (e.g., unnamed lambda types, regular functions, const char*,
571   //    fully instantiated schemas)
572   //  - We don't want to write exponentially many overloads
573   //  - We don't want to rely on implicit conversion to a common type,
574   //    because the C++ compiler will only be willing to do a single
575   //    implicit conversion (reducing the set of valid types which you
576   //    can invoke with); also error messages are worse when an implicit
577   //    conversion is not selected (as the compiler will not explain
578   //    why it didn't select an implicit conversion; this is different
579   //    from overloads where it will explain each candidate overload and
580   //    why it didn't apply)
581   //
582   // To solve all of these constraints at the same time, we use a trick taken
583   // from the pybind11 library: template over the argument in the user visible
584   // API, and inside of the templated function explicitly call an overloaded
585   // function to resolve the argument to a real type.  You get the good error
586   // messages from overloads, but at the same time you only need to write the
587   // overload for any given argument type once.
588 
589   /// Declare an operator with a schema, but don't provide any implementations
590   /// for it.  You're expected to then provide implementations using the
591   /// impl() method.  All template arguments are inferred.
592   ///
593   /// \param raw_schema The schema of the operator to be defined.
594   ///     Typically, this is a `const char*` string literal, but any type
595   ///     accepted by torch::schema() is accepted here.
596   ///
597   /// ```
598   /// // Example:
599   /// TORCH_LIBRARY(myops, m) {
600   ///   m.def("add(Tensor self, Tensor other) -> Tensor");
601   /// }
602   /// ```
603 
604   template <typename Schema>
605   Library& def(
606       Schema&& raw_schema,
607       const std::vector<at::Tag>& tags = {},
608       _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
609     c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
610     return _def(std::move(s), nullptr, tags, rv);
611   }
612 
613   /// Declares that for all operators that are subsequently def'ed, their
614   /// fake impls may be found in the given Python module (pymodule).
615   /// This registers some help text that is used if the fake impl
616   /// cannot be found.
617   ///
618   /// Args:
619   /// - pymodule: the python module
620   /// - context: We may include this in the error message.
621   Library& set_python_module(const char* pymodule, const char* context = "") {
622     python_module_ = {pymodule, context};
623     return *this;
624   }
625 
626   /// Deprecated; use set_python_module instead
627   Library& impl_abstract_pystub(const char* pymodule, const char* context = "") {
628     return set_python_module(pymodule, context);
629   }
630 
631   /// Define an operator for a schema and then register an implementation for
632   /// it.  This is typically what you would use if you aren't planning
633   /// on making use of the dispatcher to structure your operator
634   /// implementation.  It's roughly equivalent to calling def() and
635   /// then impl(), but if you omit the schema of the operator, we will
636   /// infer it from the type of your C++ function.  All template
637   /// arguments are inferred.
638   ///
639   /// \param raw_name_or_schema The schema of the operator to be
640   ///   defined, or just the name of the operator if the schema is to be
641   ///   inferred from `raw_f`.  Typically a `const char*` literal.
642   /// \param raw_f The C++ function that implements this operator.
643   ///   Any valid constructor of torch::CppFunction is accepted here;
644   ///   typically you provide a function pointer or lambda.
645   ///
646   /// ```
647   /// // Example:
648   /// TORCH_LIBRARY(myops, m) {
649   ///   m.def("add", add_fn);
650   /// }
651   /// ```
652   template <typename NameOrSchema, typename Func>
653   Library& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f,
654       const std::vector<at::Tag>& tags = {}) & {
655     CppFunction f(std::forward<Func>(raw_f));
656     return _def(
657         detail::constructSchemaOrName(
658             ::std::forward<NameOrSchema>(raw_name_or_schema)),
659         ::std::move(f), tags);
660   }
661 
662   /// Register an implementation for an operator.  You may register multiple
663   /// implementations for a single operator at different dispatch keys
664   /// (see torch::dispatch()).  Implementations must have a corresponding
665   /// declaration (from def()), otherwise they are invalid.  If you plan
666   /// to register multiple implementations, DO NOT provide a function
667   /// implementation when you def() the operator.
668   ///
669   /// \param name The name of the operator to implement.  Do NOT provide
670   ///   schema here.
671   /// \param raw_f The C++ function that implements this operator.  Any
672   ///   valid constructor of torch::CppFunction is accepted here;
673   ///   typically you provide a function pointer or lambda.
674   ///
675   /// ```
676   /// // Example:
677   /// TORCH_LIBRARY_IMPL(myops, CUDA, m) {
678   ///   m.impl("add", add_cuda);
679   /// }
680   /// ```
681   template <typename Name, typename Func>
682   Library& impl(
683       Name name,
684       Func&& raw_f,
685       _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
686     // TODO: need to raise an error when you impl a function that has a
687     // catch all def
688 #if defined C10_MOBILE
689     CppFunction f(std::forward<Func>(raw_f), NoInferSchemaTag());
690 #else
691     CppFunction f(std::forward<Func>(raw_f));
692 #endif
693     return _impl(name, std::move(f), rv);
694   }
695 
#if defined C10_MOBILE
  // Note: This overload is needed only for C10_MOBILE, since the automatically
  // defined copy constructor for the CppFunction doesn't have the additional
  // NoInferSchemaTag argument. We define the overload for the impl() function
  // to accept a CppFunction&& argument. The already constructed CppFunction
  // object may or may not have the inferred schema, but it doesn't matter
  // for our purposes since if it already has the inferred schema, then we
  // might as well just pass it through directly.
  //
  template <typename Name>
  Library& impl(Name name, CppFunction&& raw_f) & {
    // TODO: need to raise an error when you impl a function that has a
    // catch all def
    CppFunction fn(std::move(raw_f));
    return _impl(name, std::move(fn));
  }
#endif
713 
714   // Helper for getting an OperatorName for a const char*.  You probably
715   // don't need this.
716   c10::OperatorName _resolve(const char* name) const;
717 
718   /// \private
719   ///
720   /// Convenience overload for directly specifying the dispatch key when
721   /// impl().  You probably don't need this; instead, prefer specifying
722   /// the dispatch key for the entire block in TORCH_LIBRARY_IMPL()
723   template <typename Name, typename Dispatch, typename Func>
impl(Name name,Dispatch && key,Func && raw_f)724   Library& impl(Name name, Dispatch&& key, Func&& raw_f) & {
725     return impl(
726         name, dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
727   }
728 
729   template <typename Name, typename Func>
impl_UNBOXED(Name,Func *)730   Library& impl_UNBOXED(Name /*name*/, Func* /*raw_f*/) & {
731     static_assert(
732         c10::guts::false_t<Func>(),
733         ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
734     return *this;
735   }
736 
737   // These overloads cover cases when a SelectiveStr (see Note [Selective
738   // build]) has been disabled at compile time.  In that case, don't generate
739   // any code referencing the passed in functions at all.
740   Library& def(detail::SelectiveStr<false>, const std::vector<at::Tag>& tags [[maybe_unused]] = {}) & {
741     return *this;
742   }
743   Library& def(detail::SelectiveStr<true> raw_schema, const std::vector<at::Tag>& tags = {}) & {
744     return def(raw_schema.operator const char*(), tags);
745   }
746   template <typename Func>
747   Library& def(detail::SelectiveStr<false>, Func&& /*raw_f*/, const std::vector<at::Tag>& tags [[maybe_unused]] = {}) & {
748     return *this;
749   }
750   template <typename Func>
751   Library& def(detail::SelectiveStr<true> raw_name_or_schema, Func&& raw_f, const std::vector<at::Tag>& tags = {}) & {
752     return def(
753         raw_name_or_schema.operator const char*(), std::forward<Func>(raw_f), tags);
754   }
755 
756   template <typename Func>
757   // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
impl(detail::SelectiveStr<false>,Func &&)758   Library& impl(detail::SelectiveStr<false>, Func&& /*raw_f*/) & {
759     return *this;
760   }
761   template <typename Dispatch, typename Func>
impl(detail::SelectiveStr<false>,Dispatch &&,Func &&)762   Library& impl(
763       detail::SelectiveStr<false>,
764       // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
765       Dispatch&& /*key*/,
766       // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
767       Func&& /*raw_f*/) & {
768     return *this;
769   }
770   template <typename Func>
impl_UNBOXED(detail::SelectiveStr<false>,Func *)771   Library& impl_UNBOXED(
772       detail::SelectiveStr<false> /*name*/,
773       Func* /*raw_f*/) & {
774     static_assert(
775         c10::guts::false_t<Func>(),
776         ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
777     return *this;
778   }
779 
780   template <typename Func>
impl(detail::SelectiveStr<true> name,Func && raw_f)781   Library& impl(detail::SelectiveStr<true> name, Func&& raw_f) & {
782     return impl(name.operator const char*(), std::forward<Func>(raw_f));
783   }
784   template <typename Dispatch, typename Func>
impl(detail::SelectiveStr<true> name,Dispatch && key,Func && raw_f)785   Library& impl(
786       detail::SelectiveStr<true> name,
787       Dispatch&& key,
788       Func&& raw_f) & {
789     return impl(
790         name.operator const char*(),
791         std::forward<Dispatch>(key),
792         std::forward<Func>(raw_f));
793   }
794   template <typename Func>
impl_UNBOXED(detail::SelectiveStr<true>,Func *)795   Library& impl_UNBOXED(
796       detail::SelectiveStr<true> /*name*/,
797       Func* /*raw_f*/) & {
798     static_assert(
799         c10::guts::false_t<Func>(),
800         ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
801     return *this;
802   }
803 
804   /// Register a fallback implementation for all operators which will be used
805   /// if there is not a specific implementation for an operator available.
806   /// There MUST be a DispatchKey associated with a fallback; e.g.,
807   /// only call this from TORCH_LIBRARY_IMPL() with namespace `_`.
808   ///
809   /// \param raw_f The function that implements the fallback.  Unboxed
810   ///   functions typically do not work as fallback functions, as
811   ///   fallback functions must work for every operator (even though
812   ///   they have varying type signatures).  Typical arguments are
813   ///   CppFunction::makeFallthrough() or
814   ///   CppFunction::makeFromBoxedFunction()
815   ///
816   /// ```
817   /// // Example:
818   ///
819   /// TORCH_LIBRARY_IMPL(_, AutogradXLA, m) {
820   ///   // If there is not a kernel explicitly registered
821   ///   // for AutogradXLA, fallthrough to the next
822   ///   // available kernel
823   ///   m.fallback(torch::CppFunction::makeFallthrough());
824   /// }
825   ///
826   /// // See aten/src/ATen/core/dispatch/backend_fallback_test.cpp
827   /// // for a full example of boxed fallback
828   /// ```
829   template <typename Func>
fallback(Func && raw_f)830   Library& fallback(Func&& raw_f) & {
831     CppFunction f((std::forward<Func>(raw_f)));
832     return _fallback(std::move(f));
833   }
834 
835   template <class CurClass>
836   inline torch::class_<CurClass> class_(const std::string& className);
837 
838   // These overloads enable the use of selective build on classes registered
839   // within a library. The API is the same as before with 1 minor change.
840   // Instead of m.class_<foo>("foo") you instead do
841   // m.class_<foo>(TORCH_SELECTIVE_CLASS("foo"))
842   template <class CurClass>
843   inline torch::class_<CurClass> class_(detail::SelectiveStr<true> className);
844 
845   template <class CurClass>
846   inline detail::ClassNotSelected class_(detail::SelectiveStr<false> className);
847 
848   // De-registers all registrations created with this Library
849   void reset();
850 
851  private:
852   Kind kind_;
853   std::optional<std::string> ns_;
854   std::optional<c10::DispatchKey> dispatch_key_;
855   std::optional<std::pair<const char*, const char*>> python_module_;
856   const char* file_;
857   uint32_t line_;
858 
859   std::vector<c10::RegistrationHandleRAII> registrars_;
860 
861   friend class detail::TorchLibraryInit;
862 
863   // Non-user visible actual implementations of functions.  These aren't
864   // public because we only implement & qualifier and not && qualifier
865   Library& _def(
866       c10::FunctionSchema&& schema,
867       c10::OperatorName* out_name = nullptr,
868       const std::vector<at::Tag>& tags = {},
869       _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &;
870   Library& _def(
871       std::variant<c10::OperatorName, c10::FunctionSchema>&&,
872       CppFunction&& f,
873       const std::vector<at::Tag>& tags = {}) &;
874   Library& _impl(
875       const char* name,
876       CppFunction&& f,
877       _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &;
878   Library& _fallback(CppFunction&& f) &;
879 
880   at::OperatorName _parseNameForLib(const char* name_str) const;
881 };
882 
883 namespace detail {
884 
885 class TorchLibraryInit final {
886  private:
887   using InitFn = void(Library&);
888   Library lib_;
889 
890  public:
TorchLibraryInit(Library::Kind kind,InitFn * fn,const char * ns,std::optional<c10::DispatchKey> k,const char * file,uint32_t line)891   TorchLibraryInit(
892       Library::Kind kind,
893       InitFn* fn,
894       const char* ns,
895       std::optional<c10::DispatchKey> k,
896       const char* file,
897       uint32_t line)
898       : lib_(kind, ns, k, file, line) {
899     fn(lib_);
900   }
901 };
902 
903 } // namespace detail
904 
905 } // namespace torch
906 
// NB: The EXACT NAMING of the initializer functions (e.g.,
// TORCH_LIBRARY_init_aten) matters for the code analyzer;
// see the regexes at tools/code_analyzer/run_analyzer.sh

/// Macro for defining a function that will be run at static
/// initialization time to define a library of operators in the
/// namespace `ns` (must be a valid C++ identifier, no quotes).
/// Use this macro when you want to define a new set of custom operators
/// that do not already exist in PyTorch.
///
/// Example usage:
///
/// ```
/// TORCH_LIBRARY(myops, m) {
///   // m is a torch::Library; methods on it will define
///   // operators in the myops namespace
///   m.def("add", add_impl);
/// }
/// ```
///
/// The `m` argument is bound to a torch::Library that is used to
/// register operators.  There may only be one TORCH_LIBRARY()
/// for any given namespace.
#define TORCH_LIBRARY(ns, m)                                                   \
  static void TORCH_LIBRARY_init_##ns(torch::Library&);                        \
  static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_##ns( \
      torch::Library::DEF,                                                     \
      &TORCH_LIBRARY_init_##ns,                                                \
      #ns,                                                                     \
      std::nullopt,                                                            \
      __FILE__,                                                                \
      __LINE__);                                                               \
  void TORCH_LIBRARY_init_##ns(torch::Library& m)

/// \private
///
/// This macro is a version of TORCH_LIBRARY() that doesn't enforce that there
/// is only one library (it is a "fragment").  This is used inside the
/// PerOpRegistration.cpp file, as well as in places where all op registrations
/// within the same namespace cannot be easily put into one macro block
/// (this is mostly the case for custom ops in fbcode that were ported from
/// the old API)
#define TORCH_LIBRARY_FRAGMENT(ns, m) _TORCH_LIBRARY_FRAGMENT(ns, m, C10_UID)

/// \private
///
/// The above macro requires an extra unique identifier (uid) to prevent
/// variable name collisions. This can happen if TORCH_LIBRARY_FRAGMENT is
/// called multiple times with the same namespace in the same translation
/// unit. Note that the TORCH_LIBRARY variant doesn't run into this problem,
/// because it enforces that it can only be called once for a given namespace.
#define _TORCH_LIBRARY_FRAGMENT(ns, m, uid)                       \
  static void C10_CONCATENATE(                                    \
      TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(torch::Library&); \
  static const torch::detail::TorchLibraryInit C10_CONCATENATE(   \
      TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)(           \
      torch::Library::FRAGMENT,                                   \
      &C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \
      #ns,                                                        \
      std::nullopt,                                               \
      __FILE__,                                                   \
      __LINE__);                                                  \
  void C10_CONCATENATE(                                           \
      TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(torch::Library & m)

/// Macro for defining a function that will be run at static
/// initialization time to define operator overrides for dispatch key
/// `k` (must be an unqualified enum member of c10::DispatchKey) in
/// namespace `ns` (must be a valid C++ identifier, no quotes).  Use this
/// macro when you want to implement a preexisting set of custom
/// operators on a new dispatch key (e.g., you want to provide CUDA
/// implementations of already existing operators).  One common usage
/// pattern is to use TORCH_LIBRARY() to define schema for all new
/// operators you want to define, and then use several
/// TORCH_LIBRARY_IMPL() blocks to provide implementations of the
/// operator for CPU, CUDA and Autograd.
///
/// In some cases, you need to define something that applies to all namespaces,
/// not just one namespace (usually a fallback).  In that case, use the reserved
/// namespace _, e.g.,
///
/// ```
/// TORCH_LIBRARY_IMPL(_, XLA, m) {
///    m.fallback(xla_fallback);
/// }
/// ```
///
/// Example usage:
///
/// ```
/// TORCH_LIBRARY_IMPL(myops, CPU, m) {
///   // m is a torch::Library; methods on it will define
///   // CPU implementations of operators in the myops namespace.
///   // It is NOT valid to call torch::Library::def()
///   // in this context.
///   m.impl("add", add_cpu_impl);
/// }
/// ```
///
/// If ``add_cpu_impl`` is an overloaded function, use a
/// ``static_cast`` to specify which overload you want
/// (by providing the full type).
///
// NB: if the dispatch key is not whitelisted, we simply omit the Library
// call entirely
#define TORCH_LIBRARY_IMPL(ns, k, m) _TORCH_LIBRARY_IMPL(ns, k, m, C10_UID)

/// \private
///
/// The above macro requires an extra unique identifier (uid) to prevent
/// variable name collisions. This can happen if TORCH_LIBRARY_IMPL is called
/// multiple times with the same namespace and dispatch key in the same
/// translation unit.
#define _TORCH_LIBRARY_IMPL(ns, k, m, uid)                                \
  static void C10_CONCATENATE(                                            \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&);       \
  static const torch::detail::TorchLibraryInit C10_CONCATENATE(           \
      TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)(                 \
      torch::Library::IMPL,                                               \
      (c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::k)       \
           ? &C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid) \
           : [](torch::Library&) -> void {}),                             \
      #ns,                                                                \
      std::make_optional(c10::DispatchKey::k),                            \
      __FILE__,                                                           \
      __LINE__);                                                          \
  void C10_CONCATENATE(                                                   \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library & m)

// These are variants of the macros above which are to be used for testing (they
// don't setup the static initializer, so you can control the visibility of
// the allocated library yourself).
//
// DO NOT use these in production code, they are NOT understood by the
// code analyzer and will be incorrectly analyzed in those situations.

/// \private
#define MAKE_TORCH_LIBRARY(ns) \
  torch::Library(torch::Library::DEF, #ns, std::nullopt, __FILE__, __LINE__)
/// \private
#define MAKE_TORCH_LIBRARY_IMPL(ns, k)         \
  torch::Library(                              \
      torch::Library::IMPL,                    \
      #ns,                                     \
      std::make_optional(c10::DispatchKey::k), \
      __FILE__,                                \
      __LINE__)

1055 // Make the custom class API visible, so it is available from
1056 // torch::Library.
1057 
1058 #include <torch/custom_class.h>
1059