/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/xla/layout_util.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

namespace mlir {

struct MlirToHloConversionOptions {
  // Best-effort propagation of the layouts. These layouts serve as performance
  // hints to the backend.
  //
  // Note that non-array shapes do not carry layouts, and users have to figure
  // out their proper layouts through context. This is one of the reasons why
  // the attribute-based solution is temporary.
  //
  // TODO(timshen): Investigate the necessity of having layouts in MHLO.
  bool propagate_layouts = false;

  // Propagate the source and result layouts from the mhlo bitcast op into the
  // backend config for the bitcast. This is required for the XLA:GPU backend
  // to use elemental IR emitters for fused bitcasts without propagating
  // layouts.
  bool propagate_bitcast_layouts_to_backend_config = false;

  // Legalize names to be compatible with TensorFlow.
  bool legalize_node_names = true;

  LayoutPreferenceFn layout_preference_fn;
  ShapeRepresentationFn shape_representation_fn;
};

// Converts an MLIR module in the HLO dialect into an HloModuleProto. If
// use_tuple_args is set, then the entry computation's arguments are converted
// to a tuple and passed as a single parameter.
// Similarly, if return_tuple is set, then the entry function's return values
// are converted to a tuple even when there is only a single return value.
// Multiple return values are always converted to a tuple and returned as a
// single value.
Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto,
                           bool use_tuple_args, bool return_tuple,
                           MlirToHloConversionOptions options = {});
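
// Example usage of ConvertMlirHloToHlo (a minimal, illustrative sketch; how
// the MHLO `module` is obtained, e.g. parsed or produced by earlier lowering
// passes, is assumed):
//
//   mlir::ModuleOp module = ...;  // module containing MHLO ops
//   ::xla::HloProto hlo_proto;
//   mlir::MlirToHloConversionOptions options;
//   options.propagate_layouts = true;  // optional layout hints for the backend
//   Status status = mlir::ConvertMlirHloToHlo(
//       module, &hlo_proto, /*use_tuple_args=*/false, /*return_tuple=*/false,
//       options);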

// Transforms a Block into HLO, where the HLO is represented as calls into an
// XlaBuilder. Callee functions are allowed in the Block's ancestor ModuleOp.
// xla_params are the inputs to the block; returns receives the resulting
// XlaOps.
Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
                           llvm::ArrayRef<xla::XlaOp> xla_params,
                           std::vector<xla::XlaOp>& returns,
                           MlirToHloConversionOptions options = {});
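
// Example usage of BuildHloFromMlirHlo (a minimal, illustrative sketch;
// `main_fn` is a hypothetical func::FuncOp whose body contains MHLO ops, and
// the parameter shapes below must match its block arguments):
//
//   xla::XlaBuilder builder("main");
//   std::vector<xla::XlaOp> xla_params = {xla::Parameter(
//       &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {2, 2}), "arg0")};
//   std::vector<xla::XlaOp> returns;
//   Status status = mlir::BuildHloFromMlirHlo(
//       main_fn.getBody().front(), builder, xla_params, returns);
//   // On success, `returns` holds the XlaOps for the block's terminator
//   // operands.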

// Converts a region to a computation. It returns, via `func`, a standalone
// module that contains the converted region as the entry computation.
Status ConvertRegionToComputation(mlir::Region* region,
                                  ::xla::XlaComputation* func,
                                  MlirToHloConversionOptions options = {});
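
// Example usage of ConvertRegionToComputation (a minimal, illustrative
// sketch; `region` is assumed to point at a region holding MHLO ops, e.g. the
// body of an mhlo.reduce):
//
//   mlir::Region* region = ...;
//   ::xla::XlaComputation computation;
//   Status status = mlir::ConvertRegionToComputation(region, &computation);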

// Creates an XlaOp equivalent of the given MLIR operation using the operand
// info from the `value_lowering` map.
llvm::Optional<::xla::XlaOp> CreateXlaOperator(
    mlir::Operation* op,
    llvm::DenseMap<mlir::Value, ::xla::XlaOp>* value_lowering);
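
// Example usage of CreateXlaOperator (a minimal, illustrative sketch; `op` is
// a hypothetical mlir::Operation* and `value_lowering` is assumed to already
// map its operands to previously created XlaOps):
//
//   llvm::DenseMap<mlir::Value, ::xla::XlaOp> value_lowering;
//   // ... populate value_lowering for op->getOperands() ...
//   llvm::Optional<::xla::XlaOp> result =
//       mlir::CreateXlaOperator(op, &value_lowering);
//   if (result) {
//     // `*result` is the XlaOp created for `op`.
//   }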

}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_