1 #pragma once 2 3 #include <ATen/core/ivalue.h> 4 #include <ATen/core/stack.h> 5 #include <ATen/core/boxing/KernelFunction.h> 6 #include <ATen/core/dispatch/Dispatcher.h> 7 #include <c10/util/Metaprogramming.h> 8 #include <torch/library.h> 9 10 namespace at::native { 11 12 // This function implements a boxed fallback to CPU. 13 // External backends can add their own custom logging on top if it to customize their own CPU fallbacks. 14 TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false, 15 c10::DispatchKey cpu_dispatch_key = c10::DispatchKey::CPU); 16 17 // This is a helper function that backends can use to directly call their boxed CPU fallback 18 // TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands. 19 template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes> 20 struct _call_fallback_fn final {}; 21 22 template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes> 23 struct _call_fallback_fn<fallback_fn, Op, symint, ReturnType(ParameterTypes...)> final { 24 static ReturnType call(typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) { 25 auto op = c10::Dispatcher::singleton() 26 // TODO: figure out how to make compiler happy without dynamic casts 27 .findSchemaOrThrow((const char*) Op::name, (const char*) Op::overload_name) 28 //.findSchemaOrThrow("a", "b") 29 .typed<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>(); 30 return c10::impl::BoxedKernelWrapper<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>::call( 31 c10::BoxedKernel::makeFromFunction<fallback_fn>(), 32 op, 33 c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset. 34 // TODO: get std::forward<> to work 35 args... 36 ); 37 } 38 }; 39 40 template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op> 41 using call_fallback_fn_symint = _call_fallback_fn<fallback_fn, Op, true, typename Op::schema>; 42 43 template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op> 44 using call_fallback_fn = _call_fallback_fn<fallback_fn, Op, false, typename Op::schema>; 45 46 } // namespace at::native 47