xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.h>
2 
3 #include <torch/csrc/jit/jit_log.h>
4 #include <torch/csrc/jit/passes/onnx/helper.h>
5 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
6 
7 namespace torch {
8 namespace jit {
9 
convertSubgraphToSubBlock(Block * block)10 void convertSubgraphToSubBlock(Block* block) {
11   for (auto it = block->nodes().begin(), end = block->nodes().end();
12        it != end;) {
13     Node* node = *it++;
14     if (node->kind() == prim::PythonOp) {
15       // Construct subblock
16       auto subblock = node->addBlock();
17       auto graph = subblock->owningGraph();
18 
19       std::unordered_map<Value*, Value*> env;
20       // Populate subblock with subgraph nodes
21       auto subgraph = node->g(attr::Subgraph);
22       for (const auto i : c10::irange(subgraph->inputs().size())) {
23         subblock->addInput()->copyMetadata(subgraph->inputs()[i]);
24         env[subgraph->inputs()[i]] = subblock->inputs()[i];
25       }
26       for (auto* n : subgraph->nodes()) {
27         auto cloned_n =
28             subblock->appendNode(graph->createClone(n, [&](Value* v) {
29               return env.find(v) != env.end() ? env[v] : v;
30             }));
31         for (size_t i = 0; i < n->outputs().size(); ++i) {
32           env[n->outputs().at(i)] = cloned_n->outputs().at(i);
33           auto it = std::find(
34               subgraph->outputs().begin(),
35               subgraph->outputs().end(),
36               n->outputs()[i]);
37           if (it != subgraph->outputs().end()) {
38             subblock->registerOutput(cloned_n->outputs()[i]);
39           }
40         }
41       }
42       // Remove subgraph attribute from the pythonOp node and recurse through
43       // sub-blocks
44       node->removeAttribute(attr::Subgraph);
45     }
46     for (auto block : node->blocks()) {
47       convertSubgraphToSubBlock(block);
48     }
49   }
50 }
51 
52 // This pass is to be used for ONNX conversion only.
ONNXAutogradFunctionProcess(std::shared_ptr<Graph> & graph)53 void ONNXAutogradFunctionProcess(std::shared_ptr<Graph>& graph) {
54   convertSubgraphToSubBlock(graph->block());
55 }
56 
57 } // namespace jit
58 } // namespace torch
59