#pragma once

#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>

#include <ATen/TensorGeometry.h>
#include <ATen/core/DeprecatedTypeProperties.h>
#include <optional>

#include <cstdint>
#include <memory>

namespace torch::autograd {

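// Node for the backward of an in-place copy such as `a.copy_(b)`. Roughly:
// the gradient flowing into the destination is zeroed out (the pre-copy value
// of `a` does not influence the result), while the gradient for the source is
// the incoming gradient converted back to the source's options (saved here as
// `src_options`). See the definition in tensor.cpp for the exact behavior.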
struct TORCH_API CopyBackwards : public Node {
  variable_list apply(variable_list&& grads) override;
  void compiled_args(CompiledNodeArgs& args) override;
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

  at::TensorOptions src_options;
};

// Note [View + Inplace update for base tensor]
//
// This note covers a few important topics related to view + inplace handling.
//   - It explains what the CopySlices Node is and why we need it.
//   - It explains the considerations on what is saved for backward in
//     CopySlices.
//   - It explains why we sometimes need to change the exec_info of the
//     current backward pass.
//
// What is CopySlices?
// ~~~~~~~~~~~~~~~~~~~
//
// We support autograd with inplace mutation; e.g., if you write x.mul_(2)
// autograd will work as if you had multiple Tensors under the hood and
// you did
//   x = t.clone()
//   x0 = x
//   x1 = x0 * 2
//   x = x1
// As you can see here, after this operation, x.grad_fn now points to x1.grad_fn
// (the MulBackward node) and this node points to x's original grad_fn (which is
// also x0.grad_fn). It is important to keep in mind that after the inplace
// operation, there is no Tensor object that represents the x0 state anymore.
// But the graph for it is still around in autograd (in case x was used before
// being modified inplace). See Example 1 in
// https://docs.google.com/drawings/d/1-T5DyYfChMX1ONQkY-zU-hj_ayQ2zmA5CBOKDWqvEhE
// We call this rebasing the history of the Tensor.
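//
// As a concrete, hedged illustration (the printed node class names are an
// assumption and may differ across versions), the rebase is observable from
// Python:
//
//   import torch
//   t = torch.randn(3, requires_grad=True)
//   x = t.clone()
//   y = x + 1          # uses the pre-inplace value of x
//   x.mul_(2)          # rebases x's history
//   print(x.grad_fn)   # e.g. <MulBackward0>, the new grad_fn after the rebase
//   print(y.grad_fn)   # the old graph (AddBackward0 -> CloneBackward0) survives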
//
// Now, a difficult situation is what happens if x is a differentiable view
// of a base b.
//   b = t.clone()
//   x = b.select(0, 0)
//   x *= 2
// With the same approach as above, this will become
//   b = t.clone()
//   x = b.select(0, 0)
//   b0 = b
//   x0 = x
//   x1 = x0 * 2
//   b1 = b0.select_scatter(x1, 0, 0)
//   x2 = b1.select(0, 0)
//   x = x2
//   b = b1
// As you can see here, not only do we need to modify x's grad_fn, we also need
// to modify b's. We also need to ensure that the new grad_fn on x is linked to
// b's new grad_fn. Chaining the select_scatter, multiplication and select is
// what CopySlices does, all wrapped into a single Node.
//
// See Example 1 in
// https://docs.google.com/drawings/d/1-T5DyYfChMX1ONQkY-zU-hj_ayQ2zmA5CBOKDWqvEhE
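//
// As a hedged illustration (the printed node names are an assumption and may
// differ across versions), CopySlices shows up as the new grad_fn of the base
// after an inplace update through a view:
//
//   import torch
//   t = torch.randn(3, requires_grad=True)
//   b = t.clone()
//   x = b.select(0, 0)
//   x *= 2
//   print(b.grad_fn)   # e.g. <CopySlices>
//   print(x.grad_fn)   # e.g. <AsStridedBackward0>, regenerated from the new base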
//
// What do we need to save in CopySlices to run backward?
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//
// We need to perform grad_view = fn(grad_view), but out-of-place.
// view_fn_ is an optional function saved in DifferentiableViewMeta
// during the forward pass, so that we can recover the view in backward when
// as_strided is not supported. It preserves the invariants:
//   view = view_fn_(base)
//   grad_view = view_fn_(grad_base)
//
// When as_strided is supported (e.g. strided CPU/CUDA Tensors), view_fn_
// is empty and we save TensorGeometry(view) instead.
// With the TensorGeometry information we can recover views in backward with a
// single, more efficient `as_strided` call.
//
// For example:
//   view_1 = view_op_1(base)
//   view_2 = view_op_2(view_1)
//   ...
//   view_n = view_op_n(view_n-1)
//   view_n = inplace_op(view_n)
//
// In the CPU/CUDA case, where we have an efficient as_strided implementation,
// grad_view_n can be computed in a single step:
//
//   grad_view_n = grad_base.as_strided(view_sizes, view_strides, view_offset);
//
// But on the XLA backend, where we don't have full support for as_strided,
// we have to save a chained lambda function view_fn_ that exactly replays how
// the view was constructed in the forward pass:
//
//   view_fn_ = view_op_n(...(view_op_2(view_op_1())))
//   grad_view_n = view_fn_(grad_base)
//
// This chained view_fn_ works as long as the forward view ops are implemented
// (e.g. XLA simulates views without a real Storage behind the Tensor), but it
// is less efficient than the as_strided path, so we should be careful to only
// use it when necessary.
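//
// A minimal sketch of the chained view_fn_ idea in Python (the helper below is
// hypothetical and only illustrates capturing the view op arguments by value
// and replaying them on a new base):
//
//   def make_view_fn(dim, start, length):    # args captured by value
//       def view_fn(base):
//           return base.narrow(dim, start, length).select(0, 0)
//       return view_fn
//
//   view_fn = make_view_fn(0, 1, 2)
//   # Invariants from above: view == view_fn(base) and
//   # grad_view == view_fn(grad_base)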
//
//   - For CPU/CUDA we save the TensorGeometry of both the base and view
//     tensors; that's all we need to pass into as_strided, e.g. int[] sizes,
//     int[] strides, and int storage_offset (see the sketch below).
//   - For XLA we use view_fn_, which captures all forward view op arguments
//     by **value**.
//     E.g. for at::narrow, int dim, int start, and int length are saved.
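//
// As a hedged sketch (plain PyTorch calls, not the internal CopySlices code),
// the saved geometry is enough to rebuild the view from the base in backward:
//
//   import torch
//   base = torch.randn(3, 4)
//   view = base.select(0, 0)
//   rebuilt = base.as_strided(
//       view.size(), view.stride(), view.storage_offset())
//   # rebuilt aliases the same elements as view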
//
// Theoretically we could also save the Tensor `view` in the CopySlices Node,
// but that is far more expensive than what we currently save:
//   1. We cannot afford to keep large tensors alive just to recover views.
//   2. There are inplace checks when Tensors are loaded back to make sure
//      they haven't been changed (including size metadata).
// So saving metadata like TensorGeometry/view arguments is much better: it is
// the minimal information needed to recover views, and it allows the user to
// modify the original Tensor without preventing the backward pass from
// running.
//
// Why do we manually change exec_info in the apply?
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//
// Using the same example as before,
//   b = t.clone()
//   x = b.select(0, 0)
//   x *= y
//
// You can see the visualization at
// https://docs.google.com/drawings/d/1Bx-Hcz-zlIv7PabQqnPhUIVIs9F8WWi48svqMsAUMFs
// which contains the wrapped MulBackward Node and shows what it links to.
// Since a backward can happen between any subset of the inputs (t and y) and
// outputs (o, x, b), it is possible to get into a state where CopySlices's 0th
// next function (CloneBackward) needs gradient but MulBackward's 0th next
// function (SelectBackward) does not. This happens, for example, if you call
// autograd.grad between x and t.
// In such a case, we do need to mark SelectBackward as requiring gradient so
// that, during the execution of MulBackward, we actually compute the gradient
// for the 0th input.
//
// All the other next functions are always shared (this is asserted in the apply
// code) and so nothing needs to be done for them.
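//
// A hedged, minimal version of the autograd.grad call described above (shapes
// and variable names are illustrative):
//
//   import torch
//   t = torch.randn(3, requires_grad=True)
//   y = torch.randn((), requires_grad=True)
//   b = t.clone()
//   x = b.select(0, 0)
//   x *= y
//   # Backward only between x and t: SelectBackward's gradient is still
//   # needed inside CopySlices even though y's side of the graph is not.
//   (grad_t,) = torch.autograd.grad(x, t)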

// See Note [View + Inplace update for view tensor] for what we do to the view
// tensor when an in-place operation happens.
struct TORCH_API CopySlices : public Node {
  CopySlices(
      const Variable& base_var,
      at::TensorGeometry view_,
      std::unique_ptr<ViewFunc> view_fn_,
      std::shared_ptr<Node> fn_);

  // common code between apply/apply_with_saved
  template <typename T>
  variable_list apply_impl(variable_list&& inputs, const T& call_fn);

  variable_list apply(variable_list&& inputs) override;
  void release_variables() override;
  void compiled_args(CompiledNodeArgs& args) override;
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

  at::TensorGeometry base;
  // view and view_fn are redundant and view_fn will be used if available.
  // See Note [View + Inplace update for base tensor] for details.
  at::TensorGeometry view;
  std::unique_ptr<ViewFunc> view_fn;
  std::shared_ptr<Node> fn;
};

} // namespace torch::autograd