xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/function_substitution.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/onnx/function_substitution.h>
2 
3 #include <torch/csrc/jit/jit_log.h>
4 #include <torch/csrc/jit/passes/onnx/helper.h>
5 #include <torch/csrc/jit/passes/onnx/naming.h>
6 
7 namespace torch::jit {
8 
9 namespace {
10 
// Variable name used when building the scope of the top-level module; the
// top module is not held by any variable, so the name is empty. A constexpr
// char array (rather than a namespace-scope std::string) avoids a static
// object with a non-trivial destructor; it converts implicitly wherever a
// std::string is expected.
constexpr char kTopModuleVariableName[] = "";
12 
TidyClassNameFromTorchScript(const std::optional<c10::QualifiedName> & class_name)13 std::string TidyClassNameFromTorchScript(
14     const std::optional<c10::QualifiedName>& class_name) {
15   if (!class_name) {
16     return "UNKNOWN_CLASS";
17   }
18   std::string out = "";
19   for (const auto& atom : class_name->atoms()) {
20     bool is_internal_torch_atom = (atom == "__torch__");
21     bool is_mangle_atom = (atom.find("__torch_mangle") != std::string::npos);
22     if (!is_internal_torch_atom && !is_mangle_atom) {
23       if (!out.empty()) {
24         out += ".";
25       }
26       out += atom;
27     }
28   }
29   return out;
30 }
31 
GetCallNodeVariableName(const Node * call_node)32 std::string GetCallNodeVariableName(const Node* call_node) {
33   TORCH_INTERNAL_ASSERT(
34       call_node->kind() == prim::CallFunction ||
35       call_node->kind() == prim::CallMethod);
36   auto module_node = call_node->input(0)->node();
37 
38   if (!module_node->hasAttribute(attr::name)) {
39     return "";
40   }
41   std::string module_name = module_node->s(attr::name);
42   if (module_node->inputs().empty()) {
43     return module_name;
44   }
45   // If module is from container, attr::name in module node only carries
46   // index info. Need to check parent node (container) for variable name.
47   auto parent_module_value = module_node->input(0);
48   while (parent_module_value) {
49     auto parent_module_type = parent_module_value->type()->cast<ClassType>();
50     if (parent_module_type &&
51         parent_module_type->name() ==
52             "__torch__.torch.nn.modules.container.ModuleList") {
53       auto parent_module_node = parent_module_value->node();
54       module_name = parent_module_node->s(attr::name) + "." + module_name;
55       parent_module_value = !parent_module_node->inputs().empty()
56           ? parent_module_node->input(0)
57           : nullptr;
58     } else {
59       break;
60     }
61   }
62 
63   return module_name;
64 }
65 
ForwardCallScope(Graph & graph,Node * call_node)66 ScopePtr ForwardCallScope(Graph& graph, Node* call_node) {
67   TORCH_INTERNAL_ASSERT(call_node->kind() == prim::CallMethod);
68   const std::string& method_name = call_node->s(attr::name);
69   if (method_name == "forward") {
70     const auto type = call_node->input(0)->type()->expect<c10::NamedType>();
71     const std::string class_name = TidyClassNameFromTorchScript(type->name());
72     const std::string variable_name = GetCallNodeVariableName(call_node);
73     const std::string scope_name =
74         onnx::ONNXScopeName::createFullScopeName(class_name, variable_name);
75     return graph.current_scope()->push(Symbol::scope(scope_name));
76   }
77   return graph.current_scope();
78 }
79 
functionCallSubstitution(Block * block)80 void functionCallSubstitution(Block* block) {
81   auto graph = block->owningGraph();
82   for (auto it = block->nodes().begin(), end = block->nodes().end();
83        it != end;) {
84     Node* cur = *it++;
85     switch (cur->kind()) {
86       case prim::CallFunction: {
87         TORCH_INTERNAL_ASSERT(cur->input(0)->node()->kind() == prim::Constant);
88         auto function_constant = cur->input(0)->node();
89         auto fun_type =
90             function_constant->output()->type()->expect<FunctionType>();
91 
92         if ((fun_type->function()->qualname().qualifiedName().find(
93                  "torch.nn.functional") != std::string::npos) &&
94             (fun_type->function()->qualname().qualifiedName().find(
95                  "interpolate") != std::string::npos)) {
96           // Remove input[0] and the node that feeds into it
97           auto input_node_0 = cur->input(0)->node();
98           cur->removeInput(0);
99           if (!input_node_0->hasUses()) {
100             input_node_0->destroy();
101           }
102           Node* interpolate_node = block->owningGraph()->create(
103               Symbol::fromQualString("aten::__interpolate"),
104               {cur->inputs()},
105               cur->outputs().size());
106           interpolate_node->output()->copyMetadata(cur->output());
107           interpolate_node->insertAfter(cur);
108           interpolate_node->copyMetadata(cur);
109           cur->replaceAllUsesWith(interpolate_node);
110           cur->removeAllInputs();
111           cur->destroy();
112           GRAPH_UPDATE(
113               "ONNX function call substitution function: '",
114               fun_type->function()->name(),
115               "' to aten::__interpolate");
116           GRAPH_UPDATE(
117               "Function in ONNX function call substitution body: ",
118               toGraphFunction(*fun_type->function()).optimized_graph());
119         } else {
120           // Remove input[0] and the node that feeds into it
121           auto input_node_0 = cur->input(0)->node();
122           cur->removeInput(0);
123           if (!input_node_0->hasUses()) {
124             input_node_0->destroy();
125           }
126           auto& graphFunction = toGraphFunction(*fun_type->function());
127           functionCallSubstitution(graphFunction.graph()->block());
128           inlineCallTo(cur, &graphFunction, false);
129         }
130       } break;
131       case prim::CallMethod: {
132         const std::string& name = cur->s(attr::name);
133         if (auto class_type = cur->input(0)->type()->cast<ClassType>()) {
134           Function& function = class_type->getMethod(name);
135           ScopePtr call_scope = ForwardCallScope(*graph, cur);
136           WithCurrentScope scope_guard(*graph, call_scope);
137           GRAPH_DEBUG(
138               "Setting scope guard for forward call: ",
139               graph->current_scope()->namesFromRoot());
140           if (auto graphFunction = tryToGraphFunction(function)) {
141             GRAPH_DEBUG(
142                 "Inner graph for method call ",
143                 name,
144                 ": ",
145                 *graphFunction->graph());
146             WithCurrentScope inner_graph_scope_guard(
147                 *graphFunction->graph(), call_scope);
148             functionCallSubstitution(graphFunction->graph()->block());
149             inlineCallTo(cur, graphFunction, false);
150           }
151         }
152       } break;
153       default: {
154         if (!graph->current_scope()->isBlank()) {
155           cur->setScope(graph->current_scope());
156         }
157         for (auto b : cur->blocks()) {
158           functionCallSubstitution(b);
159         }
160       } break;
161     }
162     GRAPH_DEBUG(
163         "Graph current scope after node process: ",
164         graph->current_scope()->namesFromRoot());
165   }
166 }
167 
ONNXGraphTopLevelScope(Graph & graph)168 ScopePtr ONNXGraphTopLevelScope(Graph& graph) {
169   if (graph.inputs().empty()) {
170     return graph.current_scope();
171   }
172   if (auto top_module_type = graph.inputs().at(0)->type()->cast<ClassType>()) {
173     auto scope_name = ::torch::jit::onnx::ONNXScopeName::createFullScopeName(
174         TidyClassNameFromTorchScript(top_module_type->name()),
175         kTopModuleVariableName);
176     return graph.current_scope()->push(Symbol::scope(scope_name));
177   }
178   return graph.current_scope();
179 }
180 
181 } // namespace
182 
// This pass is to be used for ONNX conversion only. The ONNX converter depends
// on a number of deprecated aten operators. These operators are removed from IR
// and replaced by the compiled Python function code. However, in order to
// maintain the behavior for ONNX conversion, we replace these function calls
// with the corresponding aten operator (e.g. aten::__interpolate), which the
// ONNX converter can still handle. Other function and method calls are inlined.
ONNXFunctionCallSubstitution(Graph & graph)188 void ONNXFunctionCallSubstitution(Graph& graph) {
189   GRAPH_DUMP("Before function call substitution calls: ", &graph);
190   WithCurrentScope top_level_scope_guard(graph, ONNXGraphTopLevelScope(graph));
191   functionCallSubstitution(graph.block());
192   GRAPH_DUMP("After function call substitution calls: ", &graph);
193 }
194 
195 } // namespace torch::jit
196