/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h"

#include "absl/strings/str_split.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
#include "tfrt/bef_converter/mlir_to_bef.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_handler.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor_view.h"  // from @tf_runtime

namespace tensorflow {
namespace {

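// Returns the signature name encoded in `index_path`. Only the single-string
// case is supported; other index paths map to an empty name for now.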
llvm::StringRef ProcessIndexPath(mlir::ArrayAttr index_path) {
  if (index_path.size() == 1 && index_path[0].isa<mlir::StringAttr>()) {
    // TODO(chky): Support cases where index_path is not a single string.
    return index_path[0].cast<mlir::StringAttr>().getValue();
  }
  return "";
}

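// Converts an MLIR tensor type into a (dtype, shape) tensor spec. Unranked
// tensors are mapped to a PartialTensorShape of unknown rank.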
StatusOr<std::pair<tensorflow::DataType, tensorflow::PartialTensorShape>>
ProcessTensorSpec(mlir::TensorType type) {
  tensorflow::DataType dtype;
  TF_RETURN_IF_ERROR(
      ConvertScalarTypeToDataType(type.getElementType(), &dtype));

  if (!type.hasRank())
    return std::make_pair(dtype, tensorflow::PartialTensorShape());

  auto shape = type.getShape();
  llvm::SmallVector<int64_t, 4> dims;
  dims.assign(shape.begin(), shape.end());
  return std::make_pair(dtype, tensorflow::PartialTensorShape(dims));
}

}  // namespace

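// Walks every exported function in `module`, collects its input/output names,
// tensor specs, input devices and bound inputs, and invokes `map_fn` once per
// exported name.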
Status MapFunctionSignaturesFromTFSavedModelMLIR(
    mlir::ModuleOp module,
    llvm::function_ref<void(const TFRTSavedModelSignatureInfo&)> map_fn) {
  // Create bound inputs for each function.
  mlir::SymbolTable symbol_table(module);
  tensorflow::Status status = OkStatus();
  module.walk([&symbol_table, map_fn, &status](mlir::func::FuncOp func) {
    // Use the exported names as the function names, and skip non-exported
    // functions.
    auto func_names = mlir::tf_saved_model::GetExportedNames(func);
    if (func_names.empty()) return mlir::WalkResult::advance();

    auto func_type = func.getFunctionType();

    // Walk through the arguments to collect the input names, tensor specs,
    // input devices, and any bound inputs (e.g. variables) used by this
    // function.
    llvm::SmallVector<llvm::StringRef, 4> input_names;
    llvm::SmallVector<
        std::pair<tensorflow::DataType, tensorflow::PartialTensorShape>, 4>
        input_specs;
    llvm::SmallVector<llvm::StringRef, 4> input_devices;
    llvm::SmallVector<mlir::Operation*, 4> bound_inputs;
    for (unsigned i = 0, e = func.getNumArguments(); i != e; ++i) {
      if (auto input_index_path = func.getArgAttrOfType<mlir::ArrayAttr>(
              i, "tf_saved_model.index_path")) {
        input_names.push_back(ProcessIndexPath(input_index_path));
        auto statusor_spec =
            ProcessTensorSpec(func_type.getInput(i).cast<mlir::TensorType>());
        if (!statusor_spec.ok()) {
          status = std::move(statusor_spec).status();
          return mlir::WalkResult::interrupt();
        }
        input_specs.push_back(std::move(statusor_spec).ValueOrDie());
        if (auto input_device =
                func.getArgAttrOfType<mlir::StringAttr>(i, "tf.device")) {
          input_devices.push_back(input_device.getValue());
        } else {
          input_devices.push_back("");
        }
      }
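      // An argument may also be bound to a resource (e.g. a captured variable
      // or asset) declared elsewhere in the module; record the defining op of
      // any such bound input.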
      if (auto* bound_input =
              mlir::tf_saved_model::LookupBoundInput(func, i, symbol_table)) {
        bound_inputs.push_back(bound_input);
      }
    }

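    // Collect output names and tensor specs from the function results,
    // mirroring the argument handling above.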
    llvm::SmallVector<llvm::StringRef, 4> output_names;
    llvm::SmallVector<
        std::pair<tensorflow::DataType, tensorflow::PartialTensorShape>, 4>
        output_specs;
    for (unsigned i = 0, e = func.getNumResults(); i != e; ++i) {
      if (auto output_index_path = func.getResultAttrOfType<mlir::ArrayAttr>(
              i, "tf_saved_model.index_path")) {
        output_names.push_back(ProcessIndexPath(output_index_path));
        auto statusor_spec =
            ProcessTensorSpec(func_type.getResult(i).cast<mlir::TensorType>());
        if (!statusor_spec.ok()) {
          status = std::move(statusor_spec).status();
          return mlir::WalkResult::interrupt();
        }
        output_specs.push_back(std::move(statusor_spec).ValueOrDie());
      }
    }

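    // A function may be exported under several names; emit one signature info
    // per exported name, all referring to the same collected data.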
    for (auto func_name : func_names) {
      TFRTSavedModelSignatureInfo sig_info;
      sig_info.func_name = func_name;
      sig_info.input_names = input_names;
      sig_info.input_specs = input_specs;
      sig_info.input_devices = input_devices;
      sig_info.output_names = output_names;
      sig_info.output_specs = output_specs;
      sig_info.bound_inputs = bound_inputs;
      map_fn(sig_info);
    }

    return mlir::WalkResult::advance();
  });

  return status;
}

}  // namespace tensorflow