1 #include <torch/csrc/jit/passes/onnx/naming.h>
2 #include <torch/csrc/onnx/onnx.h>
3
4 #include <utility>
5
6 namespace torch::jit::onnx {
7
8 namespace ONNXScopeName {
9
10 using NameFunc = std::string (*)(torch::jit::ScopePtr scope);
11
12 const std::string name_separator = "::";
13
14 namespace {
15
nameFromRoot(const torch::jit::ScopePtr & scope,const std::string & layer_separator,NameFunc name_func)16 std::string nameFromRoot(
17 const torch::jit::ScopePtr& scope,
18 const std::string& layer_separator,
19 NameFunc name_func) {
20 std::string out = (*name_func)(scope);
21 if (scope->isRoot()) {
22 return out;
23 }
24 auto parent = scope->parent();
25 while (isCompatibleScope(parent)) {
26 out = std::string((*name_func)(parent)).append(layer_separator).append(out);
27 parent = parent->parent();
28 }
29 return out;
30 }
31
parseNameFromScope(const torch::jit::ScopePtr & scope)32 std::pair<std::string, std::string> parseNameFromScope(
33 const torch::jit::ScopePtr& scope) {
34 std::string full_name = scope->name().toUnqualString();
35 auto pos = full_name.find(name_separator);
36 TORCH_CHECK(
37 pos != std::string::npos,
38 "Scope name (" + full_name + ") does not contain '" + name_separator +
39 "'");
40 return std::make_pair(full_name.substr(0, pos), full_name.substr(pos + 2));
41 }
42
43 } // namespace
44
createFullScopeName(const std::string & class_name,const std::string & variable_name)45 std::string createFullScopeName(
46 const std::string& class_name,
47 const std::string& variable_name) {
48 return std::string(class_name).append(name_separator).append(variable_name);
49 }
50
variableName(torch::jit::ScopePtr scope)51 std::string variableName(torch::jit::ScopePtr scope) {
52 return parseNameFromScope(scope).second;
53 }
54
variableNameFromRoot(const torch::jit::ScopePtr & scope,const std::string & layer_separator)55 std::string variableNameFromRoot(
56 const torch::jit::ScopePtr& scope,
57 const std::string& layer_separator) {
58 return nameFromRoot(scope, layer_separator, &variableName);
59 }
60
className(torch::jit::ScopePtr scope)61 std::string className(torch::jit::ScopePtr scope) {
62 return parseNameFromScope(scope).first;
63 }
64
classNameFromRoot(const torch::jit::ScopePtr & scope,const std::string & layer_separator)65 std::string classNameFromRoot(
66 const torch::jit::ScopePtr& scope,
67 const std::string& layer_separator) {
68 return nameFromRoot(scope, layer_separator, &className);
69 }
70
isCompatibleScope(const torch::jit::ScopePtr & scope)71 bool isCompatibleScope(const torch::jit::ScopePtr& scope) {
72 return !scope->isRoot() && !scope->isBlank() &&
73 (std::string(scope->name().toUnqualString()).find(name_separator) !=
74 std::string::npos);
75 }
76 } // namespace ONNXScopeName
77
78 namespace {
79
80 class NodeNameGenerator {
81 public:
NodeNameGenerator(std::shared_ptr<Graph> g)82 NodeNameGenerator(std::shared_ptr<Graph> g) : graph_(std::move(g)){};
83 virtual ~NodeNameGenerator() = 0;
84 void PopulateNodeNames();
85
86 protected:
87 virtual void CreateNodeName(Node* n) = 0;
88 void PopulateNodeNames(Block*);
89 void UpdateOutputsNames(Node* n);
90 bool IsGraphOutput(const Value* v, const std::shared_ptr<Graph>& graph) const;
91
92 protected:
93 std::string CreateUniqueName(
94 std::unordered_map<std::string, size_t>& base_name_count,
95 std::string base_name);
96
97 std::unordered_map<const Node*, std::string> node_names_;
98 std::unordered_map<std::string, size_t> base_node_name_counts_;
99 std::shared_ptr<Graph> graph_;
100 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
101 const std::string layer_separator_ = "/";
102 };
103 NodeNameGenerator::~NodeNameGenerator() = default;
104
105 class ScopedNodeNameGenerator : public NodeNameGenerator {
106 public:
ScopedNodeNameGenerator(std::shared_ptr<Graph> g)107 ScopedNodeNameGenerator(std::shared_ptr<Graph> g)
108 : NodeNameGenerator(std::move(g)){};
109
110 protected:
111 void CreateNodeName(Node* n) override;
112
113 private:
114 std::string GetFullScopeName(const ScopePtr& scope);
115 std::unordered_map<ScopePtr, std::string> full_scope_names_;
116 std::unordered_map<std::string, size_t> base_scope_name_counts_;
117 };
118
CreateUniqueName(std::unordered_map<std::string,size_t> & base_name_count,std::string base_name)119 std::string NodeNameGenerator::CreateUniqueName(
120 std::unordered_map<std::string, size_t>& base_name_count,
121 std::string base_name) {
122 if (base_name_count.find(base_name) == base_name_count.end()) {
123 base_name_count[base_name] = 0;
124 } else {
125 auto count = ++base_name_count[base_name];
126 base_name += "_";
127 base_name += std::to_string(count);
128 }
129 return base_name;
130 }
131
IsGraphOutput(const Value * v,const std::shared_ptr<Graph> & graph) const132 bool NodeNameGenerator::IsGraphOutput(
133 const Value* v,
134 const std::shared_ptr<Graph>& graph) const {
135 for (const auto* graph_output : graph->outputs()) {
136 if (v == graph_output) {
137 return true;
138 }
139 }
140 return false;
141 }
142
UpdateOutputsNames(Node * n)143 void NodeNameGenerator::UpdateOutputsNames(Node* n) {
144 if (node_names_.find(n) != node_names_.end()) {
145 auto node_name = node_names_[n];
146 for (auto i : c10::irange(n->outputs().size())) {
147 auto output = n->output(i);
148 if (!IsGraphOutput(output, graph_)) {
149 auto output_name = node_name;
150 output_name.append("_output_").append(std::to_string(i));
151 output->setDebugName(output_name);
152 }
153 }
154 }
155 }
156
PopulateNodeNames()157 void NodeNameGenerator::PopulateNodeNames() {
158 PopulateNodeNames(graph_->block());
159 }
160
PopulateNodeNames(Block * b)161 void NodeNameGenerator::PopulateNodeNames(Block* b) {
162 for (auto* n : b->nodes()) {
163 for (auto* sub_block : n->blocks()) {
164 PopulateNodeNames(sub_block);
165 }
166 CreateNodeName(n);
167 UpdateOutputsNames(n);
168 }
169 }
170
CreateNodeName(Node * n)171 void ScopedNodeNameGenerator::CreateNodeName(Node* n) {
172 if (node_names_.find(n) == node_names_.end()) {
173 if (!ONNXScopeName::isCompatibleScope(n->scope())) {
174 return;
175 }
176 if (n->mustBeNone()) {
177 // JIT IR does not allow attribute for None node.
178 return;
179 }
180 auto name = GetFullScopeName(n->scope());
181 name += layer_separator_;
182 name += n->kind().toUnqualString();
183 node_names_[n] = CreateUniqueName(base_node_name_counts_, name);
184 }
185 n->s_(Symbol::attr(::torch::onnx::kOnnxNodeNameAttribute), node_names_[n]);
186 }
187
GetFullScopeName(const ScopePtr & scope)188 std::string ScopedNodeNameGenerator::GetFullScopeName(const ScopePtr& scope) {
189 if (full_scope_names_.find(scope) == full_scope_names_.end()) {
190 auto full_scope_name =
191 ONNXScopeName::variableNameFromRoot(scope, layer_separator_);
192 full_scope_names_[scope] =
193 CreateUniqueName(base_scope_name_counts_, full_scope_name);
194 }
195 return full_scope_names_[scope];
196 }
197
198 } // namespace
199
AssignScopedNamesForNodeAndValue(std::shared_ptr<Graph> & graph)200 void AssignScopedNamesForNodeAndValue(std::shared_ptr<Graph>& graph) {
201 auto node_name_generator = std::make_unique<ScopedNodeNameGenerator>(graph);
202 node_name_generator->PopulateNodeNames();
203 }
204
205 } // namespace torch::jit::onnx
206