// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <ATen/functorch/Macros.h>
#include <c10/core/DispatchKey.h>
#include <ATen/core/function_schema.h>
#include <optional>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <ATen/functorch/Interpreter.h>
#include <ATen/functorch/VmapInterpreter.h>
#include <ATen/functorch/ADInterpreters.h>
#include <ATen/functorch/FunctionalizeInterpreter.h>

// Forward declared
namespace c10 { struct AutogradMetaInterface; }

namespace at::functorch {

// This file contains the implementation of functorch's interpreter stack.
// See NOTE: [functorch interpreter stack] first before reading on.
//
// NB: the functorch interpreter stack is also referred to as:
// - the "dynamic layer stack" -- an older name for "interpreter" was
//   "dynamic layer".
// - the "functorch mode stack". You can think of each functorch transform as a
//   "mode" (in the same sense as a torch_dispatch mode or torch_function mode),
//   and of functorch as an implementation of a "mode stack" where the modes
//   may be arbitrarily composed.

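// For example (an illustrative sketch): while evaluating
//
//   vmap(grad(f))(x)
//
// the stack holds two interpreters -- a vmap interpreter with a grad
// interpreter on top of it -- and, roughly speaking, each operation in f is
// processed by the topmost interpreter first, which then re-dispatches to
// the interpreters below it.
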
// DynamicLayer is basically the same thing as an Interpreter.
// It represents a functorch transform and it holds an Interpreter,
// which contains metadata related to the transform and instructions on
// how to perform the transform.
//
// TODO: we can excise DynamicLayer in favor of Interpreter,
// but I am going to leave it for now as a compatibility shim to avoid
// needing to refactor a lot of callsites...
struct TORCH_API DynamicLayer {
  explicit DynamicLayer(
      TransformType transform_type,
      int64_t layerId,
      std::optional<c10::SymInt> batchSize = std::nullopt,
      std::optional<RandomnessType> randomness = std::nullopt,
      std::optional<bool> prev_grad_mode = std::nullopt,
      std::optional<bool> prev_fwd_grad_mode = std::nullopt,
      std::optional<bool> functionalize_add_back_views = std::nullopt);

  TransformType key() const;
  int64_t layerId() const;

  const Interpreter& interpreter() const { return interpreter_; }
  Interpreter& interpreter() { return interpreter_; }

  // Only valid for vmap
  c10::SymInt batchSize() const;
  RandomnessType randomness() const;

 private:
  Interpreter interpreter_;
};
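
// For instance (an illustrative sketch, assuming the TransformType and
// RandomnessType enumerators from the interpreter headers included above):
//
//   DynamicLayer layer(TransformType::Vmap, /*layerId=*/1,
//                      /*batchSize=*/c10::SymInt(4),
//                      /*randomness=*/RandomnessType::Error);
//   layer.key();        // TransformType::Vmap
//   layer.batchSize();  // 4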

TORCH_API int64_t initAndPushDynamicLayer(
    TransformType transform_type,
    std::optional<c10::SymInt> batch_size = std::nullopt,
    std::optional<RandomnessType> randomness = std::nullopt,
    std::optional<bool> prev_grad_mode = std::nullopt,
    std::optional<bool> prev_fwd_grad_mode = std::nullopt,
    std::optional<bool> functionalize_add_back_views = std::nullopt);
TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
TORCH_API std::optional<DynamicLayer> maybeCurrentDynamicLayer();
TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included);
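
// Typical usage (an illustrative sketch, not a prescribed contract): a
// transform pushes a layer on entry, runs the transformed computation, and
// pops the layer when it unwinds:
//
//   auto level = initAndPushDynamicLayer(
//       TransformType::Grad,
//       /*batch_size=*/std::nullopt,
//       /*randomness=*/std::nullopt,
//       /*prev_grad_mode=*/c10::GradMode::is_enabled());
//   // ... evaluate the function being transformed at `level` ...
//   auto layer = popDynamicLayerAndDeleteMetadata();
//   TORCH_INTERNAL_ASSERT(layer.layerId() == level);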

// NOTE: [Life handles and lexically scoped transforms]
// functorch transforms are lexically scoped.
// Given a level, we store a "life handle", a boolean that tells us whether
// the transform with that level is still active.
//
// functorch's TensorWrapper (for grad transforms) stores a life handle.
// If a TensorWrapper escapes the scope of its transform, it needs some way
// to know that it escaped; it can tell by querying its life handle.
TORCH_API const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level);
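
// For example (an illustrative sketch): a wrapper created under the
// transform at level 2 can later check whether that transform is still
// alive:
//
//   const auto& alive = getLifeHandleForLevel(/*level=*/2);
//   if (!*alive) {
//     // the transform at level 2 has exited, so the wrapper escaped it
//   }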

// Returns whether an operator is in-place. An operator is in-place if:
// 1. The first argument is a Tensor and it is being written to
// 2. The first argument is being returned
// 3. No other arguments are aliased
// Here is an example of an in-place operator:
// add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema);
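
// For contrast (following the criteria above): an out= overload such as
//   add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
// is not considered in-place here, because the argument being written to is
// not the first argument.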

// Given the schema and the index of an unwrapped (immutable) input, returns
// the index of the output, if any, that aliases it and should therefore also
// remain unwrapped.
TORCH_API std::optional<size_t> findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input);
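
// For example (an illustrative sketch): given the schema
//   transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
// findAliasedOutput(schema, /*immutable_input=*/0) returns 0, because
// output 0 aliases input 0.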

TORCH_API Tensor unwrapIfDead(const Tensor& tensor);
TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);

// Pretty printers
TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);

// While a functorch transform is active, torch.autograd.function._SingleLevelFunction
// is disabled by default. The following two APIs are for enabling it.
// These are not user-facing APIs. We can delete this in the future, but it
// is useful for debugging when something goes wrong with the
// autograd.Function <> functorch interaction, which uses _SingleLevelFunction,
// because it leads to loud errors if something is incorrect.
TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
TORCH_API bool getSingleLevelAutogradFunctionAllowed();
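
// Debugging sketch (illustrative): temporarily allow _SingleLevelFunction
// while a transform is active, then restore the previous setting:
//
//   bool prev = getSingleLevelAutogradFunctionAllowed();
//   setSingleLevelAutogradFunctionAllowed(true);
//   // ... reproduce the autograd.Function <> functorch issue ...
//   setSingleLevelAutogradFunctionAllowed(prev);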

// While a functorch grad transform is active, Tensor.requires_grad_() gets
// disabled. These two functions are the mechanism for controlling that.
TORCH_API void setInplaceRequiresGradAllowed(bool allowed);
TORCH_API bool getInplaceRequiresGradAllowed();
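
// These can be used with the same save/set/restore pattern sketched above,
// e.g. to temporarily permit requires_grad_() while a grad transform is
// active.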

TORCH_API DynamicLayer popDynamicLayer();
TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer);

} // namespace at::functorch