1 #include <torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.h> 2 3 #include <torch/csrc/jit/jit_log.h> 4 #include <torch/csrc/jit/passes/onnx/helper.h> 5 #include <torch/csrc/jit/passes/utils/subgraph_utils.h> 6 7 namespace torch { 8 namespace jit { 9 convertSubgraphToSubBlock(Block * block)10void convertSubgraphToSubBlock(Block* block) { 11 for (auto it = block->nodes().begin(), end = block->nodes().end(); 12 it != end;) { 13 Node* node = *it++; 14 if (node->kind() == prim::PythonOp) { 15 // Construct subblock 16 auto subblock = node->addBlock(); 17 auto graph = subblock->owningGraph(); 18 19 std::unordered_map<Value*, Value*> env; 20 // Populate subblock with subgraph nodes 21 auto subgraph = node->g(attr::Subgraph); 22 for (const auto i : c10::irange(subgraph->inputs().size())) { 23 subblock->addInput()->copyMetadata(subgraph->inputs()[i]); 24 env[subgraph->inputs()[i]] = subblock->inputs()[i]; 25 } 26 for (auto* n : subgraph->nodes()) { 27 auto cloned_n = 28 subblock->appendNode(graph->createClone(n, [&](Value* v) { 29 return env.find(v) != env.end() ? env[v] : v; 30 })); 31 for (size_t i = 0; i < n->outputs().size(); ++i) { 32 env[n->outputs().at(i)] = cloned_n->outputs().at(i); 33 auto it = std::find( 34 subgraph->outputs().begin(), 35 subgraph->outputs().end(), 36 n->outputs()[i]); 37 if (it != subgraph->outputs().end()) { 38 subblock->registerOutput(cloned_n->outputs()[i]); 39 } 40 } 41 } 42 // Remove subgraph attribute from the pythonOp node and recurse through 43 // sub-blocks 44 node->removeAttribute(attr::Subgraph); 45 } 46 for (auto block : node->blocks()) { 47 convertSubgraphToSubBlock(block); 48 } 49 } 50 } 51 52 // This pass is to be used for ONNX conversion only. ONNXAutogradFunctionProcess(std::shared_ptr<Graph> & graph)53void ONNXAutogradFunctionProcess(std::shared_ptr<Graph>& graph) { 54 convertSubgraphToSubBlock(graph->block()); 55 } 56 57 } // namespace jit 58 } // namespace torch 59