#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

// Introduction
//
// The encapsulation part finds the nodes that match a pattern, in the same
// way other pre-onnx passes are written. But instead of converting those
// nodes, it encapsulates them into a sub-block of a new placeholder node.
// This part is called before the onnx pass, so it runs before any symbolic
// functions are called.
//
// Note: Why separate the function into two parts
//
// The purpose is to support conversions that depend on shape and type
// information. Shape and type information is only available after
// _jit_pass_onnx, which converts aten nodes to onnx nodes. This creates an
// interdependency: _jit_pass_onnx depends on preprocess passes to bring aten
// nodes into a convertible form, and preprocess passes depend on
// _jit_pass_onnx to convert upstream nodes and apply onnx shape inference.
// Separating the pass into two parts breaks this interdependency.
//
// Note: Edit Pattern Encapsulation
//
// The encapsulation step identifies the pattern and copies its nodes into
// the sub-block of a new placeholder node. The outputs of the placeholder
// node are then used in place of the outputs of the original nodes. The
// category of the pattern is stored as attr::name.
TORCH_API std::optional<Node*> EncapsulatePatternIntoSubblock(Node* n);

} // namespace jit
} // namespace torch