xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/function_extraction.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 
// This API is used by serialization/export.cpp to extract function
// information. It converts the graph as follows:
//    1. Extract each subgraph pattern that corresponds to a function and
//    define it as a local function node.
//    2. Replace each such subgraph pattern with a single node whose type
//    is that local function node.
// Function attribute map information is also returned, because Torch IR
// cannot represent this information inside a Graph object.
// export.cpp then serializes the ONNX model, including function_proto
// entries, using the information above.
15 namespace torch::jit::onnx {
16 
17 // The following return types are used to track information regarding function
18 // attributes, that are unable to be traced through Torch IR.
19 // NodeAttrNameMap tracks mapping from attribute name of IR Node inside function
20 // subgraph, to function attribute name. Here's an example of exporting CELU and
21 // LayerNorm.
22 //
23 // clang-format off
24 // class M(torch.nn.Module):
25 //     def __init__(self) -> None:
26 //         super().__init__()
27 //         self.lns = torch.nn.ModuleList([torch.nn.LayerNorm(3, eps = i) for i in range(2)])
28 //         self.celu1 = torch.nn.CELU(1.0)
29 //         self.celu2 = torch.nn.CELU(2.0)
30 
31 //     def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
32 //         res1 = self.celu1(x)
33 //         res2 = self.celu2(y)
34 //         for ln in self.lns:
35 //             z = ln(z)
36 //         return res1 + res2 + z
37 // clang-format on
38 //
39 // Returning
40 //
41 // NodeAttrNameMap:
42 // {
43 //    %1 : Float(2, 3) = onnx::Celu[alpha=2.](%y) : {
44 //      'alpha' : 'Celu_alpha'
45 //    }
46 // }
47 //
48 // The info here helps graph._export_onnx to construct function attributes for
49 // onnx local FunctionProto.
50 using NodeAttrNameMap = std::
51     unordered_map<const Node*, std::unordered_map<std::string, std::string>>;
52 
53 TORCH_API NodeAttrNameMap ONNXFunctionExtraction(
54     std::shared_ptr<Graph>& graph,
55     const std::unordered_set<std::string>& module_names,
56     const std::vector<std::string>& param_names);
57 
58 TORCH_API void ONNXClearScopeRecords();
59 
60 TORCH_API void ONNXTrackScopeAttributes(
61     std::shared_ptr<Graph>& graph,
62     std::map<std::string, IValue>& attributes);
63 
64 } // namespace torch::jit::onnx
65