xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/boxing/BoxedKernel.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/boxing/OperatorKernel.h>
4 #include <c10/core/DispatchKeySet.h>
5 #include <c10/util/intrusive_ptr.h>
6 
namespace c10 {

struct IValue;
using Stack = std::vector<IValue>;

class OperatorHandle;
class KernelFunction;

// This kernel implements the behavior of falling through to the next available
// registered dispatch key.  The implementation of this function is FAST; it is
// no overhead to fallthrough to the next key.  See cpp file for some more
// implementation notes; notably, this does NOT actually go through the
// boxing/unboxing codepath.
TORCH_API void fallthrough_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);

// Note [Ambiguity in AutogradOther kernel]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This error-reporting kernel is registered to the AutogradOther entry in the
// dispatch table when there is both a CompositeImplicitAutograd kernel and a
// backend kernel for ANY backend that maps to AutogradOther.  To see why
// this is necessary in the AutogradOther case, it's helpful to first see
// why everything works out fine for a backend that has a reserved Autograd
// entry (see rule 2.2 in [Note] DispatchTable computation):
//
//    CPU   AutogradCPU
//    reg?  registers with...
//    -------------------------------------------------
//    y     Autograd registration takes precedence
//          over CompositeImplicitAutograd.
//          This is good, because the CPU specific backend
//          implementation is more specialized and typically better;
//          if we used the composite, we would bypass it.
//          (NB: the Autograd key is guaranteed to exist because
//          the autograd codegen requires it!)
//
//    n     CompositeImplicitAutograd takes precedence.
//          This is also good, because the Autograd
//          registration (if it exists) would try to redispatch
//          to the (non-existent) CPU implementation; by
//          using the composite, we ensure the operator
//          actually works.
//
// As you can see, when we have a specific Autograd key (AutogradCPU), we can
// decide whether or not to use the CompositeImplicitAutograd kernel or the
// Autograd kernel based on whether or not the backend kernel exists.
//
// However, for AutogradOther (which is the catchall autograd kernel for
// everything that doesn't have a specific Autograd key), we can't do this
// trick because there isn't any unique backend to peek at to disambiguate;
// backends that have implementations would prefer the Autograd kernel, while
// unimplemented backends would prefer CompositeImplicitAutograd.  Rather
// than arbitrarily pick one or the other, we just register a kernel that raises
// an error and let the user decide how to proceed.
TORCH_API void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);

// Note [named_not_supported_kernel]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This kernel implements reporting an error message saying that named tensor is
// not supported.  This kernel doesn't rely on the Stack, and so it is special
// cased in the dispatcher to be triggered before we attempt boxing (so we can
// give a good error message in cases when boxing is not supported).  When
// boxing is universally supported this can be removed.
[[noreturn]] TORCH_API void named_not_supported_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);

/**
 * BoxedKernel is similar to a std::function storing a boxed kernel.
 */
class TORCH_API BoxedKernel final {
public:
  // This is how boxed kernels are actually stored
  //
  // Note [Plumbing Keys Through The Dispatcher]
  // Benchmarks have shown that it is expensive for the dispatcher to read from thread-local storage (TLS)
  // upon every dispatch call in order to compute which kernel to dispatch to.
  //
  // To mitigate this, we've updated the calling convention inside the dispatcher to expect every kernel that it stores
  // to have a first argument of type DispatchKeySet.
  //
  // What are the invariants of the DispatchKeySet when it gets passed to a kernel?
  // - All keys to the left of the current dispatch key have been masked out.
  //   (e.g. a Tracing kernel that takes in the DispatchKeySet will expect the highest bit to be DispatchKey::Tracer)
  // - All other keys that dispatcher normally would have computed through TLS + global state + op arguments
  //   are still in the set.
  //
  // Kernels can then opt into using this keyset to save the dispatcher from doing repeated work during redispatches:
  // recalculating the highest-priority dispatch key, which involves reading from TLS. Instead, the kernels that opt in will
  // calculate an updated DispatchKeySet directly from the old one, and pass the updated set directly into the dispatcher
  // upon redispatching.
  //
  // This is an opt-in mechanism: Kernels can automatically opt in by setting the first argument in their signature
  // to be of type DispatchKeySet. See the kernels in VariableTypeEverything.cpp and TraceTypeEverything.cpp for examples.
  //
  // The mechanism for optionally passing that DispatchKeySet into the kernel lives in make_boxed_from_unboxed_functor.h.
  // See Note [Plumbing Keys Through The Dispatcher 2] for details.
  using InternalBoxedKernelFunction = void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
  // This is the public API for how boxed kernels are defined
  using BoxedKernelFunction = void(const OperatorHandle&, Stack*);
  using BoxedKernelFunction_withDispatchKeys = void(const OperatorHandle&, DispatchKeySet, Stack*);

  // Constructs an empty BoxedKernel; isValid() distinguishes it from one
  // holding a real kernel (see BoxedKernel_impl.h).
  BoxedKernel();

  // Fast path for dispatch to allow not touching the boxed kernel in
  // the common case where unboxed is available.
  bool isValid() const;
  // True iff this kernel is the special fallthrough kernel
  // (see fallthrough_kernel above / makeFallthrough below).
  bool isFallthrough() const;

  /**
   * Call the function with boxed arguments.
   */
  void callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const;

  /**
   * Create a BoxedKernel from a boxed function.
   *
   * Example:
   *
   * > void boxed_func(const OperatorHandle&, Stack* stack) {...}
   * > BoxedKernel func = BoxedKernel::makeFromFunction<&boxed_func>();
   */
  template<BoxedKernelFunction* func>
  static BoxedKernel makeFromFunction();

  /**
   * TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none)
   * See Note [Plumbing Keys Through The Dispatcher] for details.
   */
  template<BoxedKernelFunction_withDispatchKeys* func>
  static BoxedKernel makeFromFunction();

  /**
   * Create a BoxedKernel from a boxed functor.
   *
   * Example:
   *
   * > class MyFunctor final : public c10::OperatorKernel {
   * >   public:
   * >     void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
   * > };
   * > BoxedKernel func = BoxedKernel::makeFromFunctor(std::make_unique<MyFunctor>());
   */
  template<class KernelFunctor>
  static BoxedKernel makeFromFunctor(std::unique_ptr<KernelFunctor> kernelFunctor);

  // Factories for the special built-in kernels declared at the top of this
  // file (fallthrough_kernel, ambiguous_autogradother_kernel,
  // named_not_supported_kernel respectively).
  static BoxedKernel makeFallthrough();
  static BoxedKernel makeAmbiguousAutogradOther();
  static BoxedKernel makeNamedNotSupported();

private:

  friend class KernelFunction;

  // Adapters declared with the internal calling convention; they wrap the
  // compile-time function pointer `func`, presumably dropping the unused
  // OperatorKernel* (and, for the first overload, DispatchKeySet) arguments
  // before calling it — see BoxedKernel_impl.h for the definitions.
  template<BoxedKernelFunction* func>
  static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);

  template<BoxedKernelFunction_withDispatchKeys* func>
  static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);

  explicit BoxedKernel(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func);

  OperatorKernel* getFunctor() const;
  InternalBoxedKernelFunction* getFnPtr() const;

  // Kernel state; matches the OperatorKernel* parameter of
  // InternalBoxedKernelFunction (may be empty for plain function kernels —
  // confirm against BoxedKernel_impl.h).
  c10::intrusive_ptr<OperatorKernel> functor_;
  // The stored kernel entry point, in the internal calling convention above.
  InternalBoxedKernelFunction* boxed_kernel_func_;
};

}  // namespace c10
175 
176 #include <ATen/core/boxing/BoxedKernel_impl.h>
177