/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "llvm/ADT/Optional.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"

namespace tensorflow {

namespace {

// A fake device to simulate the presence of a CPU.
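// It exists only to populate the DeviceSet handed to AddDevicesToOp below;
// it never executes kernels (Sync() is unimplemented).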
class FakeDevice : public Device {
 public:
  explicit FakeDevice(const DeviceAttributes& device_attributes)
      : Device(nullptr, device_attributes) {}

  Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
};

// Translates the graph input information from tf2xla::Config to
// GraphImportConfig.
Status ConvertInputInfo(
    const tf2xla::Config& config,
    const std::unordered_map<std::string, std::string>& feed_name_remap,
    GraphImportConfig* specs) {
  std::vector<std::string> array_names;
  std::vector<std::string> data_types;
  std::vector<llvm::Optional<std::vector<int>>> shapes;
  for (const tf2xla::Feed& feed : config.feed()) {
    std::string place_holder_name =
        feed_name_remap.at(TensorIdToString(feed.id()));
    array_names.push_back(place_holder_name);
    data_types.push_back(
        feed.type() == DT_INVALID ? "" : DataType_Name(feed.type()));
    if (feed.shape().unknown_rank()) {
      shapes.push_back(llvm::None);
      continue;
    }
    std::vector<int> dims;
    dims.reserve(feed.shape().dim_size());
    absl::c_for_each(feed.shape().dim(), [&](const TensorShapeProto::Dim d) {
      dims.push_back(d.size());
    });
    shapes.push_back(dims);
  }

  return ParseInputArrayInfo(array_names, data_types, shapes, &specs->inputs);
}

// Translates the graph output information from tf2xla::Config to
// GraphImportConfig.
Status ConvertOutputInfo(const tf2xla::Config& config,
                         GraphImportConfig* specs) {
  std::vector<std::string> array_names;
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    array_names.push_back(fetch.id().node_name());
  }

  return ParseOutputArrayInfo(array_names, &specs->outputs);
}

}  // namespace

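// Converts the given GraphDef to an xla::XlaComputation by importing the
// graph into MLIR and lowering it through the TF dialects to XLA HLO.
// A minimal usage sketch (hypothetical values; error handling elided):
//
//   GraphDef graph_def = ...;        // graph containing the feeds/fetches
//   tf2xla::Config config = ...;     // feed/fetch specification
//   xla::XlaComputation computation;
//   Status s = ConvertGraphDefToXlaViaMlir(
//       std::move(graph_def), config, &computation,
//       /*debug_info_filename=*/"", /*debug_info_path_begin_marker=*/"");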
Status ConvertGraphDefToXlaViaMlir(
    GraphDef graph_def, const tf2xla::Config& config,
    xla::XlaComputation* computation, absl::string_view debug_info_filename,
    absl::string_view debug_info_path_begin_marker) {
  // AddPlaceholdersForFeeds prepares for PruneGraphDefInto and serves two
  // purposes: (1) It creates a placeholder node for each feed, so that
  // PruneGraphDefInto can prune away the node containing the feed. (2) It
  // is also a workaround for b/149029125. It replaces a feed representation
  // with a placeholder node that contains a single output.
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
  std::unique_ptr<Graph> graph(new Graph(flib_def));
  std::unordered_map<string, string> feed_name_remap;
  TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, graph->op_registry(),
                                             &feed_name_remap, &graph_def));

  // TODO(b/149024678): remove this workaround after the ticket is fixed.
  // Prune the GraphDef because the MLIR importer doesn't allow unknown ops in
  // graph nodes, even when those nodes are not needed for computing the
  // outputs.
  GraphDef pruned_graph_def;
  TF_RETURN_IF_ERROR(PruneGraphDefInto(config, graph_def, &pruned_graph_def));

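  // Import settings for translating the pruned GraphDef to MLIR. Pruning was
  // already done above, so the importer's own pruning is disabled; legacy
  // graphs are upgraded so that v1 control flow can be functionalized.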
  GraphImportConfig specs;
  specs.prune_unused_nodes = false;
  specs.convert_legacy_fed_inputs = false;
  specs.graph_as_function = false;
  specs.upgrade_legacy = true;
  TF_RETURN_IF_ERROR(ConvertInputInfo(config, feed_name_remap, &specs));
  TF_RETURN_IF_ERROR(ConvertOutputInfo(config, &specs));

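  // If requested, load the GraphDebugInfo proto and strip everything up to
  // (and including) the marker from each file name, so that source locations
  // are reported relative to the marker rather than as absolute build paths.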
  GraphDebugInfo debug_info;
  if (!debug_info_filename.empty()) {
    TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_filename, &debug_info));

    if (!debug_info_path_begin_marker.empty()) {
      for (size_t i = 0, e = debug_info.files_size(); i < e; ++i) {
        std::string* file_name = debug_info.mutable_files(i);
        size_t location =
            file_name->rfind(std::string(debug_info_path_begin_marker));
        if (location != std::string::npos) {
          *file_name = file_name->substr(location +
                                         debug_info_path_begin_marker.length());
        }
      }
    }
  }

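  // Import the pruned GraphDef into an MLIR module (in the TF executor
  // dialect) using the import specs configured above.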
  mlir::MLIRContext context;
  TF_ASSIGN_OR_RETURN(
      mlir::OwningOpRef<mlir::ModuleOp> module,
      ConvertGraphdefToMlir(pruned_graph_def, debug_info, specs, &context));

  // Construct a CPU device and add the device to the operations.
  DeviceSet device_set;
  DeviceAttributes attr;
  attr.set_name("/job:localhost/replica:0/task:0/device:CPU:0");
  attr.set_device_type(DeviceType("CPU").type());
  FakeDevice device(attr);
  device_set.AddDevice(&device);
  AddDevicesToOp(*module, &device_set);

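  // Run the standard set of TF MLIR bridge passes (with the inliner enabled,
  // and pass logging when VLOG level >= 1) to prepare the module for lowering.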
  TF_RETURN_IF_ERROR(mlir::TF::RunBridgeWithStandardPipeline(
      *module, /*enable_logging=*/VLOG_IS_ON(1), /*enable_inliner=*/true));

  // Convert the MLIR module to an XLA computation. If the previous step could
  // not lower the input graph to a single graph node with a single island,
  // this step will return an error.
  return ConvertMLIRToXlaComputation(
      *module, /*device_type=*/"XLA_CPU_JIT", computation,
      /*use_tuple_args=*/false, /*prefer_tf2xla=*/false,
      /*return_tuple=*/true);
}

}  // namespace tensorflow