#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/update_differentiable_graph_requires_grad.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

#include <algorithm>
#include <optional>
7
8 namespace torch::jit {
9
10 // aten and prim nodes (except FusionGroup) are guaranteed to work
11 // with Autograd, other nodes (e.g. user-defined nodes) are not necessarily
12 // Autograd-aware
canRunWithAutograd(Node * node)13 bool canRunWithAutograd(Node* node) {
14 auto kind = node->kind();
15 for (Block* block : node->blocks()) {
16 if (!std::all_of(
17 block->nodes().begin(), block->nodes().end(), canRunWithAutograd)) {
18 return false;
19 }
20 }
21 return kind != prim::FusionGroup && kind != prim::CudaFusionGroup &&
22 kind != prim::TypeCheck && kind != prim::TensorExprGroup &&
23 kind != prim::CudaFusionGuard && kind != prim::oneDNNFusionGroup &&
24 kind != prim::oneDNNFusionGuard && (kind.is_aten() || kind.is_prim());
25 }
26
27 namespace {
28
29 void InlineAutodiffSubgraphs(Block* block, size_t threshold);
30
blockSize(Block * block)31 size_t blockSize(Block* block) {
32 size_t num = 0;
33 for (Node* n : block->nodes()) {
34 for (Block* b : n->blocks()) {
35 num += blockSize(b);
36 }
37 num++;
38 }
39 return num;
40 }
41
scanNode(Node * node,size_t threshold)42 graph_node_list::iterator scanNode(Node* node, size_t threshold) {
43 auto next_node = ++node->iterator();
44
45 for (Block* block : node->blocks()) {
46 InlineAutodiffSubgraphs(block, threshold);
47 }
48
49 if (node->kind() != prim::DifferentiableGraph) {
50 return next_node;
51 }
52
53 auto subgraph = node->g(attr::Subgraph);
54 size_t subgraph_size = blockSize(subgraph->block());
55 if (subgraph_size >= threshold) {
56 return next_node;
57 }
58
59 if (!std::all_of(
60 subgraph->nodes().begin(),
61 subgraph->nodes().end(),
62 canRunWithAutograd)) {
63 return next_node;
64 }
65
66 // now that we inline the graph, we are no longer detaching input tensors,
67 // so the profiles will have outdated requires_grad=False.
68 // conservatively update them to maybe requiring grad, bc we might create
69 // autodiff graphs when the tensors maybe require grad
70 UpdateDifferentiableGraphRequiresGrad(subgraph, std::nullopt);
71 SubgraphUtils::unmergeSubgraph(node);
72 return next_node;
73 }
74
InlineAutodiffSubgraphs(Block * block,size_t threshold)75 void InlineAutodiffSubgraphs(Block* block, size_t threshold) {
76 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
77 it = scanNode(*it, threshold);
78 }
79 }
80
81 } // anonymous namespace
82
// Public entry point: inlines every eligible DifferentiableGraph node in
// `graph` (those smaller than `threshold` whose nodes can all run under
// Autograd), then removes any code left dead by the unmerging.
void InlineAutodiffSubgraphs(std::shared_ptr<Graph>& graph, size_t threshold) {
  InlineAutodiffSubgraphs(graph->block(), threshold);
  EliminateDeadCode(graph);
}
87
88 } // namespace torch::jit
89