1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17 #include <type_traits>
18 #include <vector>
19
20 #include "absl/strings/str_join.h"
21 #include "absl/strings/str_split.h"
22 #include "absl/strings/string_view.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "mlir/AsmParser/AsmParser.h" // from @llvm-project
30 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
31 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
32 #include "mlir/IR/Attributes.h" // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
34 #include "mlir/IR/Dialect.h" // from @llvm-project
35 #include "mlir/Support/LogicalResult.h" // from @llvm-project
36 #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
39 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
40 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
41 #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
43 #include "tensorflow/compiler/mlir/utils/string_container_utils.h"
44 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
45 #include "tensorflow/compiler/tf2xla/xla_argument.h"
46 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
47 #include "tensorflow/compiler/xla/service/hlo.pb.h"
48 #include "tensorflow/compiler/xla/service/hlo_module.h"
49 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
50 #include "tensorflow/core/framework/tensor_shape.h"
51 #include "tensorflow/core/framework/types.h"
52 #include "tensorflow/core/framework/types.pb.h"
53 #include "tensorflow/core/platform/errors.h"
54 #include "tensorflow/core/platform/status.h"
55
56 namespace {
57
58 // NOLINTNEXTLINE
59 llvm::cl::opt<std::string> input_types(
60 "tf-xla-input-types",
61 llvm::cl::desc("XLA input argument types (kinds), separated by ','. "
62 "Supported types include ['parameter', 'resource']. If "
63 "empty, all arguments are assumed to be parameters."),
64 llvm::cl::init(""));
65 // NOLINTNEXTLINE
66 llvm::cl::opt<bool> emit_use_tuple_arg(
67 "tf-xla-emit-use-tuple-args",
68 llvm::cl::desc(
69 "Emit HLO modules using tuples as args for the entry computation"),
70 llvm::cl::init(false));
71 // NOLINTNEXTLINE
72 llvm::cl::opt<bool> emit_return_tuple(
73 "tf-xla-emit-return-tuple",
74 llvm::cl::desc("Emit HLO modules with entry computations returning tuple"),
75 llvm::cl::init(false));
76 } // namespace
77
78 namespace tensorflow {
79
80 namespace {
81
PrintHloModuleText(const XlaCompilationResult & compilation_result,llvm::raw_ostream & output)82 mlir::LogicalResult PrintHloModuleText(
83 const XlaCompilationResult& compilation_result, llvm::raw_ostream& output) {
84 const xla::HloModuleConfig module_config(
85 compilation_result.computation->GetProgramShape().ValueOrDie());
86 auto status_or_hlo_module = xla::HloModule::CreateFromProto(
87 compilation_result.computation->proto(), module_config);
88 if (!status_or_hlo_module.ok()) {
89 LOG(ERROR) << "Conversion to HLO module failed: "
90 << status_or_hlo_module.status().ToString();
91 return mlir::failure();
92 }
93
94 xla::HloModule* hlo_module = status_or_hlo_module.ValueOrDie().get();
95
96 output << hlo_module->ToString();
97
98 if (!compilation_result.input_mapping.empty())
99 output << "// InputMapping {"
100 << absl::StrJoin(compilation_result.input_mapping, ", ") << "}\n";
101
102 for (const auto& xla_input_shape : compilation_result.xla_input_shapes)
103 output << "// XlaInputShape " << xla_input_shape.ToString() << '\n';
104
105 output << "// XlaOutputShape "
106 << compilation_result.xla_output_shape.ToString() << '\n';
107
108 for (const auto& xla_output_description : compilation_result.outputs) {
109 output << "// XlaOutputDescription type="
110 << DataTypeString(xla_output_description.type) << " shape=("
111 << absl::StrJoin(xla_output_description.shape.dim_sizes(), ", ")
112 << ')';
113 if (xla_output_description.input_index >= 0)
114 output << " input_index=" << xla_output_description.input_index;
115 if (xla_output_description.is_constant) output << " constant";
116 if (xla_output_description.is_tensor_list) output << " tensor_list";
117 output << '\n';
118 }
119
120 for (const auto& resource_update : compilation_result.resource_updates) {
121 output << "// ResourceUpdate input_index=" << resource_update.input_index
122 << " type=" << DataTypeString(resource_update.type) << " shape=("
123 << absl::StrJoin(resource_update.shape.dim_sizes(), " ") << ')';
124 if (resource_update.modified) output << " modified";
125 output << '\n';
126 }
127
128 return mlir::success();
129 }
130
ParseArgumentShapes(absl::string_view input_shapes_str,llvm::SmallVectorImpl<TensorOrResourceShape> & arg_shapes)131 Status ParseArgumentShapes(
132 absl::string_view input_shapes_str,
133 llvm::SmallVectorImpl<TensorOrResourceShape>& arg_shapes) {
134 arg_shapes.clear();
135 std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
136 TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes_str, input_shapes_vector));
137 arg_shapes.resize(input_shapes_vector.size());
138 for (const auto& shape : llvm::enumerate(input_shapes_vector)) {
139 if (!shape.value().has_value()) {
140 TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
141 static_cast<int*>(nullptr), 0, &arg_shapes[shape.index()].shape));
142 continue;
143 }
144 TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
145 shape.value().getValue(), &arg_shapes[shape.index()].shape));
146 }
147
148 return OkStatus();
149 }
150
ParseDataTypes(absl::string_view data_types_str,llvm::SmallVectorImpl<DataType> & data_types)151 Status ParseDataTypes(absl::string_view data_types_str,
152 llvm::SmallVectorImpl<DataType>& data_types) {
153 data_types.clear();
154 std::vector<std::string> input_dtypes_vector;
155 TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types_str, input_dtypes_vector));
156 data_types.resize(input_dtypes_vector.size(), DT_INVALID);
157 for (auto data_type : llvm::enumerate(input_dtypes_vector)) {
158 if (!DataType_Parse(data_type.value(), &data_types[data_type.index()]))
159 return errors::InvalidArgument("Invalid dtype at index ",
160 data_type.index(), ": ",
161 data_type.value());
162 const auto& resolved_dtype = data_types[data_type.index()];
163 if (resolved_dtype == DT_INVALID || resolved_dtype == DT_STRING ||
164 resolved_dtype == DT_RESOURCE || resolved_dtype == DT_VARIANT ||
165 IsRefType(resolved_dtype))
166 return errors::InvalidArgument("Unsupported dtype at index ",
167 data_type.index(), ": ",
168 data_type.value());
169 }
170
171 return OkStatus();
172 }
173
ParseArgumentKinds(absl::string_view input_types_str,llvm::SmallVectorImpl<XlaArgument::Kind> & argument_kinds)174 Status ParseArgumentKinds(
175 absl::string_view input_types_str,
176 llvm::SmallVectorImpl<XlaArgument::Kind>& argument_kinds) {
177 argument_kinds.clear();
178 if (input_types_str.empty()) return OkStatus();
179
180 std::vector<absl::string_view> argument_kind_strs =
181 absl::StrSplit(input_types_str, ',');
182 argument_kinds.reserve(argument_kind_strs.size());
183 for (const auto& argument_kind_str : llvm::enumerate(argument_kind_strs)) {
184 const auto& value = argument_kind_str.value();
185 if (value == "parameter") {
186 argument_kinds.push_back(XlaArgument::Kind::kParameter);
187 } else if (value == "resource") {
188 argument_kinds.push_back(XlaArgument::Kind::kResource);
189 } else {
190 return errors::InvalidArgument(
191 "Unsupported TF/XLA argument kind at index ",
192 argument_kind_str.index(), ": ", value);
193 }
194 }
195
196 return OkStatus();
197 }
198
ParseXlaArguments(absl::string_view input_shapes_str,absl::string_view input_dtypes_str,absl::string_view arg_kinds_str,llvm::SmallVectorImpl<XlaArgument> & xla_arguments)199 Status ParseXlaArguments(absl::string_view input_shapes_str,
200 absl::string_view input_dtypes_str,
201 absl::string_view arg_kinds_str,
202 llvm::SmallVectorImpl<XlaArgument>& xla_arguments) {
203 xla_arguments.clear();
204 std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
205 TF_RETURN_IF_ERROR(
206 tensorflow::ParseNodeShapes(input_shapes_str, input_shapes_vector));
207 llvm::SmallVector<DataType, 4> dtypes_vector;
208 TF_RETURN_IF_ERROR(ParseDataTypes(input_dtypes_str, dtypes_vector));
209 llvm::SmallVector<XlaArgument::Kind, 4> arg_kinds_vector;
210 TF_RETURN_IF_ERROR(ParseArgumentKinds(arg_kinds_str, arg_kinds_vector));
211
212 if (input_shapes_vector.empty())
213 input_shapes_vector.resize(dtypes_vector.size());
214
215 if (arg_kinds_vector.empty())
216 arg_kinds_vector.resize(input_shapes_vector.size(),
217 XlaArgument::Kind::kParameter);
218
219 if (input_shapes_vector.size() != dtypes_vector.size() ||
220 input_shapes_vector.size() != arg_kinds_vector.size())
221 return errors::InvalidArgument(
222 "Input shapes, dtypes, and types/kinds must be of the same "
223 "length, but got ",
224 input_shapes_vector.size(), ", ", dtypes_vector.size(), ", and ",
225 arg_kinds_vector.size(), " respectively");
226
227 xla_arguments.resize(input_shapes_vector.size());
228 for (const auto& arg_components :
229 llvm::zip(xla_arguments, input_shapes_vector, dtypes_vector,
230 arg_kinds_vector)) {
231 XlaArgument& arg = std::get<0>(arg_components);
232 TensorShape shape;
233 auto input_shapes = std::get<1>(arg_components);
234 if (input_shapes.has_value()) {
235 TF_RETURN_IF_ERROR(
236 TensorShapeUtils::MakeShape(input_shapes.getValue(), &shape));
237 } else {
238 TF_RETURN_IF_ERROR(
239 TensorShapeUtils::MakeShape(static_cast<int*>(nullptr), 0, &shape));
240 }
241 arg.shape = std::move(shape);
242 arg.type = std::get<2>(arg_components);
243 arg.kind = std::get<3>(arg_components);
244 }
245
246 return OkStatus();
247 }
248
249 } // anonymous namespace
250
251 // Test BuildHloFromTf. BuildHloFromTf only performs part of the conversion, so
252 // to make this test comparable to other compile tests, the test implements
253 // the remaining parts of the conversion.
CompileMlirToXlaHloViaBuilder(mlir::ModuleOp module_op,llvm::ArrayRef<TensorOrResourceShape> arg_shapes,llvm::StringRef device_type,XlaCompilationResult * compilation_result,llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> custom_legalization_passes)254 Status CompileMlirToXlaHloViaBuilder(
255 mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
256 llvm::StringRef device_type, XlaCompilationResult* compilation_result,
257 llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
258 custom_legalization_passes) {
259 // This call to RefineShapes is redundant with the call in BuildHloFromTf.
260 // It's here so xla::Parameters that are created form block.getArguments will
261 // have the proper shapes.
262 TF_RETURN_IF_ERROR(RefineShapes(arg_shapes, module_op));
263
264 mlir::func::FuncOp main = module_op.lookupSymbol<mlir::func::FuncOp>("main");
265 mlir::Block& block = main.getRegion().front();
266 xla::XlaBuilder builder("main");
267
268 // Create xla_params.
269 std::vector<xla::XlaOp> xla_params;
270 for (mlir::BlockArgument& arg : block.getArguments()) {
271 auto num = arg.getArgNumber();
272 xla::Shape shape = xla::TypeToShape(arg.getType());
273 xla::XlaOp argop =
274 xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
275 xla_params.push_back(argop);
276 }
277
278 std::vector<xla::XlaOp> returns(1);
279 TF_RETURN_IF_ERROR(BuildHloFromTf(module_op, builder, xla_params, returns,
280 arg_shapes, device_type,
281 custom_legalization_passes));
282
283 xla::XlaOp return_value;
284 if (returns.size() == 1)
285 return_value = returns[0];
286 else
287 return_value = xla::Tuple(&builder, returns);
288
289 TF_ASSIGN_OR_RETURN(
290 xla::XlaComputation computation,
291 return_value.valid() ? builder.Build(return_value) : builder.Build());
292 auto hlo_module = computation.proto();
293 xla::HloProto hlo_proto;
294 hlo_proto.mutable_hlo_module()->Swap(&hlo_module);
295
296 compilation_result->computation = std::make_shared<xla::XlaComputation>();
297 xla::XlaComputation* xla_computation = compilation_result->computation.get();
298 *xla_computation = xla::XlaComputation(hlo_proto.hlo_module());
299
300 XlaHelpers::ShapeRepresentationFn shape_representation_fn =
301 IdentityShapeRepresentationFn();
302 XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns{
303 UseNoPreferenceLayoutFn(), IdentityShapeRepresentationFn()};
304 return PopulateResultIOInfo(module_op, arg_shapes, /*use_tuple_args=*/false,
305 /*use_resource_updates_for_aliases=*/false,
306 shape_determination_fns, compilation_result);
307 }
308
MlirTfToHloTextTranslateFunctionImpl(mlir::ModuleOp module_op,llvm::raw_ostream & output,bool via_builder)309 static mlir::LogicalResult MlirTfToHloTextTranslateFunctionImpl(
310 mlir::ModuleOp module_op, llvm::raw_ostream& output, bool via_builder) {
311 if (!module_op) return mlir::failure();
312
313 llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
314 auto args_status =
315 ParseArgumentShapes(mlir::StringRefToView(input_shapes), arg_shapes);
316 if (!args_status.ok()) {
317 LOG(ERROR) << args_status.ToString();
318 return mlir::failure();
319 }
320
321 auto device_type = "XLA_CPU_JIT";
322 llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
323 custom_legalization_passes{};
324 XlaCompilationResult compilation_result;
325 auto compilation_status =
326 via_builder ? CompileMlirToXlaHloViaBuilder(
327 module_op, arg_shapes, device_type, &compilation_result,
328 custom_legalization_passes)
329 : CompileMlirToXlaHlo(
330 module_op, arg_shapes, device_type, emit_use_tuple_arg,
331 /*analyse_graph=*/false, emit_return_tuple,
332 /*use_resource_updates_for_aliases=*/true,
333 /*shape_determination_fns=*/{}, &compilation_result,
334 custom_legalization_passes);
335 if (!compilation_status.ok()) {
336 LOG(ERROR) << "TF/XLA compilation failed: "
337 << compilation_status.ToString();
338 return mlir::failure();
339 }
340
341 return PrintHloModuleText(compilation_result, output);
342 }
343
MlirTfGraphToHloTextTranslateFunction(mlir::ModuleOp module_op,llvm::raw_ostream & output)344 static mlir::LogicalResult MlirTfGraphToHloTextTranslateFunction(
345 mlir::ModuleOp module_op, llvm::raw_ostream& output) {
346 if (!module_op) return mlir::failure();
347
348 llvm::SmallVector<XlaArgument, 4> xla_arguments;
349 auto args_status = ParseXlaArguments(
350 mlir::StringRefToView(input_shapes), mlir::StringRefToView(input_dtypes),
351 mlir::StringRefToView(input_types), xla_arguments);
352 if (!args_status.ok()) {
353 LOG(ERROR) << args_status.ToString();
354 return mlir::failure();
355 }
356
357 XlaCompilationResult compilation_result;
358 auto compilation_status =
359 CompileGraphToXlaHlo(module_op, xla_arguments,
360 /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg,
361 /*analyse_graph=*/false, emit_return_tuple,
362 /*shape_determination_fns=*/{}, &compilation_result,
363 /*custom_legalization_passes=*/{});
364 if (!compilation_status.ok()) {
365 LOG(ERROR) << "TF/XLA compilation failed: "
366 << compilation_status.ToString();
367 return mlir::failure();
368 }
369
370 return PrintHloModuleText(compilation_result, output);
371 }
372
RegisterMlirInputDialects(mlir::DialectRegistry & registry)373 static void RegisterMlirInputDialects(mlir::DialectRegistry& registry) {
374 registry.insert<mlir::arith::ArithmeticDialect, mlir::func::FuncDialect,
375 mlir::TF::TensorFlowDialect>();
376 }
377
RegisterGraphInputDialects(mlir::DialectRegistry & registry)378 static void RegisterGraphInputDialects(mlir::DialectRegistry& registry) {
379 RegisterMlirInputDialects(registry);
380 registry.insert<mlir::tf_executor::TensorFlowExecutorDialect>();
381 }
382
383 static mlir::OwningOpRef<mlir::ModuleOp>
SerializedMlirStringAttrToMlirModuleTranslate(llvm::StringRef input,mlir::MLIRContext * context)384 SerializedMlirStringAttrToMlirModuleTranslate(llvm::StringRef input,
385 mlir::MLIRContext* context) {
386 mlir::Attribute attr = mlir::parseAttribute(input, context);
387 if (!attr || !attr.isa<mlir::StringAttr>()) {
388 LOG(ERROR) << "Input is not parsable as a MLIR StringAttr.";
389 return nullptr;
390 }
391 auto str_attr = attr.cast<mlir::StringAttr>();
392
393 mlir::DialectRegistry registry;
394 RegisterMlirInputDialects(registry);
395 context->appendDialectRegistry(registry);
396 mlir::OwningOpRef<mlir::ModuleOp> module_ref;
397 auto status =
398 DeserializeMlirModule(str_attr.getValue().str(), context, &module_ref);
399 if (!status.ok()) {
400 LOG(ERROR) << status.ToString();
401 return nullptr;
402 }
403
404 return module_ref;
405 }
406
MlirModuleToSerializedMlirStringAttrTranslate(mlir::ModuleOp module_op,llvm::raw_ostream & output)407 static mlir::LogicalResult MlirModuleToSerializedMlirStringAttrTranslate(
408 mlir::ModuleOp module_op, llvm::raw_ostream& output) {
409 output << "\"";
410 std::string serialized_module = SerializeMlirModule(module_op);
411 llvm::printEscapedString(serialized_module, output);
412 output << "\"";
413 return mlir::success();
414 }
415
MlirTfToHloTextTranslateFunction(mlir::ModuleOp module_op,llvm::raw_ostream & output)416 static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
417 mlir::ModuleOp module_op, llvm::raw_ostream& output) {
418 return MlirTfToHloTextTranslateFunctionImpl(module_op, output, false);
419 }
420
MlirTfToHloTextViaBuilderTranslateFunction(mlir::ModuleOp module_op,llvm::raw_ostream & output)421 static mlir::LogicalResult MlirTfToHloTextViaBuilderTranslateFunction(
422 mlir::ModuleOp module_op, llvm::raw_ostream& output) {
423 return MlirTfToHloTextTranslateFunctionImpl(module_op, output, true);
424 }
425
426 } // namespace tensorflow
427
428 static mlir::TranslateFromMLIRRegistration MlirTfToHloTextTranslate(
429 "mlir-tf-to-hlo-text", tensorflow::MlirTfToHloTextTranslateFunction,
430 tensorflow::RegisterMlirInputDialects);
431
432 static mlir::TranslateFromMLIRRegistration MlirTfToHloTextViaBuilderTranslate(
433 "mlir-tf-to-hlo-text-via-builder",
434 tensorflow::MlirTfToHloTextViaBuilderTranslateFunction,
435 tensorflow::RegisterMlirInputDialects);
436
437 static mlir::TranslateFromMLIRRegistration MlirTfGraphToHloTextTranslate(
438 "mlir-tf-graph-to-hlo-text",
439 tensorflow::MlirTfGraphToHloTextTranslateFunction,
440 tensorflow::RegisterGraphInputDialects);
441
442 static mlir::TranslateToMLIRRegistration SerializedMlirStringAttrToMlirModule(
443 "mlir-tf-str-attr-to-mlir",
444 tensorflow::SerializedMlirStringAttrToMlirModuleTranslate);
445
446 static mlir::TranslateFromMLIRRegistration MlirModuleToSerializedMlirStringAttr(
447 "mlir-tf-mlir-to-str-attr",
448 tensorflow::MlirModuleToSerializedMlirStringAttrTranslate,
449 tensorflow::RegisterMlirInputDialects);
450