xref: /aosp_15_r20/external/pytorch/aten/src/ATen/functorch/Interpreter.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/functorch/Macros.h>
4 #include <ATen/core/dispatch/Dispatcher.h>
5 #include <c10/core/impl/LocalDispatchKeySet.h>
6 #include <optional>
7 #include <bitset>
8 #include <utility>
9 #include <variant>
10 
11 namespace at::functorch {
12 
13 // NOTE: [functorch interpreter stack]
14 //
15 // functorch's dispatching system uses a stack of interpreters.
16 // Historically we've referred to this as the "DynamicLayerStack".
17 //
18 // An interpreter is something that reads in the code it is passed
19 // and then executes it. We have a different interpreter per-transform:
20 // the "VmapInterpreter" is responsible for reading in operators (like aten::mv)
21 // and executing the batched version of it (the batching rule for aten::mv).
22 //
23 // Concretely, each interpreter is responsible for two things:
24 //
25 // 1) process(ophandle, stack)
26 // Given an operator handle and a stack of arguments, the interpreter is
27 // responsible for figuring out how to execute the operation under the semantics
28 // of the interpreter. For e.g. VmapInterpreter, this is figuring out how to call
29 // the batching rule.
30 //
31 // The batching rules are stored as kernels on the FuncTorchBatched key, so the way
32 // VmapInterpreter calls the batching rule is roughly: (A) exclude all
33 // dispatch keys aside from the Batched key, (B) redispatch so we get to the
34 // Batched key.
35 //
36 // 2) sendToNextInterpreter(ophandle, stack)
37 // The VmapInterpreter, when it sees aten::mv, will process it into a call to
38 // aten::mm. It then needs to send the call to aten::mm to the next interpreter
39 // in the interpreter stack.
40 //
41 // The VmapInterpreter just does this via a call to ophandle.callBoxed(stack)
42 // and most Interpreters will implement it this way.
43 
// Policy for how random operations behave under a vmap transform.
enum class RandomnessType {
    Error,      // calling a random op raises an error
    Same,       // every batch observes identical randomness
    Different,  // each batch observes different randomness
    END         // sentinel marking the number of valid policies
};
50 
// Identifies which functorch transform an Interpreter implements.
enum class TransformType {
  Torch,          // base case; currently unused
  Vmap,           // batching
  Grad,           // reverse-mode AD, aka vjp
  Jvp,            // forward-mode AD
  Functionalize,  // functionalization
};
58 
59 std::ostream& operator<<(std::ostream& os, const TransformType& t);
60 
61 // NOTE: [Interpreter "subclassing" design]
62 //
63 // How are various Interpreters for different transforms (vmap, grad, ...)
64 // implemented?
65 //
66 // Accessing interpreters is in the hot-path of functorch so we have a constraint
67 // that this code must be as fast as possible.
68 //
69 // As a result, we stay away from virtual methods and this causes our code
70 // to look a little funny.
71 //
72 // `Interpreter` is the struct for Interpreters. It holds ALL of the
73 // relevant information (what type of interpreter it is and the metadata).
74 // Metadata for each interpreter is represented as a Union (std::variant)
75 // of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...).
76 //
77 // Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this
78 // if you want to access the metadata fields (like batchSize and randomness).
79 //
80 // Each type of interpreter (e.g. Vmap) has a convenience struct
81 // (e.g. VmapInterpreterPtr) associated with it.
82 //
83 // Construct the convenience struct with VmapInterpreterPtr(Interpreter*),
84 // and then one can access methods on VmapInterpreterPtr like so:
85 // >>> VmapInterpreterPtr(&interpreter).batchSize()
86 //
87 // Finally, Interpreter::process switches on the type of the interpreter
88 //     and calls one of {Transform}Interpreter::processImpl under the hood.
89 // Same for Interpreter::sendToNextInterpreter :)
90 
91 struct VmapInterpreterMeta {
VmapInterpreterMetaVmapInterpreterMeta92   explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) :
93     batchSize_(std::move(batchSize)), randomness_(randomness) {}
94   c10::SymInt batchSize_;
95   RandomnessType randomness_;
96 };
97 
// Metadata carried by a grad (reverse-mode AD) interpreter: the grad-mode
// flag that was active before this transform was entered.
struct GradInterpreterMeta {
  explicit GradInterpreterMeta(bool prevGradMode)
      : prevGradMode_(prevGradMode) {}

  bool prevGradMode_;
};
102 
// Metadata carried by a jvp (forward-mode AD) interpreter: the forward-grad
// mode flag that was active before this transform was entered.
struct JvpInterpreterMeta {
  explicit JvpInterpreterMeta(bool prevFwdGradMode)
      : prevFwdGradMode_(prevFwdGradMode) {}

  bool prevFwdGradMode_;
};
107 
// Metadata carried by a functionalize interpreter: whether functionalization
// should add back views when exiting the transform.
struct FunctionalizeInterpreterMeta {
  explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews)
      : functionalizeAddBackViews_(functionalizeAddBackViews) {}

  bool functionalizeAddBackViews_;
};
113 
114 typedef std::variant<
115   int64_t,
116   GradInterpreterMeta,
117   JvpInterpreterMeta,
118   VmapInterpreterMeta,
119   FunctionalizeInterpreterMeta
120 > InterpreterMeta;
121 
122 
123 struct Interpreter {
124   // factory functions
VmapInterpreter125   static Interpreter Vmap(int64_t level, c10::SymInt batchSize, RandomnessType randomness) {
126     return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(std::move(batchSize), randomness));
127   }
GradInterpreter128   static Interpreter Grad(int64_t level, bool prevGradMode) {
129     return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode));
130   }
JvpInterpreter131   static Interpreter Jvp(int64_t level, bool prevFwdGradMode) {
132     return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode));
133   }
FunctionalizeInterpreter134   static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) {
135     return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews));
136   }
137 
138   // methods
keyInterpreter139   TransformType key() const { return type_; }
levelInterpreter140   int64_t level() const { return level_; }
metaInterpreter141   const InterpreterMeta& meta() const { return meta_; }
142 
143   void process(const c10::OperatorHandle& op, torch::jit::Stack* stack);
144   void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
145 
saveLocalDispatchKeySetInterpreter146   void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) {
147     TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value());
148     savedLocalDispatchKeySet_ = keyset;
149   }
clearSavedLocalDispatchKeySetInterpreter150   void clearSavedLocalDispatchKeySet() {
151     TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
152     savedLocalDispatchKeySet_ = std::nullopt;
153   }
getSavedLocalDispatchKeySetInterpreter154   c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const {
155     TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
156     return *savedLocalDispatchKeySet_;
157   }
158 
159   // An Interpreter is alive if we are currently inside the ongoing transform
160   // for the interpreter. For example, vmap(f)(x); inside of f, the vmap's
161   // corresponding Interpreter is alive, even when it is not on the DynamicLayerStack.
is_aliveInterpreter162   bool is_alive() const {
163     return *is_alive_;
164   }
is_alive_ptrInterpreter165   const std::shared_ptr<bool>& is_alive_ptr() const {
166     return is_alive_;
167   }
set_is_aliveInterpreter168   void set_is_alive(bool alive) {
169     *is_alive_ = alive;
170   }
171 
172   // Please don't use this
173   explicit Interpreter() = default;
174 
175  private:
InterpreterInterpreter176   explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
177     type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(std::move(meta)) {}
178 
179   // fields
180   TransformType type_{};
181   int64_t level_{};
182   std::optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
183   std::shared_ptr<bool> is_alive_;
184   InterpreterMeta meta_;
185 };
186 
187 // Applies the following for-loop:
188 // for i in range(begin, end):
189 //   args[i] = func(args[i])
190 void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
191     std::function<Tensor(const Tensor&)> func);
192 
193 // Applies the following for-loop, where use_flag_relative is a bitset
194 // indexed relative to `begin`:
195 // for i in range(begin, end):
196 //   args[i] = func(args[i], use_flag_relative[i - begin])
198 void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int64_t end,
199     const std::bitset<64> use_flag_relative, const std::function<Tensor(const Tensor&, bool)>& func);
200 
201 std::vector<int64_t> findUnwrappedInputs(std::vector<IValue>& args, int64_t begin, int64_t end);
202 
203 DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key);
204 
205 void setup_dispatch_key_tls(TransformType key, DispatchKeySet include);
206 
207 void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack);
208 
209 } // namespace at::functorch
210