// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/boxing/BoxedKernel_impl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 namespace c10 {
4 
BoxedKernel()5 inline BoxedKernel::BoxedKernel()
6     : functor_()
7 , boxed_kernel_func_(nullptr)
8 {}
9 
BoxedKernel(std::unique_ptr<OperatorKernel> functor,InternalBoxedKernelFunction * boxed_kernel_func)10 inline BoxedKernel::BoxedKernel(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func)
11 : functor_(std::move(functor))
12 , boxed_kernel_func_(boxed_kernel_func)
13 {}
14 
15 template<BoxedKernel::BoxedKernelFunction* func>
make_boxed_function(OperatorKernel *,const OperatorHandle & opHandle,DispatchKeySet,Stack * stack)16 inline void BoxedKernel::make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack) {
17     // Note that we're dropping the DispatchKeySet argument.
18     // See Note [Plumbing Keys Through The Dispatcher 2] for details.
19     func(opHandle, stack);
20 }
21 
22 template<BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
make_boxed_function(OperatorKernel *,const OperatorHandle & opHandle,DispatchKeySet ks,Stack * stack)23 inline void BoxedKernel::make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet ks, Stack* stack) {
24     // See Note [Plumbing Keys Through The Dispatcher 2] for details.
25     func(opHandle, ks, stack);
26 }
27 
isValid()28 inline bool BoxedKernel::isValid() const {
29     return boxed_kernel_func_ != nullptr;
30 }
31 
isFallthrough()32 inline bool BoxedKernel::isFallthrough() const {
33     return boxed_kernel_func_ == &fallthrough_kernel;
34 }
35 
callBoxed(const OperatorHandle & opHandle,DispatchKeySet dispatchKeySet,Stack * stack)36 inline void BoxedKernel::callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const {
37     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
38         boxed_kernel_func_ != nullptr,
39         "Tried to call BoxedKernel::callBoxed() on an uninitialized BoxedKernel."
40     );
41     (*boxed_kernel_func_)(functor_.get(), opHandle, dispatchKeySet, stack);
42 }
43 
44 template<BoxedKernel::BoxedKernelFunction* func>
makeFromFunction()45 inline BoxedKernel BoxedKernel::makeFromFunction() {
46     return BoxedKernel(
47         nullptr,  // no functor_ object
48         &make_boxed_function<func>
49     );
50 }
51 
52 template<BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
makeFromFunction()53 inline BoxedKernel BoxedKernel::makeFromFunction() {
54     return BoxedKernel(
55         nullptr,  // no functor_ object
56         &make_boxed_function<func>
57     );
58 }
59 
makeFallthrough()60 inline BoxedKernel BoxedKernel::makeFallthrough() {
61     return BoxedKernel(
62         nullptr,  // no functor_ object
63         &fallthrough_kernel
64     );
65 }
66 
makeAmbiguousAutogradOther()67 inline BoxedKernel BoxedKernel::makeAmbiguousAutogradOther() {
68     return BoxedKernel(
69         nullptr,  // no functor_ object
70         &ambiguous_autogradother_kernel
71     );
72 }
73 
makeNamedNotSupported()74 inline BoxedKernel BoxedKernel::makeNamedNotSupported() {
75     return BoxedKernel(
76         nullptr,  // no functor_ object
77         &named_not_supported_kernel
78     );
79 }
80 
81 template<class KernelFunctor>
makeFromFunctor(std::unique_ptr<KernelFunctor> kernelFunctor)82 inline BoxedKernel BoxedKernel::makeFromFunctor(std::unique_ptr<KernelFunctor> kernelFunctor) {
83     static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call BoxedKernel::makeFromFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
84     return BoxedKernel(
85         std::move(kernelFunctor),
86         [](OperatorKernel* kernel, const OperatorHandle& op, DispatchKeySet ks, Stack* stack) {
87           (*static_cast<KernelFunctor*>(kernel))(op, ks, stack);
88         }
89     );
90 }
91 
getFunctor()92 inline OperatorKernel* BoxedKernel::getFunctor() const {
93   return functor_.get();
94 }
getFnPtr()95 inline BoxedKernel::InternalBoxedKernelFunction* BoxedKernel::getFnPtr() const {
96   return boxed_kernel_func_;
97 }
98 
99 }  // namespace c10
100