#pragma once

#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>

#include <memory>
#include <vector>

namespace torch::jit {

using value_list = std::vector<Value*>;
// clang-format off
// Example showcasing how Gradient is constructed:
//
// Let's assume we have a function f, where `m` and `n` do not require grad
// (`n` can depend only on `m`):
//   y, n = f(x, m)
//
// Now, let's assume that the reverse of f (called f') needs to use values of `x`, `t` and `y`.
// `t` is an intermediate value produced in the body of f, and let's assume that it requires
// grad too.
//
// In this case differentiate(f) will return this:
//   y, n, t = f(x, m)        // `t` is appended to the output list
//   dx = f'(dy, dt, x, t, y) // No `dm` or `dn` because they do not require gradient
//                            // All needed values from f are prepended to the input list
//
//   f_real_outputs = 2               // Only the first two outputs were present in f originally
//   df_input_vjps = {0, 2}           // i.e. connect grad_fn of the y and t variables produced by f,
//                    y  t            // with y's output_nr = 0 and t's output_nr = 1
//   df_input_captures = {I0, O2, O0} // Order matches the prefix of inputs to df
//                         x   t   y
//   df_output_vjps = {0}             // i.e. connect next_edge[0] of grad_fn to x's (grad_fn, output_nr).
//
// Terminology: vjp = vector-jacobian product
// clang-format on

struct Gradient {
  explicit operator bool() const {
    return df != nullptr;
  }
  std::shared_ptr<Graph> f;
  std::shared_ptr<Graph> df;

  // Describes how to construct outputs of f from what its graph will return.
  // This is necessary because some trailing outputs are intermediates produced
  // only to be saved for df (and should be ignored).
  size_t f_real_outputs = 0; // initialized for safety.

  // df inputs are split into two sections: vjps (aka grad_outputs) and
  // captures. VJPs are "seeds" for the gradient computation, given for each
  // input capture of an Output kind. Captures are values that need to be saved
  // when f is run. We handle inputs specially, because this allows us to avoid
  // adding extra vjps as df inputs.

  std::vector<size_t> df_input_vjps; // Offsets into f's outputs.
  // A capture can come from either inputs or outputs.
  std::vector<size_t> df_input_captured_inputs; // Offsets into f's inputs
  std::vector<size_t> df_input_captured_outputs; // Offsets into f's outputs

  // df will produce vjps for a subset of inputs of f that required grad.
  // df_output_vjps[idx] == inp_idx means that the idx-th output of df produces
  // a vjp for the inp_idx-th input of f.
  std::vector<size_t> df_output_vjps; // Offsets into f's inputs.
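
  // Illustrative values only: a hand-worked reading of the f(x, m) example at
  // the top of this file, spelled out in terms of the fields above. It is not
  // produced by any code in this header, and `graph_of_f` is a hypothetical
  // placeholder for f's graph.
  //
  //   Gradient grad = differentiate(graph_of_f);
  //   // grad.f_real_outputs            == 2      (only y and n existed in f)
  //   // grad.df_input_vjps             == {0, 2} (dy and dt seed the backward pass)
  //   // grad.df_input_captured_inputs  == {0}    (x is captured from f's inputs)
  //   // grad.df_input_captured_outputs == {2, 0} (t and y are captured from f's outputs)
  //   // grad.df_output_vjps            == {0}    (df's single vjp output is dx, for input x)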

  // How to use Gradient to implement a differentiable autograd function:
  // When running f:
  //   - Unwrap input Variables
  //   - Run f's graph
  //   - Create grad_fn
  //   - Wrap outputs in Variables (assume we have a tensor_outputs array):
  //       outputs = map(Variable, tensor_outputs)
  //       for i, offset in enumerate(df_input_vjps):
  //         outputs[offset].set_grad_fn(grad_fn, output_nr=i)
  //   - Use df_output_vjps to connect next_edges of grad_fn:
  //       for idx in df_output_vjps:
  //         grad_fn.add_next_edge(inputs[idx].gradient_edge())
  //   - Save captures for df (care needs to be taken to use SavedVariables for
  //     inputs and outputs that we will actually return)
  //   - Return outputs[:f_real_outputs]
  //
  // When running df:
  //   - Concatenate received vjps and captured Variables
  //   - Interpret df
  //   - Wrap outputs of df into Variables (that don't require grad)
};
TORCH_API Gradient differentiate(std::shared_ptr<Graph>& graph);

// can we take a derivative of this node symbolically?
TORCH_API bool isDifferentiable(const Node* n);
TORCH_API bool isDifferentiable(Graph& g);
TORCH_API bool isZero(Value* v);

} // namespace torch::jit
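
// A minimal calling sketch (an illustration only, not a statement of
// differentiate's preconditions; how `graph` is built and how its tensor types
// come to carry requires_grad information is elided with `...`):
//
//   std::shared_ptr<Graph> graph = ...; // graph of f
//   if (isDifferentiable(*graph)) {
//     Gradient grad = differentiate(graph);
//     // grad.f computes f's outputs plus the captured intermediates;
//     // grad.df consumes the vjps and captures and produces vjps for f's
//     // inputs, as described on struct Gradient above.
//   }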