xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/variable.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/utils/python_stub.h>
4 
5 #include <torch/csrc/Export.h>
6 #include <torch/csrc/autograd/cpp_hook.h>
7 #include <torch/csrc/autograd/edge.h>
8 #include <torch/csrc/autograd/forward_grad.h>
9 #include <torch/csrc/autograd/function_hook.h>
10 
11 #include <ATen/NamedTensorUtils.h>
12 #include <ATen/core/Tensor.h>
13 #include <ATen/core/VariableHooksInterface.h>
14 #include <c10/util/Exception.h>
15 
16 #include <cstdint>
17 #include <memory>
18 #include <mutex>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 namespace torch::autograd {
24 
25 /// `Variable` is exactly the same as `Tensor` (i.e. we have `using Variable =
26 /// at::Tensor`). This means you can perform all the usual mathematical and
27 /// other operations you can perform on `Tensor`s also on `Variable`s.
28 ///
29 /// The only reason we are keeping the `Variable` class is backward
30 /// compatibility with external users' legacy C++ frontend code. Our intention
31 /// is to eliminate the `Variable` class in the near future.
32 using Variable = at::Tensor;
33 
34 } // namespace torch::autograd
35 
36 // The following are all internal APIs and should not be shown in libtorch docs.
37 // Therefore, we wrap the following code with `#ifndef DOXYGEN_SHOULD_SKIP_THIS
38 // ... #endif`
39 
40 #ifndef DOXYGEN_SHOULD_SKIP_THIS
41 
42 namespace torch::autograd {
43 
44 /// Check if this type is supported by the autograd engine.
45 /// If you change this, update the doc at the top of the
46 /// torch/autograd/__init__.py file and
47 /// "test_set_requires_grad_only_for_continuous_types" in test/test_autograd.py
48 static inline bool isDifferentiableType(at::ScalarType t) {
49   return isFloatingType(t) || isComplexType(t);
50 }
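// For illustration, a small sketch of how this predicate behaves (assuming the
// standard ATen scalar type constants):
//
//   isDifferentiableType(at::kFloat);          // true  (floating point)
//   isDifferentiableType(at::kHalf);           // true  (floating point)
//   isDifferentiableType(at::kComplexDouble);  // true  (complex)
//   isDifferentiableType(at::kLong);           // false (integral)
//   isDifferentiableType(at::kBool);           // false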
51 
52 struct Node;
53 
54 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
55 ///                                Variable
56 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
57 /// A `Variable` augments a `Tensor` with the ability to interact in our
58 /// autograd machinery. Conceptually, `Variable`s travel along `Edge`s between
59 /// `Node`s in the autograd graph. A `Variable` can either be a leaf, like a
60 /// weight in a neural network, or an interior variable, when it is the result
61 /// of an operation between variables. Every `Variable` also stores another
62 /// `Variable` called its `grad` (gradient). If the variable is a leaf, its
63 /// gradient will be accumulated into this variable.
64 ///
65 /// Every Tensor is a Variable, but sometimes we colloquially refer to Variables
66 /// that don't require gradients as Tensors (since none of the autograd
67 /// machinery for Variables applies).  Historically, Variables and Tensors
68 /// were separate concepts, but now they are exactly the same (i.e. we have
69 /// `using Variable = at::Tensor`).
70 ///
71 ///                              Gradient Edges
72 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
73 /// Furthermore, `Variable`s have the notion of a `gradient_edge`, which is the
74 /// edge in the autograd graph that connects the variable to a particular input
75 /// of the gradient function that will be invoked with the variable during the
76 /// backward pass. More precisely, this gradient function can be one of two
77 /// things:
78 /// 1. A `grad_fn`, if the variable is in the interior of the graph. This is the
79 ///    gradient of the function that produced the variable.
80 /// 2. A `grad_accumulator`, if the variable is a leaf, which accumulates a
81 ///    scalar gradient value into its `grad` variable.
82 ///
83 ///                               Versioning
84 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
85 /// Another major feature of `Variable`s are *versions*. Versions are
86 /// incremented when an in-place mutation of a variable occurs. Versions are
87 /// useful when constructing `SavedVariable`s, which take a snapshot of a
88 /// `Variable` at a certain version. You can retrieve a `Variable`'s version
89 /// through its `current_version()` method.
90 ///
91 ///                                 Views
92 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
93 /// It is possible for a `Variable` to be a *view* of another `Variable`, in
94 /// which case it tracks that `Variable`'s data and autograd history. Beyond
95 /// construction, the interface of a view is identical to that of a regular
96 /// `Variable`. You can determine whether a `Variable` is in fact a view by
97 /// probing its `is_view()` method. Note that the *view* semantics are only
98 /// meaningful for `Variable` relations that are relevant to autograd.
99 /// See NOTE [ Autograd View Variables ] for more details.
100 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
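// A minimal usage sketch of the concepts above (leaf vs. interior variables,
// versions, and views), assuming the full ATen/autograd API is available:
//
//   at::Tensor leaf = at::ones({2, 2}).set_requires_grad(true);
//   at::Tensor y = leaf * 2;            // interior variable: has a grad_fn
//   // leaf.is_leaf() == true, y.is_leaf() == false
//
//   auto v = y._version();
//   y.add_(1);                          // in-place ops bump the version
//   // y._version() == v + 1
//
//   at::Tensor view = y.view({4});      // autograd-tracked view of y
//   // view.is_view() == true, view._base() is y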
101 
102 struct AutogradMeta;
103 struct DifferentiableViewMeta;
104 
105 // Private-ish functions for manipulating variables; we don't want to put them
106 // on Tensor proper
107 namespace impl {
108 
109 // WARNING: This may return a nullptr.  If you require a materialized
110 // AutogradMeta, use materialize_autograd_meta instead.
111 TORCH_API AutogradMeta* get_autograd_meta(const at::TensorBase&);
112 
113 // WARNING: This will return a nullptr if the Tensor is not a view.
114 TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase&);
115 
116 // Returns the current autograd meta, materializing it if it was previously
117 // none.  This counts as a *mutating* operation, so do not call it on
118 // "read-only" operators; in particular, this is NOT thread safe
119 TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase&);
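// For example, a sketch of the intended usage pattern (`t` is some at::Tensor):
//
//   if (AutogradMeta* meta = impl::get_autograd_meta(t)) {
//     // t already has autograd metadata; safe to inspect it.
//   }
//   // Never returns nullptr, but may allocate and attach a fresh AutogradMeta,
//   // so only call this on tensors you are allowed to mutate:
//   AutogradMeta* meta = impl::materialize_autograd_meta(t);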
120 
121 /// Set the gradient accumulator of the `Variable`. This is only applicable to
122 /// leaf variables. Interior variables should call `set_gradient_edge()`.
123 TORCH_API void set_grad_accumulator(
124     const Variable&,
125     std::weak_ptr<Node> grad_accumulator);
126 
127 /// Attempts to get a pointer to the gradient accumulator of the `Variable`,
128 /// if it still exists. If the gradient accumulator function has been
129 /// destroyed, returns a `nullptr`.
130 TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);
131 
132 /// Gets the gradient accumulator of the `Variable` if it has one, or else
133 /// creates one on the fly and returns it.
134 TORCH_API std::shared_ptr<Node> grad_accumulator(const Variable&);
135 
136 /// Returns the "canonical" gradient edge of this `Variable`, i.e. either the
137 /// gradient function if this is an interior `Variable`, or the gradient
138 /// accumulator otherwise. If the `Variable` is interior, the returned `Edge`
139 /// will store the input index of the `Node` to which this variable is
140 /// connected in its `input_nr` field. For leaves, the `input_nr` is always
141 /// zero. Note that `set_gradient_edge` and `gradient_edge` are not
142 /// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and
143 /// `set_grad_accumulator` to set the accumulator.
144 TORCH_API Edge gradient_edge(const Variable&);
145 
146 /// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the
147 /// `Variable`.
148 /// NOTE: This will always set the `grad_fn`, even if this is a leaf variable,
149 /// and never the `grad_accumulator`. For the latter, use
150 /// `set_grad_accumulator`. This allows late construction of an interior
151 /// `Variable`.
152 TORCH_API void set_gradient_edge(const Variable&, Edge edge);
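// A sketch of how these helpers relate (not a prescription; the autograd
// engine and codegen are the main callers):
//
//   at::Tensor leaf = at::ones({3}).set_requires_grad(true);
//   at::Tensor out = leaf.exp();
//
//   Edge e1 = impl::gradient_edge(out);   // {out.grad_fn(), output_nr of out}
//   Edge e2 = impl::gradient_edge(leaf);  // {grad accumulator of leaf, 0}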
153 
154 // Autograd Graph Interaction
155 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
156 
157 /// Update the `grad_fn` of an existing Variable. Called after in-place
158 /// modifications.
159 ///
160 /// For View Variables:
161 /// Called after in-place modifications. Modifies the grad_fn of the base
162 /// Variable.
163 TORCH_API void rebase_history(const Variable&, Edge gradient_edge);
164 
165 /// Gets the raw gradient function pointer, whatever it currently is.
166 TORCH_API Node* grad_fn_unsafe(const Variable&);
167 
168 /// Increments the version count of this `Variable`.
169 TORCH_API void bump_version(const Variable&);
170 TORCH_API void set_version_counter(
171     const Variable&,
172     const c10::VariableVersion& version_counter);
173 
174 /// Retrieves this `Variable`'s version counter.
175 TORCH_API const c10::VariableVersion& version_counter(const Variable&);
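// For example, a sketch of the relationship between these two helpers
// (`t` is some Variable):
//
//   const c10::VariableVersion& vc = impl::version_counter(t);
//   uint32_t before = vc.current_version();
//   impl::bump_version(t);   // what in-place kernels effectively do
//   // vc.current_version() == before + 1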
176 
177 TORCH_API void set_name(const Variable&, const std::string& name);
178 
179 TORCH_API void add_hook(
180     const at::TensorBase&,
181     std::unique_ptr<FunctionPreHook> hook);
182 TORCH_API std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable&);
183 TORCH_API void clear_hooks(const at::TensorBase&);
184 
185 TORCH_API void set_post_acc_grad_hooks(
186     const at::TensorBase&,
187     std::unique_ptr<PostAccumulateGradHook> dict);
188 TORCH_API std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
189     const Variable&);
190 
191 TORCH_API void create_cpp_hook(
192     const at::TensorBase&,
193     bool is_retains_grad_hooks = false);
194 } // namespace impl
195 
196 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
197 //                            AutogradMeta
198 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
199 
200 /// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd
201 /// metadata fields that are necessary for tracking the Variable's autograd
202 /// history. As an optimization, a Variable may store a nullptr, in lieu of a
203 /// default constructed AutogradMeta.
204 
205 struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
206   std::string name_;
207 
208   Variable grad_;
209   std::shared_ptr<Node> grad_fn_;
210   std::weak_ptr<Node> grad_accumulator_;
211 
212   // This field is used to store all the forward AD gradients
213   // associated with this AutogradMeta (and the Tensor it corresponds to)
214   // There is a semantic 1:1 correspondence between AutogradMeta and
215   // ForwardGrad but:
216   //   - This field is lazily populated.
217   //   - This field is a shared_ptr but it must never be
218   //     shared by multiple Tensors. See Note [ Using ForwardGrad ]
219   // Any transition from not_initialized to initialized
220   // must be protected by mutex_
221   mutable std::shared_ptr<ForwardGrad> fw_grad_;
222 
223   // The hooks_ field is actually reused by both python and cpp logic
224   // For both cases, we have a data structure, cpp_hooks_list_ (cpp)
225   // or dict (python) which is the canonical copy.
226   // Then, for both cases, we always register a single hook to
227   // hooks_ which wraps all the hooks in the list/dict.
228   // And, again in both cases, if the grad_fn exists on that tensor
229   // we will additionally register a single hook to the grad_fn.
230   //
231   // Note that the cpp and python use cases aren't actually aware of
232   // each other, so using both at once is undefined behavior.
233   std::vector<std::unique_ptr<FunctionPreHook>> hooks_;
234   std::shared_ptr<hooks_list> cpp_hooks_list_;
235 
236   // The post_acc_grad_hooks_ field stores only Python hooks
237   // (PyFunctionTensorPostAccGradHooks) that are called after the
238   // .grad field has been accumulated into. This is less complicated
239   // than the hooks_ field, which encapsulates a lot more.
240   std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;
241 
242   // Only meaningful on leaf variables (must be false otherwise)
243   bool requires_grad_{false};
244 
245   // Only meaningful on non-leaf variables (must be false otherwise)
246   bool retains_grad_{false};
247 
248   bool is_view_{false};
249 
250   // The "output number" of this variable; e.g., if this variable
251   // was the second output of a function, then output_nr == 1.
252   // We use this to make sure we can setup the backwards trace
253   // correctly when this variable is passed to another function.
254   uint32_t output_nr_;
255 
256   // Mutex to ensure that concurrent read operations that modify internal
257   // state are still thread-safe. Used by grad_fn(), grad_accumulator(),
258   // fw_grad() and set_fw_grad()
259   // This is mutable because we need to be able to acquire this from const
260   // version of this class for the functions above
261   mutable std::mutex mutex_;
262 
263   /// Sets the `requires_grad` property of `Variable`. This should be true for
264   /// leaf variables that want to accumulate gradients, and false for all other
265   /// variables.
266   void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) final {
267     TORCH_CHECK(
268         !requires_grad ||
269             isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())),
270         "Only Tensors of floating point and complex dtype can require gradients");
271     requires_grad_ = requires_grad;
272   }
273 
274   bool requires_grad() const override {
275     return requires_grad_ || grad_fn_;
276   }
277 
278   /// Accesses the gradient `Variable` of this `Variable`.
279   Variable& mutable_grad() override {
280     return grad_;
281   }
282 
283   const Variable& grad() const override {
284     return grad_;
285   }
286 
287   const Variable& fw_grad(uint64_t level, const at::TensorBase& self)
288       const override;
289 
290   void set_fw_grad(
291       const at::TensorBase& new_grad,
292       const at::TensorBase& self,
293       uint64_t level,
294       bool is_inplace_op) override;
295 
296   AutogradMeta(
297       at::TensorImpl* self_impl = nullptr,
298       bool requires_grad = false,
299       Edge gradient_edge = Edge())
300       : grad_fn_(std::move(gradient_edge.function)),
301 
302         output_nr_(gradient_edge.input_nr) {
303     // set_requires_grad also checks error conditions.
304     if (requires_grad) {
305       TORCH_INTERNAL_ASSERT(self_impl);
306       set_requires_grad(requires_grad, self_impl);
307     }
308     TORCH_CHECK(
309         !grad_fn_ || !requires_grad_,
310         "requires_grad should be false if grad_fn is set");
311   }
312 
313   ~AutogradMeta() override {
314     // If AutogradMeta is being destroyed, it means that there is no other
315     // reference to its corresponding Tensor. It implies that no other thread
316     // can be using this object and so there is no need to lock mutex_ here to
317     // guard the check if fw_grad_ is populated.
318     if (fw_grad_) {
319       // See note [ Using ForwardGrad ]
320       fw_grad_->clear();
321     }
322   }
323 };
324 
325 /// Base class for view functions, providing reapplication of a view on a new
326 /// base. Each view op should get a codegenerated subclass of this class
327 /// containing any state needed to reconstruct the view. The class also provides
328 /// convenience accessors for saved SymInts / tensor state. This is useful for
329 /// e.g. fake-ification, where we want to use symbolic values or fake tensors
330 /// instead.
331 struct TORCH_API ViewFunc {
332   virtual ~ViewFunc() = default;
333   /// Returns any SymInts in the saved state.
334   virtual std::vector<c10::SymInt> get_symints() const {
335     return {};
336   }
337   /// Returns the number of SymInts in the saved state.
338   virtual size_t num_symints() const {
339     return 0;
340   }
341   /// Returns any tensors in the saved state.
342   virtual std::vector<at::Tensor> get_tensors() const {
343     return {};
344   }
345   /// Returns the number of tensors in the saved state.
346   virtual size_t num_tensors() const {
347     return 0;
348   }
349   /// Reapplies the view on the given base using the saved state.
350   virtual at::Tensor operator()(const at::Tensor&) const = 0;
351   /// Returns a clone of this ViewFunc, optionally with the specified saved
352   /// state.
353   virtual std::unique_ptr<ViewFunc> clone_and_set(
354       std::optional<std::vector<c10::SymInt>> = std::nullopt,
355       std::optional<std::vector<at::Tensor>> = std::nullopt) const = 0;
356 
357  protected:
358   /// Sets the values of any SymInts in the saved state. The input vector size
359   /// must match the number of SymInts in the saved state (i.e. the size of the
360   /// list returned by get_symints()).
361   virtual void set_symints(std::vector<c10::SymInt>) {}
362   /// Sets the values of any Tensors in the saved state. The input vector size
363   /// must match the number of Tensors in the saved state (i.e. the size of the
364   /// list returned by get_tensors()).
365   virtual void set_tensors(std::vector<at::Tensor>) {}
366 };
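// A hand-written sketch of what a subclass might look like for a narrow()-style
// view. Real subclasses are codegenerated together with the corresponding view
// ops; `NarrowViewFunc` below is purely illustrative.
//
//   struct NarrowViewFunc : public ViewFunc {
//     NarrowViewFunc(int64_t dim, c10::SymInt start, c10::SymInt length)
//         : dim_(dim), start_(std::move(start)), length_(std::move(length)) {}
//
//     std::vector<c10::SymInt> get_symints() const override {
//       return {start_, length_};
//     }
//     size_t num_symints() const override {
//       return 2;
//     }
//     at::Tensor operator()(const at::Tensor& base) const override {
//       // Reapply the saved view op on the (possibly new) base.
//       return base.narrow_symint(dim_, start_, length_);
//     }
//     std::unique_ptr<ViewFunc> clone_and_set(
//         std::optional<std::vector<c10::SymInt>> symints = std::nullopt,
//         std::optional<std::vector<at::Tensor>> = std::nullopt) const override {
//       auto copy = std::make_unique<NarrowViewFunc>(dim_, start_, length_);
//       if (symints) {
//         copy->set_symints(std::move(*symints));
//       }
//       return copy;
//     }
//
//    protected:
//     void set_symints(std::vector<c10::SymInt> symints) override {
//       start_ = symints[0];
//       length_ = symints[1];
//     }
//
//    private:
//     int64_t dim_;
//     c10::SymInt start_, length_;
//   };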
367 
368 /// ViewFunc that represents a chain of two ViewFuncs.
369 struct ChainedViewFunc : public ViewFunc {
370   ChainedViewFunc(
371       std::unique_ptr<ViewFunc> first,
372       std::unique_ptr<ViewFunc> second)
373       : first(std::move(first)), second(std::move(second)) {}
374   ~ChainedViewFunc() override = default;
375   std::vector<c10::SymInt> get_symints() const override;
376   size_t num_symints() const override {
377     return first->num_symints() + second->num_symints();
378   }
379   std::vector<at::Tensor> get_tensors() const override;
380   size_t num_tensors() const override {
381     return first->num_tensors() + second->num_tensors();
382   }
383   at::Tensor operator()(const at::Tensor&) const override;
384   std::unique_ptr<ViewFunc> clone_and_set(
385       std::optional<std::vector<c10::SymInt>> = std::nullopt,
386       std::optional<std::vector<at::Tensor>> = std::nullopt) const override;
387 
388  private:
389   std::unique_ptr<ViewFunc> first;
390   std::unique_ptr<ViewFunc> second;
391 };
392 
393 /// ViewFunc that errors with a specified error message when called.
394 struct ErroringViewFunc : public ViewFunc {
395   ErroringViewFunc(std::string error_msg) : error_msg(std::move(error_msg)) {}
396   ~ErroringViewFunc() override = default;
397   at::Tensor operator()(const at::Tensor&) const override {
398     TORCH_CHECK(false, error_msg);
399   }
400   std::unique_ptr<ViewFunc> clone_and_set(
401       std::optional<std::vector<c10::SymInt>> = std::nullopt,
402       std::optional<std::vector<at::Tensor>> = std::nullopt) const override {
403     return std::make_unique<ErroringViewFunc>(error_msg);
404   }
405 
406  private:
407   std::string error_msg;
408 };
409 
410 struct TORCH_API ViewInfo {
411   /// The base `Variable`
412   /// If this ViewInfo represents a forward (respectively backward) AD gradient,
413   /// then this Tensor cannot be a forward (respectively backward) view.
414   Variable base_;
415 
416   /// By default we use as_strided to recover views, which is more efficient.
417   /// view_fn is only saved when as_strided is not supported.
418   /// If view_fn has a value, we use it to recover views in backward.
419   std::unique_ptr<ViewFunc> view_fn_;
420 
421   /// Analogue of view_fn but in reverse: given a view -> produce the base by
422   /// applying the inverse view.
423   std::function<Variable(const Variable&)> rev_view_fn_;
424 
425   /// Accessors for the view function
426   bool has_view_fn() const {
427     // assume either BOTH or NEITHER of view_fn_ and rev_view_fn_ exist
428     return view_fn_ != nullptr;
429   }
430 
431   const ViewFunc& view_fn() const {
432     TORCH_CHECK(
433         has_view_fn(), "Can only access the view function if it exists.");
434     return *view_fn_;
435   }
436 
437   std::function<Variable(const Variable&)> rev_view_fn() const {
438     TORCH_CHECK(
439         has_view_fn(),
440         "Can only access the reverse view function if it exists.");
441     return rev_view_fn_;
442   }
443 
444   /// The chain function can be used to build a new ViewInfo for a
445   /// differentiable view function. It will return a new view info that
446   /// accurately represents how "tensor" is a view of this instance's "base_".
447   /// The "base" and "tensor" are respectively the input and output of the
448   /// differentiable view function that happened. They are required to properly
449   /// set the optional view_fn_ when it is not provided. The "view_func", if
450   /// provided, should be a function that allows re-doing the view between
451   /// "base" and "tensor".
452   ViewInfo chain(
453       const Variable& base,
454       const Variable& tensor,
455       std::unique_ptr<ViewFunc> view_func = nullptr,
456       std::function<Variable(const Variable&)> rev_view_func = nullptr) const;
457 
458   ViewInfo(
459       Variable base,
460       std::unique_ptr<ViewFunc> view_fn,
461       std::function<Variable(const Variable&)> rev_view_fn)
462       : base_(std::move(base)),
463         view_fn_(std::move(view_fn)),
464         rev_view_fn_(std::move(rev_view_fn)) {
465     TORCH_CHECK(base_.defined(), "base is undefined");
466   }
467 };
468 
469 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
470 //                     DifferentiableViewMeta
471 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
472 
473 /// NOTE [ Autograd View Variables ]
474 ///
475 /// Many operations return Variable that shares storage with an input Variable.
476 /// The returned Variable is called a **view** Variable on the input **base**
477 /// Variable.
478 ///
479 /// In PyTorch, we have two types of views: differentiable views, and
480 /// non-differentiable views. In either type, to support proper version
481 /// checking, the base and view Variables must always share the same
482 /// version_counter.
483 ///
484 ///
485 /// Differentiable Views
486 /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
487 /// This class allows tracking both forward and backward AD differentiable
488 /// views. These views can have different bases, because the set of operations
489 /// treated as non-differentiable views differs between forward and backward AD.
490 ///
491 /// Most functions are either both forward and backward differentiable views
492 /// (for example: view, select, narrow, transpose, etc.) or neither forward nor
493 /// backward differentiable views (for example: indices, values, eq, lt, etc.).
494 /// But there are also functions that are forward but not backward
495 /// differentiable views (only detach for now) and functions that are backward
496 /// but not forward differentiable views (only make_dual and unpack_dual for
497 /// now).
498 ///
499 /// A concrete example of two views with different bases is as follows:
500 ///
501 ///     # Have:
502 ///     #   dual is a dual Tensor that is neither a forward nor a backward view
503 ///     detached_dual = dual.detach()
504 ///     view = detached_dual.view_as(dual)
505 ///     # The forward base of view is dual
506 ///     # The backward base of view is detached_dual
507 ///
508 /// - Backward Mode View
509 /// Differentiable views are the view variables where you want gradients to flow
510 /// back to the base variables. Out-of-place operations on views are quite
511 /// straightforward, but in-place ones are very tricky. Even if the base
512 /// variable may not require grad when we create the view, we still need to
513 /// track the view relation because future in-place ops may require backpropagating
514 /// through it. For example, we need to support
515 ///
516 ///   (1) in-place operation on view, e.g.,
517 ///
518 ///     # Have:
519 ///     #   base.requires_grad = False
520 ///     #   var.requires_grad = True
521 ///     base[1] = var  # i.e., base[1].copy_(var)
522 ///     torch.autograd.grad(base.sum(), var)  <- should return an all ones
523 ///     tensor
524 ///
525 ///   (2) in-place operation on base after view is created, e.g.,
526 ///
527 ///     # Have:
528 ///     #   base.requires_grad = False
529 ///     #   var.requires_grad = True
530 ///     view = base[1]
531 ///     base.copy_(var)
532 ///     torch.autograd.grad(view.sum(), var)  <- should return a tensor with
533 ///                                              var[1] filled with all ones and
534 ///                                              zeros everywhere else
535 ///
536 /// - Forward Mode View
537 /// Forward differentiable views follow the same semantics as backward ones but
538 /// show up differently as they are computed along with the forward evaluation.
539 /// The hard examples above are thus very similar
540 ///
541 ///   (1) in-place operation on view, e.g.,
542 ///
543 ///     # Have:
544 ///     #   base is a regular Tensor
545 ///     #   var is a dual Tensor whose tangent is all ones
546 ///     base[1] = var  # i.e., base[1].copy_(var)
547 ///     # Now, base is a dual Tensor
548 ///     _, fw_grad = fwAD.unpack_dual(base) <- fw_grad should be a tensor with
549 ///                                              fw_grad[1] filled with all ones
550 ///                                              and zeros everywhere else
551 ///
552 ///   (2) in-place operation on base after view is created, e.g.,
553 ///
554 ///     # Have:
555 ///     #   base is a regular Tensor
556 ///     #   var is a dual Tensor whose tangent is all ones
557 ///     view = base[1]
558 ///     base.copy_(var)
559 ///     _, fw_grad = fwAD.unpack_dual(view) <- fw_grad should be an all ones
560 ///     tensor
561 ///
562 /// See Note [Forward Grad View/inplace] for more details on how we handle these
563 /// hard cases.
564 ///
565 ///
566 /// DifferentiableViewMeta is created to support gradient tracking of
567 /// such **in-place** operations. In particular,
568 ///   + if an in-place op is done on base, the grad_fn field of the view may
569 ///     become stale. So accesses should always go through grad_fn(), which
570 ///     reconstructs an updated grad_fn if the version_counter has incremented.
571 ///     All other fields are always valid.
572 ///   + if an in-place op is done on view, in rebase_history() of view, which is
573 ///     called after every in-place op in VariableType.cpp, the grad_fn of base
574 ///     is updated.
575 ///   + if a single autograd Node returns multiple differentiable views and any
576 ///     output is modified by an inplace operation, the autograd engine will
577 ///     rebuild an equivalent graph from the recorded view operations, in which
578 ///     each output is treated as if it were produced by a distinct view
579 ///     operation. This discards the original (e.g., user provided) grad_fn. If
580 ///     the provided grad_fn does more than the backward of the view, then the
581 ///     DifferentiableViewMeta must be created with
582 ///     creation_meta=CreationMeta::MULTI_OUTPUT_NODE to prevent the
583 ///     engine from ignoring the provided grad_fn.
584 ///
585 /// Interaction with GradMode:
586 /// The particular case that we consider here is:
587 ///
588 ///     # Have:
589 ///     #   base.requires_grad = True or False
590 ///     with torch.no_grad():
591 ///         view = base[1]
592 ///     base.requires_grad_()
593 ///     view.copy_(var)
594 ///     torch.autograd.grad(base.sum(), var)  <- what should it return?
595 ///
596 /// Given that this particular code example is ambiguous and can easily be
597 /// replaced by either moving both statements inside the no_grad block or both
598 /// outside, we explicitly forbid it. For now, it only raises a deprecation
599 /// warning. This is achieved by setting creation_meta=CreationMeta::NO_GRAD_MODE
600 /// for all differentiable views created in no_grad mode.
601 ///
602 /// See Note [View + Inplace update for base tensor]
603 /// and Note [View + Inplace update for view tensor] for the details how
604 /// autograd handles inplace update with view ops.
605 ///
606 /// Non-Differentiable Views
607 /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
608 /// In certain cases, although function outputs share storage with inputs, they
609 /// will **never** require gradient history tracking. Instead of registering the
610 /// view relation via DifferentiableViewMeta in autograd, these views use the
611 /// usual AutogradMeta and just share the version counters with the base
612 /// Variables.
613 /// Such views include:
614 ///   1. Views created from .detach()
615 ///   2. Views that are non-differentiable by their nature.
616 ///      E.g., `sparse_tensor.indices()` is an integral view on a (possibly)
617 ///      floating point tensor.
618 ///      See top of `derivatives.yaml` on how to specify that outputs of a
619 ///      function are non-differentiable.
620 /// These are called non-differentiable views as the gradients do not flow
621 /// through the view relation.
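/// For example (a sketch):
///
///     at::Tensor base = at::rand({3}).set_requires_grad(true);
///     at::Tensor d = base.detach();  // non-differentiable (backward) view
///     // d shares storage and the version counter with base, but backward
///     // gradients never flow from d back to base.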
622 ///
623 /// Relevant logic for both differentiable and non-differentiable views is
624 /// implemented in make_variable_(non_)differentiable_view below, and
625 /// wrap_output of gen_variable_type.py.
626 
627 /// NOTE [ View + Inplace detection ]
628 ///
629 /// We want to detect views followed by inplace operations, as they are often
630 /// forbidden to ensure correctness of the computed gradients. But since we only
631 /// want to notify the user when both happen, we tag the DifferentiableViewMeta
632 /// when the view is created via the `make_variable_*_view()` functions. This tag
633 /// is then checked by the `check_inplace()` function from `VariableTypeUtils.h`,
634 /// which should be called before every inplace operation. To detect cases where
635 /// other views are modified and this one is rebased as a side effect, we also
636 /// check in `VariableHooks::grad_fn()`.
637 
638 /// Flag that gives more information about when this view was created:
639 /// - IN_CUSTOM_FUNCTION should be set when the view is created inside a custom
640 ///   autograd Function and is returned from it.
641 /// - NO_GRAD_MODE should be set when a view is created while GradMode is
642 ///   disabled.
643 /// - MULTI_OUTPUT_NODE should be set when a Node created by codegen code
644 ///   returns multiple
645 ///   differentiable views.
646 /// - INFERENCE_MODE should be set when a view of a normal tensor is created in
647 ///   InferenceMode.
648 /// - DEFAULT is for all other cases.
649 enum class CreationMeta : uint8_t {
650   DEFAULT,
651   IN_CUSTOM_FUNCTION,
652   MULTI_OUTPUT_NODE,
653   NO_GRAD_MODE,
654   INFERENCE_MODE
655 };
656 
657 /// Handles correctly propagating CreationMeta when a new view is created from a
658 /// previous view. In general, we don't want the new view to be _less_
659 /// restrictive than the previous view (it's okay to be _more_ restrictive). A
660 /// CreationMeta value of DEFAULT is currently the least restrictive, as the
661 /// behavior for all other CreationMeta values is to error out for in-place ops.
662 /// A CreationMeta value of INFERENCE_MODE is currently the most restrictive, so
663 /// it takes precedence in propagation. If this changes, the logic here will
664 /// need to be updated to properly handle the new semantics.
665 inline CreationMeta propagate_creation_meta(
666     CreationMeta prev_view_creation_meta,
667     CreationMeta new_view_creation_meta) {
668   return (new_view_creation_meta == CreationMeta::DEFAULT)
669       ? prev_view_creation_meta
670       : (prev_view_creation_meta == CreationMeta::INFERENCE_MODE
671              ? prev_view_creation_meta
672              : new_view_creation_meta);
673 }
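// For example, given the rule above (a sketch):
//
//   propagate_creation_meta(CreationMeta::NO_GRAD_MODE, CreationMeta::DEFAULT)
//       == CreationMeta::NO_GRAD_MODE;   // DEFAULT never relaxes the old tag
//   propagate_creation_meta(CreationMeta::DEFAULT, CreationMeta::NO_GRAD_MODE)
//       == CreationMeta::NO_GRAD_MODE;   // the new, more restrictive tag wins
//   propagate_creation_meta(CreationMeta::INFERENCE_MODE, CreationMeta::MULTI_OUTPUT_NODE)
//       == CreationMeta::INFERENCE_MODE; // INFERENCE_MODE takes precedence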
674 
675 /// Unified function to handle error checking when rebase happens
676 /// indirect=true means that the caller is not doing the inplace, but the
677 /// inplace happened somewhere else.
678 TORCH_API void handle_view_on_rebase(
679     DifferentiableViewMeta* diff_view_meta,
680     bool indirect = false);
681 
682 struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
683  private:
684   /// Information about the views
685   std::optional<ViewInfo> backward_info_;
686   std::optional<ViewInfo> forward_info_;
687 
688   // Optimization to reduce the number of ViewInfo we create.
689   // In the (very common) case where backward_info_ == forward_info_, we only
690   // populate backward_info_ (that should be used as both the forward and
691   // backward view information) and set shared_view_info_ = true. Invariants:
692   //   - If shared_view_info_ is false, there are no special constraints on
693   //     backward_info_ and forward_info_
694   //   - If shared_view_info_ is true, we must have:
695   //      - backward_info_.has_value() == true
696   //      - forward_info_.has_value() == false
697   bool shared_view_info_;
698 
699   /// The two following fields are extra information that we track to ensure
700   /// that any operation on this backward view is valid.
701 
702   /// The value of the version_counter at the time grad_fn was created. The
703   /// grad_fn field is stale if attr_version_ !=
704   /// version_counter.current_version().
705   uint32_t attr_version_;
706   CreationMeta creation_meta_;
707 
708  public:
709   /// requires_grad is a backward AD field so we only use the view specific
710   /// logic for backward differentiable views
711   bool requires_grad() const override {
712     return requires_grad_ || grad_fn_ ||
713         (has_bw_view() && get_backward_view().base_.requires_grad());
714   }
715 
716   bool shared_view_info() const {
717     return shared_view_info_;
718   }
719 
720   bool has_bw_view() const {
721     return backward_info_.has_value();
722   }
723 
724   const ViewInfo& get_backward_view() const {
725     TORCH_CHECK(
726         has_bw_view(), "backward view info can only exist for backward views.");
727     return backward_info_.value();
728   }
729 
730   uint32_t get_attr_version() const {
731     TORCH_CHECK(
732         has_bw_view(), "attr_version can only exist for backward views.");
733     return attr_version_;
734   }
735 
736   void set_attr_version(uint32_t new_attr_version) {
737     TORCH_CHECK(
738         has_bw_view(), "attr_version can only exist for backward views.");
739     attr_version_ = new_attr_version;
740   }
741 
742   CreationMeta get_creation_meta() const {
743     TORCH_CHECK(
744         has_bw_view(), "creation_meta can only exist for backward views.");
745     return creation_meta_;
746   }
747 
748   void set_creation_meta(CreationMeta new_creation_meta) {
749     TORCH_CHECK(
750         has_bw_view(), "creation_meta can only exist for backward views.");
751     creation_meta_ = new_creation_meta;
752   }
753 
754   bool has_fw_view() const {
755     return shared_view_info_ || forward_info_.has_value();
756   }
757 
758   const ViewInfo& get_forward_view() const {
759     TORCH_CHECK(
760         has_fw_view(), "forward view info can only exist for forward views.");
761     TORCH_CHECK(
762         !shared_view_info_ || has_bw_view(),
763         "forward view info can only exist for forward views.");
764     return shared_view_info_ ? backward_info_.value() : forward_info_.value();
765   }
766 
767   DifferentiableViewMeta(
768       at::TensorImpl* self_impl,
769       std::optional<ViewInfo> backward_info,
770       std::optional<ViewInfo> forward_info,
771       bool shared_view_info,
772       CreationMeta creation_meta = CreationMeta::DEFAULT);
773 };
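// A sketch of the staleness check implied by attr_version_ (the real logic
// lives in VariableHooks::grad_fn(); variable names here are illustrative):
//
//   auto* diff_view_meta = impl::get_view_autograd_meta(view);
//   if (diff_view_meta && diff_view_meta->has_bw_view()) {
//     uint32_t current = impl::version_counter(view).current_version();
//     if (diff_view_meta->get_attr_version() != current) {
//       // grad_fn_ is stale: the view (or its base) was modified in-place
//       // since grad_fn_ was recorded, so it must be reconstructed.
//     }
//   }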
774 
775 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
776 //                        Variable Implementation
777 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
778 
779 // Factory Functions
780 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
781 
782 /// Creates a `Variable` that is a *view* of another (*base*) variable.
783 /// The `gradient_edge` is an optional (gradient_function, input_number) pair.
784 /// `is_differentiable` is a bool that specifies whether this view is
785 /// differentiable, i.e., whether the relation should be tracked by autograd.
786 /// See NOTE [ Autograd View Variables ] for details.
787 
788 /// NOTE: `allow_tensor_metadata_change` is set to true by default, because
789 /// there are a lot of call sites to these factory functions that need to change
790 /// the variable's size or storage afterwards, and they don't expect the
791 /// original tensor (where the variable is created from) to be updated. Setting
792 /// `allow_tensor_metadata_change_` to false by default would unnecessarily
793 /// prevent those changes from happening and is undesirable.
794 
795 // See NOTE [ Autograd View Variables ] for details.
796 // Differentiable view. Track history with DifferentiableViewMeta.
797 inline Variable make_variable_differentiable_view(
798     const at::Tensor& data,
799     std::optional<ViewInfo> backward_info,
800     std::optional<ViewInfo> forward_info,
801     bool shared_view_info,
802     CreationMeta creation_meta,
803     bool allow_tensor_metadata_change = true) {
804   if (data.defined()) {
805     TORCH_CHECK(
806         data.getIntrusivePtr()->autograd_meta() == nullptr,
807         "Attempted to make a tensor into a differentiable view, but the "
808         "tensor already had autograd metadata associated with it.  If you are "
809         "using a __torch_dispatch__ mode, the most common cause for this "
810         "problem is that you used torch.overrides.enable_reentrant_dispatch() "
811         "improperly; tensors created within the extent of reentrant dispatch "
812         "MUST NOT be directly returned from __torch_dispatch__; instead, they "
813         "must be wrapped into fresh tensors that serve as the output.  If you "
814         "are not using wrappers, you probably don't need reentrant dispatch.  "
815         "If this doesn't seem applicable, please file a bug to PyTorch.");
816     at::TensorImpl* data_impl = data.unsafeGetTensorImpl();
817     data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
818     data_impl->set_autograd_meta(std::make_unique<DifferentiableViewMeta>(
819         data_impl,
820         std::move(backward_info),
821         std::move(forward_info),
822         shared_view_info,
823         creation_meta));
824     return data;
825   }
826   return Variable();
827 }
828 
829 // See NOTE [ Autograd View Variables ] for details.
830 // Non-differentiable view. Just share version counter.
831 inline Variable make_variable_non_differentiable_view(
832     const Variable& base,
833     const at::Tensor& data,
834     bool allow_tensor_metadata_change = true) {
835   if (data.defined()) {
836     // Currently all non-differentiable view ops (detach/_indices/_values)
837     // share the same TensorImpl as their base Tensor. Thus a new TensorImpl
838     // allocation here is required.
839     auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
840         /*version_counter=*/impl::version_counter(base),
841         /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
842     data_impl_copy->set_autograd_meta(nullptr);
843     return Variable(data_impl_copy);
844   }
845   return Variable();
846 }
847 
848 /// Creates a `Variable` from the given `Tensor`, copying its underlying
849 /// `TensorImpl`. `requires_grad` should be set only for leaves, and determines
850 /// whether the `Variable` will accumulate gradients. NOTE: `data` must *not* be
851 /// a `Variable` already. Its dynamic type *must* be `Tensor`.
852 ///
853 /// TODO: Eliminate this function as much as possible, as it can be expressed
854 /// more clearly as detach() or a no-op in most call sites (especially when
855 /// there is only one use of the variable).
856 inline Variable make_variable(
857     at::Tensor data,
858     bool requires_grad = false,
859     bool allow_tensor_metadata_change = true) {
860   if (data.defined()) {
861     if (data.getIntrusivePtr().use_count() == 1 &&
862         data.getIntrusivePtr()->unique_version()) {
863       auto data_impl = data.unsafeReleaseIntrusivePtr();
864       data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
865       if (requires_grad) {
866         data_impl->set_autograd_meta(
867             std::make_unique<AutogradMeta>(data_impl.get(), requires_grad));
868       } else {
869         data_impl->set_autograd_meta(nullptr);
870       }
871       return Variable(std::move(data_impl));
872     } else {
873       auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
874           /*version_counter=*/0,
875           /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
876       if (requires_grad) {
877         data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>(
878             data_impl_copy.get(), requires_grad));
879       } else {
880         data_impl_copy->set_autograd_meta(nullptr);
881       }
882       return Variable(data_impl_copy);
883     }
884   }
885   return Variable();
886 }
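// For example, a sketch of typical legacy usage (assuming `data` carries no
// autograd metadata yet):
//
//   at::Tensor data = at::ones({2, 2});                 // plain tensor data
//   Variable v = make_variable(data, /*requires_grad=*/true);
//   // v.requires_grad() == true, v.is_leaf() == true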
887 
888 /// Creates a `Variable` from the given `Tensor`, copying its underlying
889 /// `TensorImpl`. `gradient_edge` should be a (function, input_nr) pair
890 /// specifying the function in the autograd graph, and what particular input of
891 /// that function this variable is connected to.
892 inline Variable make_variable(
893     const at::Tensor& data,
894     Edge gradient_edge,
895     bool allow_tensor_metadata_change = true) {
896   if (data.defined()) {
897     auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
898         /*version_counter=*/0,
899         /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
900     data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>(
901         data_impl_copy.get(), false, std::move(gradient_edge)));
902     return Variable(data_impl_copy);
903   }
904   return Variable();
905 }
906 
907 struct VariableHooks final : at::impl::VariableHooksInterface {
908   at::TensorBase tensor_data(const at::TensorBase&) const override;
909   at::TensorBase variable_data(const at::TensorBase&) const override;
910   const std::shared_ptr<torch::autograd::Node>& grad_fn(
911       const at::TensorBase&) const override;
912   unsigned _register_hook(
913       const at::TensorBase&,
914       std::function<at::TensorBase(const at::TensorBase&)> hook) const override;
915   void remove_hook(const at::TensorBase&, unsigned pos) const override;
916   bool is_view(const at::TensorBase&) const override;
917   const at::TensorBase& base(const at::TensorBase&) const override;
918   const std::string& name(const at::TensorBase&) const override;
919   bool is_leaf(const at::TensorBase&) const override;
920   int64_t output_nr(const at::TensorBase&) const override;
921   void set_data(const at::TensorBase& self, const at::TensorBase& new_data)
922       const override;
923   at::TensorBase data(const at::TensorBase& self) const override;
924   int64_t _version(const at::TensorBase& self) const override;
925   void retain_grad(const at::TensorBase& self) const override;
926   bool retains_grad(const at::TensorBase& self) const override;
927   void _backward(
928       const at::Tensor& self,
929       at::TensorList inputs,
930       const std::optional<at::Tensor>& gradient,
931       std::optional<bool> keep_graph,
932       bool create_graph) const override;
933   void requires_grad_(const at::TensorBase& self, bool _requires_grad)
934       const override;
935   void basic_autograd_not_implemented_fallback(
936       const c10::OperatorHandle& op,
937       c10::DispatchKeySet dispatch_keys,
938       torch::jit::Stack* stack) const override;
939 };
940 
941 namespace utils {
942 
943 TORCH_API bool has_same_meta(const Variable& base, const Variable& other);
944 
945 } // namespace utils
946 } // namespace torch::autograd
947 
948 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
949