xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/update_differentiable_graph_requires_grad.h>
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
5 
6 namespace torch::jit {
7 
UpdateDifferentiableGraphRequiresGrad(Block * block,std::optional<bool> new_requires_grad)8 static void UpdateDifferentiableGraphRequiresGrad(
9     Block* block,
10     std::optional<bool> new_requires_grad) {
11   for (Node* n : block->nodes()) {
12     for (Value* v : n->inputs()) {
13       auto ty = v->type()->cast<TensorType>();
14       if (ty) {
15         v->setType(ty->withRequiresGrad(new_requires_grad));
16       }
17     }
18     if (n->kind() == prim::profile) {
19       n->ty_(
20           attr::profiled_type,
21           n->ty(attr::profiled_type)
22               ->expectRef<TensorType>()
23               .withRequiresGrad(new_requires_grad));
24     }
25     for (Block* b : n->blocks()) {
26       UpdateDifferentiableGraphRequiresGrad(b, new_requires_grad);
27     }
28   }
29 }
30 
UpdateDifferentiableGraphRequiresGrad(std::shared_ptr<Graph> & diff_forward_graph,std::optional<bool> new_requires_grad)31 void UpdateDifferentiableGraphRequiresGrad(
32     std::shared_ptr<Graph>& diff_forward_graph,
33     std::optional<bool> new_requires_grad) {
34   UpdateDifferentiableGraphRequiresGrad(
35       diff_forward_graph->block(), new_requires_grad);
36 }
37 
38 } // namespace torch::jit
39