// torch/csrc/autograd/function.h
#pragma once

#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/input_metadata.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/python_stub.h>
#include <torch/csrc/utils/variadic.h>

#include <ATen/SequenceNumber.h>
#include <ATen/core/Tensor.h>
#include <ATen/record_function.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch::autograd {

struct Edge;
struct FunctionPostHook;
struct FunctionPreHook;

using tensor_list = std::vector<at::Tensor>;
using variable_list = std::vector<Variable>;
using edge_list = std::vector<Edge>;
using saved_variable_list = std::vector<SavedVariable>;
using IndexRange = std::pair<size_t, size_t>;
using torch::dynamo::autograd::CompiledNodeArgs;
using torch::dynamo::autograd::SwapSavedVariables;

// Custom deleter to prevent stack overflows.
TORCH_API void deleteNode(Node* function);

// Guard that sets and restores the evaluating node
class NodeGuard {
 public:
  explicit NodeGuard(std::shared_ptr<Node> node);
  ~NodeGuard();

 private:
  std::shared_ptr<Node> last_evaluating_node_;
};

// Return the Node currently being evaluated (if any)
// This is only set during the backward pass while a Node is being
// executed.
TORCH_API std::shared_ptr<Node> get_current_node();

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//                               Node
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// A `Node` is an abstract class that represents an operation taking zero
// or more input `Variable`s and producing zero or more output `Variable`s. All
// functions in PyTorch's autograd machinery derive from this class and
// override its `apply` method. Instances of such subclasses will then be
// invokable via the call operator.
//
//                    Nodes in the Autograd Graph
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// When viewing the autograd system as a graph, `Node`s are the vertices or
// nodes, connected to each other via (directed) `Edge`s, which themselves are
// represented via (`Node`, input_nr) pairs. `Variable`s are the inputs to
// and outputs of `Node`s, and travel along these edges during execution
// of the graph. When two or more `Edge`s (from different sources) point at the
// same input to a `Node`, the values produced along all of these edges are
// implicitly summed prior to being forwarded to the target `Node`.
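//
// For example (an illustrative sketch, not part of this header), with the
// C++ frontend:
//
//   auto x = torch::ones({2, 2}, torch::requires_grad());
//   auto y = torch::rand({2, 2}, torch::requires_grad());
//   auto z = x * y;
//
// `z.grad_fn()` is the `Node` created for the multiplication (a
// `MulBackward0` in current builds). Its two next edges point at the
// `AccumulateGrad` sink nodes of `x` and `y`, i.e. edge i is the pair
// (backward node of input i, input_nr). If `x` were used by two different
// ops, both resulting backward nodes would hold an edge into the same
// `AccumulateGrad` input, and the gradients flowing along those edges would
// be summed as described above.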
//
//                              Hierarchy
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Subclasses usually represent differentiable functions as well as their
// gradient operators. Note, however, that due to the very general definition
// of a `Node` taking *zero* or more inputs and producing *zero* or more
// outputs, uses of `Node`s are flexible and extend beyond purely
// mathematical operations. For example, the `AccumulateGrad` function is a
// *sink*: it takes one input, but produces no outputs, instead accumulating
// the input as a side effect. At the other extreme, the `GraphRoot` function
// receives no inputs from other functions, but produces multiple outputs.
//
//                              Interface
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// The most important method on `Node` is the call operator, which takes in
// a list of variables and produces a list of variables. The precise size of
// these lists can be determined with `num_inputs()` and `num_outputs()`.
// `Node`s are stitched together via their `next_edge` interface, which lets
// you manipulate the set of outgoing edges of a `Node`. You can add an
// edge with `add_next_edge()`, retrieve an edge with `next_edge(index)` and
// iterate over them via the `next_edges()` method. Other methods exist for
// integration with the JIT and other parts of PyTorch. Every `Node` has a
// *sequence number* that increases monotonically in the order of `Node`
// construction. It can be retrieved via the `sequence_nr()` method. Note that
// this sequence number is *thread local*. This means that when `Node`s
// `A`, `B` and `C` are created consecutively in the same thread, their
// sequence numbers will be ordered `A` < `B` < `C`. If, however, `A` and `B`
// are created in one thread and `C` is created in a new thread, there are *no
// guarantees* w.r.t. the ordering of `C` relative to `A` or `B`.
// See NOTE [ Sequence Number ] for more details on the uses of the sequence
// number.
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
struct TORCH_API Node : std::enable_shared_from_this<Node> {
 public:
  /// Construct a new `Node` with the given `next_edges`
  explicit Node(uint64_t sequence_nr, edge_list&& next_edges = edge_list())
      : sequence_nr_(sequence_nr), next_edges_(std::move(next_edges)) {
    for (const Edge& edge : next_edges_) {
      update_topological_nr(edge);
    }

    if (AnomalyMode::is_enabled()) {
      metadata()->store_stack();

      // If anomaly mode is enabled and the graph is constructed, then assign
      // the currently evaluating node as the parent of this node.
      // A parent is the Node during whose execution this Node was created.
      // We track the parents to track multiple backward operations.
      assign_parent();
    }

    // Store the thread_id of the forward operator.
    // See NOTE [ Sequence Number ]
    thread_id_ = at::RecordFunction::currentThreadId();
  }

  explicit Node(edge_list&& next_edges = edge_list())
      : Node(
            /*sequence_nr=*/at::sequence_number::get_and_increment(),
            std::move(next_edges)) {}

  /// Nodes are neither copyable nor moveable.
  Node(const Node& other) = delete;
  Node(Node&& other) = delete;
  Node& operator=(const Node& other) = delete;
  Node& operator=(Node&& other) = delete;
  virtual ~Node() = default;

  std::shared_ptr<Node> getptr() {
    return shared_from_this();
  }
  /// Evaluates the function on the given inputs and returns the result of the
  /// function call.
  variable_list operator()(variable_list&& inputs) {
    // In the first iteration of named tensors, autograd ignores names and
    // operates on unnamed tensors. In the long term, autograd should
    // probably operate with names.
    at::NoNamesGuard no_names_guard;

#ifdef USE_ROCM
    // Keep track of backward pass for rocblas.
    at::ROCmBackwardPassGuard in_backward;
#endif

    auto step_callbacks =
        at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION);
    if (C10_UNLIKELY(step_callbacks.has_value())) {
      at::RecordFunction guard(std::move(*step_callbacks));
      // Use the sequence number and thread id to correlate with
      // the forward-pass function.
      guard.setForwardThreadId(thread_id_);
      if (guard.needsInputs()) {
        std::vector<c10::IValue> inputs_vec(inputs.begin(), inputs.end());
        guard.before(
            name(),
            c10::ArrayRef<const c10::IValue>(
                inputs_vec.data(), inputs_vec.size()),
            static_cast<int64_t>(sequence_nr()));
      } else {
        guard.before(name(), static_cast<int64_t>(sequence_nr()));
      }
      return apply(std::move(inputs));
    } else {
      return apply(std::move(inputs));
    }
  }
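
  // Illustrative sketch (not part of this header): a `Node` can be invoked
  // directly through its call operator, which is what the engine does during
  // backward. For example, with the C++ frontend:
  //
  //   auto x = torch::ones({2, 2}, torch::requires_grad());
  //   auto y = x * x;
  //   // Feed an upstream gradient of ones into y's backward node.
  //   auto grads = (*y.grad_fn())({torch::ones_like(y)});
  //   // `grads` has one entry per next edge of the node -- here, one
  //   // gradient per use of `x`, which the engine would sum before handing
  //   // the result to x's AccumulateGrad node.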

  // Graph Connectivity API
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  // Inputs. NOTE: inputs of the grad_fn correspond to Tensor outputs of the
  // forward function.

  // Marker for expected undefined input
  struct undefined_input {};

  /// Adds the type and shape metadata for a new input. Returns the index of
  /// the new input.
  uint32_t add_input_metadata(
      const at::TensorOptions& options,
      c10::SymIntArrayRef shape,
      bool is_tensor_subclass,
      bool is_nested) noexcept {
    uint32_t input_nr = input_metadata_.size();
    auto meta_shape = MetadataShape{std::in_place_type<SymIntSmallVec>, shape};
    input_metadata_.emplace_back(
        options, meta_shape, is_tensor_subclass, is_nested);
    return input_nr;
  }

  uint32_t add_input_metadata(const at::Tensor& t) noexcept {
    uint32_t input_nr = input_metadata_.size();
    input_metadata_.emplace_back(t);
    return input_nr;
  }

  /// Adds a placeholder for an input that will not be used.
  uint32_t add_input_metadata(undefined_input u) noexcept {
    uint32_t input_nr = input_metadata_.size();
    input_metadata_.emplace_back();
    return input_nr;
  }

  uint32_t num_inputs() const noexcept {
    return input_metadata_.size();
  }

  const InputMetadata& input_metadata(size_t index) const {
    return input_metadata_[index];
  }

  // Danger: not thread safe, caller must protect with lock
  InputMetadata& mutable_input_metadata(size_t index) {
    return input_metadata_[index];
  }

  /**
   * Note: Function Streams
   * A function's stream (for a given device type) is the stream of the first
   * element of its input buffer on a device of that type.
   *
   * If all elements are on the same device they MUST share a stream. If
   * elements are on different devices (across multiple GPUs, for example)
   * they may have different streams.
   */
  std::optional<c10::Stream> stream() {
    auto opt_device_type = at::getAccelerator();
    if (!opt_device_type.has_value()) {
      return std::nullopt;
    }
    for (const auto& metadata : input_metadata_) {
      if (metadata.device().type() == opt_device_type.value())
        return metadata.stream();
    }

    return std::nullopt;
  }

  void clear_input_metadata() {
    input_metadata_.clear();
  }

  // Outputs ("Next Edges")

  void update_topological_nr(const Edge& edge) {
    TORCH_INTERNAL_ASSERT(
        !has_parent_,
        "Cannot update a node's topological_nr after it already has a parent."
        " If we allow this, we can no longer guarantee that a parent's"
        " topo_nr is always greater than those of all its children")
    Node* node = edge.function.get();
    if (node) {
      auto topo_nr = node->topological_nr();
      if (topological_nr_ <= topo_nr) {
        topological_nr_ = topo_nr + 1;
      }
    }
  }

  void set_next_edge(size_t index, Edge edge) {
    update_topological_nr(edge);
    next_edges_[index] = std::move(edge);
  }

  void add_next_edge(Edge edge) {
    update_topological_nr(edge);
    next_edges_.emplace_back(std::move(edge));
  }

  void set_next_edges(edge_list&& next_edges) {
    next_edges_ = std::move(next_edges);
    for (const auto& next_edge : next_edges_) {
      update_topological_nr(next_edge);
    }
  }

  const Edge& next_edge(size_t index) const noexcept {
    return next_edges_[index];
  }

  const edge_list& next_edges() const noexcept {
    return next_edges_;
  }

  edge_list& next_edges() noexcept {
    return next_edges_;
  }

  uint32_t num_outputs() const noexcept {
    return next_edges_.size();
  }

  // Miscellaneous Methods
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// NOTE [ Sequence Number ]
  ///
  /// The sequence_nr has two main usages in autograd:
  ///
  /// 1) Helps determine the node's execution priority in the engine.
  ///    All else being equal, nodes with higher priority numbers are executed
  ///    first. Thus, nodes corresponding to ops executed later are the first
  ///    to be executed in the backward pass. One caveat is that we prioritize
  ///    AccumulateGrad nodes by explicitly setting their sequence_nr to
  ///    UINT64_MAX.
  /// 2) The sequence number of this `Node` is paired with the thread_id it
  ///    was created in as a unique identifier by the profiler to annotate
  ///    recorded events. The purpose of this is to help users (and possibly
  ///    programs) interpreting the profiler's output correlate backward nodes
  ///    with their forward ops. We need both sequence_nr and thread_id to
  ///    identify a node because sequence_nr is thread_local, i.e., it starts
  ///    counting up from zero in each new thread.
  uint64_t sequence_nr() const noexcept {
    return sequence_nr_;
  }

  void set_sequence_nr(uint64_t sequence_nr) {
    sequence_nr_ = sequence_nr;
  }
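
  // Illustrative sketch (not part of this header): because sequence numbers
  // increase with every node created on a thread, backward nodes of later
  // forward ops get higher priority in the engine. Assuming two ops recorded
  // on the same thread:
  //
  //   auto a = x.exp();   // ExpBackward0 created with some sequence_nr N
  //   auto b = a.sum();   // SumBackward0 created later, with sequence_nr > N
  //
  // During backward, SumBackward0 is (all else being equal) scheduled before
  // ExpBackward0, i.e. roughly in reverse order of the forward pass. The
  // profiler records the (sequence_nr, thread_id) pair for both the forward
  // op and its backward node so the two events can be matched up.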

  // NOTE [ Topological Number ]
  //
  // topological_nr is used to prune branches in the DAG during autograd
  // discovery as maintaining topological_nr helps us check in O(1) if there
  // does NOT exist a directed path between two nodes.
  //
  // The topological order number of this `Node` represents the length of the
  // longest possible path from this Node to any leaf node. If this is a leaf
  // node, i.e., an AccumulateGrad, the number is zero. This value has the
  // property that for every pair of nodes X, Y in G, the existence of a
  // directed path from X to Y implies topo_nr(X) > topo_nr(Y). The converse
  // is not true, however, so we cannot prove the existence of a path from X
  // to Y, only its non-existence.
  //
  // One assumption we make when using topo_nr is that once a node
  // has been used, i.e., has a parent node, its own topo_nr does not change.
  // We have added some checks with the `has_parent_` field to enforce this.
  //
  // What NOT to do:
  //
  //   1) 2 -> 1 -> 0               In this diagram we label nodes with their
  //      2 -> 1 -> 0               topo_nr. We have two simple graphs that
  //                                can each arise from `t.exp().exp()`, for
  //                                example.
  //   2)        2 -> 1 -> 0
  //            /
  //      2 -> 1 -> 0               We add 2 as a next edge to 1 even though 1
  //                                already has a parent.
  //   3)        2 -> 1 -> 0
  //            /
  //      2 -> 3 -> 0               2 < 3, yet there exists a path from 2 to 3!
  //
  uint64_t topological_nr() const noexcept {
    has_parent_ = true;
    return topological_nr_;
  }
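
  // Illustrative sketch (not part of this header): autograd discovery can use
  // the property above to skip whole branches. For two candidate nodes X and
  // Y:
  //
  //   if (X->topological_nr() <= Y->topological_nr()) {
  //     // There cannot be a directed path from X to Y, so any search
  //     // starting at X may prune Y immediately.
  //   }
  //
  // Note that calling topological_nr() marks the node as having a parent
  // (has_parent_ = true), which freezes its topological number from that
  // point on; see NOTE [ Topological Number ].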

  // assigning a node as a parent to this node
  void assign_parent();

  /// Id of the thread that created Node
  uint64_t thread_id() const noexcept {
    return thread_id_;
  }

  /// Returns the name of the dynamic type of the function, for debugging.
  virtual std::string name() const;

  /// The difference between functions `should_compute_output` and
  /// `task_should_compute_output`:
  /// - `should_compute_output` should only be used during graph construction
  /// and takes into account only requires_grad information
  /// - `task_should_compute_output` should only be called during the backward
  /// pass (unless called directly through grad_fn) and takes into account the
  /// current graph task.  Specifically, the autograd engine trims unnecessary
  /// edges when `inputs` are specified, and during backward untrimmed nodes
  /// left on the graph can/should check `task_should_compute_output` to see if
  /// any outgoing edges have been trimmed by the engine. If that is the case,
  /// gradient computation wrt those edges can be omitted.
  ///
  /// Returns true if the particular output edge is active, and that particular
  /// output of this function should be computed.
  bool should_compute_output(size_t output_edge_index) const {
    TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range");
    return next_edges_[output_edge_index].is_valid();
  }

  /// Returns true if any of the output edges in any of the ranges are active.
  bool should_compute_output(std::initializer_list<IndexRange> idxs) const {
    return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) {
      for (const auto i : c10::irange(range.first, range.second)) {
        if (should_compute_output(i))
          return true;
      }
      return false;
    });
  }

  /// Same as the above `should_compute_output` function but will also
  /// check whether this edge is needed within the current graph task.
  bool task_should_compute_output(size_t output_edge_index) const {
    TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range");
    const auto& next = next_edges_[output_edge_index];
    if (next.is_valid()) {
      const auto exec_info = get_current_graph_task_exec_info();
      if (exec_info && !exec_info->empty()) {
        auto it = exec_info->find(next.function.get());
        if (it == exec_info->end() || !it->second.should_execute()) {
          return false; // this edge is not needed for the current graph_task
        }
      }
      return true;
    }
    return false;
  }

  /// Returns true if any of the output edges in any of the ranges are active
  /// and should be computed in the current graph task.
  bool task_should_compute_output(
      std::initializer_list<IndexRange> idxs) const {
    return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) {
      for (const auto i : c10::irange(range.first, range.second)) {
        if (task_should_compute_output(i))
          return true;
      }
      return false;
    });
  }
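
  // Illustrative sketch (not part of this header): a backward implementation
  // typically guards each gradient computation on whether the corresponding
  // output edge is actually needed. For a hypothetical subclass
  // `MulConstantBackward` holding a `scale_` member:
  //
  //   variable_list MulConstantBackward::apply(variable_list&& grads) {
  //     variable_list grad_inputs(num_outputs());
  //     if (task_should_compute_output(0)) {
  //       grad_inputs[0] = grads[0] * scale_;
  //     }
  //     return grad_inputs;
  //   }
  //
  // `should_compute_output` answers the same question from requires_grad
  // information alone and is meant for graph-construction time, while
  // `task_should_compute_output` additionally honors the edges trimmed by
  // the engine for the current graph task (e.g. when `inputs` are passed to
  // backward/grad).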

  /// Returns the `PyObject` stored for this `Node` (for Python
  /// interaction).
  PyObject* pyobj() const noexcept {
    return pyobj_;
  }

  /// Sets the `PyObject` stored for this `Node` (for Python interaction).
  void set_pyobj(PyObject* pyobj) noexcept {
    pyobj_ = pyobj;
  }

  /// Returns the anomaly metadata stored for this `Node`.
  /// If none exist, creates a new empty one.
  AnomalyMetadata* metadata() noexcept;

  // Hook API
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
    post_hooks_.emplace_back(std::move(post_hook));
    // Use the raw pointer as the unique key to identify this hook. This key
    // can then be used in del_post_hook(key) to remove this hook.
    return reinterpret_cast<std::uintptr_t>(post_hooks_.back().get());
  }

  const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks()
      const noexcept {
    return post_hooks_;
  }

  // delete a post hook matching the key
  bool del_post_hook(const uintptr_t& key) {
    for (auto it = post_hooks_.begin(); it != post_hooks_.end(); ++it) {
      if (key == reinterpret_cast<std::uintptr_t>(it->get())) {
        post_hooks_.erase(it);
        return true;
      }
    }
    return false;
  }
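
  // Illustrative sketch (not part of this header), assuming the
  // `FunctionPostHook` interface declared in function_hook.h (a virtual call
  // operator receiving the node's gradient outputs and inputs):
  //
  //   struct LoggingPostHook : FunctionPostHook {
  //     variable_list operator()(
  //         const variable_list& outputs,
  //         const variable_list& inputs) override {
  //       // Inspect or replace the gradients here; returning `outputs`
  //       // unchanged keeps the original behavior.
  //       return outputs;
  //     }
  //   };
  //
  //   auto node = y.grad_fn();  // std::shared_ptr<Node> of some output y
  //   auto key = node->add_post_hook(std::make_unique<LoggingPostHook>());
  //   // Later, remove the hook again using the key returned above:
  //   node->del_post_hook(key);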

  std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept {
    return post_hooks_;
  }

  void add_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
    pre_hooks_.emplace_back(std::move(pre_hook));
  }

  void add_tensor_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
    tensor_pre_hooks_.emplace_back(std::move(pre_hook));
  }

  void add_retains_grad_hook(
      std::unique_ptr<FunctionPreHook>&& pre_hook,
      size_t output_idx) {
    retains_grad_hooks_[output_idx] = std::move(pre_hook);
  }

  std::unique_ptr<FunctionPreHook> pop_retains_grad_hook(size_t output_idx) {
    auto ret = std::move(retains_grad_hooks_[output_idx]);
    retains_grad_hooks_.erase(output_idx);
    return ret;
  }

  const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks()
      const noexcept {
    return pre_hooks_;
  }

  std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() noexcept {
    return pre_hooks_;
  }

  virtual std::vector<std::unique_ptr<FunctionPreHook>>&
  tensor_pre_hooks() noexcept {
    return tensor_pre_hooks_;
  }

  virtual std::unique_ptr<PostAccumulateGradHook>&
  tensor_post_acc_grad_hooks() noexcept {
    static std::unique_ptr<PostAccumulateGradHook> empty = nullptr;
    return empty;
  }

  std::unordered_map<size_t, std::unique_ptr<FunctionPreHook>>&
  retains_grad_hooks() noexcept {
    return retains_grad_hooks_;
  }

  // Customization Points for Subclasses
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Releases saved variables if the operation won't be reused.
  virtual void release_variables() {}

  /// Called before an apply if `release_variables()` is going to be called.
  /// Allows larger ops like `InterpreterAutogradFunction` to incrementally
  /// release variables as they run.
  virtual void will_release_variables() {}

  /// Returns true if this function is traceable. An op is traceable if all
  /// operations happening within `apply()` are performed on autograd
  /// `Variables` (i.e. apply mostly instantiates and applies other functions).
  virtual bool is_traceable() {
    return false;
  }

  /// A `Node` is said to pass state transparently to backward, if the
  /// state consists only of (Saved)Variables and only non-variable objects
  /// that parameterize the operation in some way that defines the graph
  /// structure AND the backward function is traceable. In particular,
  /// parametrization MUST NOT depend on the data of any `Variable`.
  /// TODO: it might be possible to handle cases where backward is
  /// non-traceable but state passing could be considered transparent. This
  /// will probably depend on saved_variable_list being mutable.
  /// NOTE: this value matters only if is_traceable() returns false.
  virtual bool passes_state_transparently() {
    return false;
  }

  // See [Note: Compiled Autograd]
  // Used by compiled autograd to
  //   1) Extract tensors/symint args
  //   2) Collect node information for specialization and caching
  // Implementations in subclasses should call args.collect() with all node
  // attrs. These functions are only called during backward.
  virtual void compiled_args(CompiledNodeArgs& args) {
    throw std::runtime_error(
        std::string("compiled_args not implemented: ") + name());
  }

  // Used by compiled autograd to call apply() with different saved tensors.
  // Implementations should call saved.before() on all attrs, then apply(),
  // then saved.after() on all attrs in the same order.
  virtual variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) {
    throw std::runtime_error(
        std::string("apply_with_saved not implemented: ") + name());
  }

 protected:
  /// Performs the `Node`'s actual operation.
  virtual variable_list apply(variable_list&& inputs) = 0;

  /// Calls `apply()`, but instruments it with tracing machinery.
  variable_list traced_apply(variable_list inputs);

  // Sequence number used to correlate backward nodes with forward ops in the
  // profiler and provide determinism in the engine.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  uint64_t sequence_nr_;

  // See NOTE [ Topological Number ]
  uint64_t topological_nr_ = 0;

  // Tracks whether this node has been added as the next_edge of another node
  // via set_next_edge(s), which always calls topological_nr() of all its
  // children. See NOTE [ Topological Number ] for why we need this.
  mutable bool has_parent_ = false;

  // Id of the thread that created the instance
  uint64_t thread_id_ = 0;

  // Note [Thread Safety on Autograd Node]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // The autograd engine lets the owning thread (the one that calls
  // Engine::execute) drive the GraphTask execution. However, parts of a
  // GraphTask may be shared across different `backward()` or `grad()` calls,
  // e.g. when new threads are forked in the middle of the forward pass and
  // `backward()` is then called separately from those threads. We therefore
  // need to protect NodeTask execution against data races on shared
  // variables.
  //
  // NB: This is only needed for autograd Nodes that run on CPU; "CUDA" and
  // "XLA" nodes don't need locking because each device thread is single
  // threaded.
  //
  // Here we add a mutex to help protect the Node's thread safety, so that
  // different threads cannot race on shared data when executing the same
  // NodeTask from multiple CPU threads. It IS the user/developer's
  // responsibility to take advantage of this mutex to protect the thread
  // safety of their autograd Node. The general strategy for thread safety on
  // an autograd Node is:
  //
  // 1. Users should lock the mutex during Node::release_variables() if the
  //    Node needs to release variables on the fly. This ensures that when we
  //    release saved_variables from one thread, no other thread can release
  //    them concurrently.
  // 2. Users should lock the mutex during Node::apply() if the Node writes to
  //    shared data, to ensure that such writes do not race across threads
  //    (e.g. AccumulateGrad and custom C++ autograd Nodes that write to
  //    shared variables).
  // 3. Items 1 and 2 work together so that while saved variables are being
  //    released from one thread, no other thread can be inside Node::apply();
  //    this ensures that variable references held by other threads are not
  //    left dangling.
  // 4. If the Node does not release any variables and has no shared data
  //    reads/writes (i.e. it is purely functional), there is no need to lock
  //    the mutex.
  //
  // This protects the thread safety of the autograd Node itself, but it does
  // not protect the Node's pre/post C++ hooks (Python hooks are automatically
  // thread safe); we rely on users to write thread-safe C++ hooks if they
  // want the hooks to behave correctly in a multithreaded environment.
  std::mutex mutex_;
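
  // Illustrative sketch (not part of this header): a custom C++ Node that
  // mutates shared state could follow the strategy above by locking `mutex_`
  // in both `apply()` and `release_variables()` (hypothetical subclass):
  //
  //   struct MySharedStateBackward : public Node {
  //     variable_list apply(variable_list&& grads) override {
  //       std::lock_guard<std::mutex> lock(mutex_);
  //       auto grad_input = grads[0] * saved_scale_.unpack(getptr());
  //       return {grad_input};
  //     }
  //     void release_variables() override {
  //       std::lock_guard<std::mutex> lock(mutex_);
  //       saved_scale_.reset_data();
  //     }
  //     SavedVariable saved_scale_;
  //   };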

  edge_list next_edges_;
  PyObject* pyobj_ = nullptr; // weak reference
  std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr;

  // NOTE [Hooks ordering]
  // We have 3 separate fields for pre hooks registered to the autograd nodes
  // because the conditions under which they execute are different, and we
  // want more fine-grained control over the order in which different types
  // of hooks are executed.
  // - pre_hooks are only executed when the node itself is executed.
  // - tensor_pre_hooks are executed as long as the engine traverses over the
  //   node, even if that node won't actually be executed.
  // - retains_grad_hooks are like tensor_pre_hooks except that they are
  //   always ordered after all other tensor pre hooks.
  std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
  std::vector<std::unique_ptr<FunctionPreHook>> tensor_pre_hooks_;
  std::unordered_map<size_t, std::unique_ptr<FunctionPreHook>>
      retains_grad_hooks_;
  std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
  at::SmallVector<InputMetadata, 2> input_metadata_;
};

/// See Node::is_traceable() for definition.
struct TraceableFunction : public Node {
  using Node::Node;
  bool is_traceable() final {
    return true;
  }
};

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//                       Associated Free Nodes
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

namespace detail {
// Implementation of `collect_next_edges` (see below).
struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
  edge_list next_edges;
  using IterArgs<MakeNextFunctionList>::operator();
  void operator()(const Variable& variable) {
    if (variable.defined()) {
      next_edges.emplace_back(impl::gradient_edge(variable));
    } else {
      next_edges.emplace_back();
    }
  }
  void operator()(const Variable* variable) {
    operator()(*variable);
  }
  void operator()(const std::optional<Variable>& variable) {
    if (variable.has_value()) {
      operator()(*variable);
    } else {
      next_edges.emplace_back();
    }
  }
};
} // namespace detail

/// Create an `Edge` between the given `variable` and the `function`, which is
/// assumed to be the gradient function of this variable (i.e. the function
/// through which this variable is backpropagated during the backward pass).
/// This sets the `grad_fn` property of the `variable`. This function assumes
/// that the `Variable` is a new input to the gradient function and its
/// `input_nr` is thus equal to `function->num_inputs()`. Additionally, it
/// increments the `Node`'s number of inputs by one. Approximately
/// equivalent to `variable.set_gradient_edge(function,
/// function->add_input_metadata(variable.dispatch_type(), variable.sizes()))`.
/// If you don't want the `Node`'s `num_inputs` to be incremented, use
/// `set_gradient_edge` directly.
inline void create_gradient_edge(
    Variable& variable,
    std::shared_ptr<Node> function) {
  // Copy before move.
  const auto input_nr = function->add_input_metadata(variable);
  impl::set_gradient_edge(variable, {std::move(function), input_nr});
}

/// Return true if any of the variables in the list require a gradient.
inline bool any_variable_requires_grad(const variable_list& variables) {
  return std::any_of(
      variables.begin(), variables.end(), [](const Variable& variable) {
        return variable.defined() && variable.requires_grad();
      });
}

/// Return the next edges of all the given variables, or tuples of variables.
template <typename... Variables>
edge_list collect_next_edges(Variables&&... variables) {
  detail::MakeNextFunctionList make;
  make.apply(std::forward<Variables>(variables)...);
  return std::move(make.next_edges);
}
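
// Illustrative sketch (not part of this header): wiring a custom Node into
// the graph by hand, roughly mirroring what generated derivative code does.
// `ScaleBackward` is a hypothetical node that back-propagates `y = x * scale`:
//
//   struct ScaleBackward : public Node {
//     variable_list apply(variable_list&& grads) override {
//       variable_list grad_inputs(num_outputs());
//       if (task_should_compute_output(0)) {
//         grad_inputs[0] = grads[0] * scale_;
//       }
//       return grad_inputs;
//     }
//     double scale_{1.0};
//   };
//
//   // During the forward pass, for output `y` computed from input `x`:
//   auto node = std::shared_ptr<ScaleBackward>(new ScaleBackward(), deleteNode);
//   node->scale_ = 2.0;
//   // One outgoing edge per differentiable input of the forward op.
//   node->set_next_edges(collect_next_edges(x));
//   // Register `y` as an input of the backward node and set y's grad_fn.
//   create_gradient_edge(y, node);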

struct TypeAndSize {
  TypeAndSize() : options(at::TensorOptions()) {}
  /* implicit */
  TypeAndSize(const at::Tensor& t)
      : sym_sizes(t.sym_sizes().vec()), options(t.options()) {}

  at::Tensor zeros();

  std::vector<c10::SymInt> sym_sizes;
  at::TensorOptions options;
};
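
// Illustrative sketch (not part of this header): backward nodes use
// TypeAndSize to remember just the options and (symbolic) sizes of a forward
// tensor when only a zero-filled gradient of matching shape may be needed
// later, instead of saving the whole tensor. For some forward tensor `self`:
//
//   TypeAndSize self_info(self);         // capture dtype/device/layout + sizes
//   // later, in the backward pass:
//   auto grad_self = self_info.zeros();  // zero tensor with the same shape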

} // namespace torch::autograd