// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <ATen/functorch/Macros.h>
#include <c10/core/DispatchKey.h>
#include <ATen/core/function_schema.h>
#include <optional>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <ATen/functorch/Interpreter.h>
#include <ATen/functorch/VmapInterpreter.h>
#include <ATen/functorch/ADInterpreters.h>
#include <ATen/functorch/FunctionalizeInterpreter.h>

// Forward declared
namespace c10 { struct AutogradMetaInterface; }

namespace at::functorch {

// This file contains the implementation of functorch's interpreter stack.
// See NOTE: [functorch interpreter stack] before reading on.
//
// NB: the functorch interpreter stack is also referred to as:
// - the "dynamic layer stack" -- an older name for "interpreter" was
//   "dynamic layer".
// - the "functorch mode stack". You can think of each functorch transform as a
//   "mode" (in the same sense as torch_dispatch mode or torch_function mode),
//   with functorch being an implementation of a "mode stack" where the modes
//   may be arbitrarily composed.

// DynamicLayer is basically the same thing as an Interpreter.
// It represents a functorch transform and it holds an Interpreter,
// which contains metadata related to the transform and instructions on
// how to perform the transform.
//
// TODO: we can excise DynamicLayer in favor of Interpreter,
// but I am going to leave it for now as a compatibility shim to avoid
// needing to refactor a lot of callsites...
struct TORCH_API DynamicLayer {
  explicit DynamicLayer(
      TransformType transform_type,
      int64_t layerId,
      std::optional<c10::SymInt> batchSize = std::nullopt,
      std::optional<RandomnessType> randomness = std::nullopt,
      std::optional<bool> prev_grad_mode = std::nullopt,
      std::optional<bool> prev_fwd_grad_mode = std::nullopt,
      std::optional<bool> functionalize_add_back_views = std::nullopt);

  TransformType key() const;
  int64_t layerId() const;

  const Interpreter& interpreter() const { return interpreter_; }
  Interpreter& interpreter() { return interpreter_; }

  // Only valid for vmap
  c10::SymInt batchSize() const;
  RandomnessType randomness() const;

 private:
  Interpreter interpreter_;
};

TORCH_API int64_t initAndPushDynamicLayer(
    TransformType transform_type,
    std::optional<c10::SymInt> batch_size = std::nullopt,
    std::optional<RandomnessType> randomness = std::nullopt,
    std::optional<bool> prev_grad_mode = std::nullopt,
    std::optional<bool> prev_fwd_grad_mode = std::nullopt,
    std::optional<bool> functionalize_add_back_views = std::nullopt);
TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
TORCH_API std::optional<DynamicLayer> maybeCurrentDynamicLayer();
TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included);
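
// A minimal sketch of the push/pop pattern that transform entry points are
// expected to follow, for illustration only. `run_vmapped_fn` is a
// hypothetical callable, not part of this header, standing in for "run the
// user's function under the transform".
//
//   int64_t level = initAndPushDynamicLayer(
//       TransformType::Vmap,
//       /*batch_size=*/c10::SymInt(4),
//       /*randomness=*/RandomnessType::Error);
//   auto result = run_vmapped_fn();  // ops dispatched here see the vmap interpreter
//   auto layer = popDynamicLayerAndDeleteMetadata();
//   TORCH_INTERNAL_ASSERT(layer.layerId() == level);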

// NOTE: [Life handles and lexically scoped transforms]
// functorch transforms are lexically scoped.
// Given a level, we store a "life handle", which is a boolean that tells us
// whether the transform with that level is active or not.
//
// functorch's TensorWrapper (for grad transforms) stores a life handle.
// If a TensorWrapper escapes from the scope of the transform, then it must
// somehow know that it escaped; it can tell by querying the life handle.
TORCH_API const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level);

// Returns whether an operator is in-place. An operator is in-place if:
// 1. The first argument is a Tensor and it is being written to
// 2. The first argument is being returned
// 3. No other arguments are aliased
// Here is an example of an in-place operator:
//   add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema);

// Given the index of an unwrapped (immutable) input and the schema, this
// returns the index of the output that aliases it (and therefore should also
// remain unwrapped), if any.
TORCH_API std::optional<size_t> findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input);

TORCH_API Tensor unwrapIfDead(const Tensor& tensor);
TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);

// Pretty printers
TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);

// While a functorch transform is active, torch.autograd.function._SingleLevelFunction
// is disabled by default. The following two APIs re-enable it; they are not
// user-facing. We can delete this in the future, but it is useful for
// debugging when something goes wrong with the autograd.Function <> functorch
// interaction, which uses _SingleLevelFunction, because it leads to loud
// errors if something is incorrect.
TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
TORCH_API bool getSingleLevelAutogradFunctionAllowed();

// While a functorch grad transform is active, Tensor.requires_grad_() gets
// disabled. These two functions are the mechanism for controlling that.
TORCH_API void setInplaceRequiresGradAllowed(bool allowed);
TORCH_API bool getInplaceRequiresGradAllowed();

TORCH_API DynamicLayer popDynamicLayer();
TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer);

} // namespace at::functorch
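
// A minimal sketch of how the life-handle machinery declared above is meant to
// be used, for illustration only. `wrapped` stands for a hypothetical Tensor
// that was produced under a grad transform at `level` and then escaped it.
//
//   const std::shared_ptr<bool>& alive = getLifeHandleForLevel(level);
//   // ... the transform at `level` exits; *alive flips to false ...
//   if (isDeadTensorWrapper(wrapped)) {
//     // The wrapper outlived its transform; peel it off before using the value.
//     Tensor unwrapped = unwrapIfDead(wrapped);
//   }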