#include <torch/csrc/jit/passes/onnx/function_substitution.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/naming.h>

namespace torch::jit {

namespace {

const std::string kTopModuleVariableName = "";

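// Strip the "__torch__" prefix and any "__torch_mangle_*" atoms from a
// TorchScript qualified class name, e.g.
// "__torch__.__torch_mangle_0.MyModule" -> "MyModule".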
std::string TidyClassNameFromTorchScript(
    const std::optional<c10::QualifiedName>& class_name) {
  if (!class_name) {
    return "UNKNOWN_CLASS";
  }
  std::string out = "";
  for (const auto& atom : class_name->atoms()) {
    bool is_internal_torch_atom = (atom == "__torch__");
    bool is_mangle_atom = (atom.find("__torch_mangle") != std::string::npos);
    if (!is_internal_torch_atom && !is_mangle_atom) {
      if (!out.empty()) {
        out += ".";
      }
      out += atom;
    }
  }
  return out;
}

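// Best-effort lookup of the attribute (variable) name of the module being
// called. For submodules held in a ModuleList, attr::name on the module node
// only carries the index, so walk up the chain of parent containers and
// prepend their names (e.g. producing "layers.0").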
std::string GetCallNodeVariableName(const Node* call_node) {
  TORCH_INTERNAL_ASSERT(
      call_node->kind() == prim::CallFunction ||
      call_node->kind() == prim::CallMethod);
  auto module_node = call_node->input(0)->node();

  if (!module_node->hasAttribute(attr::name)) {
    return "";
  }
  std::string module_name = module_node->s(attr::name);
  if (module_node->inputs().empty()) {
    return module_name;
  }
  // If module is from container, attr::name in module node only carries
  // index info. Need to check parent node (container) for variable name.
  auto parent_module_value = module_node->input(0);
  while (parent_module_value) {
    auto parent_module_type = parent_module_value->type()->cast<ClassType>();
    if (parent_module_type &&
        parent_module_type->name() ==
            "__torch__.torch.nn.modules.container.ModuleList") {
      auto parent_module_node = parent_module_value->node();
      module_name = parent_module_node->s(attr::name) + "." + module_name;
      parent_module_value = !parent_module_node->inputs().empty()
          ? parent_module_node->input(0)
          : nullptr;
    } else {
      break;
    }
  }

  return module_name;
}

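// Compute the scope for a method call: for calls to `forward`, push a new
// scope built from the tidied class name and the caller's variable name;
// for any other method, keep the graph's current scope.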
ScopePtr ForwardCallScope(Graph& graph, Node* call_node) {
  TORCH_INTERNAL_ASSERT(call_node->kind() == prim::CallMethod);
  const std::string& method_name = call_node->s(attr::name);
  if (method_name == "forward") {
    const auto type = call_node->input(0)->type()->expect<c10::NamedType>();
    const std::string class_name = TidyClassNameFromTorchScript(type->name());
    const std::string variable_name = GetCallNodeVariableName(call_node);
    const std::string scope_name =
        onnx::ONNXScopeName::createFullScopeName(class_name, variable_name);
    return graph.current_scope()->push(Symbol::scope(scope_name));
  }
  return graph.current_scope();
}

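// Recursively inline prim::CallFunction and prim::CallMethod nodes in the
// block, recording module scope information on nodes along the way. Calls to
// torch.nn.functional interpolate are special-cased and rewritten to
// aten::__interpolate instead of being inlined.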
void functionCallSubstitution(Block* block) {
  auto graph = block->owningGraph();
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* cur = *it++;
    switch (cur->kind()) {
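      // Free function call, e.g. a call into torch.nn.functional.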
      case prim::CallFunction: {
        TORCH_INTERNAL_ASSERT(cur->input(0)->node()->kind() == prim::Constant);
        auto function_constant = cur->input(0)->node();
        auto fun_type =
            function_constant->output()->type()->expect<FunctionType>();

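        // Calls whose qualified name refers to torch.nn.functional's
        // interpolate are not inlined; they are replaced with an
        // aten::__interpolate node that the ONNX exporter can convert.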
        if ((fun_type->function()->qualname().qualifiedName().find(
                 "torch.nn.functional") != std::string::npos) &&
            (fun_type->function()->qualname().qualifiedName().find(
                 "interpolate") != std::string::npos)) {
          // Remove input[0] and the node that feeds into it
          auto input_node_0 = cur->input(0)->node();
          cur->removeInput(0);
          if (!input_node_0->hasUses()) {
            input_node_0->destroy();
          }
          Node* interpolate_node = block->owningGraph()->create(
              Symbol::fromQualString("aten::__interpolate"),
              {cur->inputs()},
              cur->outputs().size());
          interpolate_node->output()->copyMetadata(cur->output());
          interpolate_node->insertAfter(cur);
          interpolate_node->copyMetadata(cur);
          cur->replaceAllUsesWith(interpolate_node);
          cur->removeAllInputs();
          cur->destroy();
          GRAPH_UPDATE(
              "ONNX function call substitution function: '",
              fun_type->function()->name(),
              "' to aten::__interpolate");
          GRAPH_UPDATE(
              "Function in ONNX function call substitution body: ",
              toGraphFunction(*fun_type->function()).optimized_graph());
        } else {
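          // Any other free function: recursively substitute inside the callee
          // graph, then inline the call in place.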
          // Remove input[0] and the node that feeds into it
          auto input_node_0 = cur->input(0)->node();
          cur->removeInput(0);
          if (!input_node_0->hasUses()) {
            input_node_0->destroy();
          }
          auto& graphFunction = toGraphFunction(*fun_type->function());
          functionCallSubstitution(graphFunction.graph()->block());
          inlineCallTo(cur, &graphFunction, false);
        }
      } break;
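      // Method call on a ScriptModule, e.g. a submodule's `forward`.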
      case prim::CallMethod: {
        const std::string& name = cur->s(attr::name);
        if (auto class_type = cur->input(0)->type()->cast<ClassType>()) {
          Function& function = class_type->getMethod(name);
          ScopePtr call_scope = ForwardCallScope(*graph, cur);
          WithCurrentScope scope_guard(*graph, call_scope);
          GRAPH_DEBUG(
              "Setting scope guard for forward call: ",
              graph->current_scope()->namesFromRoot());
          if (auto graphFunction = tryToGraphFunction(function)) {
            GRAPH_DEBUG(
                "Inner graph for method call ",
                name,
                ": ",
                *graphFunction->graph());
            WithCurrentScope inner_graph_scope_guard(
                *graphFunction->graph(), call_scope);
            functionCallSubstitution(graphFunction->graph()->block());
            inlineCallTo(cur, graphFunction, false);
          }
        }
      } break;
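      // Any other node: tag it with the current scope and recurse into its
      // sub-blocks (e.g. prim::If / prim::Loop bodies).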
      default: {
        if (!graph->current_scope()->isBlank()) {
          cur->setScope(graph->current_scope());
        }
        for (auto b : cur->blocks()) {
          functionCallSubstitution(b);
        }
      } break;
    }
    GRAPH_DEBUG(
        "Graph current scope after node process: ",
        graph->current_scope()->namesFromRoot());
  }
}

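// Build the top-level scope from the graph's first input. When that input is
// the top-level ScriptModule (a ClassType), push a scope named after its
// tidied class name; otherwise leave the current scope untouched.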
ScopePtr ONNXGraphTopLevelScope(Graph& graph) {
  if (graph.inputs().empty()) {
    return graph.current_scope();
  }
  if (auto top_module_type = graph.inputs().at(0)->type()->cast<ClassType>()) {
    auto scope_name = ::torch::jit::onnx::ONNXScopeName::createFullScopeName(
        TidyClassNameFromTorchScript(top_module_type->name()),
        kTopModuleVariableName);
    return graph.current_scope()->push(Symbol::scope(scope_name));
  }
  return graph.current_scope();
}

} // namespace

// This pass is to be used for ONNX conversion only. The ONNX converter depends
// on a number of deprecated aten operators. These operators are removed from
// the IR and replaced by the compiled Python function code. However, in order
// to maintain the behavior for ONNX conversion, we replace these function
// calls with the aten symbolic, which can still be used by the ONNX converter.
void ONNXFunctionCallSubstitution(Graph& graph) {
  GRAPH_DUMP("Before function call substitution calls: ", &graph);
  WithCurrentScope top_level_scope_guard(graph, ONNXGraphTopLevelScope(graph));
  functionCallSubstitution(graph.block());
  GRAPH_DUMP("After function call substitution calls: ", &graph);
}

} // namespace torch::jit