xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 
5 namespace torch {
6 namespace jit {
7 
8 // Introduction
9 //
10 // The encapsulation part will find the nodes of patterns, like how other
11 // pre-onnx passes are written. But instead of converting the nodes, it will
12 // encapsulate them into a sub-block of a new placeholder node. This part is
13 // called before onnx pass, so it runs before calling symbolic functions.
14 //
15 // Note: Why separate the function into two parts
16 //
17 // The purpose is to support conversions that depend on shape and type
18 // information. Shape and type information is only available after
19 // _jit_pass_onnx, which converts aten nodes to onnx nodes. So there is an
20 // interdependency issue. _jit_pass_onnx depends on preprocess passes to convert
21 // aten nodes into a convertible condition, and preprocess passes depend on
22 // _jit_pass_onnx to convert upstream nodes and apply onnx shape inference.
23 // Separating the pass into two parts breaks the interdependency.
24 //
25 // Note: Edit Pattern Encapsulation
26 //
27 // Encapsulation step identifies the pattern, and copies the nodes into
28 // the subblock of a new placeholder node. The outputs of the new placeholder
29 // node are used in place of the original nodes instead. The category of the
30 // pattern is stored as attr::name.
31 TORCH_API std::optional<Node*> EncapsulatePatternIntoSubblock(Node* n);
32 
33 } // namespace jit
34 } // namespace torch
35