1 #pragma once 2 3 #include <ATen/core/ATen_fwd.h> 4 #include <ATen/core/boxing/BoxedKernel.h> 5 #include <ATen/core/stack.h> 6 #include <c10/core/DispatchKeySet.h> 7 #include <c10/util/intrusive_ptr.h> 8 #include <c10/util/TypeList.h> 9 #include <type_traits> 10 11 namespace c10 { 12 13 using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace. 14 15 class OperatorHandle; 16 struct OperatorKernel; 17 class KernelFunction; 18 19 template <typename T> 20 using has_symint = 21 std::disjunction< 22 std::is_same<c10::SymInt, T>, 23 std::is_same<c10::SymIntArrayRef, T>, 24 std::is_same<at::OptionalSymIntArrayRef, T>, 25 std::is_same<std::optional<c10::SymInt>, T> 26 >; 27 28 template <typename T> 29 struct remove_symint { 30 using type = T; 31 }; 32 33 template <> 34 struct remove_symint<c10::SymInt> { 35 using type = int64_t; 36 }; 37 38 template <> 39 struct remove_symint<at::OptionalSymIntArrayRef> { 40 using type = OptionalIntArrayRef; 41 }; 42 43 template <> 44 struct remove_symint<c10::SymIntArrayRef> { 45 using type = c10::IntArrayRef; 46 }; 47 48 template <> 49 struct remove_symint<std::optional<c10::SymInt>> { 50 using type = std::optional<int64_t>; 51 }; 52 53 54 template <bool symint, typename T> 55 struct maybe_keep_symint final {}; 56 57 template <typename T> 58 struct maybe_keep_symint<true, T> { using type = T; }; 59 60 template <typename T> 61 struct maybe_keep_symint<false, T> { using type = typename remove_symint<T>::type; }; 62 63 template <typename T> 64 using fn_has_symint = typename guts::typelist::true_for_any_type< 65 has_symint, 66 typename guts::infer_function_traits<T>::type::parameter_types 67 >; 68 69 template <typename T> 70 struct fn_remove_symint; 71 72 template <typename Ret, typename... Args> 73 struct fn_remove_symint<Ret(Args...)> { 74 using type = Ret(typename remove_symint<Args>::type...); 75 }; 76 77 /** 78 * KernelFunction is similar to std::function but stores a kernel function. 
 * You can create a KernelFunction from a boxed or unboxed function/functor/lambda
 * and call it in a boxed or unboxed way. If the way it was created doesn't
 * match the way it was called, it will do boxing or unboxing as necessary.
 *
 * Internally it stores a BoxedKernel plus two type-erased unboxed entry
 * points (a plain one and a SymInt-aware one); see the private members at the
 * bottom of the class. Template members are defined in
 * KernelFunction_impl.h, which is included at the end of this header.
 */
class TORCH_API KernelFunction final {
public:
  using InternalBoxedKernelFunction = BoxedKernel::InternalBoxedKernelFunction;
  using BoxedKernelFunction = BoxedKernel::BoxedKernelFunction;
  using BoxedKernelFunction_withDispatchKeys = BoxedKernel::BoxedKernelFunction_withDispatchKeys;

  KernelFunction();

  // Fast path for dispatch to allow not touching the boxed kernel in
  // the common case where unboxed is available.
  // NOTE(review): isValidUnboxed / isValidSymUnboxed presumably report whether
  // unboxed_kernel_func_ / sym_unboxed_kernel_func_ are set; they are defined
  // out of line, so confirm there.
  bool isValidUnboxed() const;
  bool isValidSymUnboxed() const;
  bool isValid() const;
  bool isFallthrough() const;

  /**
   * Call the function in a boxed way.
   * If the kernel function was created with an unboxed function,
   * this will call an unboxing wrapper which then calls into that
   * unboxed function.
   *
   * The call itself returns void; inputs and outputs travel through `stack`.
   *
   * Example:
   *
   * > // signature must match KernelFunction::BoxedKernelFunction
   * > void boxed_func(const OperatorHandle&, Stack* stack) {...}
   * > KernelFunction func = KernelFunction::makeFromBoxedFunction<&boxed_func>();
   * > func.callBoxed(opHandle, dispatchKeySet, &stack);
   *
   * Or, with an unboxed implementation:
   *
   * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
   * >      [] (Tensor a, bool b) -> Tensor {...});
   * > func.callBoxed(opHandle, dispatchKeySet, &stack);
   */
  void callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const;

  /**
   * Call the function in an unboxed way.
   * If the kernel function was created with a boxed function,
   * this will box all inputs and then call into that boxed function.
   *
   * Note that this doesn't work for all types yet.
   *
   * `Return` cannot be deduced and must be passed explicitly. Spelling out
   * `Args` as well pins the exact parameter types; deducing them from the
   * by-value pack `Args... args` would decay them (e.g. drop const&).
   *
   * Example:
   *
   * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
   * >      [] (Tensor a, bool b) -> Tensor {...});
   * > Tensor result = func.call<Tensor, Tensor, bool>(opHandle, dispatchKeySet, tensor1, true);
   *
   * Or, with a boxed implementation:
   *
   * > void boxed_func(const OperatorHandle&, Stack* stack) {...}
   * > KernelFunction func = KernelFunction::makeFromBoxedFunction<&boxed_func>();
   * > Tensor result = func.call<Tensor, Tensor, bool>(opHandle, dispatchKeySet, tensor1, true);
   */
  template<class Return, class... Args>
  Return call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const;

  /**
   * Create a KernelFunction from a BoxedKernel.
   */
  static KernelFunction makeFromBoxedKernel(BoxedKernel boxed_fn);

  /**
   * Create a KernelFunction from a boxed function.
   * The function pointer is a template argument, so the call takes no
   * runtime arguments.
   *
   * Example:
   *
   * > // signature must match KernelFunction::BoxedKernelFunction
   * > void boxed_func(const OperatorHandle&, Stack* stack) {...}
   * > KernelFunction func = KernelFunction::makeFromBoxedFunction<&boxed_func>();
   */
  template<BoxedKernelFunction* func>
  static KernelFunction makeFromBoxedFunction();

  /**
   * Overload for boxed functions that additionally receive the DispatchKeySet
   * (signature BoxedKernelFunction_withDispatchKeys).
   *
   * TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none)
   * See Note [Plumbing Keys Through The Dispatcher] for details.
   */
  template<BoxedKernelFunction_withDispatchKeys* func>
  static KernelFunction makeFromBoxedFunction();

  /**
   * Create a KernelFunction from an unboxed functor.
   *
   * `KernelFunctor` cannot be deduced from the std::unique_ptr<OperatorKernel>
   * argument, and it comes after `AllowLegacyTypes` in the template parameter
   * list, so both must be given explicitly.
   *
   * Example:
   *
   * > class MyFunctor final : public c10::OperatorKernel {
   * >   public:
   * >     Tensor operator()(Tensor a, Tensor b) {...}
   * > };
   * > KernelFunction func = KernelFunction::makeFromUnboxedFunctor<false, MyFunctor>(std::make_unique<MyFunctor>());
   */
  template<bool AllowLegacyTypes = false, class KernelFunctor>
  static KernelFunction makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor);

  /**
   * Create a KernelFunction from a boxed functor.
   *
   * Example:
   *
   * > class MyFunctor final : public c10::OperatorKernel {
   * >   public:
   * >     void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
   * > };
   * > KernelFunction func = KernelFunction::makeFromBoxedFunctor(std::make_unique<MyFunctor>());
   */
  template<class KernelFunctor>
  static KernelFunction makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor);

  /**
   * Create a KernelFunction from an unboxed function.
   * This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction
   * because knowing the function pointer as a template argument (i.e. at
   * compile time) allows the compiler to inline the function into its
   * unboxing wrapper and yields better performance when calling the function.
   *
   * `FuncPtr` is deduced from the argument; it is expected to be a wrapper
   * type that encodes the function pointer in its type, such as the value
   * produced by TORCH_FN.
   *
   * Example:
   *
   * > Tensor unboxed_func(Tensor a, Tensor b) {...}
   * > KernelFunction func = KernelFunction::makeFromUnboxedFunction(TORCH_FN(unboxed_func));
   */
  template<class FuncPtr, bool AllowLegacyTypes = false>
  static KernelFunction makeFromUnboxedFunction(FuncPtr);

  /**
   * Create a KernelFunction from an unboxed function.
   * KernelFunction::makeFromUnboxedFunction is usually a better choice than
   * this if you know the function pointer at compile time, see doc comment
   * there for an explanation.
   *
   * Example:
   *
   * > Tensor unboxed_func(Tensor a, Tensor b) {...}
   * > KernelFunction func = KernelFunction::makeFromUnboxedRuntimeFunction(&unboxed_func);
   */
  template<bool AllowLegacyTypes = false, class FuncType>
  static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func);

  // Factories for special-purpose kernels (fallthrough, ambiguous
  // AutogradOther, named-tensor not supported). Their behavior is defined out
  // of line, not in this header.
  static KernelFunction makeFallthrough();
  static KernelFunction makeAmbiguousAutogradOther();
  static KernelFunction makeNamedNotSupported();

  /**
   * Create a KernelFunction from an unboxed lambda.
   *
   * Two overloads are selected via SFINAE on guts::is_stateless_lambda: the
   * first accepts lambdas without captures, the second accepts lambdas with
   * captures.
   *
   * Example:
   *
   * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
   * >      [] (Tensor a, bool b) -> Tensor {...});
   */
  template<bool AllowLegacyTypes = false, class Lambda>
  static std::enable_if_t<guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> makeFromUnboxedLambda(Lambda&& lambda);
  template<bool AllowLegacyTypes = false, class Lambda>
  static std::enable_if_t<!guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> makeFromUnboxedLambda(Lambda&& lambda);

  // Debug description of the stored kernel pointers.
  std::string dumpState() const;
  // For testing internal invariants only
  bool _equalsBoxedAndUnboxed(const KernelFunction&) const;

private:

  // Builds the boxed half from an owned functor plus an internal boxed
  // function pointer; the unboxed entry points are passed type-erased.
  explicit KernelFunction(
      std::unique_ptr<OperatorKernel> functor,
      InternalBoxedKernelFunction* boxed_kernel_func,
      void* unboxed_kernel_func,
      void* sym_unboxed_kernel_func);
  // Same, but takes an already-constructed BoxedKernel.
  explicit KernelFunction(
      BoxedKernel boxed_fn,
      void* unboxed_kernel_func,
      void* sym_unboxed_kernel_func);

  // Boxed implementation (used by callBoxed, and by call() when no unboxed
  // entry point applies).
  BoxedKernel boxed_kernel_func_;
  // Type-erased pointer to the unboxed entry point; its real signature is
  // only known to the templates in KernelFunction_impl.h.
  void* unboxed_kernel_func_;
  // Type-erased pointer to the SymInt-aware unboxed entry point
  // (presumably the signature that keeps c10::SymInt types — confirm in _impl).
  void* sym_unboxed_kernel_func_;
};

} // namespace c10

// Definitions of the template members declared above; included last so the
// class is complete at that point.
#include <ATen/core/boxing/KernelFunction_impl.h>