1 #pragma once 2 3 #include <torch/csrc/jit/ir/ir.h> 4 5 namespace torch::jit { 6 7 // Because differentiable graphs detach the gradients of input Tensors, 8 // creating and inlining differentiable graphs changes the requires_grad 9 // property of tensors in the graph. This pass updates prim::profiles 10 // requires_grad to keep profiled properties up to date, it does not update 11 // grad properties of other nodes like graph inputs bc the only downstream 12 // user of the grad property is the profiling executor, which just uses 13 // the types of prim::profiles 14 TORCH_API void UpdateDifferentiableGraphRequiresGrad( 15 std::shared_ptr<Graph>& diff_forward_graph, 16 std::optional<bool> new_requires_grad); 17 18 } // namespace torch::jit 19