#pragma once

#include <ATen/core/boxing/OperatorKernel.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/intrusive_ptr.h>

namespace c10 {

struct IValue;
// The boxed calling convention: arguments and results travel as a
// vector of type-erased IValues.
using Stack = std::vector<IValue>;

class OperatorHandle;
class KernelFunction;

// This kernel implements the behavior of falling through to the next available
// registered dispatch key.  The implementation of this function is FAST; it is
// no overhead to fallthrough to the next key.  See cpp file for some more
// implementation notes; notably, this does NOT actually go through the
// boxing/unboxing codepath.
TORCH_API void fallthrough_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);

// Note [Ambiguity in AutogradOther kernel]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This error-reporting kernel is registered to the AutogradOther entry in the
// dispatch table when there is both a CompositeImplicitAutograd kernel and a
// backend kernel for ANY backend that maps to AutogradOther.  To see why
// this is necessary in the AutogradOther case, it's helpful to first see
// why everything works out fine for a backend that has a reserved Autograd
// entry (see rule 2.2 in [Note] DispatchTable computation):
//
//    CPU   AutogradCPU
//    reg?  registers with...
//    -------------------------------------------------
//     y    Autograd registration takes precedence
//          over CompositeImplicitAutograd.
//          This is good, because the CPU specific backend
//          implementation is more specialized and typically better;
//          if we used the composite, we would bypass it.
//          (NB: the Autograd key is guaranteed to exist because
//          the autograd codegen requires it!)
//
//     n    CompositeImplicitAutograd takes precedence.
//          This is also good, because the Autograd
//          registration (if it exists) would try to redispatch
//          to the (non-existent) CPU implementation; by
//          using the composite, we ensure the operator
//          actually works.
//
// As you can see, when we have a specific Autograd key (AutogradCPU), we can
// decide whether or not to use the CompositeImplicitAutograd kernel or the
// Autograd kernel based on whether or not the backend kernel exists.
//
// However, for AutogradOther (which is the catchall autograd kernel for
// everything that doesn't have a specific Autograd key), we can't do this
// trick because there isn't any unique backend to peek at to disambiguate;
// if there are some backends that have implementations they prefer Autograd,
// but unimplemented backends would prefer CompositeImplicitAutograd.  Rather
// than arbitrarily pick one or the other, we just register a kernel that raises
// an error and let the user decide how to proceed.
TORCH_API void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);

// Note [named_not_supported_kernel]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This kernel implements reporting an error message saying that named tensor is
// not supported.  This kernel doesn't rely on the Stack, and so it is special
// cased in the dispatcher to be triggered before we attempt boxing (so we can
// give a good error message in cases when boxing is not supported).  When
// boxing is universally supported this can be removed.
[[noreturn]] TORCH_API void named_not_supported_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);

/**
 * BoxedKernel is similar to a std::function storing a boxed kernel.
 */
class TORCH_API BoxedKernel final {
public:
  // This is how boxed kernels are actually stored
  //
  // Note [Plumbing Keys Through The Dispatcher]
  // Benchmarks have shown that it is expensive for the dispatcher to read from thread-local storage (TLS)
  // upon every dispatch call into order to compute which kernel to dispatch to.
  //
  // To mitigate this, we've updated the calling convention inside the dispatcher to expect every kernel that it stores
  // to have a first argument of type DispatchKeySet.
  //
  // What are the invariants of the DispatchKeySet when it gets passed to a kernel?
  // - All keys to the left of the current dispatch key have been masked out.
  //   (e.g. a Tracing kernel that takes in the DispatchKeySet will expect the highest bit to be DispatchKey::Tracer)
  // - All other keys that dispatcher normally would have computed through TLS + global state + op arguments
  //   are still in the set.
  //
  // Kernels can then opt into using this keyset to save the dispatcher from doing repeated work during redispatches:
  // recalculating the highest-priority dispatch key, which involves reading from TLS.  Instead, the kernels that opt in will
  // calculate an updated DispatchKeySet directly from the old one, and pass the updated set directly into the dispatcher
  // upon redispatching.
  //
  // This is an opt-in mechanism: Kernels can automatically opt in by setting the first argument in their signature
  // to be of type DispatchKeySet.  See the kernels in VariableTypeEverything.cpp and TraceTypeEverything.cpp for examples.
  //
  // The mechanism for optionally passing that DispatchKeySet into the kernel lives in make_boxed_from_unboxed_functor.h.
  // See Note [Plumbing Keys Through The Dispatcher 2] for details.
  using InternalBoxedKernelFunction = void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
  // This is the public API for how boxed kernels are defined
  using BoxedKernelFunction = void(const OperatorHandle&, Stack*);
  using BoxedKernelFunction_withDispatchKeys = void(const OperatorHandle&, DispatchKeySet, Stack*);

  // Default-constructs an invalid (empty) BoxedKernel; isValid() is false.
  BoxedKernel();

  // Fast path for dispatch to allow not touching the boxed kernel in
  // the common case where unboxed is available.
  bool isValid() const;
  bool isFallthrough() const;

  /**
   * Call the function with boxed arguments.
   */
  void callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const;

  /**
   * Create a KernelFunction from a boxed function.
   *
   * Example:
   *
   * > void boxed_func(const OperatorHandle&, Stack* stack) {...}
   * > BoxedKernel func = BoxedKernel::makeFromFunction<&boxed_func>();
   */
  template<BoxedKernelFunction* func>
  static BoxedKernel makeFromFunction();

  /**
   * TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none)
   * See Note [Plumbing Keys Through The Dispatcher] for details.
   */
  template<BoxedKernelFunction_withDispatchKeys* func>
  static BoxedKernel makeFromFunction();

  /**
   * Create a KernelFunction from a boxed functor.
   *
   * Example:
   *
   * > class MyFunctor final : public c10::OperatorKernel {
   * >   public:
   * >     void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
   * > };
   * > BoxedKernel func = BoxedKernel::makeFromFunctor(std::make_unique<MyFunctor>());
   */
  template<class KernelFunctor>
  static BoxedKernel makeFromFunctor(std::unique_ptr<KernelFunctor> kernelFunctor);

  // Factories wrapping the three special kernels declared above
  // (fallthrough_kernel, ambiguous_autogradother_kernel,
  // named_not_supported_kernel).
  static BoxedKernel makeFallthrough();
  static BoxedKernel makeAmbiguousAutogradOther();
  static BoxedKernel makeNamedNotSupported();

private:

  friend class KernelFunction;

  // Adapters that lift a plain BoxedKernelFunction (with or without a
  // DispatchKeySet parameter) into the internal storage signature.
  template<BoxedKernelFunction* func>
  static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);

  template<BoxedKernelFunction_withDispatchKeys* func>
  static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);

  explicit BoxedKernel(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func);

  OperatorKernel* getFunctor() const;
  InternalBoxedKernelFunction* getFnPtr() const;

  // Owning handle to the functor state (may be null for plain functions)
  // plus the raw function pointer that is actually invoked.
  c10::intrusive_ptr<OperatorKernel> functor_;
  InternalBoxedKernelFunction* boxed_kernel_func_;
};

}  // namespace c10

#include <ATen/core/boxing/BoxedKernel_impl.h>