/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <type_traits>
#include <vector>

#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/AsmParser/AsmParser.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Tools/mlir-translate/Translation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/utils/string_container_utils.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

namespace {

// NOLINTNEXTLINE
llvm::cl::opt<std::string> input_types(
    "tf-xla-input-types",
    llvm::cl::desc("XLA input argument types (kinds), separated by ','. "
                   "Supported types include ['parameter', 'resource']. If "
                   "empty, all arguments are assumed to be parameters."),
    llvm::cl::init(""));
// NOLINTNEXTLINE
llvm::cl::opt<bool> emit_use_tuple_arg(
    "tf-xla-emit-use-tuple-args",
    llvm::cl::desc(
        "Emit HLO modules using tuples as args for the entry computation"),
    llvm::cl::init(false));
// NOLINTNEXTLINE
llvm::cl::opt<bool> emit_return_tuple(
    "tf-xla-emit-return-tuple",
    llvm::cl::desc("Emit HLO modules with entry computations returning tuple"),
    llvm::cl::init(false));
}  // namespace

namespace tensorflow {

namespace {

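// Prints the HLO module text for `compilation_result` to `output`, followed
// by comment lines describing the input mapping, XLA input/output shapes,
// output descriptions, and resource updates recorded in the result.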
mlir::LogicalResult PrintHloModuleText(
    const XlaCompilationResult& compilation_result, llvm::raw_ostream& output) {
  const xla::HloModuleConfig module_config(
      compilation_result.computation->GetProgramShape().ValueOrDie());
  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
      compilation_result.computation->proto(), module_config);
  if (!status_or_hlo_module.ok()) {
    LOG(ERROR) << "Conversion to HLO module failed: "
               << status_or_hlo_module.status().ToString();
    return mlir::failure();
  }

  xla::HloModule* hlo_module = status_or_hlo_module.ValueOrDie().get();

  output << hlo_module->ToString();

  if (!compilation_result.input_mapping.empty())
    output << "// InputMapping {"
           << absl::StrJoin(compilation_result.input_mapping, ", ") << "}\n";

  for (const auto& xla_input_shape : compilation_result.xla_input_shapes)
    output << "// XlaInputShape " << xla_input_shape.ToString() << '\n';

  output << "// XlaOutputShape "
         << compilation_result.xla_output_shape.ToString() << '\n';

  for (const auto& xla_output_description : compilation_result.outputs) {
    output << "// XlaOutputDescription type="
           << DataTypeString(xla_output_description.type) << " shape=("
           << absl::StrJoin(xla_output_description.shape.dim_sizes(), ", ")
           << ')';
    if (xla_output_description.input_index >= 0)
      output << " input_index=" << xla_output_description.input_index;
    if (xla_output_description.is_constant) output << " constant";
    if (xla_output_description.is_tensor_list) output << " tensor_list";
    output << '\n';
  }

  for (const auto& resource_update : compilation_result.resource_updates) {
    output << "// ResourceUpdate input_index=" << resource_update.input_index
           << " type=" << DataTypeString(resource_update.type) << " shape=("
           << absl::StrJoin(resource_update.shape.dim_sizes(), " ") << ')';
    if (resource_update.modified) output << " modified";
    output << '\n';
  }

  return mlir::success();
}

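// Parses the `input_shapes_str` command line value (see ParseNodeShapes for
// the exact string format) into `arg_shapes`. Entries without an explicit
// shape are treated as scalar (rank 0) shapes.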
Status ParseArgumentShapes(
    absl::string_view input_shapes_str,
    llvm::SmallVectorImpl<TensorOrResourceShape>& arg_shapes) {
  arg_shapes.clear();
  std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
  TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes_str, input_shapes_vector));
  arg_shapes.resize(input_shapes_vector.size());
  for (const auto& shape : llvm::enumerate(input_shapes_vector)) {
    if (!shape.value().has_value()) {
      TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
          static_cast<int*>(nullptr), 0, &arg_shapes[shape.index()].shape));
      continue;
    }
    TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
        shape.value().getValue(), &arg_shapes[shape.index()].shape));
  }

  return OkStatus();
}

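// Parses the `data_types_str` command line value into `data_types`. Entries
// that do not name a DataType, or that resolve to DT_INVALID, DT_STRING,
// DT_RESOURCE, DT_VARIANT, or a ref type, are reported as errors.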
Status ParseDataTypes(absl::string_view data_types_str,
                      llvm::SmallVectorImpl<DataType>& data_types) {
  data_types.clear();
  std::vector<std::string> input_dtypes_vector;
  TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types_str, input_dtypes_vector));
  data_types.resize(input_dtypes_vector.size(), DT_INVALID);
  for (auto data_type : llvm::enumerate(input_dtypes_vector)) {
    if (!DataType_Parse(data_type.value(), &data_types[data_type.index()]))
      return errors::InvalidArgument("Invalid dtype at index ",
                                     data_type.index(), ": ",
                                     data_type.value());
    const auto& resolved_dtype = data_types[data_type.index()];
    if (resolved_dtype == DT_INVALID || resolved_dtype == DT_STRING ||
        resolved_dtype == DT_RESOURCE || resolved_dtype == DT_VARIANT ||
        IsRefType(resolved_dtype))
      return errors::InvalidArgument("Unsupported dtype at index ",
                                     data_type.index(), ": ",
                                     data_type.value());
  }

  return OkStatus();
}

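// Parses the comma-separated `input_types_str` command line value
// (e.g. "parameter,resource") into XLA argument kinds. Only "parameter" and
// "resource" are accepted; an empty string yields an empty list.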
Status ParseArgumentKinds(
    absl::string_view input_types_str,
    llvm::SmallVectorImpl<XlaArgument::Kind>& argument_kinds) {
  argument_kinds.clear();
  if (input_types_str.empty()) return OkStatus();

  std::vector<absl::string_view> argument_kind_strs =
      absl::StrSplit(input_types_str, ',');
  argument_kinds.reserve(argument_kind_strs.size());
  for (const auto& argument_kind_str : llvm::enumerate(argument_kind_strs)) {
    const auto& value = argument_kind_str.value();
    if (value == "parameter") {
      argument_kinds.push_back(XlaArgument::Kind::kParameter);
    } else if (value == "resource") {
      argument_kinds.push_back(XlaArgument::Kind::kResource);
    } else {
      return errors::InvalidArgument(
          "Unsupported TF/XLA argument kind at index ",
          argument_kind_str.index(), ": ", value);
    }
  }

  return OkStatus();
}

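// Builds `xla_arguments` by zipping the parsed shapes, dtypes, and argument
// kinds. Missing shapes default to scalars, missing kinds default to
// kParameter, and mismatched lengths are reported as an error.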
Status ParseXlaArguments(absl::string_view input_shapes_str,
                         absl::string_view input_dtypes_str,
                         absl::string_view arg_kinds_str,
                         llvm::SmallVectorImpl<XlaArgument>& xla_arguments) {
  xla_arguments.clear();
  std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
  TF_RETURN_IF_ERROR(
      tensorflow::ParseNodeShapes(input_shapes_str, input_shapes_vector));
  llvm::SmallVector<DataType, 4> dtypes_vector;
  TF_RETURN_IF_ERROR(ParseDataTypes(input_dtypes_str, dtypes_vector));
  llvm::SmallVector<XlaArgument::Kind, 4> arg_kinds_vector;
  TF_RETURN_IF_ERROR(ParseArgumentKinds(arg_kinds_str, arg_kinds_vector));

  if (input_shapes_vector.empty())
    input_shapes_vector.resize(dtypes_vector.size());

  if (arg_kinds_vector.empty())
    arg_kinds_vector.resize(input_shapes_vector.size(),
                            XlaArgument::Kind::kParameter);

  if (input_shapes_vector.size() != dtypes_vector.size() ||
      input_shapes_vector.size() != arg_kinds_vector.size())
    return errors::InvalidArgument(
        "Input shapes, dtypes, and types/kinds must be of the same "
        "length, but got ",
        input_shapes_vector.size(), ", ", dtypes_vector.size(), ", and ",
        arg_kinds_vector.size(), " respectively");

  xla_arguments.resize(input_shapes_vector.size());
  for (const auto& arg_components :
       llvm::zip(xla_arguments, input_shapes_vector, dtypes_vector,
                 arg_kinds_vector)) {
    XlaArgument& arg = std::get<0>(arg_components);
    TensorShape shape;
    auto input_shapes = std::get<1>(arg_components);
    if (input_shapes.has_value()) {
      TF_RETURN_IF_ERROR(
          TensorShapeUtils::MakeShape(input_shapes.getValue(), &shape));
    } else {
      TF_RETURN_IF_ERROR(
          TensorShapeUtils::MakeShape(static_cast<int*>(nullptr), 0, &shape));
    }
    arg.shape = std::move(shape);
    arg.type = std::get<2>(arg_components);
    arg.kind = std::get<3>(arg_components);
  }

  return OkStatus();
}

}  // anonymous namespace

// Test BuildHloFromTf. BuildHloFromTf only performs part of the conversion, so
// to make this test comparable to other compile tests, the test implements
// the remaining parts of the conversion.
Status CompileMlirToXlaHloViaBuilder(
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    llvm::StringRef device_type, XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  // This call to RefineShapes is redundant with the call in BuildHloFromTf.
  // It's here so xla::Parameters that are created from block.getArguments will
  // have the proper shapes.
  TF_RETURN_IF_ERROR(RefineShapes(arg_shapes, module_op));

  mlir::func::FuncOp main = module_op.lookupSymbol<mlir::func::FuncOp>("main");
  mlir::Block& block = main.getRegion().front();
  xla::XlaBuilder builder("main");

  // Create xla_params.
  std::vector<xla::XlaOp> xla_params;
  for (mlir::BlockArgument& arg : block.getArguments()) {
    auto num = arg.getArgNumber();
    xla::Shape shape = xla::TypeToShape(arg.getType());
    xla::XlaOp argop =
        xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
    xla_params.push_back(argop);
  }

  std::vector<xla::XlaOp> returns(1);
  TF_RETURN_IF_ERROR(BuildHloFromTf(module_op, builder, xla_params, returns,
                                    arg_shapes, device_type,
                                    custom_legalization_passes));

  xla::XlaOp return_value;
  if (returns.size() == 1)
    return_value = returns[0];
  else
    return_value = xla::Tuple(&builder, returns);

  TF_ASSIGN_OR_RETURN(
      xla::XlaComputation computation,
      return_value.valid() ? builder.Build(return_value) : builder.Build());
  auto hlo_module = computation.proto();
  xla::HloProto hlo_proto;
  hlo_proto.mutable_hlo_module()->Swap(&hlo_module);

  compilation_result->computation = std::make_shared<xla::XlaComputation>();
  xla::XlaComputation* xla_computation = compilation_result->computation.get();
  *xla_computation = xla::XlaComputation(hlo_proto.hlo_module());

  XlaHelpers::ShapeRepresentationFn shape_representation_fn =
      IdentityShapeRepresentationFn();
  XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns{
      UseNoPreferenceLayoutFn(), IdentityShapeRepresentationFn()};
  return PopulateResultIOInfo(module_op, arg_shapes, /*use_tuple_args=*/false,
                              /*use_resource_updates_for_aliases=*/false,
                              shape_determination_fns, compilation_result);
}

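// Translates a TF dialect module to HLO text. When `via_builder` is true the
// conversion goes through CompileMlirToXlaHloViaBuilder above; otherwise it
// uses CompileMlirToXlaHlo. Argument shapes are parsed from the shared
// `input_shapes` command line flag declared in tf_mlir_translate_cl.h.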
static mlir::LogicalResult MlirTfToHloTextTranslateFunctionImpl(
    mlir::ModuleOp module_op, llvm::raw_ostream& output, bool via_builder) {
  if (!module_op) return mlir::failure();

  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
  auto args_status =
      ParseArgumentShapes(mlir::StringRefToView(input_shapes), arg_shapes);
  if (!args_status.ok()) {
    LOG(ERROR) << args_status.ToString();
    return mlir::failure();
  }

  auto device_type = "XLA_CPU_JIT";
  llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
      custom_legalization_passes{};
  XlaCompilationResult compilation_result;
  auto compilation_status =
      via_builder ? CompileMlirToXlaHloViaBuilder(
                        module_op, arg_shapes, device_type, &compilation_result,
                        custom_legalization_passes)
                  : CompileMlirToXlaHlo(
                        module_op, arg_shapes, device_type, emit_use_tuple_arg,
                        /*analyse_graph=*/false, emit_return_tuple,
                        /*use_resource_updates_for_aliases=*/true,
                        /*shape_determination_fns=*/{}, &compilation_result,
                        custom_legalization_passes);
  if (!compilation_status.ok()) {
    LOG(ERROR) << "TF/XLA compilation failed: "
               << compilation_status.ToString();
    return mlir::failure();
  }

  return PrintHloModuleText(compilation_result, output);
}

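// Translates a tf_executor graph module to HLO text via CompileGraphToXlaHlo,
// using XlaArguments built from the `input_shapes`, `input_dtypes`, and
// `input_types` command line flags.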
static mlir::LogicalResult MlirTfGraphToHloTextTranslateFunction(
    mlir::ModuleOp module_op, llvm::raw_ostream& output) {
  if (!module_op) return mlir::failure();

  llvm::SmallVector<XlaArgument, 4> xla_arguments;
  auto args_status = ParseXlaArguments(
      mlir::StringRefToView(input_shapes), mlir::StringRefToView(input_dtypes),
      mlir::StringRefToView(input_types), xla_arguments);
  if (!args_status.ok()) {
    LOG(ERROR) << args_status.ToString();
    return mlir::failure();
  }

  XlaCompilationResult compilation_result;
  auto compilation_status =
      CompileGraphToXlaHlo(module_op, xla_arguments,
                           /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg,
                           /*analyse_graph=*/false, emit_return_tuple,
                           /*shape_determination_fns=*/{}, &compilation_result,
                           /*custom_legalization_passes=*/{});
  if (!compilation_status.ok()) {
    LOG(ERROR) << "TF/XLA compilation failed: "
               << compilation_status.ToString();
    return mlir::failure();
  }

  return PrintHloModuleText(compilation_result, output);
}

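// Dialect registration hooks for the translations registered at the bottom of
// this file. The graph translation additionally needs the tf_executor dialect.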
static void RegisterMlirInputDialects(mlir::DialectRegistry& registry) {
  registry.insert<mlir::arith::ArithmeticDialect, mlir::func::FuncDialect,
                  mlir::TF::TensorFlowDialect>();
}

static void RegisterGraphInputDialects(mlir::DialectRegistry& registry) {
  RegisterMlirInputDialects(registry);
  registry.insert<mlir::tf_executor::TensorFlowExecutorDialect>();
}

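// Parses `input` as an MLIR StringAttr whose value is a serialized MLIR
// module, and deserializes that module. Returns nullptr on failure.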
static mlir::OwningOpRef<mlir::ModuleOp>
SerializedMlirStringAttrToMlirModuleTranslate(llvm::StringRef input,
                                              mlir::MLIRContext* context) {
  mlir::Attribute attr = mlir::parseAttribute(input, context);
  if (!attr || !attr.isa<mlir::StringAttr>()) {
    LOG(ERROR) << "Input is not parsable as a MLIR StringAttr.";
    return nullptr;
  }
  auto str_attr = attr.cast<mlir::StringAttr>();

  mlir::DialectRegistry registry;
  RegisterMlirInputDialects(registry);
  context->appendDialectRegistry(registry);
  mlir::OwningOpRef<mlir::ModuleOp> module_ref;
  auto status =
      DeserializeMlirModule(str_attr.getValue().str(), context, &module_ref);
  if (!status.ok()) {
    LOG(ERROR) << status.ToString();
    return nullptr;
  }

  return module_ref;
}

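// Serializes `module_op` to text and emits it as a double-quoted, escaped
// string suitable for embedding as an MLIR StringAttr.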
static mlir::LogicalResult MlirModuleToSerializedMlirStringAttrTranslate(
    mlir::ModuleOp module_op, llvm::raw_ostream& output) {
  output << "\"";
  std::string serialized_module = SerializeMlirModule(module_op);
  llvm::printEscapedString(serialized_module, output);
  output << "\"";
  return mlir::success();
}

static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
    mlir::ModuleOp module_op, llvm::raw_ostream& output) {
  return MlirTfToHloTextTranslateFunctionImpl(module_op, output, false);
}

static mlir::LogicalResult MlirTfToHloTextViaBuilderTranslateFunction(
    mlir::ModuleOp module_op, llvm::raw_ostream& output) {
  return MlirTfToHloTextTranslateFunctionImpl(module_op, output, true);
}

}  // namespace tensorflow

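// The registrations below expose these translations to MLIR translate drivers
// that link in this file (e.g. tf-mlir-translate). An illustrative invocation,
// assuming the shared -tf-input-shapes flag from tf_mlir_translate_cl.h and
// a colon-separated per-argument shape format, might look like:
//   tf-mlir-translate -mlir-tf-to-hlo-text -tf-input-shapes=10,17:17,19 \
//     -tf-xla-emit-return-tuple input.mlir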
static mlir::TranslateFromMLIRRegistration MlirTfToHloTextTranslate(
    "mlir-tf-to-hlo-text", tensorflow::MlirTfToHloTextTranslateFunction,
    tensorflow::RegisterMlirInputDialects);

static mlir::TranslateFromMLIRRegistration MlirTfToHloTextViaBuilderTranslate(
    "mlir-tf-to-hlo-text-via-builder",
    tensorflow::MlirTfToHloTextViaBuilderTranslateFunction,
    tensorflow::RegisterMlirInputDialects);

static mlir::TranslateFromMLIRRegistration MlirTfGraphToHloTextTranslate(
    "mlir-tf-graph-to-hlo-text",
    tensorflow::MlirTfGraphToHloTextTranslateFunction,
    tensorflow::RegisterGraphInputDialects);

static mlir::TranslateToMLIRRegistration SerializedMlirStringAttrToMlirModule(
    "mlir-tf-str-attr-to-mlir",
    tensorflow::SerializedMlirStringAttrToMlirModuleTranslate);

static mlir::TranslateFromMLIRRegistration MlirModuleToSerializedMlirStringAttr(
    "mlir-tf-mlir-to-str-attr",
    tensorflow::MlirModuleToSerializedMlirStringAttrTranslate,
    tensorflow::RegisterMlirInputDialects);