1 #pragma once 2 3 #include <torch/csrc/jit/ir/ir.h> 4 5 // This api will be used by serialization/export.cpp to extract function 6 // information. It should do conversion on graph to 7 // 1. Extract subgraph pattern of functions and define as local function 8 // node. 9 // 2. Replace subgraph pattern of functions with a single node reflecting 10 // that local function node type. 11 // Function attribute map information is also returned, as Torch IR cannot 12 // represent these info inside Graph object. 13 // export.cpp will serialize the ONNX model with function_proto with 14 // above information. 15 namespace torch::jit::onnx { 16 17 // The following return types are used to track information regarding function 18 // attributes, that are unable to be traced through Torch IR. 19 // NodeAttrNameMap tracks mapping from attribute name of IR Node inside function 20 // subgraph, to function attribute name. Here's an example of exporting CELU and 21 // LayerNorm. 22 // 23 // clang-format off 24 // class M(torch.nn.Module): 25 // def __init__(self) -> None: 26 // super().__init__() 27 // self.lns = torch.nn.ModuleList([torch.nn.LayerNorm(3, eps = i) for i in range(2)]) 28 // self.celu1 = torch.nn.CELU(1.0) 29 // self.celu2 = torch.nn.CELU(2.0) 30 31 // def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor: 32 // res1 = self.celu1(x) 33 // res2 = self.celu2(y) 34 // for ln in self.lns: 35 // z = ln(z) 36 // return res1 + res2 + z 37 // clang-format on 38 // 39 // Returning 40 // 41 // NodeAttrNameMap: 42 // { 43 // %1 : Float(2, 3) = onnx::Celu[alpha=2.](%y) : { 44 // 'alpha' : 'Celu_alpha' 45 // } 46 // } 47 // 48 // The info here helps graph._export_onnx to construct function attributes for 49 // onnx local FunctionProto. 50 using NodeAttrNameMap = std:: 51 unordered_map<const Node*, std::unordered_map<std::string, std::string>>; 52 53 TORCH_API NodeAttrNameMap ONNXFunctionExtraction( 54 std::shared_ptr<Graph>& graph, 55 const std::unordered_set<std::string>& module_names, 56 const std::vector<std::string>& param_names); 57 58 TORCH_API void ONNXClearScopeRecords(); 59 60 TORCH_API void ONNXTrackScopeAttributes( 61 std::shared_ptr<Graph>& graph, 62 std::map<std::string, IValue>& attributes); 63 64 } // namespace torch::jit::onnx 65