1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tfr/utils/utils.h"
17
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/ADT/StringSet.h"
21 #include "mlir/Support/LogicalResult.h" // from @llvm-project
22 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
23
24 namespace mlir {
25 namespace TFR {
26 namespace {
27
28 // TODO(b/174692018): Use the official allowlist of the unregistered attrs.
GetAllowedAttributes()29 const llvm::StringSet<>& GetAllowedAttributes() {
30 static auto* const ops = new llvm::StringSet<>({"device", "_tpu_replicate"});
31 return *ops;
32 }
33
34 // Some TFL optional attributes may not appear in their corresponding TF op
35 // attributes.
GetOptionalAttributes()36 const llvm::StringSet<>& GetOptionalAttributes() {
37 static auto* const ops =
38 new llvm::StringSet<>({"asymmetric_quantize_inputs"});
39 return *ops;
40 }
41
CollectAllowedAttrs(CallOp src,NamedAttrList * attrs)42 void CollectAllowedAttrs(CallOp src, NamedAttrList* attrs) {
43 for (auto& attr : src->getAttrs()) {
44 if (GetAllowedAttributes().contains(attr.getName().strref())) {
45 attrs->append(attr);
46 }
47 }
48 }
49
50 // Adds `attrs` to all the operations between `begin` and `end` in the same
51 // block. Does not include `end`.
AddAttributesInSameBlock(Block::iterator begin,Block::iterator end,const NamedAttrList & attrs)52 void AddAttributesInSameBlock(Block::iterator begin, Block::iterator end,
53 const NamedAttrList& attrs) {
54 for (Block::iterator it = begin; it != end; ++it) {
55 for (auto& attr : attrs) {
56 it->setAttr(attr.getName(), attr.getValue());
57 }
58 }
59 }
60
61 // Adds `attrs` to all the operations between `begin` and `end`. Does not
62 // include `end`. The operations might be across multiple blocks.
AddAttributes(Block::iterator begin,Block::iterator end,const NamedAttrList & attrs)63 void AddAttributes(Block::iterator begin, Block::iterator end,
64 const NamedAttrList& attrs) {
65 if (begin->getBlock() == end->getBlock()) {
66 AddAttributesInSameBlock(begin, end, attrs);
67 } else {
68 Region::iterator begin_block = Region::iterator(begin->getBlock());
69 Region::iterator end_block = Region::iterator(end->getBlock());
70 AddAttributesInSameBlock(begin, begin_block->end(), attrs);
71 for (Region::iterator it = ++begin_block; it != end_block; ++it) {
72 AddAttributesInSameBlock(it->begin(), it->end(), attrs);
73 }
74 }
75 }
76
77 } // namespace
78
GetComposeFuncName(StringRef tf_op_name)79 std::string GetComposeFuncName(StringRef tf_op_name) {
80 std::string compose_func_name;
81 for (int i = 0; i < tf_op_name.size(); ++i) {
82 if (tf_op_name[i] == '_') {
83 // The field name must not contain "_"s. "_Arg" and "_RetVal" are special
84 // op names and we can return empty string to skip the decomposition.
85 return {};
86 }
87 if (tf_op_name[i] == '.') {
88 compose_func_name.push_back('_');
89 } else if (tf_op_name[i] >= 'A' && tf_op_name[i] <= 'Z') {
90 compose_func_name.push_back('_');
91 compose_func_name.push_back(tf_op_name[i] + 'a' - 'A');
92 } else {
93 compose_func_name.push_back(tf_op_name[i]);
94 }
95 }
96 return compose_func_name;
97 }
98
GetTFOpName(StringRef compose_func_name)99 std::string GetTFOpName(StringRef compose_func_name) {
100 std::string tf_op_name;
101 bool after_underscore = false;
102 for (int i = 0; i < compose_func_name.size(); ++i) {
103 if (compose_func_name[i] >= 'A' && compose_func_name[i] <= 'Z') {
104 // The field name must not contain uppercase letters.
105 return {};
106 }
107 if (after_underscore) {
108 if (compose_func_name[i] >= 'a' && compose_func_name[i] <= 'z') {
109 tf_op_name.push_back(compose_func_name[i] + 'A' - 'a');
110 after_underscore = false;
111 } else {
112 // The character after a "_" must be a lowercase letter.
113 return {};
114 }
115 } else if (compose_func_name[i] == '_') { // first time visit '_'
116 if (i + 1 < compose_func_name.size() && compose_func_name[i + 1] == '_') {
117 tf_op_name.push_back('.');
118 i++;
119 }
120 after_underscore = true;
121 } else {
122 tf_op_name.push_back(compose_func_name[i]);
123 }
124 }
125 if (after_underscore) {
126 // Trailing "_".
127 return {};
128 }
129 return tf_op_name;
130 }
131
ValidateAttrs(Operation * src,const StringSet<> & registered)132 LogicalResult ValidateAttrs(Operation* src, const StringSet<>& registered) {
133 for (auto& attr : src->getAttrs()) {
134 StringRef attr_name = attr.getName().strref();
135
136 if (!registered.contains(attr_name) &&
137 !(GetAllowedAttributes().contains(attr_name) ||
138 GetOptionalAttributes().contains(attr_name))) {
139 src->emitError("Denied unregistered attribute was found: " + attr_name);
140 return failure();
141 }
142 }
143 return success();
144 }
145
CopyAllowedUnregisteredAttrs(Operation * src,CallOp dst,const StringSet<> & registered)146 LogicalResult CopyAllowedUnregisteredAttrs(Operation* src, CallOp dst,
147 const StringSet<>& registered) {
148 for (auto& attr : src->getAttrs()) {
149 StringRef attr_name = attr.getName().strref();
150 // Skip the registered or optional attribute.
151 if (registered.contains(attr_name) ||
152 GetOptionalAttributes().contains(attr_name))
153 continue;
154
155 // Unregistered attribute.
156 if (GetAllowedAttributes().contains(attr_name)) {
157 dst->setAttr(attr.getName(), attr.getValue());
158 } else {
159 src->emitError("Denied unregistered attribute was found: " + attr_name);
160 return failure();
161 }
162 }
163 return success();
164 }
165
CopyNonSymbolRefAttrs(CallOp src,Operation * dst)166 LogicalResult CopyNonSymbolRefAttrs(CallOp src, Operation* dst) {
167 NamedAttrList attrs;
168 CollectAllowedAttrs(src, &attrs);
169
170 for (auto& attr : attrs) {
171 dst->setAttr(attr.getName(), attr.getValue());
172 }
173
174 return success();
175 }
176
PropagateAttrsToOperations(CallOp src,Block::iterator begin,Block::iterator end)177 void PropagateAttrsToOperations(CallOp src, Block::iterator begin,
178 Block::iterator end) {
179 // Find all the attributes in the call op. These attributes are not in the
180 // op definition, so needs to be propagated to all the target ops.
181 NamedAttrList attrs;
182 CollectAllowedAttrs(src, &attrs);
183
184 // Add all the attributes to the operations in the range.
185 if (!attrs.empty()) {
186 AddAttributes(begin, end, attrs);
187 }
188 }
189
190 } // namespace TFR
191 } // namespace mlir
192