#pragma once

#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>

#include <memory>
#include <vector>

namespace torch::jit {

using value_list = std::vector<Value*>;
// clang-format off
// Example showcasing how Gradient is constructed:
//
// Let's assume we have a function f where `m` and `n` do not require grad
// (`n` can depend only on `m`):
//   y, n = f(x, m)
//
// Now, let's assume that the reverse of f (called f') needs to use values of `x`, `t` and `y`.
// `t` is an intermediate value produced in the body of f, and let's assume that it requires
// grad too.
//
// In this case differentiate(f) will return this:
//   y, n, t = f(x, m)        // `t` is appended to the output list
//   dx = f'(dy, dt, x, t, y) // No `dm` or `dn` because they do not require gradient
//                            // All needed values from f are prepended to the input list
//
//   f_real_outputs = 2       // Only first two outputs were present in f originally
//   df_input_vjps = {0, 2}   // i.e. connect grad_fn of y and t variables produced by f,
//                    y  t    // with y's output_nr = 0 and t's output_nr = 1
//   df_input_captures = {I0, O2, O0} // Order matches the prefix of inputs to df
//                        x   t   y
//   df_output_vjps = {0}     // i.e. connect next_edge[0] of grad_fn to x's (grad_fn, output_nr).
//
// Terminology: vjp = vector-jacobian product
// clang-format on

struct Gradient {
  explicit operator bool() const {
    return df != nullptr;
  }
  std::shared_ptr<Graph> f;
  std::shared_ptr<Graph> df;

  // Describes how to construct outputs of f from what its graph will return.
  // This is necessary because some trailing outputs are intermediates produced
  // only to be saved for df (and should be ignored).
  size_t f_real_outputs = 0; // initialized for safety.
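  // E.g. in the example at the top of this file, f's graph returns (y, n, t)
  // but f_real_outputs == 2, so only y and n are returned to the caller.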

  // df inputs are split into two sections: vjps (aka grad_outputs) and
  // captures. VJPs are "seeds" for the gradient computation, given for each
  // input capture of an Output kind. Captures are values that need to be saved
  // when f is run. We handle inputs specially, because this allows us to avoid
  // adding extra vjps as df inputs.
  std::vector<size_t> df_input_vjps; // Offsets into f's outputs.
  // A capture can come from f's inputs or outputs.
  std::vector<size_t> df_input_captured_inputs; // Offsets into f's inputs
  std::vector<size_t> df_input_captured_outputs; // Offsets into f's outputs
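  // For the example at the top of this file, a caller would assemble df's
  // input list as (a sketch of the layout, not literal code):
  //   df_inputs = [dy, dt,   // vjps, per df_input_vjps = {0, 2}
  //                x, t, y]  // captures: inputs {0}, then outputs {2, 0}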

  // df will produce vjps for a subset of f's inputs that required grad.
  // df_output_vjps[idx] == inp_idx means that the idx-th output of df produces
  // a vjp for the inp_idx-th input of f.
  std::vector<size_t> df_output_vjps; // Offsets into f's inputs.
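  // In the example at the top of this file, df_output_vjps = {0}: df's only
  // output is dx, the vjp for f's input x (input 0).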

  // How to use Gradient to implement a differentiable autograd function:
  // When running f:
  //   - Unwrap input Variables
  //   - Run f's graph
  //   - Create grad_fn
  //   - Wrap outputs in Variables (assume we have a tensor_outputs array):
  //       outputs = map(Variable, tensor_outputs)
  //       for i, offset in enumerate(df_input_vjps):
  //         outputs[offset].set_grad_fn(grad_fn, output_nr=i)
  //   - Use df_output_vjps to connect next_edges of grad_fn:
  //       for idx in df_output_vjps:
  //         grad_fn.add_next_edge(inputs[idx].gradient_edge())
  //   - Save captures for df (care needs to be taken to use SavedVariables for
  //                           inputs and outputs that we will actually return)
  //   - Return outputs[:f_real_outputs]
  //
  // When running df:
  //   - Concatenate received vjps and captured Variables
  //   - Interpret df
  //   - Wrap outputs of df into Variables (that don't require grad)
};
TORCH_API Gradient differentiate(std::shared_ptr<Graph>& graph);

// can we take a derivative of this node symbolically?
TORCH_API bool isDifferentiable(const Node* n);
TORCH_API bool isDifferentiable(Graph& g);
TORCH_API bool isZero(Value* v);
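
// A minimal usage sketch (assumptions: `build_graph` is a hypothetical helper;
// differentiate takes a non-const shared_ptr reference because it may mutate
// and take over the graph, so the caller should hold it uniquely):
//   std::shared_ptr<Graph> g = build_graph(); // hypothetical
//   if (isDifferentiable(*g)) {
//     Gradient grad = differentiate(g);
//     // grad.f computes f's original outputs plus the captured intermediates;
//     // grad.df maps (vjps, captures) to vjps for f's inputs.
//   }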

} // namespace torch::jit