1 #include <torch/csrc/jit/passes/update_differentiable_graph_requires_grad.h> 2 3 #include <torch/csrc/jit/ir/ir.h> 4 #include <torch/csrc/jit/passes/utils/subgraph_utils.h> 5 6 namespace torch::jit { 7 UpdateDifferentiableGraphRequiresGrad(Block * block,std::optional<bool> new_requires_grad)8static void UpdateDifferentiableGraphRequiresGrad( 9 Block* block, 10 std::optional<bool> new_requires_grad) { 11 for (Node* n : block->nodes()) { 12 for (Value* v : n->inputs()) { 13 auto ty = v->type()->cast<TensorType>(); 14 if (ty) { 15 v->setType(ty->withRequiresGrad(new_requires_grad)); 16 } 17 } 18 if (n->kind() == prim::profile) { 19 n->ty_( 20 attr::profiled_type, 21 n->ty(attr::profiled_type) 22 ->expectRef<TensorType>() 23 .withRequiresGrad(new_requires_grad)); 24 } 25 for (Block* b : n->blocks()) { 26 UpdateDifferentiableGraphRequiresGrad(b, new_requires_grad); 27 } 28 } 29 } 30 UpdateDifferentiableGraphRequiresGrad(std::shared_ptr<Graph> & diff_forward_graph,std::optional<bool> new_requires_grad)31void UpdateDifferentiableGraphRequiresGrad( 32 std::shared_ptr<Graph>& diff_forward_graph, 33 std::optional<bool> new_requires_grad) { 34 UpdateDifferentiableGraphRequiresGrad( 35 diff_forward_graph->block(), new_requires_grad); 36 } 37 38 } // namespace torch::jit 39