1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/utils/name_utils.h"
17
18 #include <cctype>
19
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
24
25 namespace mlir {
26
27 namespace {
// Checks if a character is legal for a TensorFlow node name, with special
// handling if a character is at the beginning.
//
// Letters, digits, '.' and '_' are accepted at any position; '/' and '-' are
// accepted only when `first_char` is false. Everything else is rejected.
//
// Classification uses explicit ASCII ranges instead of <cctype>: passing a
// negative `char` (any byte >= 0x80 where `char` is signed) to isalpha() or
// isdigit() is undefined behavior, and their results depend on the current C
// locale, which could let non-ASCII letters through.
bool IsLegalChar(char c, bool first_char) {
  const bool is_ascii_letter =
      (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
  const bool is_ascii_digit = (c >= '0' && c <= '9');
  if (is_ascii_letter || is_ascii_digit) return true;
  if (c == '.' || c == '_') return true;

  // First character of a node name can only be a letter, digit, dot or
  // underscore.
  if (first_char) return false;

  return c == '/' || c == '-';
}
45 } // anonymous namespace
46
LegalizeNodeName(std::string & name)47 void LegalizeNodeName(std::string& name) {
48 if (name.empty()) return;
49
50 if (!IsLegalChar(name[0], /*first_char=*/true)) name[0] = '.';
51
52 for (char& c : llvm::drop_begin(name, 1))
53 if (!IsLegalChar(c, /*first_char=*/false)) c = '.';
54 }
55
GetNameFromLoc(Location loc)56 std::string GetNameFromLoc(Location loc) {
57 llvm::SmallVector<llvm::StringRef, 8> loc_names;
58 llvm::SmallVector<Location, 8> locs;
59 locs.push_back(loc);
60 bool names_is_nonempty = false;
61
62 while (!locs.empty()) {
63 Location curr_loc = locs.pop_back_val();
64
65 if (auto name_loc = curr_loc.dyn_cast<NameLoc>()) {
66 // Add name in NameLoc. For NameLoc we also account for names due to ops
67 // in functions where the op's name is first.
68 auto name = name_loc.getName().strref().split('@').first;
69 // Skip if the name is for op type.
70 if (!name.endswith(":")) {
71 loc_names.push_back(name);
72 if (!name.empty()) names_is_nonempty = true;
73 }
74 continue;
75 } else if (auto call_loc = curr_loc.dyn_cast<CallSiteLoc>()) {
76 // Use location of the Callee to generate the name.
77 locs.push_back(call_loc.getCallee());
78 continue;
79 } else if (auto fused_loc = curr_loc.dyn_cast<FusedLoc>()) {
80 // Push all locations in FusedLoc in reverse order, so locations are
81 // visited based on order in FusedLoc.
82 auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations());
83 locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end());
84 continue;
85 }
86
87 // Location is not a supported, so an empty StringRef is added.
88 loc_names.push_back(llvm::StringRef());
89 }
90
91 if (names_is_nonempty)
92 return llvm::join(loc_names.begin(), loc_names.end(), ";");
93
94 return "";
95 }
96
GetOpTypeFromLoc(Location loc)97 std::string GetOpTypeFromLoc(Location loc) {
98 llvm::SmallVector<llvm::StringRef, 1> loc_op_types;
99 llvm::SmallVector<Location, 8> locs;
100 locs.push_back(loc);
101 bool op_types_is_nonempty = false;
102
103 while (!locs.empty()) {
104 Location curr_loc = locs.pop_back_val();
105
106 if (auto name_loc = curr_loc.dyn_cast<NameLoc>()) {
107 // Add name in NameLoc. For NameLoc we also account for names due to ops
108 // in functions where the op's name is first.
109 auto op_type = name_loc.getName().strref().split('@').first;
110 if (op_type.endswith(":")) {
111 op_type = op_type.substr(0, op_type.size() - 1);
112 loc_op_types.push_back(op_type);
113 if (!op_type.empty()) op_types_is_nonempty = true;
114 }
115 continue;
116 } else if (auto call_loc = curr_loc.dyn_cast<CallSiteLoc>()) {
117 // Use location of the Callee to generate the name.
118 locs.push_back(call_loc.getCallee());
119 continue;
120 } else if (auto fused_loc = curr_loc.dyn_cast<FusedLoc>()) {
121 // The first location is reserved for op_type.
122 if (!fused_loc.getLocations().empty())
123 locs.push_back(fused_loc.getLocations()[0]);
124 continue;
125 }
126
127 // Location is not a supported, so an empty StringRef is added.
128 loc_op_types.push_back(llvm::StringRef());
129 }
130
131 if (op_types_is_nonempty)
132 return llvm::join(loc_op_types.begin(), loc_op_types.end(), ";");
133
134 return "";
135 }
136
137 } // namespace mlir
138