xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/CPUFallback.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/Metaprogramming.h>
#include <torch/library.h>

#include <utility>
9 
10 namespace at::native {
11 
12 // This function implements a boxed fallback to CPU.
13 // External backends can add their own custom logging on top if it to customize their own CPU fallbacks.
14 TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false,
15                             c10::DispatchKey cpu_dispatch_key = c10::DispatchKey::CPU);
16 
17 // This is a helper function that backends can use to directly call their boxed CPU fallback
18 // TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
19 template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
20 struct _call_fallback_fn final {};
21 
22 template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
23 struct _call_fallback_fn<fallback_fn, Op, symint, ReturnType(ParameterTypes...)> final {
24     static ReturnType call(typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
25         auto op = c10::Dispatcher::singleton()
26             // TODO: figure out how to make compiler happy without dynamic casts
27             .findSchemaOrThrow((const char*) Op::name, (const char*) Op::overload_name)
28             //.findSchemaOrThrow("a", "b")
29             .typed<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>();
30         return c10::impl::BoxedKernelWrapper<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>::call(
31             c10::BoxedKernel::makeFromFunction<fallback_fn>(),
32             op,
33             c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset.
34             // TODO: get std::forward<> to work
35             args...
36             );
37     }
38 };
39 
40 template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
41 using call_fallback_fn_symint = _call_fallback_fn<fallback_fn, Op, true, typename Op::schema>;
42 
43 template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
44 using call_fallback_fn = _call_fallback_fn<fallback_fn, Op, false, typename Op::schema>;
45 
46 } // namespace at::native
47