1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/SmallSet.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/ADT/StringSet.h"
20 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
21 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
23 #include "mlir/IR/SymbolTable.h" // from @llvm-project
24 #include "mlir/Pass/Pass.h" // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/savedmodel_passes_detail.h"
29 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
30 #include "tensorflow/compiler/mlir/tensorflow/utils/session_utils.h"
31 #include "tensorflow/core/framework/resource_var.h"
32 #include "tensorflow/core/public/session.h"
33
34 namespace mlir {
35 namespace tf_saved_model {
36 namespace {
37
InitializeVariable(TF::VarHandleOp var_handle_op,tensorflow::Tensor * tensor,func::FuncOp session_init_func,OpBuilder builder)38 void InitializeVariable(TF::VarHandleOp var_handle_op,
39 tensorflow::Tensor* tensor,
40 func::FuncOp session_init_func, OpBuilder builder) {
41 tensorflow::StatusOr<ElementsAttr> tensor_attr_or =
42 tensorflow::ConvertTensor(*tensor, &builder);
43 assert(tensor_attr_or.ok() && "Expect valid tensor");
44 ElementsAttr tensor_attr = tensor_attr_or.ValueOrDie();
45
46 builder.setInsertionPointToStart(&session_init_func.getBlocks().front());
47 auto var_handle_op_in_init = var_handle_op->clone();
48 builder.insert(var_handle_op_in_init);
49 auto const_op = builder.create<mlir::arith::ConstantOp>(
50 session_init_func.getLoc(), tensor_attr.getType(), tensor_attr);
51
52 builder.create<TF::AssignVariableOp>(
53 session_init_func.getLoc(), llvm::ArrayRef<mlir::Type>{},
54 llvm::ArrayRef<mlir::Value>{var_handle_op_in_init->getResult(0),
55 const_op.getResult()});
56 }
57
58 constexpr char kTfSavedModelExportedNameAttr[] =
59 "tf_saved_model.exported_names";
60
CreateSessionInitFunc(ModuleOp module)61 func::FuncOp CreateSessionInitFunc(ModuleOp module) {
62 constexpr char kSessionInitFuncName[] = "SessionInitializerFunction";
63
64 mlir::OpBuilder builder(module.getBodyRegion());
65 auto func_type =
66 FunctionType::get(module.getContext(), /*inputs=*/{}, /*results=*/{});
67 auto func = builder.create<func::FuncOp>(module->getLoc(),
68 kSessionInitFuncName, func_type);
69 func->setAttr(kTfSavedModelExportedNameAttr,
70 builder.getStrArrayAttr({kSessionInitFuncName}));
71 func.setVisibility(mlir::func::FuncOp::Visibility::Public);
72 auto func_builder = OpBuilder::atBlockBegin(func.addEntryBlock());
73 func_builder.create<mlir::func::ReturnOp>(func.getLoc());
74 // In cases where there is a session initializer op with empty initializer,
75 // replace the session initializer with the new one that points to the session
76 // initializer func.
77 SessionInitializerOp session_init_op = GetSessionInitializerOp(module);
78 auto new_session_init_op =
79 builder.create<tf_saved_model::SessionInitializerOp>(
80 module->getLoc(), builder.getArrayAttr(SymbolRefAttr::get(
81 builder.getContext(), kSessionInitFuncName)));
82 if (session_init_op) {
83 session_init_op->replaceAllUsesWith(new_session_init_op);
84 session_init_op->erase();
85 }
86 return func;
87 }
88
GetOrCreateSessionInitFunc(ModuleOp module)89 func::FuncOp GetOrCreateSessionInitFunc(ModuleOp module) {
90 SessionInitializerOp session_init_op = GetSessionInitializerOp(module);
91 if (!session_init_op) return CreateSessionInitFunc(module);
92
93 SymbolTable symbol_table(module);
94 if (!session_init_op.initializers().empty()) {
95 func::FuncOp init_func_op = symbol_table.lookup<mlir::func::FuncOp>(
96 session_init_op.initializers()[0].cast<FlatSymbolRefAttr>().getValue());
97 return init_func_op;
98 }
99 return CreateSessionInitFunc(module);
100 }
101
102 } // namespace
103
InitializeVariablesInSessionInitializer(ModuleOp module,tensorflow::Session * session)104 LogicalResult InitializeVariablesInSessionInitializer(
105 ModuleOp module, tensorflow::Session* session) {
106 const tensorflow::DeviceMgr* mgr = nullptr;
107 auto status = session->LocalDeviceManager(&mgr);
108 if (!status.ok()) {
109 module->emitError("failed to fetch device manager: " +
110 status.error_message());
111 return failure();
112 }
113
114 // Fetch all VarHandleOp.
115 llvm::StringSet<> variable_names;
116 llvm::SmallVector<TF::VarHandleOp, 4> var_ops;
117 for (auto func_op : module.getOps<func::FuncOp>()) {
118 for (auto var_handle_op : func_op.getOps<TF::VarHandleOp>()) {
119 auto variable_name = GetVariableName(var_handle_op);
120 if (variable_names.count(variable_name)) continue;
121 var_ops.emplace_back(var_handle_op);
122 variable_names.insert(variable_name);
123 }
124 }
125
126 // Get resources from Session.
127 auto resource_tensors_or = GetResourcesFromSession(var_ops, session);
128 if (!resource_tensors_or.ok()) {
129 module->emitError(resource_tensors_or.status().message().data());
130 return failure();
131 }
132
133 auto session_init_func = GetOrCreateSessionInitFunc(module);
134 OpBuilder builder(session_init_func.getContext());
135
136 for (auto var_and_tensor : llvm::zip(var_ops, resource_tensors_or.value())) {
137 auto& var_op = std::get<0>(var_and_tensor);
138 auto& resource_tensor = std::get<1>(var_and_tensor);
139 if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) {
140 InitializeVariable(var_op, &resource_tensor, session_init_func, builder);
141 continue;
142 }
143
144 auto handle = resource_tensor.scalar<tensorflow::ResourceHandle>()();
145 auto* var_ptr = GetVariableFromSession(var_op, handle.device(), mgr);
146 if (!var_ptr) {
147 // If no value in session, then just skip this variable.
148 // This can happen if the variable is not saved in checkpoint.
149 // For example, when the variable is created on every call.
150 continue;
151 }
152 tensorflow::core::RefCountPtr<tensorflow::Var> var(var_ptr);
153 auto* tensor = var_ptr->tensor();
154
155 InitializeVariable(var_op, tensor, session_init_func, builder);
156 }
157 return success();
158 }
159
160 } // namespace tf_saved_model
161 } // namespace mlir
162