1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
17
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
26 #include "mlir/IR/Attributes.h" // from @llvm-project
27 #include "mlir/IR/Builders.h" // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
30 #include "mlir/IR/OpImplementation.h" // from @llvm-project
31 #include "mlir/IR/PatternMatch.h" // from @llvm-project
32 #include "mlir/IR/SymbolTable.h" // from @llvm-project
33 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
34 #include "mlir/Support/LogicalResult.h" // from @llvm-project
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
37
38 namespace mlir {
39 namespace tf_saved_model {
40
41 //===----------------------------------------------------------------------===//
42 // Utilities
43 //===----------------------------------------------------------------------===//
44
IsStrArrayAttr(Attribute attr)45 static bool IsStrArrayAttr(Attribute attr) {
46 auto array = attr.dyn_cast<ArrayAttr>();
47 if (!array) return false;
48
49 return llvm::all_of(array,
50 [](Attribute attr) { return attr.isa<StringAttr>(); });
51 }
52
53 //===----------------------------------------------------------------------===//
54 // TensorFlowSavedModelDialect Op's
55 //===----------------------------------------------------------------------===//
56
VerifyTensorTypesCompatible(Type t1,Type t2)57 LogicalResult VerifyTensorTypesCompatible(Type t1, Type t2) {
58 if (!t1.isa<TensorType>() || !t2.isa<TensorType>()) {
59 return failure();
60 }
61 return verifyCompatibleShape(t1.cast<TensorType>(), t2.cast<TensorType>());
62 }
63
verify()64 LogicalResult GlobalTensorOp::verify() {
65 GlobalTensorOp global_tensor = *this;
66 if (failed(VerifyTensorTypesCompatible(global_tensor.type(),
67 global_tensor.value().getType()))) {
68 return global_tensor.emitError() << "'type' and 'value' attributes should "
69 "have compatible tensor types";
70 }
71 if (!global_tensor.is_mutable()) {
72 if (!global_tensor.type().cast<TensorType>().hasStaticShape()) {
73 return global_tensor.emitError()
74 << "'type' attribute for immutable 'tf_saved_model.global_tensor' "
75 "should have a static shape";
76 }
77 }
78 return success();
79 }
80
verify()81 LogicalResult SessionInitializerOp::verify() {
82 SessionInitializerOp session_initializer = *this;
83 mlir::SymbolTable symbol_table(
84 session_initializer->getParentOfType<ModuleOp>());
85
86 for (auto sym_ref : session_initializer.initializers()) {
87 auto init_func_op = symbol_table.lookup<mlir::func::FuncOp>(
88 sym_ref.cast<FlatSymbolRefAttr>().getValue());
89
90 if (!init_func_op)
91 return session_initializer.emitOpError()
92 << "the initializer function does not exist";
93
94 if (!init_func_op.getFunctionType().getResults().empty())
95 return session_initializer.emitOpError()
96 << "the initializer function should have no output";
97
98 auto exported_names = GetExportedNames(init_func_op);
99
100 if (exported_names.empty())
101 return session_initializer.emitOpError()
102 << "the initializer function should be exported";
103
104 if (exported_names.size() != 1)
105 return session_initializer.emitOpError()
106 << "the initializer function should have only one exported names";
107 }
108
109 return success();
110 }
111
112 } // namespace tf_saved_model
113 } // namespace mlir
114
115 #define GET_OP_CLASSES
116 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
117
118 namespace mlir {
119 namespace tf_saved_model {
120
121 //===----------------------------------------------------------------------===//
122 // TensorFlowSavedModelDialect Dialect
123 //===----------------------------------------------------------------------===//
124
TensorFlowSavedModelDialect(MLIRContext * context)125 TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context)
126 : Dialect(/*name=*/"tf_saved_model", context,
127 TypeID::get<TensorFlowSavedModelDialect>()) {
128 // The TensorFlow Dialect is needed in the verifier and other routines
129 // associated to this dialect. It makes little sense anyway to use the
130 // SavedModel dialect without the TensorFlow Dialect.
131 context->loadDialect<TF::TensorFlowDialect>();
132
133 addOperations<
134 #define GET_OP_LIST
135 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
136 >();
137 }
138
VerifyIndexPath(Operation * op,NamedAttribute named_attr)139 static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) {
140 auto attr = named_attr.getValue().dyn_cast<ArrayAttr>();
141 if (!attr) {
142 return op->emitError()
143 << "'tf_saved_model.index_path' attribute should be an ArrayAttr";
144 }
145 for (auto element : attr) {
146 if (element.isa<StringAttr>()) {
147 continue;
148 }
149 if (auto integer = element.dyn_cast<IntegerAttr>()) {
150 if (integer.getValue().getBitWidth() == 64) {
151 continue;
152 }
153 }
154 return op->emitError() << "'tf_saved_model.index_path' elements should "
155 "be strings or 64-bit integers";
156 }
157 return mlir::success();
158 }
159
GetBoundInputArgTypeFor(mlir::Operation * op)160 Type GetBoundInputArgTypeFor(mlir::Operation *op) {
161 if (auto global_tensor = llvm::dyn_cast<GlobalTensorOp>(op)) {
162 auto type = global_tensor.type().cast<TensorType>();
163 return RankedTensorType::get(
164 {}, TF::ResourceType::get({type}, type.getContext()));
165 }
166
167 if (auto asset = llvm::dyn_cast<AssetOp>(op)) {
168 return RankedTensorType::get({}, TF::StringType::get(asset.getContext()));
169 }
170
171 op->emitError() << "unknown symbol operation";
172 return {};
173 }
174
VerifyBoundInputArgType(Operation * op_for_diagnostics,Type arg_type,mlir::Operation * symbol_op)175 static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics,
176 Type arg_type,
177 mlir::Operation *symbol_op) {
178 auto expected_type = GetBoundInputArgTypeFor(symbol_op);
179 if (!expected_type) return failure();
180
181 if (arg_type != expected_type) {
182 return op_for_diagnostics->emitError()
183 << "bound input with type " << arg_type << " expected to have type "
184 << expected_type;
185 }
186 return success();
187 }
188
verifyRegionArgAttribute(Operation * op,unsigned region_index,unsigned arg_index,NamedAttribute named_attr)189 LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute(
190 Operation *op, unsigned region_index, unsigned arg_index,
191 NamedAttribute named_attr) {
192 if (named_attr.getName() == "tf_saved_model.bound_input") {
193 if (!named_attr.getValue().isa<FlatSymbolRefAttr>()) {
194 return op->emitError() << "'tf_saved_model.bound_input' attribute should "
195 "be a FlatSymbolRefAttr";
196 }
197 auto symbol_name =
198 named_attr.getValue().cast<FlatSymbolRefAttr>().getValue();
199 auto module = op->getParentOfType<ModuleOp>();
200 mlir::Operation *symbol_op = module.lookupSymbol(symbol_name);
201 if (!symbol_op) {
202 return op->emitError() << "'tf_saved_model.bound_input' attribute must "
203 "reference a valid symbol, got invalid symbol '"
204 << symbol_name << "'";
205 }
206 auto arg_type = cast<func::FuncOp>(op).getArgument(arg_index).getType();
207 return VerifyBoundInputArgType(op, arg_type, symbol_op);
208 }
209 if (named_attr.getName() == "tf_saved_model.index_path") {
210 return VerifyIndexPath(op, named_attr);
211 }
212
213 return op->emitError() << "unknown tf_saved_model dialect arg attribute '"
214 << named_attr.getName().getValue() << "'";
215 }
216
verifyRegionResultAttribute(Operation * op,unsigned region_index,unsigned result_index,NamedAttribute named_attr)217 LogicalResult TensorFlowSavedModelDialect::verifyRegionResultAttribute(
218 Operation *op, unsigned region_index, unsigned result_index,
219 NamedAttribute named_attr) {
220 if (named_attr.getName() == "tf_saved_model.index_path") {
221 return VerifyIndexPath(op, named_attr);
222 }
223
224 return op->emitError() << "unknown tf_saved_model dialect result attribute '"
225 << named_attr.getName().getValue() << "'";
226 }
227
HasAnyTfSavedModelArgAttr(func::FuncOp func)228 static bool HasAnyTfSavedModelArgAttr(func::FuncOp func) {
229 for (int i = 0, e = func.getNumArguments(); i < e; i++) {
230 if (func.getArgAttr(i, "tf_saved_model.index_path") ||
231 func.getArgAttr(i, "tf_saved_model.bound_input")) {
232 return true;
233 }
234 }
235 for (int i = 0, e = func.getNumResults(); i < e; i++) {
236 if (func.getResultAttr(i, "tf_saved_model.index_path") ||
237 func.getResultAttr(i, "tf_saved_model.bound_input")) {
238 return true;
239 }
240 }
241 return false;
242 }
243
VerifySavedModelModule(ModuleOp module,TensorFlowSavedModelDialect * dialect)244 static LogicalResult VerifySavedModelModule(
245 ModuleOp module, TensorFlowSavedModelDialect *dialect) {
246 auto exported_names_ident =
247 StringAttr::get(dialect->getContext(), "tf_saved_model.exported_names");
248 // Check that there are no duplicated exported_names.
249 DenseMap<StringRef, Operation *> exported_name_to_op;
250 for (auto &op : module) {
251 auto attr = op.getAttr(exported_names_ident);
252 if (!attr) continue;
253 // If this verifier is called before we verify the
254 // 'tf_saved_model.exported_names' attribute, then it might be invalid.
255 // Forward to the dialect's verification to establish that precondition.
256 if (failed(dialect->verifyOperationAttribute(
257 &op, {exported_names_ident, attr}))) {
258 return failure();
259 }
260 for (auto str : attr.cast<ArrayAttr>()) {
261 auto exported_name = str.cast<StringAttr>().getValue();
262 auto p = exported_name_to_op.insert({exported_name, &op});
263 if (!p.second) {
264 return op.emitError()
265 .append("duplicate exported name '", exported_name, "'")
266 .attachNote(p.first->getSecond()->getLoc())
267 .append("previously seen here");
268 }
269 }
270 }
271 for (auto func : module.getOps<func::FuncOp>()) {
272 const bool is_exported = IsExported(func);
273
274 if (is_exported && !func.isPublic()) {
275 return func.emitError()
276 << "exported function @" << func.getName() << " should be public";
277 }
278
279 if (!is_exported && func.isPublic()) {
280 return func.emitError() << "non-exported function @" << func.getName()
281 << " should be private";
282 }
283 if (!is_exported && HasAnyTfSavedModelArgAttr(func)) {
284 return func.emitError() << "can only apply 'tf_saved_model' argument "
285 "attributes to exported functions";
286 }
287 }
288
289 auto session_initializers = module.getOps<SessionInitializerOp>();
290 if (!session_initializers.empty() &&
291 !llvm::hasSingleElement(session_initializers)) {
292 return (*++session_initializers.begin()).emitError()
293 << "there must be no more than one session_initializer op";
294 }
295
296 auto is_init = [&session_initializers](mlir::func::FuncOp func) {
297 if (session_initializers.empty()) return false;
298 auto init_syms = (*session_initializers.begin()).initializers();
299 return std::any_of(
300 init_syms.begin(), init_syms.end(), [&](Attribute sym_ref) {
301 return sym_ref.cast<FlatSymbolRefAttr>().getValue() == func.getName();
302 });
303 };
304
305 SymbolTable symbol_table(module);
306 auto symbol_uses = SymbolTable::getSymbolUses(&module.getBodyRegion());
307 if (!symbol_uses.has_value()) {
308 return module.emitError() << "modules with 'tf_saved_model.semantics' must "
309 "have analyzable symbol uses";
310 }
311 for (auto symbol_use : *symbol_uses) {
312 auto func = symbol_table.lookupNearestSymbolFrom<func::FuncOp>(
313 symbol_use.getUser(), symbol_use.getSymbolRef());
314 if (func && IsExported(func)) {
315 // If it is an init function, then it can be used by the unique
316 // session_initializer op.
317 if (is_init(func) &&
318 llvm::isa<SessionInitializerOp>(symbol_use.getUser()))
319 continue;
320
321 return symbol_use.getUser()
322 ->emitError("exported function cannot be internally referenced")
323 .attachNote(func.getLoc())
324 .append("references this exported function");
325 }
326 }
327 return success();
328 }
329
VerifyExportedFunc(func::FuncOp func)330 LogicalResult VerifyExportedFunc(func::FuncOp func) {
331 bool reached_bound_inputs = false;
332 auto module = func->getParentOfType<ModuleOp>();
333 for (int i = 0, e = func.getNumArguments(); i < e; i++) {
334 if (func.getArgAttr(i, "tf_saved_model.bound_input")) {
335 reached_bound_inputs = true;
336 continue;
337 }
338 if (func.getArgAttr(i, "tf_saved_model.index_path")) {
339 if (reached_bound_inputs) {
340 return func.emitError()
341 << "all 'tf_saved_model.index_path' arg attributes should "
342 "precede all 'tf_saved_model.bound_input' arg attributes";
343 }
344 continue;
345 }
346 if (func.getArgAttr(i, "tf.resource_name")) {
347 if (module->getAttr("tf_saved_model.under_construction")) continue;
348 return func.emitError() << "'tf.resource_name' attribute is not allowed "
349 "unless it is being under construction";
350 }
351 return func.emitError()
352 << "all arguments should have 'tf_saved_model.index_path', "
353 "'tf_saved_model.bound_input' or 'tf.resource_name' attributes";
354 }
355 llvm::SmallDenseSet<StringRef, 8> unique_bound_inputs;
356 for (int i = 0, e = func.getNumArguments(); i < e; i++) {
357 if (auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
358 i, "tf_saved_model.bound_input")) {
359 if (!unique_bound_inputs.insert(attr.getValue()).second) {
360 if (module->getAttr("tf_saved_model.under_construction")) continue;
361 return func.emitError()
362 << "duplicate 'tf_saved_model.bound_input' binding";
363 }
364 }
365 }
366
367 for (int i = 0, e = func.getNumResults(); i < e; i++) {
368 if (!func.getResultAttr(i, "tf_saved_model.index_path")) {
369 return func.emitError() << "all results should have "
370 "'tf_saved_model.index_path' attributes";
371 }
372 }
373
374 return success();
375 }
376
verifyOperationAttribute(Operation * op,NamedAttribute named_attr)377 LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute(
378 Operation *op, NamedAttribute named_attr) {
379 if (named_attr.getName() == "tf_saved_model.exported_names") {
380 if (!isa<func::FuncOp, GlobalTensorOp>(op)) {
381 return op->emitError() << "'tf_saved_model.exported_names' must be on a "
382 "'func' or 'tf_saved_model.global_tensor' op";
383 }
384 if (!IsStrArrayAttr(named_attr.getValue())) {
385 return op->emitError()
386 << "'tf_saved_model.exported_names' must be an array of strings";
387 }
388 if (!op->getParentOp()->getAttr("tf_saved_model.semantics")) {
389 return op->emitError()
390 << "'tf_saved_model.exported_names' must be on an op "
391 "whose immediate parent has attribute "
392 "'tf_saved_model.semantics'";
393 }
394 if (auto func = dyn_cast<func::FuncOp>(op)) {
395 if (failed(VerifyExportedFunc(func))) {
396 return failure();
397 }
398 }
399 return success();
400 }
401 if (named_attr.getName() == "tf_saved_model.semantics") {
402 auto module = dyn_cast<ModuleOp>(op);
403 if (!module) {
404 return op->emitError() << "'tf_saved_model.semantics' must "
405 "be on a module op";
406 }
407 return VerifySavedModelModule(module, this);
408 }
409 if (named_attr.getName() == "tf_saved_model.under_construction") {
410 return success();
411 }
412
413 return op->emitError() << "unknown tf_saved_model dialect attribute '"
414 << named_attr.getName().getValue() << "'";
415 }
416
GetExportedNames(Operation * op)417 SmallVector<StringRef, 2> GetExportedNames(Operation *op) {
418 SmallVector<StringRef, 2> ret;
419 auto exported_names =
420 op->getAttrOfType<ArrayAttr>("tf_saved_model.exported_names");
421 if (exported_names) {
422 for (auto name : exported_names) {
423 ret.push_back(name.cast<StringAttr>().getValue());
424 }
425 }
426 return ret;
427 }
428
IsExported(Operation * op)429 bool IsExported(Operation *op) {
430 auto exported_names =
431 op->getAttrOfType<ArrayAttr>("tf_saved_model.exported_names");
432 return exported_names && !exported_names.empty();
433 }
434
HasTfSavedModelSemantics(ModuleOp module)435 bool HasTfSavedModelSemantics(ModuleOp module) {
436 return module->getAttr("tf_saved_model.semantics") != nullptr;
437 }
438
LookupBoundInput(func::FuncOp func,int arg_index,const SymbolTable & symbol_table)439 Operation *LookupBoundInput(func::FuncOp func, int arg_index,
440 const SymbolTable &symbol_table) {
441 auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
442 arg_index, "tf_saved_model.bound_input");
443 if (!attr) return nullptr;
444 return symbol_table.lookup(attr.getValue());
445 }
446
GetSessionInitializerOp(mlir::ModuleOp op)447 SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op) {
448 auto initializers = op.getOps<SessionInitializerOp>();
449 if (initializers.empty()) return {};
450 return *initializers.begin();
451 }
452
453 class OptimizeSessionInitializerPattern
454 : public OpRewritePattern<SessionInitializerOp> {
455 public:
456 using OpRewritePattern::OpRewritePattern;
457
matchAndRewrite(SessionInitializerOp op,PatternRewriter & rewriter) const458 LogicalResult matchAndRewrite(SessionInitializerOp op,
459 PatternRewriter &rewriter) const override {
460 SymbolTable symbol_table(op->getParentOfType<ModuleOp>());
461
462 SmallVector<func::FuncOp, 2> to_remove;
463 SmallVector<mlir::Attribute, 2> to_keep;
464 for (auto sym_ref : op.initializers()) {
465 auto init_func_op = symbol_table.lookup<mlir::func::FuncOp>(
466 sym_ref.cast<FlatSymbolRefAttr>().getValue());
467
468 // The init function can only be referenced from the SessionInitializerOp.
469 // And there is at most one SessionInitializerOp in the module. So if both
470 // ops have no other uses or have one NoOp only, they can be simply
471 // erased.
472 auto &operations = init_func_op.front().getOperations();
473 if ((operations.size() == 1 &&
474 operations.front().hasTrait<OpTrait::IsTerminator>()) ||
475 (operations.size() == 2 &&
476 dyn_cast<mlir::TF::NoOp>(operations.front()) &&
477 operations.back().hasTrait<OpTrait::IsTerminator>())) {
478 to_remove.push_back(init_func_op);
479 } else {
480 to_keep.push_back(sym_ref);
481 }
482 }
483
484 for (auto func_op : to_remove) rewriter.eraseOp(func_op);
485
486 if (to_keep.empty())
487 rewriter.eraseOp(op);
488 else
489 op->setAttr("initializers", rewriter.getArrayAttr(to_keep));
490
491 return success();
492 }
493 };
494
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)495 void SessionInitializerOp::getCanonicalizationPatterns(
496 RewritePatternSet &results, MLIRContext *context) {
497 results.add<OptimizeSessionInitializerPattern>(context);
498 }
499
GetSessionInitializerExportedName(ModuleOp op)500 SmallVector<StringRef, 2> GetSessionInitializerExportedName(ModuleOp op) {
501 auto session_initializer_op = GetSessionInitializerOp(op);
502 if (!session_initializer_op) return {};
503
504 SymbolTable symbol_table(op);
505
506 SmallVector<StringRef, 2> results;
507 for (auto sym_ref : session_initializer_op.initializers()) {
508 auto init_func_op = symbol_table.lookup<mlir::func::FuncOp>(
509 sym_ref.cast<FlatSymbolRefAttr>().getValue());
510 auto exported_names = GetExportedNames(init_func_op);
511 assert(exported_names.size() == 1);
512 results.push_back(exported_names[0]);
513 }
514
515 return results;
516 }
517
518 } // namespace tf_saved_model
519 } // namespace mlir
520