1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/transforms/graph_compactor/pass.h"
17
18 #include <iterator>
19 #include <memory>
20 #include <string>
21 #include <utility>
22
23 #include "llvm/ADT/BitVector.h"
24 #include "mlir/IR/Builders.h" // from @llvm-project
25 #include "mlir/IR/MLIRContext.h" // from @llvm-project
26 #include "mlir/IR/OperationSupport.h" // from @llvm-project
27 #include "mlir/Pass/Pass.h" // from @llvm-project
28 #include "mlir/Support/LLVM.h" // from @llvm-project
29 #include "mlir/Support/LogicalResult.h" // from @llvm-project
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/framework/op_def_builder.h"
32 #include "tensorflow/core/ir/dialect.h"
33 #include "tensorflow/core/ir/importexport/convert_attributes.h"
34 #include "tensorflow/core/ir/interfaces.h"
35 #include "tensorflow/core/ir/ops.h"
36 #include "tensorflow/core/ir/tf_op_registry.h"
37 #include "tensorflow/core/platform/statusor.h"
38 #include "tensorflow/core/transforms/pass_detail.h"
39
40 namespace mlir {
41 namespace tfg {
42
43 // Encode an unsigned integer in as few characters as possible to a string that
44 // is still a valid TensorFlow node name. The regex for valid names, according
45 // to `NodeDef`, is "[A-Za-z0-9.][A-Za-z0-9_>./]*"
46 //
47 // The valid characters are provided in the two arrays `first_valid_chars` and
48 // `trailing_valid_chars`.
EncodeName(unsigned counter,std::string & output,ArrayRef<char> first_valid_chars,ArrayRef<char> trailing_valid_chars)49 static void EncodeName(unsigned counter, std::string &output,
50 ArrayRef<char> first_valid_chars,
51 ArrayRef<char> trailing_valid_chars) {
52 assert(!first_valid_chars.empty() && !trailing_valid_chars.empty());
53 unsigned rem = counter % first_valid_chars.size();
54 counter /= first_valid_chars.size();
55 output.push_back(first_valid_chars[rem]);
56 while (counter > 0) {
57 --counter;
58 rem = counter % trailing_valid_chars.size();
59 counter /= trailing_valid_chars.size();
60 output.push_back(trailing_valid_chars[rem]);
61 }
62 }
63
64 // Encode an unsigned integer to a valid TensorFlow node name.
EncodeName(unsigned counter,std::string & output)65 static void EncodeName(unsigned counter, std::string &output) {
66 // The alphabet of valid characters, but the last 3 are only valid in trailing
67 // characters.
68 static constexpr char valid_chars[] =
69 "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._>/";
70 // Sanity check: all alphanumeric characters, four special characters, and a
71 // null terminator.
72 constexpr unsigned valid_first_chars = 26 * 2 + 10 + 1;
73 constexpr unsigned valid_trailing_chars = valid_first_chars + 3;
74 static_assert(sizeof(valid_chars) == valid_trailing_chars + 1,
75 "alphabet sanity check");
76 EncodeName(counter, output,
77 llvm::makeArrayRef(valid_chars, valid_first_chars),
78 llvm::makeArrayRef(valid_chars, valid_trailing_chars));
79 }
80
81 namespace {
82 class NameCompressPass : public NameCompressBase<NameCompressPass> {
83 public:
initialize(MLIRContext * context)84 LogicalResult initialize(MLIRContext *context) override {
85 dialect_ = context->getOrLoadDialect<TFGraphDialect>();
86 empty_dict_ = DictionaryAttr::get(context);
87 return success();
88 }
89
runOnOperation()90 void runOnOperation() override {
91 GraphFuncOp func = getOperation();
92
93 Builder b(&getContext());
94 unsigned counter = 0;
95 std::string name;
96 const auto encode_new_name = [&name, &b, &counter] {
97 name.clear();
98 EncodeName(counter++, name);
99 return b.getStringAttr(name);
100 };
101
102 // Rename the arguments and results.
103 NamedAttrList attrs = func->getAttrDictionary();
104 if (func.getNumArguments()) {
105 assert(func.arg_attrs().has_value() && "expected argument attributes");
106 SmallVector<Attribute> arg_attrs;
107 arg_attrs.reserve(func.getNumArguments());
108 // Iterate over the function arguments, skipping the control tokens.
109 for (int i = 0, e = func.getNumArguments(); i != e; i += 2) {
110 NamedAttrList attrs = func.arg_attrsAttr()[i].cast<DictionaryAttr>();
111 attrs.set(dialect_->getTfgNameAttrIdentifier(), encode_new_name());
112 arg_attrs.append({attrs.getDictionary(&getContext()), empty_dict_});
113 }
114 attrs.set(func.arg_attrsAttrName(), b.getArrayAttr(arg_attrs));
115 }
116 if (func.getNumResults()) {
117 assert(func.res_attrs().has_value() && "expected result attributes");
118 SmallVector<Attribute> res_attrs;
119 res_attrs.reserve(func.getNumResults());
120 for (NamedAttrList attrs :
121 func.res_attrsAttr().getAsRange<DictionaryAttr>()) {
122 attrs.set(dialect_->getTfgNameAttrIdentifier(), encode_new_name());
123 res_attrs.push_back(attrs.getDictionary(&getContext()));
124 }
125 attrs.set(func.res_attrsAttrName(), b.getArrayAttr(res_attrs));
126 }
127 if (func.getNumArguments() || func.getNumResults()) {
128 func->setAttrs(attrs.getDictionary(&getContext()));
129 }
130
131 // Rename the control results.
132 ReturnOp terminator = cast<ReturnOp>(func.getBody()->getTerminator());
133 ArrayAttr control_attrs = terminator.control_ret_attrs();
134 if (!attrs.empty()) {
135 SmallVector<Attribute> control_ret_attrs;
136 control_ret_attrs.reserve(control_attrs.size());
137 for (NamedAttrList attrs : control_attrs.getAsRange<DictionaryAttr>()) {
138 attrs.set(dialect_->getTfgNameAttrIdentifier(), encode_new_name());
139 control_ret_attrs.push_back(attrs.getDictionary(&getContext()));
140 }
141 terminator.control_ret_attrsAttr(b.getArrayAttr(control_ret_attrs));
142 }
143
144 // Rename all non-intrisic operations.
145 func.walk([this, &encode_new_name](Operation *op) {
146 if (op->hasTrait<OpTrait::IntrinsicOperation>()) return;
147 op->setAttr(dialect_->getNameAttrIdentifier(), encode_new_name());
148 });
149 }
150
151 private:
152 // An instance of the TFG dialect for accessing cached identifiers.
153 TFGraphDialect *dialect_;
154 // An instance of the empty dictionary attribute.
155 DictionaryAttr empty_dict_;
156 };
157 } // namespace
158
CreateNameCompressPass()159 std::unique_ptr<Pass> CreateNameCompressPass() {
160 return std::make_unique<NameCompressPass>();
161 }
162
163 namespace {
164 class StripDefaultAttrsPass
165 : public StripDefaultAttrsBase<StripDefaultAttrsPass> {
166 public:
initialize(MLIRContext * context)167 LogicalResult initialize(MLIRContext *context) override {
168 // Initialize the pass by getting a registered instance of the TensorFlow
169 // operation registry. If no instance was registered, this pass will fail.
170 dialect_ = context->getOrLoadDialect<TFGraphDialect>();
171 registry_ = nullptr;
172 if (auto registry_interface =
173 dialect_->getRegisteredInterface<TensorFlowOpRegistryInterface>()) {
174 registry_ = registry_interface->GetRegistry();
175 }
176 return success(registry_);
177 }
178
runOnOperation()179 void runOnOperation() override {
180 WalkResult result = getOperation()->walk([&](Operation *op) {
181 // Ignore intrinsic operations.
182 if (op->hasTrait<OpTrait::IntrinsicOperation>())
183 return WalkResult::advance();
184
185 // If removing default-valued attributes failed (attribute conversion
186 // error), bail out.
187 if (failed(removeDefaultValuedAttrs(op))) return WalkResult::interrupt();
188
189 return WalkResult::advance();
190 });
191
192 // If the pass failed on any operation, signal failure.
193 if (result.wasInterrupted()) return signalPassFailure();
194 }
195
196 private:
197 // Remove attributes from the operation equal to their default values
198 // according to the TensorFlow op registry.
199 LogicalResult removeDefaultValuedAttrs(Operation *op);
200
201 // The TFG dialect instance.
202 TFGraphDialect *dialect_;
203 // The TensorFlow op registry to query for default-valued attributes.
204 const tensorflow::OpRegistry *registry_;
205 };
206 } // namespace
207
removeDefaultValuedAttrs(Operation * op)208 LogicalResult StripDefaultAttrsPass::removeDefaultValuedAttrs(Operation *op) {
209 const tensorflow::OpRegistrationData *op_reg_data =
210 registry_->LookUp(op->getName().stripDialect().str());
211 // Ignore unregistered ops.
212 if (!op_reg_data) return success();
213
214 // Find the attributes to remove.
215 ArrayRef<NamedAttribute> attrs = op->getAttrs();
216 llvm::BitVector indices_to_remove(attrs.size());
217 Builder b(&getContext());
218 for (const tensorflow::OpDef::AttrDef &attr : op_reg_data->op_def.attr()) {
219 // Ignore attributes without default values.
220 if (!attr.has_default_value()) continue;
221 auto it = impl::findAttrSorted(attrs.begin(), attrs.end(), attr.name());
222 // Ignore default-valued attributes that are already missing.
223 if (!it.second) continue;
224 // Convert the TensorFlow attribute value and compare it to the MLIR
225 // attribute.
226 tensorflow::StatusOr<Attribute> maybe_attr =
227 ConvertAttributeValue(attr.default_value(), b);
228 if (!maybe_attr.ok())
229 return op->emitError(maybe_attr.status().error_message());
230 if (maybe_attr.ValueOrDie() == it.first->getValue())
231 indices_to_remove.set(std::distance(attrs.begin(), it.first));
232 }
233 if (indices_to_remove.none()) return success();
234
235 // Construct and set the new attributes.
236 SmallVector<NamedAttribute> new_attrs;
237 new_attrs.reserve(attrs.size());
238 for (auto &it : llvm::enumerate(attrs)) {
239 if (indices_to_remove.test(it.index())) continue;
240 new_attrs.push_back(it.value());
241 }
242 op->setAttrs(DictionaryAttr::getWithSorted(&getContext(), new_attrs));
243
244 return success();
245 }
246
CreateStripDefaultAttrsPass()247 std::unique_ptr<Pass> CreateStripDefaultAttrsPass() {
248 return std::make_unique<StripDefaultAttrsPass>();
249 }
250
251 namespace {
252 class AddDefaultAttrsPass : public AddDefaultAttrsBase<AddDefaultAttrsPass> {
253 public:
initialize(MLIRContext * context)254 LogicalResult initialize(MLIRContext *context) override {
255 // Initialize the pass by getting a registered instance of the TensorFlow
256 // operation registry. If no instance was registered, this pass will fail.
257 dialect_ = context->getOrLoadDialect<TFGraphDialect>();
258 registry_ = nullptr;
259 if (auto registry_interface =
260 dialect_->getRegisteredInterface<TensorFlowOpRegistryInterface>()) {
261 registry_ = registry_interface->GetRegistry();
262 }
263 return success(registry_);
264 }
265
runOnOperation()266 void runOnOperation() override {
267 WalkResult result = getOperation()->walk([&](Operation *op) {
268 // Ignore intrinsic operations.
269 if (op->hasTrait<OpTrait::IntrinsicOperation>())
270 return WalkResult::advance();
271
272 // If removing default-valued attributes failed (attribute conversion
273 // error), bail out.
274 if (failed(addDefaultValuedAttrs(op))) return WalkResult::interrupt();
275
276 return WalkResult::advance();
277 });
278
279 // If the pass failed on any operation, signal failure.
280 if (result.wasInterrupted()) return signalPassFailure();
281 }
282
283 private:
284 // Remove attributes from the operation equal to their default values
285 // according to the TensorFlow op registry.
286 LogicalResult addDefaultValuedAttrs(Operation *op);
287
288 // The TFG dialect instance.
289 TFGraphDialect *dialect_;
290 // The TensorFlow op registry to query for default-valued attributes.
291 const tensorflow::OpRegistry *registry_;
292 };
293 } // namespace
294
addDefaultValuedAttrs(Operation * op)295 LogicalResult AddDefaultAttrsPass::addDefaultValuedAttrs(Operation *op) {
296 const tensorflow::OpRegistrationData *op_reg_data =
297 registry_->LookUp(op->getName().stripDialect().str());
298 // Ignore unregistered ops.
299 if (!op_reg_data) return success();
300
301 // Ignore operations with no default-valued attributes.
302 if (llvm::all_of(op_reg_data->op_def.attr(),
303 [](const auto &attr) { return !attr.has_default_value(); }))
304 return success();
305
306 // Add missing default-valued attributes
307 Builder b(&getContext());
308 NamedAttrList attrs = op->getAttrDictionary();
309 for (const auto &attr : op_reg_data->op_def.attr()) {
310 // Ignore attributes without default values.
311 if (!attr.has_default_value()) continue;
312 // Ignore default-valued attributes that are present.
313 if (attrs.get(attr.name())) continue;
314 // Convert the TensorFlow attribute value and set it.
315 tensorflow::StatusOr<Attribute> maybe_attr =
316 ConvertAttributeValue(attr.default_value(), b);
317 if (!maybe_attr.ok())
318 return op->emitError(maybe_attr.status().error_message());
319 attrs.set(attr.name(), maybe_attr.ValueOrDie());
320 }
321 op->setAttrs(attrs.getDictionary(&getContext()));
322
323 return success();
324 }
325
CreateAddDefaultAttrsPass()326 std::unique_ptr<Pass> CreateAddDefaultAttrsPass() {
327 return std::make_unique<AddDefaultAttrsPass>();
328 }
329
330 } // namespace tfg
331 } // namespace mlir
332