/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

namespace tensorflow {

namespace {

// A fake device to simulate the presence of a CPU.
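// Only its attributes are consulted (by AddDevicesToOp below); it never
// executes kernels, hence the unimplemented Sync().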
class FakeDevice : public Device {
 public:
  explicit FakeDevice(const DeviceAttributes& device_attributes)
      : Device(nullptr, device_attributes) {}

  Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
};

// Translates the graph input information from tf2xla::Config to
// GraphImportConfig.
Status ConvertInputInfo(
    const tf2xla::Config& config,
    const std::unordered_map<std::string, std::string>& feed_name_remap,
    GraphImportConfig* specs) {
  std::vector<std::string> array_names;
  std::vector<std::string> data_types;
  std::vector<llvm::Optional<std::vector<int>>> shapes;
  for (const tf2xla::Feed& feed : config.feed()) {
    std::string place_holder_name =
        feed_name_remap.at(TensorIdToString(feed.id()));
    array_names.push_back(place_holder_name);
    data_types.push_back(
        feed.type() == DT_INVALID ? "" : DataType_Name(feed.type()));
    if (feed.shape().unknown_rank()) {
      shapes.push_back(llvm::None);
      continue;
    }
    std::vector<int> dims;
    dims.reserve(feed.shape().dim_size());
    absl::c_for_each(feed.shape().dim(), [&](const TensorShapeProto::Dim d) {
      dims.push_back(d.size());
    });
    shapes.push_back(dims);
  }

  return ParseInputArrayInfo(array_names, data_types, shapes, &specs->inputs);
}
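// Illustrative example (hypothetical values): a feed for output 0 of a node
// "x" with dtype DT_FLOAT and shape [2, 3] contributes the remapped
// placeholder name from feed_name_remap (keyed by the id string "x:0"), the
// data type string "DT_FLOAT", and the shape vector {2, 3} to the parallel
// arrays passed to ParseInputArrayInfo above.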

// Translates the graph output information from tf2xla::Config to
// GraphImportConfig.
Status ConvertOutputInfo(const tf2xla::Config& config,
                         GraphImportConfig* specs) {
  std::vector<std::string> array_names;
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    array_names.push_back(fetch.id().node_name());
  }

  return ParseOutputArrayInfo(array_names, &specs->outputs);
}
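// Illustrative example (hypothetical values): fetches for nodes "y" and "z"
// yield array_names {"y", "z"}. Note that only the node name is used here,
// not the output index.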

}  // namespace

Status ConvertGraphDefToXlaViaMlir(
    GraphDef graph_def, const tf2xla::Config& config,
    xla::XlaComputation* computation, absl::string_view debug_info_filename,
    absl::string_view debug_info_path_begin_marker) {
  // AddPlaceholdersForFeeds prepares for PruneGraphDefInto and serves two
  // purposes: (1) It creates a placeholder node for each feed, so that
  // PruneGraphDefInto can prune away the node containing the feed. (2) It
  // is also a workaround for b/149029125. It replaces a feed representation
  // with a placeholder node that contains a single output.
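  // For example (hypothetical node names): a feed on "split:1" is replaced by
  // a fresh single-output placeholder node, and feed_name_remap records the
  // mapping from the id string "split:1" to that placeholder's name, which
  // ConvertInputInfo later looks up.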
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
  std::unique_ptr<Graph> graph(new Graph(flib_def));
  std::unordered_map<string, string> feed_name_remap;
  TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, graph->op_registry(),
                                             &feed_name_remap, &graph_def));

  // TODO(b/149024678): remove this workaround after the ticket is fixed.
  //   Prune the GraphDef because the MLIR importer doesn't allow unknown ops
  //   in graph nodes, even when those nodes are not needed for computing the
  //   outputs.
  GraphDef pruned_graph_def;
  TF_RETURN_IF_ERROR(PruneGraphDefInto(config, graph_def, &pruned_graph_def));

  GraphImportConfig specs;
  specs.prune_unused_nodes = false;
  specs.convert_legacy_fed_inputs = false;
  specs.graph_as_function = false;
  specs.upgrade_legacy = true;
  TF_RETURN_IF_ERROR(ConvertInputInfo(config, feed_name_remap, &specs));
  TF_RETURN_IF_ERROR(ConvertOutputInfo(config, &specs));

  GraphDebugInfo debug_info;
  if (!debug_info_filename.empty()) {
    TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_filename, &debug_info));

    if (!debug_info_path_begin_marker.empty()) {
      for (size_t i = 0, e = debug_info.files_size(); i < e; ++i) {
        std::string* file_name = debug_info.mutable_files(i);
        size_t location =
            file_name->rfind(std::string(debug_info_path_begin_marker));
        if (location != std::string::npos) {
          *file_name = file_name->substr(location +
                                         debug_info_path_begin_marker.length());
        }
      }
    }
  }
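  // For example (hypothetical paths): with debug_info_path_begin_marker set
  // to "genfiles/", a recorded file name "/build/genfiles/foo/bar.py" is
  // trimmed to "foo/bar.py": everything up to and including the last
  // occurrence of the marker is stripped.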

  mlir::MLIRContext context;
  TF_ASSIGN_OR_RETURN(
      mlir::OwningOpRef<mlir::ModuleOp> module,
      ConvertGraphdefToMlir(pruned_graph_def, debug_info, specs, &context));

  // Construct a CPU device and add the device to the operations.
  DeviceSet device_set;
  DeviceAttributes attr;
  attr.set_name("/job:localhost/replica:0/task:0/device:CPU:0");
  attr.set_device_type(DeviceType("CPU").type());
  FakeDevice device(attr);
  device_set.AddDevice(&device);
  AddDevicesToOp(*module, &device_set);

  TF_RETURN_IF_ERROR(mlir::TF::RunBridgeWithStandardPipeline(
      *module, /*enable_logging=*/VLOG_IS_ON(1), /*enable_inliner=*/true));

  // Convert the MLIR module to an XLA computation. If the previous step could
  // not lower the input graph to a single graph node with a single island,
  // this step returns an error.
  return ConvertMLIRToXlaComputation(
      *module, /*device_type=*/"XLA_CPU_JIT", computation,
      /*use_tuple_args=*/false, /*prefer_tf2xla=*/false,
      /*return_tuple=*/true);
}
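
// Illustrative usage sketch (hypothetical caller; `my_graph_def` and the node
// names are assumptions, not part of this file):
//
//   tf2xla::Config config;
//   tf2xla::Feed* feed = config.add_feed();
//   feed->mutable_id()->set_node_name("x");
//   feed->set_type(DT_FLOAT);
//   feed->mutable_shape()->add_dim()->set_size(2);
//   config.add_fetch()->mutable_id()->set_node_name("y");
//
//   xla::XlaComputation computation;
//   TF_RETURN_IF_ERROR(ConvertGraphDefToXlaViaMlir(
//       std::move(my_graph_def), config, &computation,
//       /*debug_info_filename=*/"", /*debug_info_path_begin_marker=*/""));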

}  // namespace tensorflow