1 #pragma once
2
3 #include <torch/csrc/utils/python_stub.h>
4
5 #include <torch/csrc/Export.h>
6 #include <torch/csrc/autograd/cpp_hook.h>
7 #include <torch/csrc/autograd/edge.h>
8 #include <torch/csrc/autograd/forward_grad.h>
9 #include <torch/csrc/autograd/function_hook.h>
10
11 #include <ATen/NamedTensorUtils.h>
12 #include <ATen/core/Tensor.h>
13 #include <ATen/core/VariableHooksInterface.h>
14 #include <c10/util/Exception.h>
15
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <utility>
#include <vector>
22
namespace torch::autograd {

/// `Variable` is exactly the same as `Tensor` (i.e. we have `using Variable =
/// at::Tensor`). This means you can perform all the usual mathematical and
/// other operations you can perform on `Tensor`s also on `Variable`s.
///
/// The only reason we are keeping the `Variable` class is backward
/// compatibility with external users' legacy C++ frontend code. Our intention
/// is to eliminate the `Variable` class in the near future.
using Variable = at::Tensor;

} // namespace torch::autograd
35
36 // The following are all internal APIs and should not be shown in libtorch docs.
37 // Therefore, we wrap the following code with `#ifndef DOXYGEN_SHOULD_SKIP_THIS
38 // ... #endif`
39
40 #ifndef DOXYGEN_SHOULD_SKIP_THIS
41
42 namespace torch::autograd {
43
44 /// Check if this type is supported by the autograd engine.
45 /// If you change this, update the doc at the top of the
46 /// torch/autograd/__init__.py file and
47 /// "test_set_requires_grad_only_for_continuous_types" in test/test_autograd.py
isDifferentiableType(at::ScalarType t)48 static inline bool isDifferentiableType(at::ScalarType t) {
49 return isFloatingType(t) || isComplexType(t);
50 }
51
52 struct Node;
53
54 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
55 /// Variable
56 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
57 /// A `Variable` augments a `Tensor` with the ability to interact in our
58 /// autograd machinery. Conceptually, `Variable`s travel along `Edge`s between
59 /// `Node`s in the autograd graph. A `Variable` can either be a leaf, like a
60 /// weight in a neural network, or an interior variable, when it is the result
61 /// of an operation between variables. Every `Variable` also stores another
62 /// `Variable` called its `grad` (gradient). If the variable is a leaf, its
63 /// gradient will be accumulated into this variable.
64 ///
65 /// Every Tensor is a Variable, but sometimes we colloquially refer to Variables
66 /// that don't require gradients as Tensors (since none of the autograd
67 /// machinery for Variables applies). Historically, Variables and Tensors
68 /// were separate concepts, but now they are exactly the same (i.e. we have
69 /// `using Variable = at::Tensor`).
70 ///
71 /// Gradient Edges
72 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
73 /// Furthermore, `Variable`s have the notion of a `gradient_edge`, which is the
74 /// edge in the autograd graph that connects the variable to a particular input
75 /// of the gradient function that will be invoked with the variable during the
76 /// backward pass. More precisely, this gradient function can be one of two
77 /// things:
78 /// 1. A `grad_fn`, if the variable is in the interior of the graph. This is the
79 /// gradient of the function that produced the variable.
/// 2. A `grad_accumulator`, if the variable is a leaf, which accumulates
///    incoming gradients into its `grad` variable.
82 ///
83 /// Versioning
84 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Another major feature of `Variable`s is *versions*. Versions are
86 /// incremented when an in-place mutation of a variable occurs. Versions are
87 /// useful when constructing `SavedVariable`s, which take a snapshot of a
88 /// `Variable` at a certain version. You can retrieve a `Variable`'s version
89 /// through its `current_version()` method.
90 ///
91 /// Views
92 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
93 /// It is possible for a `Variable` to be a *view* of another `Variable`, in
94 /// which case it tracks that `Variable`'s data and autograd history. Beyond
95 /// construction, the interface of a view is identical to that of a regular
/// `Variable`. You can determine whether a `Variable` is in fact a view by
97 /// probing its `is_view()` method. Note that the *view* semantics are only
98 /// meaningful for `Variable` relations that are relevant to autograd.
99 /// See NOTE [ Autograd View Variables ] for more details.
100 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
101
102 struct AutogradMeta;
103 struct DifferentiableViewMeta;
104
105 // Private-ish functions for manipulating variables; we don't want to put them
106 // on Tensor proper
namespace impl {

// WARNING: This may return a nullptr. If you require AutogradMeta to return
// a materialized structure, use materialize_autograd_meta instead.
TORCH_API AutogradMeta* get_autograd_meta(const at::TensorBase&);

// WARNING: This will return a nullptr if the Tensor is not a view.
TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase&);

// Returns the current autograd meta, materializing it if it was previously
// none. This counts as a *mutating* operation, so do not call it on
// "read-only" operators; in particular, this is NOT thread safe.
TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase&);

/// Set the gradient accumulator of the `Variable`. This is only applicable to
/// leaf variables. Interior variables should call `set_gradient_edge()`.
/// The accumulator is passed as a `std::weak_ptr` (the same type as
/// `AutogradMeta::grad_accumulator_`), so this call does not take ownership.
TORCH_API void set_grad_accumulator(
    const Variable&,
    std::weak_ptr<Node> grad_accumulator);

/// Attempts to get a pointer to the gradient accumulator of the `Variable`,
/// if it still exists. If the gradient accumulator function has been
/// destroyed, returns a `nullptr`.
TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);

/// Gets the gradient accumulator of the `Variable` if it has one, or else
/// creates one on the fly and returns it.
TORCH_API std::shared_ptr<Node> grad_accumulator(const Variable&);

/// Returns the "canonical" gradient edge of this `Variable`, i.e. either the
/// gradient function if this is an interior `Variable`, or the gradient
/// accumulator otherwise. If the `Variable` is interior, the returned `Edge`
/// will store the input index of the `Node` to which this variable is
/// connected in its `input_nr` field. For leaves, the `input_nr` is always
/// zero. Note that `set_gradient_edge` and `gradient_edge` are not
/// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and
/// `set_grad_accumulator` to set the accumulator.
TORCH_API Edge gradient_edge(const Variable&);

/// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the
/// `Variable`.
/// NOTE: This will always set the `grad_fn`, even if this is a leaf variable,
/// and never the `grad_accumulator`. For the latter, use
/// `set_grad_accumulator`. This allows late construction of an interior
/// `Variable`.
TORCH_API void set_gradient_edge(const Variable&, Edge edge);

// Autograd Graph Interaction
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Update the `grad_fn` of an existing Variable. Called after in-place
/// modifications.
///
/// For View Variables:
/// Called after in-place modifications. Modifies the grad_fn of the base
/// Variable.
TORCH_API void rebase_history(const Variable&, Edge gradient_edge);

/// Gets the raw (non-owning) gradient function pointer, whatever it currently
/// is. NOTE(review): for views this presumably does not regenerate a stale
/// grad_fn the way grad_fn() does (see NOTE [ Autograd View Variables ]) --
/// confirm against the definition.
TORCH_API Node* grad_fn_unsafe(const Variable&);

/// Increments the version count of this `Variable`.
TORCH_API void bump_version(const Variable&);
/// Replaces the version counter of this `Variable` with `version_counter`.
/// Base and view Variables must share the same version counter (see
/// NOTE [ Autograd View Variables ]); presumably this is how that sharing is
/// established -- confirm at the call sites.
TORCH_API void set_version_counter(
    const Variable&,
    const c10::VariableVersion& version_counter);

/// Retrieves this `Variable`'s version counter.
TORCH_API const c10::VariableVersion& version_counter(const Variable&);

/// Sets the name of this `Variable` (presumably stored in
/// `AutogradMeta::name_` -- confirm).
TORCH_API void set_name(const Variable&, const std::string& name);

/// Tensor pre-hook management. The returned container type matches
/// `AutogradMeta::hooks_`; see the comment on that field for how the
/// Python and C++ hook machinery share it.
TORCH_API void add_hook(
    const at::TensorBase&,
    std::unique_ptr<FunctionPreHook> hook);
TORCH_API std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable&);
TORCH_API void clear_hooks(const at::TensorBase&);

/// Post-accumulate-grad hooks; the storage type matches
/// `AutogradMeta::post_acc_grad_hooks_`. NOTE(review): the parameter name
/// `dict` suggests it wraps a Python dict of hooks -- confirm.
TORCH_API void set_post_acc_grad_hooks(
    const at::TensorBase&,
    std::unique_ptr<PostAccumulateGradHook> dict);
TORCH_API std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
    const Variable&);

/// Sets up the single wrapper hook for C++ hooks described on
/// `AutogradMeta::hooks_` / `cpp_hooks_list_`. NOTE(review): the meaning of
/// `is_retains_grad_hooks` is not visible here; presumably it marks the
/// hook installed by retain_grad -- confirm against the definition.
TORCH_API void create_cpp_hook(
    const at::TensorBase&,
    bool is_retains_grad_hooks = false);
} // namespace impl
195
196 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
197 // AutogradMeta
198 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
199
200 /// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd
201 /// metadata fields that are necessary for tracking the Variable's autograd
202 /// history. As an optimization, a Variable may store a nullptr, in lieu of a
203 /// default constructed AutogradMeta.
204
struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
  // Name associated with this variable (NOTE(review): presumably assigned via
  // impl::set_name -- confirm).
  std::string name_;

  // The gradient `Variable`; exposed through grad() / mutable_grad() below.
  Variable grad_;
  // Function that produced this variable. When non-null, requires_grad()
  // reports true, and requires_grad_ must be false (checked in the ctor).
  std::shared_ptr<Node> grad_fn_;
  // Accumulator for leaf variables. Held weakly, so this meta does not keep
  // the accumulator Node alive.
  std::weak_ptr<Node> grad_accumulator_;

  // This field is used to store all the forward AD gradients
  // associated with this AutogradMeta (and the Tensor it corresponds to)
  // There is a semantic 1:1 correspondence between AutogradMeta and
  // ForwardGrad but:
  //   - This field is lazily populated.
  //   - This field is a shared_ptr but it must never be
  //     shared by multiple Tensors. See Note [ Using ForwardGrad ]
  // Any transition from not_initialized to initialized
  // must be protected by mutex_
  mutable std::shared_ptr<ForwardGrad> fw_grad_;

  // The hooks_ field is actually reused by both python and cpp logic
  // For both cases, we have a data structure, cpp_hooks_list_ (cpp)
  // or dict (python) which is the canonical copy.
  // Then, for both cases, we always register a single hook to
  // hooks_ which wraps all the hooks in the list/dict.
  // And, again in both cases, if the grad_fn exists on that tensor
  // we will additionally register a single hook to the grad_fn.
  //
  // Note that the cpp and python use cases aren't actually aware of
  // each other, so using both is not defined behavior.
  std::vector<std::unique_ptr<FunctionPreHook>> hooks_;
  std::shared_ptr<hooks_list> cpp_hooks_list_;

  // The post_acc_grad_hooks_ field stores only Python hooks
  // (PyFunctionTensorPostAccGradHooks) that are called after the
  // .grad field has been accumulated into. This is less complicated
  // than the hooks_ field, which encapsulates a lot more.
  std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;

  // Only meaningful on leaf variables (must be false otherwise)
  bool requires_grad_{false};

  // Only meaningful on non-leaf variables (must be false otherwise)
  bool retains_grad_{false};

  // NOTE(review): nothing in this struct sets this flag; presumably it is
  // true for DifferentiableViewMeta instances -- confirm at construction.
  bool is_view_{false};

  // The "output number" of this variable; e.g., if this variable
  // was the second output of a function, then output_nr == 1.
  // We use this to make sure we can setup the backwards trace
  // correctly when this variable is passed to another function.
  // Initialized from gradient_edge.input_nr in the constructor.
  uint32_t output_nr_;

  // Mutex to ensure that concurrent read operations that modify internal
  // state are still thread-safe. Used by grad_fn(), grad_accumulator(),
  // fw_grad() and set_fw_grad()
  // This is mutable because we need to be able to acquire this from const
  // version of this class for the functions above
  mutable std::mutex mutex_;

  /// Sets the `requires_grad` property of `Variable`. This should be true for
  /// leaf variables that want to accumulate gradients, and false for all other
  /// variables.
  /// Throws (TORCH_CHECK) when enabling it on a non floating point / complex
  /// dtype. `self_impl` is only dereferenced when `requires_grad` is true,
  /// thanks to the short-circuiting `||`.
  void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) final {
    TORCH_CHECK(
        !requires_grad ||
            isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())),
        "Only Tensors of floating point and complex dtype can require gradients");
    requires_grad_ = requires_grad;
  }

  /// True for leaves flagged with requires_grad_ and for any variable that
  /// has a grad_fn_ (the shared_ptr converts to bool).
  bool requires_grad() const override {
    return requires_grad_ || grad_fn_;
  }

  /// Accesses the gradient `Variable` of this `Variable`.
  Variable& mutable_grad() override {
    return grad_;
  }

  /// Read-only access to the gradient `Variable` of this `Variable`.
  const Variable& grad() const override {
    return grad_;
  }

  /// Forward-AD gradient at `level`; defined out of line. Per the fw_grad_
  /// comment, lazily populating fw_grad_ must happen under mutex_.
  const Variable& fw_grad(uint64_t level, const at::TensorBase& self)
      const override;

  /// Sets the forward-AD gradient at `level`; defined out of line.
  void set_fw_grad(
      const at::TensorBase& new_grad,
      const at::TensorBase& self,
      uint64_t level,
      bool is_inplace_op) override;

  /// Builds the metadata from an optional gradient edge. `self_impl` is
  /// required (internal assert) only when `requires_grad` is true, because
  /// set_requires_grad needs its dtype. Throws if both `requires_grad` and a
  /// grad_fn are given, since requires_grad_ is reserved for leaves.
  AutogradMeta(
      at::TensorImpl* self_impl = nullptr,
      bool requires_grad = false,
      Edge gradient_edge = Edge())
      : grad_fn_(std::move(gradient_edge.function)),
        // All other fields use their default member initializers.
        output_nr_(gradient_edge.input_nr) {
    // set_requires_grad also checks error conditions.
    if (requires_grad) {
      TORCH_INTERNAL_ASSERT(self_impl);
      set_requires_grad(requires_grad, self_impl);
    }
    TORCH_CHECK(
        !grad_fn_ || !requires_grad_,
        "requires_grad should be false if grad_fn is set");
  }

  ~AutogradMeta() override {
    // If AutogradMeta is being destroyed, it means that there is no other
    // reference to its corresponding Tensor. It implies that no other thread
    // can be using this object and so there is no need to lock mutex_ here to
    // guard the check if fw_grad_ is populated.
    if (fw_grad_) {
      // See note [ Using ForwardGrad ]
      fw_grad_->clear();
    }
  }
};
324
325 /// Base class for view functions, providing reapplication of a view on a new
326 /// base. Each view op should get a codegenerated subclass of this class
327 /// containing any state needed to reconstruct the view. The class also provides
328 /// convenience accessors for saved SymInts / tensor state. This is useful for
329 /// e.g. fake-ification, where we want to use symbolic values or fake tensors
330 /// instead.
331 struct TORCH_API ViewFunc {
332 virtual ~ViewFunc() = default;
333 /// Returns any SymInts in the saved state.
get_symintsViewFunc334 virtual std::vector<c10::SymInt> get_symints() const {
335 return {};
336 }
337 /// Returns the number of SymInts in the saved state.
num_symintsViewFunc338 virtual size_t num_symints() const {
339 return 0;
340 }
341 /// Returns any tensors in the saved state.
get_tensorsViewFunc342 virtual std::vector<at::Tensor> get_tensors() const {
343 return {};
344 }
345 /// Returns the number of tensors in the saved state.
num_tensorsViewFunc346 virtual size_t num_tensors() const {
347 return 0;
348 }
349 /// Reapplies the view on the given base using the saved state.
350 virtual at::Tensor operator()(const at::Tensor&) const = 0;
351 /// Returns a clone of this ViewFunc, optionally with the specified saved
352 /// state.
353 virtual std::unique_ptr<ViewFunc> clone_and_set(
354 std::optional<std::vector<c10::SymInt>> = std::nullopt,
355 std::optional<std::vector<at::Tensor>> = std::nullopt) const = 0;
356
357 protected:
358 /// Sets the values of any SymInts in the saved state. The input vector size
359 /// must match the number of SymInts in the saved state (i.e. the size of the
360 /// list returned by get_symints()).
set_symintsViewFunc361 virtual void set_symints(std::vector<c10::SymInt>) {}
362 /// Sets the values of any Tensors in the saved state. The input vector size
363 /// must match the number of Tensors in the saved state (i.e. the size of the
364 /// list returned by get_tensors()).
set_tensorsViewFunc365 virtual void set_tensors(std::vector<at::Tensor>) {}
366 };
367
368 /// ViewFunc that represents a chain of two ViewFuncs.
369 struct ChainedViewFunc : public ViewFunc {
ChainedViewFuncChainedViewFunc370 ChainedViewFunc(
371 std::unique_ptr<ViewFunc> first,
372 std::unique_ptr<ViewFunc> second)
373 : first(std::move(first)), second(std::move(second)) {}
374 ~ChainedViewFunc() override = default;
375 std::vector<c10::SymInt> get_symints() const override;
num_symintsChainedViewFunc376 size_t num_symints() const override {
377 return first->num_symints() + second->num_symints();
378 }
379 std::vector<at::Tensor> get_tensors() const override;
num_tensorsChainedViewFunc380 size_t num_tensors() const override {
381 return first->num_tensors() + second->num_tensors();
382 }
383 at::Tensor operator()(const at::Tensor&) const override;
384 std::unique_ptr<ViewFunc> clone_and_set(
385 std::optional<std::vector<c10::SymInt>> = std::nullopt,
386 std::optional<std::vector<at::Tensor>> = std::nullopt) const override;
387
388 private:
389 std::unique_ptr<ViewFunc> first;
390 std::unique_ptr<ViewFunc> second;
391 };
392
393 /// ViewFunc that errors with a specified error message when called.
394 struct ErroringViewFunc : public ViewFunc {
ErroringViewFuncErroringViewFunc395 ErroringViewFunc(std::string error_msg) : error_msg(std::move(error_msg)) {}
396 ~ErroringViewFunc() override = default;
operatorErroringViewFunc397 at::Tensor operator()(const at::Tensor&) const override {
398 TORCH_CHECK(false, error_msg);
399 }
400 std::unique_ptr<ViewFunc> clone_and_set(
401 std::optional<std::vector<c10::SymInt>> = std::nullopt,
402 std::optional<std::vector<at::Tensor>> = std::nullopt) const override {
403 return std::make_unique<ErroringViewFunc>(error_msg);
404 }
405
406 private:
407 std::string error_msg;
408 };
409
410 struct TORCH_API ViewInfo {
411 /// The base `Variable`
412 /// If this ViewInfo represents a forward (respectively backward) AD gradient,
413 /// then this Tensor cannot be a forward (respectively backward) view.
414 Variable base_;
415
416 /// By default we use as_strided to recover views which is more efficient.
417 /// view_fn is only saved when as_strided is not supported.
418 /// If view_fn has value, we use it to recover views in backward.
419 std::unique_ptr<ViewFunc> view_fn_;
420
421 /// Analogue of view_fn but in reverse: given a view -> produce the base by
422 /// applying the inverse view.
423 std::function<Variable(const Variable&)> rev_view_fn_;
424
425 /// Accessors for the view function
has_view_fnViewInfo426 bool has_view_fn() const {
427 // assume either BOTH or NEITHER of view_fn_ and rev_view_fn_ exist
428 return view_fn_ != nullptr;
429 }
430
view_fnViewInfo431 const ViewFunc& view_fn() const {
432 TORCH_CHECK(
433 has_view_fn(), "Can only access the view function if it exists.");
434 return *view_fn_;
435 }
436
rev_view_fnViewInfo437 std::function<Variable(const Variable&)> rev_view_fn() const {
438 TORCH_CHECK(
439 has_view_fn(),
440 "Can only access the reverse view function if it exists.");
441 return rev_view_fn_;
442 }
443
444 /// The chain function can be used to build a new ViewInfo for a
445 /// differentiable view function. It will return a new view info that
446 /// accurately represents how "tensor" is a view of this instance's "base_".
447 /// The "base" and "tensor" are respectively the input and output of the
448 /// differentiable view function that happened. They are required to properly
449 /// set the optional view_fn_ when it is not provided. The "view_func", if
450 /// provided, should be a function that allows to re-do the view between
451 /// "base" and "tensor".
452 ViewInfo chain(
453 const Variable& base,
454 const Variable& tensor,
455 std::unique_ptr<ViewFunc> view_func = nullptr,
456 std::function<Variable(const Variable&)> rev_view_func = nullptr) const;
457
ViewInfoViewInfo458 ViewInfo(
459 Variable base,
460 std::unique_ptr<ViewFunc> view_fn,
461 std::function<Variable(const Variable&)> rev_view_fn)
462 : base_(std::move(base)),
463 view_fn_(std::move(view_fn)),
464 rev_view_fn_(std::move(rev_view_fn)) {
465 TORCH_CHECK(base_.defined(), "base is undefined");
466 }
467 };
468
469 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
470 // DifferentiableViewMeta
471 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
472
473 /// NOTE [ Autograd View Variables ]
474 ///
475 /// Many operations return Variable that shares storage with an input Variable.
476 /// The returned Variable is called a **view** Variable on the input **base**
477 /// Variable.
478 ///
479 /// In PyTorch, we have two types of views: differentiable views, and
480 /// non-differentiable views. In either type, to support proper version
481 /// checking, the base and view Variables must always share the same
482 /// version_counter.
483 ///
484 ///
485 /// Differentiable Views
486 /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// This class allows tracking both forward and backward AD differentiable
/// views. These views can have different bases, because non-differentiable
/// views for forward and backward mode AD are not the same.
490 ///
/// Most functions are either both forward and backward differentiable views (for
492 /// example: view, select, narrow, transpose, etc) or both not forward and not
493 /// backward differentiable views (for example: indices, values, eq, lt, etc).
494 /// But there are also functions that are forward but not backward
495 /// differentiable views (only detach for now) or functions that are backward
/// but not forward differentiable views (only make_dual and unpack_dual for
497 /// now).
498 ///
/// A concrete example of two views with different bases is as follows:
500 ///
501 /// # Have:
502 /// # dual is a dual Tensor that is neither a forward or backward view
503 /// detached_dual = dual.detach()
504 /// view = detached_dual.view_as(dual)
505 /// # The forward base of view is dual
506 /// # The backward base of view is detached_dual
507 ///
508 /// - Backward Mode View
509 /// Differentiable views are the view variables where you want gradients to flow
510 /// back to the base variables. Out-of-place operations on views are quite
511 /// straightforward, but in-place ones are very tricky. Even if the base
512 /// variable may not require grad when we create the view, we still need to
/// track the view relation because future in-place ops may require backpropagating
514 /// through it. For example, we need to support
515 ///
516 /// (1) in-place operation on view, e.g.,
517 ///
518 /// # Have:
519 /// # base.requires_grad = False
520 /// # var.requires_grad = True
521 /// base[1] = var # i.e., base[1].copy_(var)
522 /// torch.autograd.grad(base.sum(), var) <- should return an all ones
523 /// tensor
524 ///
525 /// (2) in-place operation on base after view is created, e.g.,
526 ///
527 /// # Have:
528 /// # base.requires_grad = False
529 /// # var.requires_grad = True
530 /// view = base[1]
531 /// base.copy_(var)
532 /// torch.autograd.grad(view.sum(), var) <- should return a tensor with
533 /// var[1] filled with all ones and
534 /// zeros everywhere else
535 ///
536 /// - Forward Mode View
/// Forward differentiable views follow the same semantics as backward ones but
538 /// show up differently as they are computed along with the forward evaluation.
539 /// The hard examples above are thus very similar
540 ///
541 /// (1) in-place operation on view, e.g.,
542 ///
543 /// # Have:
544 /// # base is a regular Tensor
545 /// # var is a dual Tensor whose tangent is all ones
546 /// base[1] = var # i.e., base[1].copy_(var)
547 /// # Now, base is a dual Tensor
548 /// _, fw_grad = fwAD.unpack_dual(base) <- fw_grad should be a tensor with
549 /// fw_grad[1] filled with all ones
550 /// and zeros everywhere else
551 ///
552 /// (2) in-place operation on base after view is created, e.g.,
553 ///
554 /// # Have:
555 /// # base is a regular Tensor
556 /// # var is a dual Tensor whose tangent is all ones
557 /// view = base[1]
558 /// base.copy_(var)
559 /// _, fw_grad = fwAD.unpack_dual(view) <- fw_grad should be an all ones
560 /// tensor
561 ///
562 /// See Note [Forward Grad View/inplace] for more details on how we handle these
563 /// hard cases.
564 ///
565 ///
566 /// DifferentiableViewMeta is created to support gradient tracking of
567 /// such **in-place** operations. In particular,
568 /// + if an in-place op is done on base, the grad_fn field of the view may
569 /// become stale. So accesses should always go through grad_fn(), which
570 /// reconstructs an updated grad_fn if the version_counter has incremented.
571 /// All other fields are always valid.
572 /// + if an in-place op is done on view, in rebase_history() of view, which is
573 /// called after every in-place op in VariableType.cpp, the grad_fn of base
574 /// is updated.
/// + if a single autograd Node returns multiple differentiable views, if any
///   output is modified by an inplace operation, the autograd engine will
///   replace the original grad_fn with an equivalent graph (corresponding to
///   the view operations), where each output is treated as if it were
///   produced by a distinct view operation. This discards the original (e.g.,
///   user provided) grad_fn. If the provided grad_fn does more than the
///   backward of the view, then the DifferentiableViewMeta must be created
///   with creation_meta=CreationMeta::MULTI_OUTPUT_NODE to prevent the
///   engine from ignoring the provided grad_fn.
584 ///
585 /// Interaction with GradMode:
586 /// The particular case that we consider here is:
587 ///
588 /// # Have:
589 /// # base.requires_grad = True or False
590 /// with torch.no_grad():
591 /// view = base[1]
592 /// base.requires_grad_()
593 /// view.copy_(var)
594 /// torch.autograd.grad(base.sum(), var) <- what should it return?
595 ///
596 /// Given that this particular code example is ambiguous and can easily be
/// replaced by either moving both inside the no_grad block or both outside, we
598 /// explicitly forbid it. For now, it is deprecated by a warning. This is
599 /// achieved by setting creation_meta=CreationMeta::NO_GRAD_MODE for all
600 /// differentiable views created in no_grad mode.
601 ///
602 /// See Note [View + Inplace update for base tensor]
603 /// and Note [View + Inplace update for view tensor] for the details how
604 /// autograd handles inplace update with view ops.
605 ///
606 /// Non-Differentiable Views
607 /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
608 /// In certain cases, although function outputs share storage with inputs, they
609 /// will **never** require gradient history tracking. Instead of registering the
610 /// view relation via DifferentiableViewMeta in autograd, the views will be
611 /// using usual AutogradMeta and just share the version counters with the base
612 /// Variables.
613 /// Such views include:
614 /// 1. Views created from .detach()
/// 2. Views that are non-differentiable by their nature.
///    E.g., `sparse_tensor.indices()` is an integral view on a (possibly)
617 /// floating point tensor.
618 /// See top of `derivatives.yaml` on how to specify that outputs of a
619 /// function are non-differentiable.
620 /// These are called non-differentiable views as the gradients do not flow
621 /// through the view relation.
622 ///
623 /// Relevant logic for both differentiable and non-differentiable views is
624 /// implemented in make_variable_(non_)differentiable_view below, and
625 /// wrap_output of gen_variable_type.py.
626
627 /// NOTE [ View + Inplace detection ]
628 ///
629 /// We want to detect views followed by inplace as they are often forbidden to
630 /// ensure correctness of the computed gradients. But since we want to only
631 /// notify the user when both happen, we tag the DifferentiableViewMeta when the
632 /// view is created via the `make_variable_*_view()` functions. This tag is then
633 /// checked by the `check_inplace()` function from `VariableTypeUtils.h` that
634 /// should be called before every inplace operation and to detect cases where
635 /// other views are modified and this one is rebased by side effect, we also
636 /// check in the `VariableHooks::grad_fn()`.
637
/// Records the circumstances under which a differentiable view was created.
/// This tag is consulted when in-place operations hit views (see
/// NOTE [ View + Inplace detection ]):
/// - DEFAULT: all other cases.
/// - IN_CUSTOM_FUNCTION: the view was created inside a custom autograd
///   Function and is returned from it.
/// - MULTI_OUTPUT_NODE: a Node created by codegen code returns multiple
///   differentiable views.
/// - NO_GRAD_MODE: the view was created while GradMode was disabled.
/// - INFERENCE_MODE: a view of a normal tensor was created in InferenceMode.
/// The explicit values spell out the same numbering the enumerators would
/// get implicitly.
enum class CreationMeta : uint8_t {
  DEFAULT = 0,
  IN_CUSTOM_FUNCTION = 1,
  MULTI_OUTPUT_NODE = 2,
  NO_GRAD_MODE = 3,
  INFERENCE_MODE = 4,
};
656
657 /// Handles correctly propagating CreationMeta when a new view is created from a
658 /// previous view. In general, we don't want the new view to be _less_
659 /// restrictive than the previous view (it's okay to be _more_ restrictive). A
660 /// CreationMeta value of DEFAULT is currently the least restrictive, as the
661 /// behavior for all other CreationMeta values is to error out for in-place ops.
662 /// A CreationMeta value of INFERENCE_MODE is currently the most restrictive, so
663 /// it takes precedence in propagation. If this changes, the logic here will
664 /// need to be updated to properly handle the new semantics.
propagate_creation_meta(CreationMeta prev_view_creation_meta,CreationMeta new_view_creation_meta)665 inline CreationMeta propagate_creation_meta(
666 CreationMeta prev_view_creation_meta,
667 CreationMeta new_view_creation_meta) {
668 return (new_view_creation_meta == CreationMeta::DEFAULT)
669 ? prev_view_creation_meta
670 : (prev_view_creation_meta == CreationMeta::INFERENCE_MODE
671 ? prev_view_creation_meta
672 : new_view_creation_meta);
673 }
674
/// Unified function to handle error checking when a rebase happens on a
/// differentiable view.
/// @param diff_view_meta metadata of the view being rebased (NOTE(review):
///   presumably must be non-null -- confirm against the definition).
/// @param indirect true means that the caller is not doing the inplace, but
///   the inplace happened somewhere else (e.g. the check performed from
///   `VariableHooks::grad_fn()` described in NOTE [ View + Inplace detection ]).
TORCH_API void handle_view_on_rebase(
    DifferentiableViewMeta* diff_view_meta,
    bool indirect = false);
681
682 struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
683 private:
684 /// Information about the views
685 std::optional<ViewInfo> backward_info_;
686 std::optional<ViewInfo> forward_info_;
687
688 // Optimization to reduce the number of ViewInfo we create.
689 // In the (very common) case where backward_info_ == forward_info_, we only
690 // populate backward_info_ (that should be used as both the forward and
691 // backward view information) and set shared_view_info_ = true. Invariants:
// - If shared_view_info_ is false, there are no special constraints on
693 // backward_info_ and forward_info_
694 // - If shared_view_info_ is true, we must have:
695 // - backward_info_.has_value() == true
696 // - forward_info_.has_value() == false
697 bool shared_view_info_;
698
699 /// The two following fields are extra information that we track to ensure
700 /// that any operation on this backward view is valid.
701
702 /// The value of the version_counter at the time grad_fn was created. The
703 /// grad_fn field is stale if attr_version_ !=
704 /// version_counter.current_version().
705 uint32_t attr_version_;
706 CreationMeta creation_meta_;
707
708 public:
709 /// requires_grad is a backward AD field so we only use the view specific
710 /// logic for backward differentiable views
requires_gradDifferentiableViewMeta711 bool requires_grad() const override {
712 return requires_grad_ || grad_fn_ ||
713 (has_bw_view() && get_backward_view().base_.requires_grad());
714 }
715
shared_view_infoDifferentiableViewMeta716 bool shared_view_info() const {
717 return shared_view_info_;
718 }
719
has_bw_viewDifferentiableViewMeta720 bool has_bw_view() const {
721 return backward_info_.has_value();
722 }
723
get_backward_viewDifferentiableViewMeta724 const ViewInfo& get_backward_view() const {
725 TORCH_CHECK(
726 has_bw_view(), "backward view info can only exist for backward views.");
727 return backward_info_.value();
728 }
729
get_attr_versionDifferentiableViewMeta730 uint32_t get_attr_version() const {
731 TORCH_CHECK(
732 has_bw_view(), "attr_version can only exist for backward views.");
733 return attr_version_;
734 }
735
set_attr_versionDifferentiableViewMeta736 void set_attr_version(uint32_t new_attr_version) {
737 TORCH_CHECK(
738 has_bw_view(), "attr_version can only exist for backward views.");
739 attr_version_ = new_attr_version;
740 }
741
get_creation_metaDifferentiableViewMeta742 CreationMeta get_creation_meta() const {
743 TORCH_CHECK(
744 has_bw_view(), "creation_meta can only exist for backward views.");
745 return creation_meta_;
746 }
747
set_creation_metaDifferentiableViewMeta748 void set_creation_meta(CreationMeta new_creation_meta) {
749 TORCH_CHECK(
750 has_bw_view(), "creation_meta can only exist for backward views.");
751 creation_meta_ = new_creation_meta;
752 }
753
has_fw_viewDifferentiableViewMeta754 bool has_fw_view() const {
755 return shared_view_info_ || forward_info_.has_value();
756 }
757
get_forward_viewDifferentiableViewMeta758 const ViewInfo& get_forward_view() const {
759 TORCH_CHECK(
760 has_fw_view(), "forward view info can only exist for forward views.");
761 TORCH_CHECK(
762 !shared_view_info_ || has_bw_view(),
763 "forward view info can only exist for forward views.");
764 return shared_view_info_ ? backward_info_.value() : forward_info_.value();
765 }
766
767 DifferentiableViewMeta(
768 at::TensorImpl* self_impl,
769 std::optional<ViewInfo> backward_info,
770 std::optional<ViewInfo> forward_info,
771 bool shared_view_info,
772 CreationMeta creation_meta = CreationMeta::DEFAULT);
773 };
774
775 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
776 // Variable Implementation
777 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
778
779 // Factory Functions
780 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
781
/// Creates a `Variable` that is a *view* of another (*base*) variable.
/// Whether the view is differentiable, i.e., whether the relation should be
/// tracked by autograd, is selected by calling either
/// `make_variable_differentiable_view` or
/// `make_variable_non_differentiable_view` below. (A `gradient_edge`, i.e. a
/// (gradient_function, input_number) pair, is taken by the `make_variable`
/// overload further below, not by these view factories.)
/// See NOTE [ Autograd View Variables ] for details.
787
788 /// NOTE: `allow_tensor_metadata_change` is set to true by default, because
789 /// there are a lot of call sites to these factory functions that need to change
790 /// the variable's size or storage afterwards, and they don't expect the
791 /// original tensor (where the variable is created from) to be updated. Setting
792 /// `allow_tensor_metadata_change_` to false by default would unnecessarily
793 /// prevent those changes from happening and is undesirable.
794
795 // See NOTE [ Autograd View Variables ] for details.
796 // Differentiable view. Track history with DifferentiableViewMeta.
797 inline Variable make_variable_differentiable_view(
798 const at::Tensor& data,
799 std::optional<ViewInfo> backward_info,
800 std::optional<ViewInfo> forward_info,
801 bool shared_view_info,
802 CreationMeta creation_meta,
803 bool allow_tensor_metadata_change = true) {
804 if (data.defined()) {
805 TORCH_CHECK(
806 data.getIntrusivePtr()->autograd_meta() == nullptr,
807 "Attempted to make a tensor into a differentiable view, but the "
808 "tensor already had autograd metadata associated with it. If you are "
809 "using a __torch_dispatch__ mode, the most common cause for this "
810 "problem is that you used torch.overrides.enable_reentrant_dispatch() "
811 "improperly; tensors created within the extent of reentrant dispatch "
812 "MUST NOT be directly returned from __torch_dispatch__; instead, they "
813 "must be wrapped into fresh tensors that serve as the output. If you "
814 "are not using wrappers, you probably don't need reentrant dispatch. "
815 "If this doesn't seem applicable, please file a bug to PyTorch.");
816 at::TensorImpl* data_impl = data.unsafeGetTensorImpl();
817 data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
818 data_impl->set_autograd_meta(std::make_unique<DifferentiableViewMeta>(
819 data_impl,
820 std::move(backward_info),
821 std::move(forward_info),
822 shared_view_info,
823 creation_meta));
824 return data;
825 }
826 return Variable();
827 }
828
829 // See NOTE [ Autograd View Variables ] for details.
830 // Non-differentiable view. Just share version counter.
831 inline Variable make_variable_non_differentiable_view(
832 const Variable& base,
833 const at::Tensor& data,
834 bool allow_tensor_metadata_change = true) {
835 if (data.defined()) {
836 // Currently all of non-differentiable view ops(detach/_indices/_values)
837 // share the same TensorImpl as their base Tensor. Thus a new TensorImpl
838 // allocation here is required.
839 auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
840 /*version_counter=*/impl::version_counter(base),
841 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
842 data_impl_copy->set_autograd_meta(nullptr);
843 return Variable(data_impl_copy);
844 }
845 return Variable();
846 }
847
848 /// Creates a `Variable` from the given `Tensor`, copying its underlying
849 /// `TensorImpl`. `requires_grad` should be set only for leaves, and determines
850 /// whether the `Variable` will accumulate gradients. NOTE: `data` must *not* be
851 /// a `Variable` already. Its dynamic type *must* be `Tensor`.
852 ///
853 /// TODO: Eliminate this function as much as possible, as it can be expressed
854 /// more clearly as detach() or a no-op in most call sites (especially when
855 /// there is only one use of the variable).
856 inline Variable make_variable(
857 at::Tensor data,
858 bool requires_grad = false,
859 bool allow_tensor_metadata_change = true) {
860 if (data.defined()) {
861 if (data.getIntrusivePtr().use_count() == 1 &&
862 data.getIntrusivePtr()->unique_version()) {
863 auto data_impl = data.unsafeReleaseIntrusivePtr();
864 data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
865 if (requires_grad) {
866 data_impl->set_autograd_meta(
867 std::make_unique<AutogradMeta>(data_impl.get(), requires_grad));
868 } else {
869 data_impl->set_autograd_meta(nullptr);
870 }
871 return Variable(std::move(data_impl));
872 } else {
873 auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
874 /*version_counter=*/0,
875 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
876 if (requires_grad) {
877 data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>(
878 data_impl_copy.get(), requires_grad));
879 } else {
880 data_impl_copy->set_autograd_meta(nullptr);
881 }
882 return Variable(data_impl_copy);
883 }
884 }
885 return Variable();
886 }
887
888 /// Creates a `Variable` from the given `Tensor`, copying its underlying
889 /// `TensorImpl`. `gradient_edge` should be a (function, input_nr) pair
890 /// specifying the function in the autograd graph, and what particular input of
891 /// that function, this variable is connected to.
892 inline Variable make_variable(
893 const at::Tensor& data,
894 Edge gradient_edge,
895 bool allow_tensor_metadata_change = true) {
896 if (data.defined()) {
897 auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
898 /*version_counter=*/0,
899 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
900 data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>(
901 data_impl_copy.get(), false, std::move(gradient_edge)));
902 return Variable(data_impl_copy);
903 }
904 return Variable();
905 }
906
907 struct VariableHooks final : at::impl::VariableHooksInterface {
908 at::TensorBase tensor_data(const at::TensorBase&) const override;
909 at::TensorBase variable_data(const at::TensorBase&) const override;
910 const std::shared_ptr<torch::autograd::Node>& grad_fn(
911 const at::TensorBase&) const override;
912 unsigned _register_hook(
913 const at::TensorBase&,
914 std::function<at::TensorBase(const at::TensorBase&)> hook) const override;
915 void remove_hook(const at::TensorBase&, unsigned pos) const override;
916 bool is_view(const at::TensorBase&) const override;
917 const at::TensorBase& base(const at::TensorBase&) const override;
918 const std::string& name(const at::TensorBase&) const override;
919 bool is_leaf(const at::TensorBase&) const override;
920 int64_t output_nr(const at::TensorBase&) const override;
921 void set_data(const at::TensorBase& self, const at::TensorBase& new_data)
922 const override;
923 at::TensorBase data(const at::TensorBase& self) const override;
924 int64_t _version(const at::TensorBase& self) const override;
925 void retain_grad(const at::TensorBase& self) const override;
926 bool retains_grad(const at::TensorBase& self) const override;
927 void _backward(
928 const at::Tensor& self,
929 at::TensorList inputs,
930 const std::optional<at::Tensor>& gradient,
931 std::optional<bool> keep_graph,
932 bool create_graph) const override;
933 void requires_grad_(const at::TensorBase& self, bool _requires_grad)
934 const override;
935 void basic_autograd_not_implemented_fallback(
936 const c10::OperatorHandle& op,
937 c10::DispatchKeySet dispatch_keys,
938 torch::jit::Stack* stack) const override;
939 };
940
941 namespace utils {
942
943 TORCH_API bool has_same_meta(const Variable& base, const Variable& other);
944
945 } // namespace utils
946 } // namespace torch::autograd
947
948 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
949