1 #pragma once
2
3 namespace c10 {
4
BoxedKernel()5 inline BoxedKernel::BoxedKernel()
6 : functor_()
7 , boxed_kernel_func_(nullptr)
8 {}
9
BoxedKernel(std::unique_ptr<OperatorKernel> functor,InternalBoxedKernelFunction * boxed_kernel_func)10 inline BoxedKernel::BoxedKernel(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func)
11 : functor_(std::move(functor))
12 , boxed_kernel_func_(boxed_kernel_func)
13 {}
14
15 template<BoxedKernel::BoxedKernelFunction* func>
make_boxed_function(OperatorKernel *,const OperatorHandle & opHandle,DispatchKeySet,Stack * stack)16 inline void BoxedKernel::make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack) {
17 // Note that we're dropping the DispatchKeySet argument.
18 // See Note [Plumbing Keys Through The Dispatcher 2] for details.
19 func(opHandle, stack);
20 }
21
22 template<BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
make_boxed_function(OperatorKernel *,const OperatorHandle & opHandle,DispatchKeySet ks,Stack * stack)23 inline void BoxedKernel::make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet ks, Stack* stack) {
24 // See Note [Plumbing Keys Through The Dispatcher 2] for details.
25 func(opHandle, ks, stack);
26 }
27
isValid()28 inline bool BoxedKernel::isValid() const {
29 return boxed_kernel_func_ != nullptr;
30 }
31
isFallthrough()32 inline bool BoxedKernel::isFallthrough() const {
33 return boxed_kernel_func_ == &fallthrough_kernel;
34 }
35
callBoxed(const OperatorHandle & opHandle,DispatchKeySet dispatchKeySet,Stack * stack)36 inline void BoxedKernel::callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const {
37 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
38 boxed_kernel_func_ != nullptr,
39 "Tried to call BoxedKernel::callBoxed() on an uninitialized BoxedKernel."
40 );
41 (*boxed_kernel_func_)(functor_.get(), opHandle, dispatchKeySet, stack);
42 }
43
44 template<BoxedKernel::BoxedKernelFunction* func>
makeFromFunction()45 inline BoxedKernel BoxedKernel::makeFromFunction() {
46 return BoxedKernel(
47 nullptr, // no functor_ object
48 &make_boxed_function<func>
49 );
50 }
51
52 template<BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
makeFromFunction()53 inline BoxedKernel BoxedKernel::makeFromFunction() {
54 return BoxedKernel(
55 nullptr, // no functor_ object
56 &make_boxed_function<func>
57 );
58 }
59
makeFallthrough()60 inline BoxedKernel BoxedKernel::makeFallthrough() {
61 return BoxedKernel(
62 nullptr, // no functor_ object
63 &fallthrough_kernel
64 );
65 }
66
makeAmbiguousAutogradOther()67 inline BoxedKernel BoxedKernel::makeAmbiguousAutogradOther() {
68 return BoxedKernel(
69 nullptr, // no functor_ object
70 &ambiguous_autogradother_kernel
71 );
72 }
73
makeNamedNotSupported()74 inline BoxedKernel BoxedKernel::makeNamedNotSupported() {
75 return BoxedKernel(
76 nullptr, // no functor_ object
77 &named_not_supported_kernel
78 );
79 }
80
81 template<class KernelFunctor>
makeFromFunctor(std::unique_ptr<KernelFunctor> kernelFunctor)82 inline BoxedKernel BoxedKernel::makeFromFunctor(std::unique_ptr<KernelFunctor> kernelFunctor) {
83 static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call BoxedKernel::makeFromFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
84 return BoxedKernel(
85 std::move(kernelFunctor),
86 [](OperatorKernel* kernel, const OperatorHandle& op, DispatchKeySet ks, Stack* stack) {
87 (*static_cast<KernelFunctor*>(kernel))(op, ks, stack);
88 }
89 );
90 }
91
getFunctor()92 inline OperatorKernel* BoxedKernel::getFunctor() const {
93 return functor_.get();
94 }
getFnPtr()95 inline BoxedKernel::InternalBoxedKernelFunction* BoxedKernel::getFnPtr() const {
96 return boxed_kernel_func_;
97 }
98
99 } // namespace c10
100