// aten/src/ATen/core/VariableFallbackKernel.cpp
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/VariableHooksInterface.h>
#include <torch/library.h>

/*
 * This file implements a variable fallback kernel for custom operators.
 * Tensors always have the Autograd dispatch keys set, but custom operators
 * usually don't have a kernel registered for Autograd, so the dispatcher
 * calls into this fallback kernel instead.
 * Note that this is not a correct autograd implementation. It will just
 * fall through to the custom operator implementation.
 * If you want a custom operator to work with autograd, you need to use
 * autograd::Function so that the custom operator implementation knows how
 * to do autograd.
 * Note also that ops from native_functions.yaml register their own variable
 * kernels, so this fallback is never called for them.
 */
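
// Example (a minimal sketch; `MySquare` and `my_square` are hypothetical
// names, not part of this file): to make a custom operator differentiable,
// implement it via a torch::autograd::Function subclass (from
// <torch/autograd.h>) so the backward pass is defined explicitly:
//
//   struct MySquare : public torch::autograd::Function<MySquare> {
//     static at::Tensor forward(
//         torch::autograd::AutogradContext* ctx, const at::Tensor& input) {
//       ctx->save_for_backward({input});  // stash input for the backward pass
//       return input * input;
//     }
//     static torch::autograd::variable_list backward(
//         torch::autograd::AutogradContext* ctx,
//         torch::autograd::variable_list grad_outputs) {
//       auto saved = ctx->get_saved_variables();
//       // d(x^2)/dx = 2 * x
//       return {grad_outputs[0] * 2 * saved[0]};
//     }
//   };
//
//   at::Tensor my_square(const at::Tensor& x) {
//     return MySquare::apply(x);
//   }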

// TODO This whole file should be deleted and replaced with the mechanism
//      described in https://github.com/pytorch/pytorch/issues/29548

using c10::Stack;

namespace {

#ifdef C10_MOBILE
// NOTE [mobile/edge builds and the autograd fallback]
// To save on binary size, some of the mobile configs don't include the
// autograd kernels for built-in operators (VariableTypeEverything.cpp).
// For the mobile build:
// - we don't care about having a nice autograd fallback that warns if
//   an operator has incorrect autograd support. If you're running
//   a custom operator on mobile then it's already too late for us to warn
//   or error on it.
// - for perf reasons, we do not want mobile to go through autograd_fallback
//   for all operators (the boxing/unboxing adds overhead).
// As a result, on mobile we set the fallback to the fallthrough.
#define AUTOGRAD_FALLBACK torch::CppFunction::makeFallthrough()
#else

// Register a fallback for the Autograd backend dispatch keys.
// NB: But not the private-use ones; an extension may want
// to override those itself!
void autograd_fallback(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) {
  // PyTorch has separate builds, some of which don't include autograd.
  // So we define some behavior for when autograd isn't included and
  // go through a layer of indirection (VariableHooksInterface) when it is.
  // See aten/src/ATen/core/VariableHooksInterface.h for more details.
  if (!at::impl::HasVariableHooks()) {
    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
    return;
  }
  at::impl::GetVariableHooks()->basic_autograd_not_implemented_fallback(op, dispatch_keys, stack);
}

#define AUTOGRAD_FALLBACK torch::CppFunction::makeFromBoxedFunction<&autograd_fallback>()
#endif
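
// Example (a minimal sketch; the `myops::myadd` operator below is hypothetical
// and not part of this file): a custom operator registered only for the CPU
// key. When it is called with a tensor that requires grad, dispatch hits the
// autograd key first; since the op has no Autograd kernel, the fallback
// registered below runs and redispatches to the CPU kernel (without providing
// real gradients).
//
//   #include <ATen/ATen.h>
//   #include <torch/library.h>
//
//   at::Tensor myadd_cpu(const at::Tensor& a, const at::Tensor& b) {
//     return a + b;  // plain eager implementation, no autograd support
//   }
//
//   TORCH_LIBRARY(myops, m) {
//     m.def("myadd(Tensor a, Tensor b) -> Tensor");
//   }
//
//   TORCH_LIBRARY_IMPL(myops, CPU, m) {
//     m.impl("myadd", myadd_cpu);
//   }
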
TORCH_LIBRARY_IMPL(_, AutogradOther, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

TORCH_LIBRARY_IMPL(_, AutogradCPU, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

TORCH_LIBRARY_IMPL(_, AutogradXPU, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

TORCH_LIBRARY_IMPL(_, AutogradCUDA, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

TORCH_LIBRARY_IMPL(_, AutogradXLA, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

TORCH_LIBRARY_IMPL(_, AutogradLazy, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

TORCH_LIBRARY_IMPL(_, AutogradMPS, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

TORCH_LIBRARY_IMPL(_, AutogradMeta, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

// see Note [ADInplaceOrView key]
TORCH_LIBRARY_IMPL(_, ADInplaceOrView, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
  m.fallback(AUTOGRAD_FALLBACK);
}

#undef AUTOGRAD_FALLBACK

} // namespace
107