/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/SmallVector.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"

namespace mlir {
namespace TFL {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

// This pass inserts a TFL::CallOnce op when tf_saved_model's session
// initializer is given.
31 class InsertCallOnceOpFromSessionInitializerPass 32 : public InsertCallOnceOpFromSessionInitializerPassBase< 33 InsertCallOnceOpFromSessionInitializerPass> { 34 public: 35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 36 InsertCallOnceOpFromSessionInitializerPass) 37 38 private: 39 void runOnOperation() override; 40 }; 41 runOnOperation()42void InsertCallOnceOpFromSessionInitializerPass::runOnOperation() { 43 ModuleOp module = getOperation(); 44 tf_saved_model::SessionInitializerOp session_init_op = 45 tf_saved_model::GetSessionInitializerOp(module); 46 47 if (!session_init_op) return; 48 49 SymbolTable symbol_table(module); 50 51 for (auto sym_ref : session_init_op.initializers()) { 52 func::FuncOp init_func_op = symbol_table.lookup<mlir::func::FuncOp>( 53 sym_ref.cast<FlatSymbolRefAttr>().getValue()); 54 55 if (!init_func_op) { 56 module.emitError("no session initializer function found"); 57 return signalPassFailure(); 58 } 59 60 for (auto func : module.getOps<func::FuncOp>()) { 61 auto dict_attr = 62 func->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function"); 63 if (!dict_attr) continue; 64 65 OpBuilder builder(func.getContext()); 66 builder.setInsertionPointToStart(&func.getBlocks().front()); 67 builder.create<TFL::CallOnceOp>(func.getLoc(), init_func_op.getName()); 68 } 69 } 70 } 71 72 } // namespace 73 74 // Inserts a TFL::CallOnce op when tf_saved_model's session initializer is 75 // given. 76 std::unique_ptr<OperationPass<ModuleOp>> CreateInsertCallOnceOpFromSessionInitializerPass()77CreateInsertCallOnceOpFromSessionInitializerPass() { 78 return std::make_unique<InsertCallOnceOpFromSessionInitializerPass>(); 79 } 80 81 } // namespace TFL 82 } // namespace mlir 83