1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/SmallSet.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/ADT/StringSet.h"
20 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
23 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
24 #include "mlir/Pass/Pass.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/savedmodel_passes_detail.h"
29 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
30 #include "tensorflow/compiler/mlir/tensorflow/utils/session_utils.h"
31 #include "tensorflow/core/framework/resource_var.h"
32 #include "tensorflow/core/public/session.h"
33 
34 namespace mlir {
35 namespace tf_saved_model {
36 namespace {
37 
InitializeVariable(TF::VarHandleOp var_handle_op,tensorflow::Tensor * tensor,func::FuncOp session_init_func,OpBuilder builder)38 void InitializeVariable(TF::VarHandleOp var_handle_op,
39                         tensorflow::Tensor* tensor,
40                         func::FuncOp session_init_func, OpBuilder builder) {
41   tensorflow::StatusOr<ElementsAttr> tensor_attr_or =
42       tensorflow::ConvertTensor(*tensor, &builder);
43   assert(tensor_attr_or.ok() && "Expect valid tensor");
44   ElementsAttr tensor_attr = tensor_attr_or.ValueOrDie();
45 
46   builder.setInsertionPointToStart(&session_init_func.getBlocks().front());
47   auto var_handle_op_in_init = var_handle_op->clone();
48   builder.insert(var_handle_op_in_init);
49   auto const_op = builder.create<mlir::arith::ConstantOp>(
50       session_init_func.getLoc(), tensor_attr.getType(), tensor_attr);
51 
52   builder.create<TF::AssignVariableOp>(
53       session_init_func.getLoc(), llvm::ArrayRef<mlir::Type>{},
54       llvm::ArrayRef<mlir::Value>{var_handle_op_in_init->getResult(0),
55                                   const_op.getResult()});
56 }
57 
58 constexpr char kTfSavedModelExportedNameAttr[] =
59     "tf_saved_model.exported_names";
60 
CreateSessionInitFunc(ModuleOp module)61 func::FuncOp CreateSessionInitFunc(ModuleOp module) {
62   constexpr char kSessionInitFuncName[] = "SessionInitializerFunction";
63 
64   mlir::OpBuilder builder(module.getBodyRegion());
65   auto func_type =
66       FunctionType::get(module.getContext(), /*inputs=*/{}, /*results=*/{});
67   auto func = builder.create<func::FuncOp>(module->getLoc(),
68                                            kSessionInitFuncName, func_type);
69   func->setAttr(kTfSavedModelExportedNameAttr,
70                 builder.getStrArrayAttr({kSessionInitFuncName}));
71   func.setVisibility(mlir::func::FuncOp::Visibility::Public);
72   auto func_builder = OpBuilder::atBlockBegin(func.addEntryBlock());
73   func_builder.create<mlir::func::ReturnOp>(func.getLoc());
74   // In cases where there is a session initializer op with empty initializer,
75   // replace the session initializer with the new one that points to the session
76   // initializer func.
77   SessionInitializerOp session_init_op = GetSessionInitializerOp(module);
78   auto new_session_init_op =
79       builder.create<tf_saved_model::SessionInitializerOp>(
80           module->getLoc(), builder.getArrayAttr(SymbolRefAttr::get(
81                                 builder.getContext(), kSessionInitFuncName)));
82   if (session_init_op) {
83     session_init_op->replaceAllUsesWith(new_session_init_op);
84     session_init_op->erase();
85   }
86   return func;
87 }
88 
GetOrCreateSessionInitFunc(ModuleOp module)89 func::FuncOp GetOrCreateSessionInitFunc(ModuleOp module) {
90   SessionInitializerOp session_init_op = GetSessionInitializerOp(module);
91   if (!session_init_op) return CreateSessionInitFunc(module);
92 
93   SymbolTable symbol_table(module);
94   if (!session_init_op.initializers().empty()) {
95     func::FuncOp init_func_op = symbol_table.lookup<mlir::func::FuncOp>(
96         session_init_op.initializers()[0].cast<FlatSymbolRefAttr>().getValue());
97     return init_func_op;
98   }
99   return CreateSessionInitFunc(module);
100 }
101 
102 }  // namespace
103 
InitializeVariablesInSessionInitializer(ModuleOp module,tensorflow::Session * session)104 LogicalResult InitializeVariablesInSessionInitializer(
105     ModuleOp module, tensorflow::Session* session) {
106   const tensorflow::DeviceMgr* mgr = nullptr;
107   auto status = session->LocalDeviceManager(&mgr);
108   if (!status.ok()) {
109     module->emitError("failed to fetch device manager: " +
110                       status.error_message());
111     return failure();
112   }
113 
114   // Fetch all VarHandleOp.
115   llvm::StringSet<> variable_names;
116   llvm::SmallVector<TF::VarHandleOp, 4> var_ops;
117   for (auto func_op : module.getOps<func::FuncOp>()) {
118     for (auto var_handle_op : func_op.getOps<TF::VarHandleOp>()) {
119       auto variable_name = GetVariableName(var_handle_op);
120       if (variable_names.count(variable_name)) continue;
121       var_ops.emplace_back(var_handle_op);
122       variable_names.insert(variable_name);
123     }
124   }
125 
126   // Get resources from Session.
127   auto resource_tensors_or = GetResourcesFromSession(var_ops, session);
128   if (!resource_tensors_or.ok()) {
129     module->emitError(resource_tensors_or.status().message().data());
130     return failure();
131   }
132 
133   auto session_init_func = GetOrCreateSessionInitFunc(module);
134   OpBuilder builder(session_init_func.getContext());
135 
136   for (auto var_and_tensor : llvm::zip(var_ops, resource_tensors_or.value())) {
137     auto& var_op = std::get<0>(var_and_tensor);
138     auto& resource_tensor = std::get<1>(var_and_tensor);
139     if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) {
140       InitializeVariable(var_op, &resource_tensor, session_init_func, builder);
141       continue;
142     }
143 
144     auto handle = resource_tensor.scalar<tensorflow::ResourceHandle>()();
145     auto* var_ptr = GetVariableFromSession(var_op, handle.device(), mgr);
146     if (!var_ptr) {
147       // If no value in session, then just skip this variable.
148       // This can happen if the variable is not saved in checkpoint.
149       // For example, when the variable is created on every call.
150       continue;
151     }
152     tensorflow::core::RefCountPtr<tensorflow::Var> var(var_ptr);
153     auto* tensor = var_ptr->tensor();
154 
155     InitializeVariable(var_op, tensor, session_init_func, builder);
156   }
157   return success();
158 }
159 
160 }  // namespace tf_saved_model
161 }  // namespace mlir
162