#pragma once

#include <ATen/CachedTensorUtils.h>
#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/TensorOperators.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/utils/grad_layout_contract.h>
#include <torch/csrc/autograd/variable.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
#endif

#include <mutex>

namespace torch::autograd {

#define CHECK_RESULT(RESULT, VAR)                                          \
  if (!(RESULT.is_sparse() || VAR.is_sparse() || RESULT.is_sparse_csr() || \
        VAR.is_sparse_csr())) {                                            \
    if (!utils::obeys_layout_contract(RESULT, VAR)) {                      \
      TORCH_WARN_ONCE(                                                     \
          "grad and param do not obey the gradient layout contract. "      \
          "This is not an error, but may impair performance.\n"            \
          "grad.sizes() = ",                                               \
          RESULT.sizes(),                                                  \
          ", strides() = ",                                                \
          RESULT.strides(),                                                \
          "\n",                                                            \
          "param.sizes() = ",                                              \
          VAR.sizes(),                                                     \
          ", strides() = ",                                                \
          VAR.strides());                                                  \
    }                                                                      \
  }

struct TORCH_API AccumulateGrad : public Node {
  explicit AccumulateGrad(Variable variable_);

  variable_list apply(variable_list&& grads) override;

  std::vector<std::unique_ptr<FunctionPreHook>>& tensor_pre_hooks() noexcept
      override {
    // NB: Since the AccumulateGrad Node is only a weak ref from the Tensor,
    // it can be destroyed even though the Tensor is still alive (contrary
    // to all other Nodes). So we must lazily read the Tensor hooks here.
    return impl::hooks(variable);
  }

  std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks() noexcept
      override {
    // NB: Since the AccumulateGrad Node is only a weak ref from the Tensor,
    // it can be destroyed even though the Tensor is still alive (contrary
    // to all other Nodes). So we must lazily read the Tensor hooks here.
    return impl::post_acc_grad_hooks(variable);
  }

  // Given a variable with its current grad as variable_grad, accumulates
  // new_grad into variable_grad if in place accumulation is possible.
  // Otherwise, uses 'update_grad' to update the grad for the variable.

  // "Gradient Layout Contract"
  //
  // AccumulateGrad tries to stash strided (non-sparse) grads with memory layout
  // (strides) such that variables and grads interact efficiently in later
  // optimizer kernels, and grads interact efficiently with c10d::Reducer.cpp.
  //
  // Specifically, AccumulateGrad tries to ensure the following
  // (cf torch/csrc/autograd/utils/grad_layout_contract.h):
  //   (1) if variable.is_non_overlapping_and_dense(), the stashed grad's
  //       strides match variable.
  //   (2) else, stashed grad is rowmajor contiguous.
  // If variable's grad does not exist (!variable_grad.defined()),
  // AccumulateGrad steals new_grad if it's stealable and obeys the contract
  // already, otherwise it deep copies new_grad into an obedient clone.
  //
  // If variable's grad already exists (variable_grad.defined()), new_grad must
  // be added to variable_grad. If we aren't setting up for double backward
  // (!GradMode::is_enabled()), AccumulateGrad performs "variable_grad +=
  // new_grad" in-place, which keeps variable_grad's layout. We assume (hope)
  // variable_grad was created obeying (1) or (2) at some point in the past.
  //
  // If we are setting up for double backward, AccumulateGrad updates the grad
  // out-of-place via "variable_grad + new_grad." TensorIterator operator+
  // decides result's layout. Typically TensorIterator matches strides of the
  // first arg, so we once again assume (hope) variable_grad was originally
  // created obeying (1) or (2).
  //
  // AccumulateGrad does not enforce the contract with 100% certainty. Examples:
  //  - If a user manually permutes a param or its grad, then runs a fwd+bwd,
  //    variable_grad += new_grad keeps variable_grad's layout without
  //    rechecking the contract.
  //  - If TensorIterator changes its corner cases about operator+'s result
  //    (for example, giving more or less priority to channels_last inputs, see
  //    https://github.com/pytorch/pytorch/pull/37968) the result may not obey.
  //
  // Fortunately, if a given grad doesn't satisfy (1) or (2), the penalty is
  // degraded performance in Reducer.cpp or optimizer kernels, not death by
  // assert or silently bad numerics.
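  //
  // A concrete illustration of (1) and (2): for a transposed leaf with sizes
  // [2, 3] and strides [1, 2] (non-overlapping and dense), the stashed grad
  // ends up with matching strides [1, 2]; for a leaf that is not
  // non-overlapping and dense (e.g. an expanded tensor with a zero stride),
  // the stashed grad is made rowmajor contiguous. Ignoring the
  // sparse/MKLDNN/refcount conditions handled below, the first-time stashing
  // logic reduces to roughly:
  //
  //   if (utils::obeys_layout_contract(new_grad, variable)) {
  //     update_grad(new_grad.detach());  // already obedient: steal it
  //   } else {
  //     update_grad(utils::clone_obey_contract(new_grad, variable));
  //   }
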
  // variable: the variable whose grad we're accumulating.
  // variable_grad: the current grad for the variable.
  // new_grad: new grad we want to accumulate for the variable.
  // num_expected_refs: the number of refs we expect to hold internally
  //                    such that it is safe to avoid cloning the grad
  //                    if use_count() of the grad is less than or equal
  //                    to this value (in addition to post_hooks).
  // update_grad: Function that is used to update grad for the variable.
  //              The argument to the function is a Tensor which
  //              is used to set a new value for the grad.
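  //
  // A sketch of a typical call (an illustration only; the real call site in
  // torch/csrc/autograd/functions/accumulate_grad.cpp differs in detail,
  // e.g. in how num_expected_refs is computed):
  //
  //   at::Tensor& grad = variable.mutable_grad();
  //   accumulateGrad(
  //       variable,
  //       grad,
  //       new_grad,
  //       1 /* num_expected_refs */,
  //       [&grad](at::Tensor&& grad_update) { grad = std::move(grad_update); });
  //
  // update_grad is invoked when a new grad tensor has to be stashed (first
  // accumulation, or when in-place accumulation isn't possible); otherwise
  // variable_grad is mutated in place and update_grad is not called.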
  template <typename T>
  static void accumulateGrad(
      const Variable& variable,
      at::Tensor& variable_grad,
      const at::Tensor& new_grad,
      size_t num_expected_refs,
      const T& update_grad) {
    if (!variable_grad.defined()) {
      if (!GradMode::is_enabled() && !new_grad.is_sparse() &&
          !new_grad.is_sparse_csr() &&
          !(variable.is_sparse_csr() && new_grad.layout() == at::kStrided) &&
          at::caching::adjusted_use_count(new_grad) <= num_expected_refs &&
          (new_grad.is_mkldnn() ||
           utils::obeys_layout_contract(new_grad, variable))) {
        // we aren't setting up for double-backward
        // not sparse
        // no other user-visible tensor references new_grad
        // new_grad obeys the "Gradient Layout Contract" (MKLDNN tensors are
        // a special case: they are opaque, so we assume they obey the
        // contract).
        // Under these conditions, we can steal new_grad without a deep copy.
        update_grad(new_grad.detach());
      } else if (
          !GradMode::is_enabled() && new_grad.is_sparse() &&
          new_grad._indices().is_contiguous() &&
          new_grad._values().is_contiguous() &&
          // Use count for indices and values should always be <=1 since the
          // SparseTensor should be the only one holding a reference to these.
          new_grad._indices().use_count() <= 1 &&
          new_grad._values().use_count() <= 1 &&
          new_grad.use_count() <= num_expected_refs) {
        // Can't detach sparse tensor (since metadata changes are not allowed
        // after detach), so just create a new one for the grad which is a
        // shallow copy. We need a shallow copy so that modifying the original
        // grad tensor doesn't modify the grad we accumulate.
        // We only skip the clone if indices and values themselves are
        // contiguous, for backward compatibility: without this optimization
        // we used to clone the entire SparseTensor, which cloned indices and
        // values as well.
        // For details see https://github.com/pytorch/pytorch/issues/34375.

        // No scenario where we expect this to be true currently
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
            !at::caching::is_cached_tensor(new_grad._indices()) &&
            !at::caching::is_cached_tensor(new_grad._values()) &&
            !at::caching::is_cached_tensor(new_grad));

        update_grad(at::_sparse_coo_tensor_unsafe(
            new_grad._indices(),
            new_grad._values(),
            new_grad.sizes(),
            new_grad.options()));
      } else {
        if (new_grad.is_sparse() || new_grad.is_sparse_csr() ||
            new_grad.is_nested()) {
          update_grad(new_grad.clone());
        } else {
          if (new_grad.is_mkldnn()) {
            update_grad(new_grad.clone());
          } else {
            // Deep copies new_grad according to the "Gradient Layout
            // Contract."
            update_grad(utils::clone_obey_contract(new_grad, variable));
          }
        }
      }
    } else if (!GradMode::is_enabled()) {
      // This case is not strictly necessary, but it makes the first-order only
      // case slightly more efficient.
      if (variable_grad.is_sparse() && !new_grad.is_sparse()) {
        // If `variable_grad` is sparse and `new_grad` is not sparse, their
        // sum is not sparse, and we must change the TensorImpl type of
        // `variable_grad` for it to store the result. However, changing the
        // TensorImpl type of a tensor requires changing the tensor itself, and
        // thus in this case we have to change the grad tensor.
        auto result = new_grad + variable_grad;
        CHECK_RESULT(result, variable);
        update_grad(std::move(result));
      } else if (!at::inplaceIsVmapCompatible(variable_grad, new_grad)) {
        // Ideally we'd perform an in-place operation to avoid changing
        // the grad tensor. However, if that's impossible because the grads
        // are vmap-incompatible (See NOTE: [vmap-incompatible in-place
        // operations]), then we just add them out-of-place.
        auto result = variable_grad + new_grad;
        CHECK_RESULT(result, variable);
        update_grad(std::move(result));
      } else {
        // In this case we can avoid changing the grad tensor. There are four
        // scenarios when we'll hit this case:
        //
        // 1. `variable_grad` is sparse, and `new_grad` is sparse.
        // 2. `variable_grad` is dense, and `new_grad` is sparse.
        // 3. `variable_grad` is dense, and `new_grad` is dense.
        // 4. `variable_grad` is mkldnn, and `new_grad` is mkldnn.
        //
        // In all four of these cases, `variable_grad += new_grad` is a
        // valid operation which adds `new_grad` to `variable_grad` in
        // place. `variable_grad` is thus still referring to the same tensor
        // after the operation.
        // Also, the DistributedDataParallel (DDP) package relies on the grad
        // being mutated in place to keep peak memory usage down. DDP will
        // still work correctly if the grad is mutated out of place here, but
        // it will maintain one extra copy of the grad tensors in its buffer
        // and thus increase peak memory usage.
        variable_grad += new_grad;
        CHECK_RESULT(variable_grad, variable);
        // ^ We could enforce the contract more aggressively here by writing:
        // if (variable_grad.is_sparse() || new_grad.is_sparse()) {
        //   variable_grad += new_grad;
        // } else if (obeys_layout_contract(variable_grad, variable)) {
        //   variable_grad += new_grad;
        // } else {
        //   result = at::empty_strided(variable.sizes(), variable.strides(),
        //       variable.options().memory_format(std::nullopt));
        //   update_grad(at::native::add_out(result, variable_grad,
        //       new_grad, 1.0));
        // }
        // However, that accumulation is sometimes in place and sometimes not,
        // which may break user code.
      }
    } else {
      at::Tensor result;
      if (variable_grad.is_sparse() && !new_grad.is_sparse()) {
        // CPU backend throws an error on sparse + dense, so prefer dense +
        // sparse here.
        result = new_grad + variable_grad;
      } else {
        // Assumes operator+ result typically matches strides of first arg,
        // and hopes variable_grad was originally created obeying layout
        // contract.
        result = variable_grad + new_grad;
      }
      CHECK_RESULT(result, variable);
      update_grad(std::move(result));
      // ^ We could enforce the contract more aggressively here by saying
      // if (obeys_layout_contract(new_grad, variable)) {
      //   update_grad(new_grad + variable_grad);
      // } else {
      //   update_grad(variable_grad + new_grad);
      // }
      // such that the stashed grad is likely to have the right strides if
      // either variable_grad or new_grad already has the right strides.
      // We could enforce the contract with certainty by saying
      // auto result = variable_grad + new_grad (or vice versa), checking
      // result's layout, and copying to an obedient clone if necessary before
      // update_grad. The copy would require another gmem pass. We can't create
      // an empty result with the right layout and add_out into it with a
      // single kernel, because GradMode is enabled in this branch, and add_out
      // isn't differentiable. Maybe more trouble than it's worth.
    }
  }

  void compiled_args(CompiledNodeArgs& args) override;
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

  Variable variable;
};

#undef CHECK_RESULT

} // namespace torch::autograd