1 #include <ATen/functorch/Interpreter.h>
2 #include <ATen/functorch/BatchedTensorImpl.h>
3 #include <ATen/functorch/TensorWrapper.h>
4 #include <ATen/functorch/VmapInterpreter.h>
5 #include <ATen/functorch/FunctionalizeInterpreter.h>
6 #include <ATen/functorch/ADInterpreters.h>
7 #include <ATen/functorch/DynamicLayer.h>
8
9 namespace at::functorch {
10
get_all_dynlayer_keyset()11 static DispatchKeySet get_all_dynlayer_keyset() {
12 // NB: FULL_AFTER does not include the dispatch key
13
14 // "all dispatch keys between DynamicLayer{Front, Back}Mode, inclusive"
15 auto result =
16 DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::FuncTorchDynamicLayerFrontMode) -
17 DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::FuncTorchDynamicLayerBackMode);
18 result = result | DispatchKeySet({DispatchKey::FuncTorchDynamicLayerFrontMode});
19
20 // Hack: don't handle the autocast dispatch keys. Their interaction with functorch
21 // is weird.
22 result = result - autocast_dispatch_keyset;
23
24 // Hack: don't handle DispatchKey::FuncTorchVmapMode. We need a better way of modeling this.
25 // In e.g. grad(vmap(f)), DispatchKey::FuncTorchVmapMode makes it so that all random operations,
26 // even after we are done handling the vmap layer, error out.
27 result = result.remove(DispatchKey::FuncTorchVmapMode);
28
29 return result;
30 }
31
// TODO: This should be constexpr, but there are some methods
// of DispatchKeySet that haven't been marked constexpr yet.
// Computed once at static-initialization time via get_all_dynlayer_keyset().
static DispatchKeySet all_dynlayer_keyset = get_all_dynlayer_keyset();
35
keysForEnteringDynamicLayer(TransformType key)36 static DispatchKeySet keysForEnteringDynamicLayer(TransformType key) {
37 if (key == TransformType::Vmap) {
38 // NB: Does not include DispatchKey::FuncTorchVmapMode. We may modulate the key when
39 // constructing the DynamicLayer, but we don't control it when entering/exiting
40 // the DynamicLayer.
41 return DispatchKeySet({DispatchKey::FuncTorchBatched, DispatchKey::BatchedNestedTensor});
42 } else if (key == TransformType::Grad || key == TransformType::Jvp) {
43 return autograd_dispatch_keyset.add(DispatchKey::ADInplaceOrView);
44 } else if (key == TransformType::Functionalize) {
45 return DispatchKeySet(DispatchKey::Functionalize);
46 } else {
47 TORCH_INTERNAL_ASSERT(false, "Unsupported key: ", key);
48 }
49 }
50
keysToExcludeWhenEnteringDynamicLayer(TransformType key)51 DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key) {
52 DispatchKeySet exclude = all_dynlayer_keyset;
53 exclude = exclude.remove(DispatchKey::FuncTorchDynamicLayerBackMode);
54 exclude = exclude - keysForEnteringDynamicLayer(key);
55 return exclude;
56 }
57
setup_dispatch_key_tls(TransformType key,DispatchKeySet also_include)58 void setup_dispatch_key_tls(TransformType key, DispatchKeySet also_include) {
59 auto local_keyset = c10::impl::tls_local_dispatch_key_set();
60 auto to_exclude = local_keyset.excluded_;
61 to_exclude = to_exclude | keysToExcludeWhenEnteringDynamicLayer(key);
62 to_exclude = to_exclude - keysForEnteringDynamicLayer(key);
63 local_keyset.excluded_ = to_exclude;
64 local_keyset.included_ = local_keyset.included_ | also_include;
65 c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
66 }
67
operator <<(std::ostream & os,const TransformType & t)68 std::ostream& operator<<(std::ostream& os, const TransformType& t) {
69 switch (t) {
70 case TransformType::Torch:
71 os << "Torch";
72 break;
73 case TransformType::Vmap:
74 os << "Vmap";
75 break;
76 case TransformType::Grad:
77 os << "Grad";
78 break;
79 case TransformType::Jvp:
80 os << "Jvp";
81 break;
82 case TransformType::Functionalize:
83 os << "Functionalize";
84 break;
85 default:
86 TORCH_INTERNAL_ASSERT(false);
87 }
88 return os;
89 }
90
sanityCheckStack(const c10::OperatorHandle & op,torch::jit::Stack * stack)91 void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
92 auto num_args = op.schema().arguments().size();
93 foreachTensorInplace(*stack, static_cast<int64_t>(stack->size() - num_args), static_cast<int64_t>(stack->size()),
94 [](const Tensor& tensor) {
95 auto result = unwrapIfDead(tensor);
96 auto* wrapper = maybeGetTensorWrapper(result);
97 TORCH_INTERNAL_ASSERT(wrapper == nullptr);
98 auto* batched = maybeGetBatchedImpl(result);
99 TORCH_INTERNAL_ASSERT(batched == nullptr);
100 return tensor;
101 });
102 }
103
// Dispatches the call expression `method` to the interpreter subclass that
// matches this Interpreter's transform (read via key()), after asserting
// that the stored metadata variant agrees with that transform. `method` is
// a full call expression wrapped in SINGLE_ARG so embedded commas survive
// macro expansion; its result is returned directly.
// NB: the `type` parameter is currently unused by the macro body.
#define INTERPRETER_DISPATCH(type, method) \
  switch (key()) { \
    case TransformType::Vmap: \
      TORCH_INTERNAL_ASSERT(std::holds_alternative<VmapInterpreterMeta>(this->meta()));\
      return VmapInterpreterPtr(this). method; \
    case TransformType::Grad: \
      TORCH_INTERNAL_ASSERT(std::holds_alternative<GradInterpreterMeta>(this->meta()));\
      return GradInterpreterPtr(this). method; \
    case TransformType::Jvp: \
      TORCH_INTERNAL_ASSERT(std::holds_alternative<JvpInterpreterMeta>(this->meta()));\
      return JvpInterpreterPtr(this). method; \
    case TransformType::Functionalize: \
      TORCH_INTERNAL_ASSERT(std::holds_alternative<FunctionalizeInterpreterMeta>(this->meta()));\
      return FunctionalizeInterpreterPtr(this). method; \
    default: \
      TORCH_INTERNAL_ASSERT(false, "Unrecognized transform"); \
  }
121
// Handles `op` at this interpreter's level of the transform stack by
// forwarding to the concrete interpreter's processImpl, selected by key_.
void Interpreter::process(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  INTERPRETER_DISPATCH(key_, SINGLE_ARG(processImpl(op, stack)));
}
125
// Forwards `op` to the next interpreter in the stack via the concrete
// interpreter's sendToNextInterpreterImpl, selected by key_.
// `grad_special_case` is passed through untouched to the implementation.
void Interpreter::sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case) {
  INTERPRETER_DISPATCH(key_, SINGLE_ARG(sendToNextInterpreterImpl(op, stack, grad_special_case)));
}
129
130 } // namespace at::functorch
131