#pragma once

#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>

#include <ATen/TensorGeometry.h>
#include <ATen/core/DeprecatedTypeProperties.h>
#include <optional>

#include <cstdint>
#include <memory>

namespace torch::autograd {

struct TORCH_API CopyBackwards : public Node {
  variable_list apply(variable_list&& grads) override;
  void compiled_args(CompiledNodeArgs& args) override;
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

  at::TensorOptions src_options;
};

// Note [View + Inplace update for base tensor]
//
// This note covers a few important topics related to view + inplace handling.
//   - It explains what the CopySlices Node is and why we need it.
//   - It explains the considerations on what is saved for backward in
//     CopySlices.
//   - It explains why we sometimes need to change the exec_info of the current
//     backward.
//
// What is CopySlices?
// ~~~~~~~~~~~~~~~~~~~
//
// We support autograd with inplace mutation; e.g., if you write x.mul_(2)
// the autograd will work as if you now had multiple Tensors under the hood and
// you did
//   x = t.clone()
//   x0 = x
//   x1 = x0 * 2
//   x = x1
// As you can see here, after this operation, x.grad_fn now points to x1.grad_fn
// (the MulBackward node) and this node points to x's original grad_fn (which is
// also x0.grad_fn). It is important to keep in mind that after the inplace,
// there is no Tensor object that represents the x0 state anymore. But the graph
// for it is still around in autograd (in case x was used before being modified
// inplace). See Example 1 in
// https://docs.google.com/drawings/d/1-T5DyYfChMX1ONQkY-zU-hj_ayQ2zmA5CBOKDWqvEhE
// We call this rebasing the history of the Tensor.
//
// Now, a difficult situation is what happens if x is a differentiable view
// of a base b.
//   b = t.clone()
//   x = b.select(0, 0)
//   x *= 2
// With the same approach as above, this will become
//   b = t.clone()
//   x = b.select(0, 0)
//   b0 = b
//   x0 = x
//   x1 = x0 * 2
//   b1 = b0.select_scatter(x1, 0, 0)
//   x2 = b1.select(0, 0)
//   x = x2
//   b = b1
// As you can see here, not only do we need to modify x's grad_fn, we also need
// to modify the one from b. We also need to ensure that the new grad_fn on x is
// linked to b's new grad_fn. This chain of select_scatter, multiplication and
// select is what CopySlices does, all wrapped into a single Node.
//
// See Example 2 in
// https://docs.google.com/drawings/d/1-T5DyYfChMX1ONQkY-zU-hj_ayQ2zmA5CBOKDWqvEhE
//
// What do we need to save in CopySlices to run backward?
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//
// We need to perform grad_view = fn(grad_view), but out-of-place.
// view_fn_ is an optional function saved in DifferentiableViewMeta during the
// forward pass, so that we can recover the view when as_strided is not
// supported. It preserves the invariants:
//   view = view_fn_(base)
//   grad_view = view_fn_(grad_base)
//
// When as_strided is supported (e.g. strided CPU/CUDA Tensors), view_fn_
// is empty and we save TensorGeometry(view) instead. With the TensorGeometry
// information we can use an `as_strided` call, which is more efficient for
// recovering views in backward.
//
// For example:
//   view_1 = view_op_1(base)
//   view_2 = view_op_2(view_1)
//   ...
//   view_n = view_op_n(view_n-1)
//   view_n = inplace_op(view_n)
//
// In the CPU/CUDA case, where we have an efficient as_strided implementation,
// grad_view_n can be calculated in a single step:
//
//   grad_view_n = grad_base.as_strided(view_sizes, view_strides, view_offset);
//
// But in the XLA backend, where we don't have full support for as_strided,
// we have to save a chained lambda function view_fn_ to exactly replay how the
// view was created in the forward pass:
//
//   view_fn_ = view_op_n(...(view_op_2(view_op_1())))
//   grad_view_n = view_fn_(grad_base)
//
// This chained view_fn_ works as long as the forward view ops are implemented
// (e.g. XLA simulates views without a real Storage behind the Tensor), but it
// is less efficient than the as_strided path, so we should be careful to only
// use it when necessary.
//
//   - For CPU/CUDA we save the TensorGeometry of both the base and the view
//     tensors; that's all we need to pass into as_strided,
//     e.g. int[] sizes, int[] strides, and int storage_offset.
//   - For XLA we use view_fn_, which captures all forward view op arguments
//     by **value**.
//     E.g. for at::narrow, int dim, int start, int length are saved.
//
// Theoretically we could also save the Tensor `view` in the CopySlices Node,
// but that is far more expensive than what we currently save:
//   1. We cannot afford to keep large tensors alive just to recover views.
//   2. There are inplace checks when Tensors are loaded back to make sure
//      they haven't been changed (including size metadata).
// So saving metadata like TensorGeometry/view arguments is much better: it is
// the minimal information needed to recover views, and it allows the user to
// modify the original Tensor without preventing the backward pass from running.
//
// Why do we manually change exec_info in the apply?
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//
// Using the same example as before,
//   b = t.clone()
//   x = b.select(0, 0)
//   x *= y
//
// You can see the visualization at
// https://docs.google.com/drawings/d/1Bx-Hcz-zlIv7PabQqnPhUIVIs9F8WWi48svqMsAUMFs
// which contains the wrapped MulBackward Node and shows what it links to.
// Since a backward can happen between any subset of the inputs (t and y) and
// outputs (o, x, b), it is possible to get into a state where CopySlices's 0th
// next function (CloneBackward) needs gradient but MulBackward's 0th next
// function (SelectBackward) does not. This happens, for example, if you run
// autograd.grad between x and t.
// In such a case, we do need to mark SelectBackward as requiring gradient so
// that, during the execution of MulBackward, we actually compute the gradient
// for the 0th input.
//
// All the other next functions are always shared (this is asserted in the apply
// code) and so nothing needs to be done for them.

// See Note [View + Inplace update for view tensor] for what we do to the view
// tensor when an in-place operation happens.
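// As a rough sketch (an approximation for intuition, not the exact
// implementation), the backward of CopySlices conceptually does the following
// with the incoming gradient `grad` for the base:
//   grad_base = grad.clone()
//   if view_fn is available:
//     grad_slice = view_fn(grad_base)
//   else:
//     grad_slice = grad_base.as_strided(view_sizes, view_strides, view_offset)
//   res = fn(grad_slice.clone())     # run the wrapped Node, e.g. MulBackward
//   grad_slice.copy_(res[0])         # write the slice gradient back into grad_base
//   return (grad_base, res[1], ...)  # gradients for the other inputs pass through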
struct TORCH_API CopySlices : public Node {
  CopySlices(
      const Variable& base_var,
      at::TensorGeometry view_,
      std::unique_ptr<ViewFunc> view_fn_,
      std::shared_ptr<Node> fn_);

  // common code between apply/apply_with_saved
  template <typename T>
  variable_list apply_impl(variable_list&& inputs, const T& call_fn);

  variable_list apply(variable_list&& inputs) override;
  void release_variables() override;
  void compiled_args(CompiledNodeArgs& args) override;
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

  at::TensorGeometry base;
  // view and view_fn are redundant and view_fn will be used if available.
  // See Note [View + Inplace update for base tensor] for details.
  at::TensorGeometry view;
  std::unique_ptr<ViewFunc> view_fn;
  std::shared_ptr<Node> fn;
};

} // namespace torch::autograd
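// A minimal construction sketch (hypothetical wiring; in practice this happens
// inside autograd's history-rebasing code when an in-place op hits a
// differentiable view). Only the CopySlices constructor above is real here;
// `base`, `view_tensor` and `mul_backward` are placeholder names:
//
//   auto node = std::make_shared<CopySlices>(
//       /*base_var=*/base,                         // the view's base tensor
//       /*view_=*/at::TensorGeometry(view_tensor), // sizes/strides/offset of the view
//       /*view_fn_=*/nullptr,                      // only needed when as_strided is unsupported
//       /*fn_=*/std::move(mul_backward));          // the wrapped Node, e.g. MulBackward
//
// The new node then replaces the base tensor's grad_fn so that backward replays
// the select_scatter + mul + select chain described in the note above.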