// aten/src/ATen/functorch/Interpreter.cpp
#include <ATen/functorch/Interpreter.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/VmapInterpreter.h>
#include <ATen/functorch/FunctionalizeInterpreter.h>
#include <ATen/functorch/ADInterpreters.h>
#include <ATen/functorch/DynamicLayer.h>

namespace at::functorch {

static DispatchKeySet get_all_dynlayer_keyset() {
  // NB: FULL_AFTER does not include the dispatch key

  // "all dispatch keys between DynamicLayer{Front, Back}Mode, inclusive"
  auto result =
    DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::FuncTorchDynamicLayerFrontMode) -
    DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::FuncTorchDynamicLayerBackMode);
  result = result | DispatchKeySet({DispatchKey::FuncTorchDynamicLayerFrontMode});

  // Hack: don't handle the autocast dispatch keys. Their interaction with functorch
  // is weird.
  result = result - autocast_dispatch_keyset;

  // Hack: don't handle DispatchKey::FuncTorchVmapMode. We need a better way of modeling this.
  // In e.g. grad(vmap(f)), DispatchKey::FuncTorchVmapMode makes it so that all random operations,
  // even after we are done handling the vmap layer, error out.
  result = result.remove(DispatchKey::FuncTorchVmapMode);

  return result;
}
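// Sketch of the set arithmetic above (a reading of DispatchKeySet's FULL_AFTER
// semantics, not something asserted in this file): FULL_AFTER(k) covers the keys
// strictly below k in dispatch priority, so subtracting FULL_AFTER(BackMode) from
// FULL_AFTER(FrontMode) keeps the keys from FuncTorchDynamicLayerBackMode (inclusive)
// up to, but not including, FuncTorchDynamicLayerFrontMode; the union on the next line
// then adds FrontMode back in. Before the autocast/VmapMode carve-outs, membership of a
// single key k is roughly:
//
//   bool in_range =
//       DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::FuncTorchDynamicLayerFrontMode).has(k) &&
//       !DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::FuncTorchDynamicLayerBackMode).has(k);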

// TODO: This should be constexpr, but there are some methods
// of DispatchKeySet that haven't been marked constexpr yet.
static DispatchKeySet all_dynlayer_keyset = get_all_dynlayer_keyset();

static DispatchKeySet keysForEnteringDynamicLayer(TransformType key) {
  if (key == TransformType::Vmap) {
    // NB: Does not include DispatchKey::FuncTorchVmapMode. We may modulate the key when
    // constructing the DynamicLayer, but we don't control it when entering/exiting
    // the DynamicLayer.
    return DispatchKeySet({DispatchKey::FuncTorchBatched, DispatchKey::BatchedNestedTensor});
  } else if (key == TransformType::Grad || key == TransformType::Jvp) {
    return autograd_dispatch_keyset.add(DispatchKey::ADInplaceOrView);
  } else if (key == TransformType::Functionalize) {
    return DispatchKeySet(DispatchKey::Functionalize);
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unsupported key: ", key);
  }
}
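// For illustration (derived from the branches above, not exercised anywhere in this
// file): the Vmap branch re-enables only the batching keys, so one would expect
//
//   auto vmap_keys = keysForEnteringDynamicLayer(TransformType::Vmap);
//   // vmap_keys.has(DispatchKey::FuncTorchBatched)  -> true
//   // vmap_keys.has(DispatchKey::FuncTorchVmapMode) -> false  (see the NB above)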

DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key) {
  DispatchKeySet exclude = all_dynlayer_keyset;
  exclude = exclude.remove(DispatchKey::FuncTorchDynamicLayerBackMode);
  exclude = exclude - keysForEnteringDynamicLayer(key);
  return exclude;
}
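// Worked example (derived from the definitions above; exact membership depends on the
// DispatchKey ordering assumed in get_all_dynlayer_keyset): entering a Grad layer masks
// off the other functorch-managed keys but leaves the back-mode fallback and the
// autograd machinery enabled:
//
//   auto excl = keysToExcludeWhenEnteringDynamicLayer(TransformType::Grad);
//   // excl.has(DispatchKey::FuncTorchBatched)              -> true  (masked off)
//   // excl.has(DispatchKey::Functionalize)                 -> true  (masked off)
//   // excl.has(DispatchKey::FuncTorchDynamicLayerBackMode) -> false (kept)
//   // excl.has(DispatchKey::ADInplaceOrView)               -> false (kept)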

void setup_dispatch_key_tls(TransformType key, DispatchKeySet also_include) {
  auto local_keyset = c10::impl::tls_local_dispatch_key_set();
  auto to_exclude = local_keyset.excluded_;
  to_exclude = to_exclude | keysToExcludeWhenEnteringDynamicLayer(key);
  to_exclude = to_exclude - keysForEnteringDynamicLayer(key);
  local_keyset.excluded_ = to_exclude;
  local_keyset.included_ = local_keyset.included_ | also_include;
  c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
}
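// Minimal usage sketch (hypothetical call site; the actual callers live in the
// per-transform interpreters elsewhere in functorch, not in this file): before running
// user code under a vmap layer, a caller might do
//
//   setup_dispatch_key_tls(TransformType::Vmap,
//                          DispatchKeySet(DispatchKey::FuncTorchVmapMode));
//
// which adds every functorch-managed key that Vmap does not need to the thread-local
// exclude set, removes the Vmap keys from it, and force-includes FuncTorchVmapMode.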

std::ostream& operator<<(std::ostream& os, const TransformType& t) {
  switch (t) {
    case TransformType::Torch:
      os << "Torch";
      break;
    case TransformType::Vmap:
      os << "Vmap";
      break;
    case TransformType::Grad:
      os << "Grad";
      break;
    case TransformType::Jvp:
      os << "Jvp";
      break;
    case TransformType::Functionalize:
      os << "Functionalize";
      break;
    default:
      TORCH_INTERNAL_ASSERT(false);
  }
  return os;
}

void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  auto num_args = op.schema().arguments().size();
  foreachTensorInplace(*stack, static_cast<int64_t>(stack->size() - num_args), static_cast<int64_t>(stack->size()),
      [](const Tensor& tensor) {
        auto result = unwrapIfDead(tensor);
        auto* wrapper = maybeGetTensorWrapper(result);
        TORCH_INTERNAL_ASSERT(wrapper == nullptr);
        auto* batched = maybeGetBatchedImpl(result);
        TORCH_INTERNAL_ASSERT(batched == nullptr);
        return tensor;
      });
}
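// In other words (a reading of the asserts above): every tensor argument of the op,
// after unwrapping dead wrappers, must be neither a TensorWrapper nor a
// BatchedTensorImpl, i.e. the current layer is expected to have fully unwrapped its
// inputs by the time this check runs.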

#define INTERPRETER_DISPATCH(type, method) \
  switch (key()) { \
    case TransformType::Vmap: \
      TORCH_INTERNAL_ASSERT(std::holds_alternative<VmapInterpreterMeta>(this->meta()));\
      return VmapInterpreterPtr(this). method; \
    case TransformType::Grad: \
      TORCH_INTERNAL_ASSERT(std::holds_alternative<GradInterpreterMeta>(this->meta()));\
      return GradInterpreterPtr(this). method; \
    case TransformType::Jvp: \
      TORCH_INTERNAL_ASSERT(std::holds_alternative<JvpInterpreterMeta>(this->meta()));\
      return JvpInterpreterPtr(this). method; \
    case TransformType::Functionalize: \
      TORCH_INTERNAL_ASSERT(std::holds_alternative<FunctionalizeInterpreterMeta>(this->meta()));\
      return FunctionalizeInterpreterPtr(this). method; \
    default: \
      TORCH_INTERNAL_ASSERT(false, "Unrecognized transform"); \
  }
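// Note on the macro above: SINGLE_ARG appears to be a plain pass-through variadic
// macro (declared in the functorch headers) whose only purpose is to let a
// comma-containing expression such as processImpl(op, stack) be passed as a single
// macro argument at the call sites below. The `type` parameter is unused; the switch
// dispatches on key() instead.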

void Interpreter::process(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  INTERPRETER_DISPATCH(key_, SINGLE_ARG(processImpl(op, stack)));
}

void Interpreter::sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case) {
  INTERPRETER_DISPATCH(key_, SINGLE_ARG(sendToNextInterpreterImpl(op, stack, grad_special_case)));
}

} // namespace at::functorch