#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/update_differentiable_graph_requires_grad.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

namespace torch::jit {

// aten and prim nodes (except FusionGroup) are guaranteed to work
// with Autograd; other nodes (e.g. user-defined ops) are not necessarily
// Autograd-aware.
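// For example, an aten::add or prim::If node passes this check (provided
// any nested blocks also pass), while a prim::FusionGroup node or a custom
// op such as my_ops::warp_perspective (a hypothetical user-defined operator)
// does not.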
bool canRunWithAutograd(Node* node) {
  auto kind = node->kind();
  for (Block* block : node->blocks()) {
    if (!std::all_of(
            block->nodes().begin(), block->nodes().end(), canRunWithAutograd)) {
      return false;
    }
  }
  return kind != prim::FusionGroup && kind != prim::CudaFusionGroup &&
      kind != prim::TypeCheck && kind != prim::TensorExprGroup &&
      kind != prim::CudaFusionGuard && kind != prim::oneDNNFusionGroup &&
      kind != prim::oneDNNFusionGuard && (kind.is_aten() || kind.is_prim());
}

namespace {

void InlineAutodiffSubgraphs(Block* block, size_t threshold);

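// Recursively counts the nodes in a block, including nodes nested inside
// sub-blocks, so the threshold reflects the true size of a subgraph.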
size_t blockSize(Block* block) {
  size_t num = 0;
  for (Node* n : block->nodes()) {
    for (Block* b : n->blocks()) {
      num += blockSize(b);
    }
    num++;
  }
  return num;
}

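// Visits a single node and, if it is a small prim::DifferentiableGraph whose
// body can run under Autograd, splices its subgraph back into the enclosing
// block. Returns an iterator to the next node to visit.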
graph_node_list::iterator scanNode(Node* node, size_t threshold) {
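  // Record the successor up front: if the node is unmerged below, it is
  // destroyed and its iterator can no longer be advanced.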
  auto next_node = ++node->iterator();

  for (Block* block : node->blocks()) {
    InlineAutodiffSubgraphs(block, threshold);
  }

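  // Only prim::DifferentiableGraph nodes are candidates for inlining.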
  if (node->kind() != prim::DifferentiableGraph) {
    return next_node;
  }

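  // Leave subgraphs at or above the size threshold fused; only subgraphs
  // smaller than the threshold are spliced back out.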
  auto subgraph = node->g(attr::Subgraph);
  size_t subgraph_size = blockSize(subgraph->block());
  if (subgraph_size >= threshold) {
    return next_node;
  }

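  // Bail out if any node in the subgraph (or its nested blocks) is not
  // guaranteed to be Autograd-aware.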
  if (!std::all_of(
          subgraph->nodes().begin(),
          subgraph->nodes().end(),
          canRunWithAutograd)) {
    return next_node;
  }

  // Now that we inline the graph, we are no longer detaching input tensors,
  // so the profiles will have an outdated requires_grad=False. Conservatively
  // update them to "may require grad", because we might create autodiff
  // graphs when the tensors may require grad.
  UpdateDifferentiableGraphRequiresGrad(subgraph, std::nullopt);
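  // Splice the subgraph's nodes back into the surrounding graph; this also
  // destroys the DifferentiableGraph node itself.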
  SubgraphUtils::unmergeSubgraph(node);
  return next_node;
}

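// Scans every node in the block (and, via scanNode, every nested block),
// inlining each eligible DifferentiableGraph node along the way.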
void InlineAutodiffSubgraphs(Block* block, size_t threshold) {
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    it = scanNode(*it, threshold);
  }
}

} // anonymous namespace

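// Entry point: inlines all sufficiently small, Autograd-compatible
// DifferentiableGraph subgraphs in `graph`, then cleans up any nodes left
// dead by the unmerging. A minimal usage sketch (the setup and the threshold
// value are illustrative assumptions, not part of this file):
//
//   std::shared_ptr<Graph> graph = ...; // e.g. from a profiled execution plan
//   InlineAutodiffSubgraphs(graph, /*threshold=*/5);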
void InlineAutodiffSubgraphs(std::shared_ptr<Graph>& graph, size_t threshold) {
  InlineAutodiffSubgraphs(graph->block(), threshold);
  EliminateDeadCode(graph);
}

} // namespace torch::jit