1 #pragma once 2 3 #include <ATen/functorch/Macros.h> 4 #include <ATen/core/dispatch/Dispatcher.h> 5 #include <c10/core/impl/LocalDispatchKeySet.h> 6 #include <optional> 7 #include <bitset> 8 #include <utility> 9 #include <variant> 10 11 namespace at::functorch { 12 13 // NOTE: [functorch interpreter stack] 14 // 15 // functorch's dispatching system uses a stack of interpreters. 16 // Historically we've referred to this as the "DynamicLayerStack". 17 // 18 // An interpreter is something that reads in the code it is passed 19 // and then executes it. We have a different interpreter per-transform: 20 // the "VmapInterpreter" is responsible for reading in operators (like aten::mv) 21 // and executing the batched version of it (the batching rule for aten::mv). 22 // 23 // Concretely, each interpreter is responsible for two things: 24 // 25 // 1) process(ophandle, stack) 26 // Given an operator handle and a stack of arguments, the interpreter is 27 // responsible for figuring out how to execute the operation under the semantics 28 // of the interpreter. For e.g. VmapInterpreter, this is figuring out how to call 29 // the batching rule. 30 // 31 // The batching rules are stored as kernels on the FuncTorchBatched key, so the way 32 // VmapInterpreter calls the batching rule is roughly: (A) exclude all 33 // dispatch keys aside from the Batched key, (B) redispatch so we get to the 34 // Batched key. 35 // 36 // 2) sendToNextInterpreter(ophandle, stack) 37 // The VmapInterpreter, when it sees aten::mv, will process it into a call to 38 // aten::mm. It then needs to send the call to aten::mm to the next interpreter 39 // in the interpreter stack. 40 // 41 // The VmapInterpreter just does this via a call to ophandle.callBoxed(stack) 42 // and most Interpreters will implement it this way. 

// How random operations should behave under vmap.
enum class RandomnessType {
    Error,      // always errors when calling a random function
    Same,       // randomness appears the same across batches
    Different,  // randomness appears different across batches
    END         // sentinel marking the end of the enum; not a valid setting
};

// Which functorch transform an Interpreter implements.
enum class TransformType {
  Torch,  // Unused
  Vmap,
  Grad,  // reverse-mode AD, aka vjp
  Jvp,  // forward-mode AD
  Functionalize,
};

// Stream-insertion for TransformType; implementation is out-of-line.
std::ostream& operator<<(std::ostream& os, const TransformType& t);

// NOTE: [Interpreter "subclassing" design]
//
// How are various Interpreters for different transforms (vmap, grad, ...)
// implemented?
//
// Accessing interpreters is in the hot-path of functorch so we have a constraint
// that this code must be as fast as possible.
//
// As a result, we stay away from virtual methods and this causes our code
// to look a little funny.
//
// `Interpreter` is the struct for Interpreters. It holds ALL of the
// relevant information (what type of interpreter it is and the metadata).
// Metadata for each interpreter is represented as a Union (std::variant)
// of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...).
//
// Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this
// if you want to access the metadata fields (like batchSize and randomness).
//
// Each type of interpreter (e.g. Vmap) has a convenience struct
// (e.g. VmapInterpreterPtr) associated with it.
//
// Construct the convenience struct with VmapInterpreterPtr(Interpreter*),
// and then one can access methods on VmapInterpreterPtr like so:
// >>> VmapInterpreterPtr(&interpreter).batchSize()
//
// Finally, Interpreter::process switches on the type of the interpreter
// and calls one of {Transform}Interpreter::processImpl under the hood.
89 // Same for Interpreter::sendToNextInterpreter :) 90 91 struct VmapInterpreterMeta { VmapInterpreterMetaVmapInterpreterMeta92 explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) : 93 batchSize_(std::move(batchSize)), randomness_(randomness) {} 94 c10::SymInt batchSize_; 95 RandomnessType randomness_; 96 }; 97 98 struct GradInterpreterMeta { GradInterpreterMetaGradInterpreterMeta99 explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {} 100 bool prevGradMode_; 101 }; 102 103 struct JvpInterpreterMeta { JvpInterpreterMetaJvpInterpreterMeta104 explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {} 105 bool prevFwdGradMode_; 106 }; 107 108 struct FunctionalizeInterpreterMeta { FunctionalizeInterpreterMetaFunctionalizeInterpreterMeta109 explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) : 110 functionalizeAddBackViews_(functionalizeAddBackViews) {} 111 bool functionalizeAddBackViews_; 112 }; 113 114 typedef std::variant< 115 int64_t, 116 GradInterpreterMeta, 117 JvpInterpreterMeta, 118 VmapInterpreterMeta, 119 FunctionalizeInterpreterMeta 120 > InterpreterMeta; 121 122 123 struct Interpreter { 124 // factory functions VmapInterpreter125 static Interpreter Vmap(int64_t level, c10::SymInt batchSize, RandomnessType randomness) { 126 return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(std::move(batchSize), randomness)); 127 } GradInterpreter128 static Interpreter Grad(int64_t level, bool prevGradMode) { 129 return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode)); 130 } JvpInterpreter131 static Interpreter Jvp(int64_t level, bool prevFwdGradMode) { 132 return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode)); 133 } FunctionalizeInterpreter134 static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) { 135 return Interpreter(TransformType::Functionalize, level, 
FunctionalizeInterpreterMeta(functionalizeAddBackViews)); 136 } 137 138 // methods keyInterpreter139 TransformType key() const { return type_; } levelInterpreter140 int64_t level() const { return level_; } metaInterpreter141 const InterpreterMeta& meta() const { return meta_; } 142 143 void process(const c10::OperatorHandle& op, torch::jit::Stack* stack); 144 void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); 145 saveLocalDispatchKeySetInterpreter146 void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) { 147 TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value()); 148 savedLocalDispatchKeySet_ = keyset; 149 } clearSavedLocalDispatchKeySetInterpreter150 void clearSavedLocalDispatchKeySet() { 151 TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value()); 152 savedLocalDispatchKeySet_ = std::nullopt; 153 } getSavedLocalDispatchKeySetInterpreter154 c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const { 155 TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value()); 156 return *savedLocalDispatchKeySet_; 157 } 158 159 // An Interpreter is alive if we are currently inside the ongoing transform 160 // for the interpreter. For example, vmap(f)(x); inside of f, the vmap's 161 // corresponding Interpreter is alive, even when it is not on the DynamicLayerStack. 
is_aliveInterpreter162 bool is_alive() const { 163 return *is_alive_; 164 } is_alive_ptrInterpreter165 const std::shared_ptr<bool>& is_alive_ptr() const { 166 return is_alive_; 167 } set_is_aliveInterpreter168 void set_is_alive(bool alive) { 169 *is_alive_ = alive; 170 } 171 172 // Please don't use this 173 explicit Interpreter() = default; 174 175 private: InterpreterInterpreter176 explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta): 177 type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(std::move(meta)) {} 178 179 // fields 180 TransformType type_{}; 181 int64_t level_{}; 182 std::optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_; 183 std::shared_ptr<bool> is_alive_; 184 InterpreterMeta meta_; 185 }; 186 187 // Applies the following for-loop: 188 // for i in range(begin, end): 189 // args[i] = func(args[i]) 190 void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end, 191 std::function<Tensor(const Tensor&)> func); 192 193 // Applies the following for-loop: 194 // for i in range(begin, end): 195 // if use_flag_relative[i] == 1: <-- treats use_flag_relative as a bitset 196 // args[i] = func(args[i], i - begin, true) 197 // args[i] = func(args[i], i - begin) 198 void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int64_t end, 199 const std::bitset<64> use_flag_relative, const std::function<Tensor(const Tensor&, bool)>& func); 200 201 std::vector<int64_t> findUnwrappedInputs(std::vector<IValue>& args, int64_t begin, int64_t end); 202 203 DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key); 204 205 void setup_dispatch_key_tls(TransformType key, DispatchKeySet include); 206 207 void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack); 208 209 } // namespace at::functorch 210