xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/naming.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/onnx/naming.h>
2 #include <torch/csrc/onnx/onnx.h>
3 
4 #include <utility>
5 
6 namespace torch::jit::onnx {
7 
8 namespace ONNXScopeName {
9 
10 using NameFunc = std::string (*)(torch::jit::ScopePtr scope);
11 
12 const std::string name_separator = "::";
13 
14 namespace {
15 
nameFromRoot(const torch::jit::ScopePtr & scope,const std::string & layer_separator,NameFunc name_func)16 std::string nameFromRoot(
17     const torch::jit::ScopePtr& scope,
18     const std::string& layer_separator,
19     NameFunc name_func) {
20   std::string out = (*name_func)(scope);
21   if (scope->isRoot()) {
22     return out;
23   }
24   auto parent = scope->parent();
25   while (isCompatibleScope(parent)) {
26     out = std::string((*name_func)(parent)).append(layer_separator).append(out);
27     parent = parent->parent();
28   }
29   return out;
30 }
31 
parseNameFromScope(const torch::jit::ScopePtr & scope)32 std::pair<std::string, std::string> parseNameFromScope(
33     const torch::jit::ScopePtr& scope) {
34   std::string full_name = scope->name().toUnqualString();
35   auto pos = full_name.find(name_separator);
36   TORCH_CHECK(
37       pos != std::string::npos,
38       "Scope name (" + full_name + ") does not contain '" + name_separator +
39           "'");
40   return std::make_pair(full_name.substr(0, pos), full_name.substr(pos + 2));
41 }
42 
43 } // namespace
44 
createFullScopeName(const std::string & class_name,const std::string & variable_name)45 std::string createFullScopeName(
46     const std::string& class_name,
47     const std::string& variable_name) {
48   return std::string(class_name).append(name_separator).append(variable_name);
49 }
50 
variableName(torch::jit::ScopePtr scope)51 std::string variableName(torch::jit::ScopePtr scope) {
52   return parseNameFromScope(scope).second;
53 }
54 
variableNameFromRoot(const torch::jit::ScopePtr & scope,const std::string & layer_separator)55 std::string variableNameFromRoot(
56     const torch::jit::ScopePtr& scope,
57     const std::string& layer_separator) {
58   return nameFromRoot(scope, layer_separator, &variableName);
59 }
60 
className(torch::jit::ScopePtr scope)61 std::string className(torch::jit::ScopePtr scope) {
62   return parseNameFromScope(scope).first;
63 }
64 
classNameFromRoot(const torch::jit::ScopePtr & scope,const std::string & layer_separator)65 std::string classNameFromRoot(
66     const torch::jit::ScopePtr& scope,
67     const std::string& layer_separator) {
68   return nameFromRoot(scope, layer_separator, &className);
69 }
70 
isCompatibleScope(const torch::jit::ScopePtr & scope)71 bool isCompatibleScope(const torch::jit::ScopePtr& scope) {
72   return !scope->isRoot() && !scope->isBlank() &&
73       (std::string(scope->name().toUnqualString()).find(name_separator) !=
74        std::string::npos);
75 }
76 } // namespace ONNXScopeName
77 
78 namespace {
79 
80 class NodeNameGenerator {
81  public:
NodeNameGenerator(std::shared_ptr<Graph> g)82   NodeNameGenerator(std::shared_ptr<Graph> g) : graph_(std::move(g)){};
83   virtual ~NodeNameGenerator() = 0;
84   void PopulateNodeNames();
85 
86  protected:
87   virtual void CreateNodeName(Node* n) = 0;
88   void PopulateNodeNames(Block*);
89   void UpdateOutputsNames(Node* n);
90   bool IsGraphOutput(const Value* v, const std::shared_ptr<Graph>& graph) const;
91 
92  protected:
93   std::string CreateUniqueName(
94       std::unordered_map<std::string, size_t>& base_name_count,
95       std::string base_name);
96 
97   std::unordered_map<const Node*, std::string> node_names_;
98   std::unordered_map<std::string, size_t> base_node_name_counts_;
99   std::shared_ptr<Graph> graph_;
100   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
101   const std::string layer_separator_ = "/";
102 };
103 NodeNameGenerator::~NodeNameGenerator() = default;
104 
105 class ScopedNodeNameGenerator : public NodeNameGenerator {
106  public:
ScopedNodeNameGenerator(std::shared_ptr<Graph> g)107   ScopedNodeNameGenerator(std::shared_ptr<Graph> g)
108       : NodeNameGenerator(std::move(g)){};
109 
110  protected:
111   void CreateNodeName(Node* n) override;
112 
113  private:
114   std::string GetFullScopeName(const ScopePtr& scope);
115   std::unordered_map<ScopePtr, std::string> full_scope_names_;
116   std::unordered_map<std::string, size_t> base_scope_name_counts_;
117 };
118 
CreateUniqueName(std::unordered_map<std::string,size_t> & base_name_count,std::string base_name)119 std::string NodeNameGenerator::CreateUniqueName(
120     std::unordered_map<std::string, size_t>& base_name_count,
121     std::string base_name) {
122   if (base_name_count.find(base_name) == base_name_count.end()) {
123     base_name_count[base_name] = 0;
124   } else {
125     auto count = ++base_name_count[base_name];
126     base_name += "_";
127     base_name += std::to_string(count);
128   }
129   return base_name;
130 }
131 
IsGraphOutput(const Value * v,const std::shared_ptr<Graph> & graph) const132 bool NodeNameGenerator::IsGraphOutput(
133     const Value* v,
134     const std::shared_ptr<Graph>& graph) const {
135   for (const auto* graph_output : graph->outputs()) {
136     if (v == graph_output) {
137       return true;
138     }
139   }
140   return false;
141 }
142 
UpdateOutputsNames(Node * n)143 void NodeNameGenerator::UpdateOutputsNames(Node* n) {
144   if (node_names_.find(n) != node_names_.end()) {
145     auto node_name = node_names_[n];
146     for (auto i : c10::irange(n->outputs().size())) {
147       auto output = n->output(i);
148       if (!IsGraphOutput(output, graph_)) {
149         auto output_name = node_name;
150         output_name.append("_output_").append(std::to_string(i));
151         output->setDebugName(output_name);
152       }
153     }
154   }
155 }
156 
PopulateNodeNames()157 void NodeNameGenerator::PopulateNodeNames() {
158   PopulateNodeNames(graph_->block());
159 }
160 
PopulateNodeNames(Block * b)161 void NodeNameGenerator::PopulateNodeNames(Block* b) {
162   for (auto* n : b->nodes()) {
163     for (auto* sub_block : n->blocks()) {
164       PopulateNodeNames(sub_block);
165     }
166     CreateNodeName(n);
167     UpdateOutputsNames(n);
168   }
169 }
170 
CreateNodeName(Node * n)171 void ScopedNodeNameGenerator::CreateNodeName(Node* n) {
172   if (node_names_.find(n) == node_names_.end()) {
173     if (!ONNXScopeName::isCompatibleScope(n->scope())) {
174       return;
175     }
176     if (n->mustBeNone()) {
177       // JIT IR does not allow attribute for None node.
178       return;
179     }
180     auto name = GetFullScopeName(n->scope());
181     name += layer_separator_;
182     name += n->kind().toUnqualString();
183     node_names_[n] = CreateUniqueName(base_node_name_counts_, name);
184   }
185   n->s_(Symbol::attr(::torch::onnx::kOnnxNodeNameAttribute), node_names_[n]);
186 }
187 
GetFullScopeName(const ScopePtr & scope)188 std::string ScopedNodeNameGenerator::GetFullScopeName(const ScopePtr& scope) {
189   if (full_scope_names_.find(scope) == full_scope_names_.end()) {
190     auto full_scope_name =
191         ONNXScopeName::variableNameFromRoot(scope, layer_separator_);
192     full_scope_names_[scope] =
193         CreateUniqueName(base_scope_name_counts_, full_scope_name);
194   }
195   return full_scope_names_[scope];
196 }
197 
198 } // namespace
199 
AssignScopedNamesForNodeAndValue(std::shared_ptr<Graph> & graph)200 void AssignScopedNamesForNodeAndValue(std::shared_ptr<Graph>& graph) {
201   auto node_name_generator = std::make_unique<ScopedNodeNameGenerator>(graph);
202   node_name_generator->PopulateNodeNames();
203 }
204 
205 } // namespace torch::jit::onnx
206