/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_

#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

// Provides a way to construct mhlo dialect ops in MLIR using the XlaBuilder
// interface.
//
// Requires that every XlaOp argument is either returned by one of this
// builder's methods or constructed with this builder's MakeXlaOp method.
//
// TODO(hinsu): Support more ops and utility functions to set special attributes
// like OpMetadata and Sharding.
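//
// A minimal usage sketch (illustrative, not part of this file): it assumes
// `func` is a non-empty mlir::func::FuncOp whose argument types are supported
// by MakeXlaOp, uses the xla::Add free function from xla_builder.h, and omits
// error handling.
//
//   MlirHloBuilder builder(func);
//   StatusOr<XlaOp> arg = builder.MakeXlaOp(func.getArgument(0));
//   XlaOp sum = Add(arg.value(), arg.value());   // emits an mhlo op
//   mlir::Value result = builder.GetValue(sum);  // back to an MLIR value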
class MlirHloBuilder : public XlaBuilder {
 public:
  // Constructs a builder for the given function. New operations are added at
  // the beginning of the function, if it is non-empty and has a block.
  explicit MlirHloBuilder(mlir::func::FuncOp func)
      : XlaBuilder(func.getName().str()),
        builder_(&func.getBody()),
        loc_(builder_.getUnknownLoc()),
        build_functions_(false) {}

  // TODO(hinsu): Add a constructor to build a new MLIR function from scratch
  // and override Build methods.

  MlirHloBuilder(std::string name, mlir::OpBuilder builder, mlir::Location loc,
                 bool build_functions)
      : XlaBuilder(name),
        builder_(builder),
        loc_(loc),
        build_functions_(build_functions) {}

  MlirHloBuilder(const MlirHloBuilder&) = delete;
  MlirHloBuilder& operator=(const MlirHloBuilder&) = delete;

  ~MlirHloBuilder() override;

  // Wraps the given MLIR value under an XlaOp instance. Note that all HLO
  // operations return exactly one result, so each op has a single XlaOp
  // wrapping the result of the op.
  //
  // Returns an error if the HLO dialect doesn't support the type of the given
  // value.
  StatusOr<XlaOp> MakeXlaOp(mlir::Value val);

  // Returns the value corresponding to the given op.
  //
  // Requires that the op was created by this builder.
  mlir::Value GetValue(XlaOp op) {
    void* ptr = reinterpret_cast<void*>(op.handle());
    return mlir::Value::getFromOpaquePointer(ptr);
  }

  // Returns MLIR values corresponding to the given XLA ops.
  //
  // Requires that the ops were created by this builder.
  std::vector<mlir::Value> GetValues(absl::Span<const XlaOp> ops) {
    std::vector<mlir::Value> values;
    for (auto xla_op : ops) {
      values.push_back(GetValue(xla_op));
    }
    return values;
  }
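
  // For example (illustrative; assumes `lhs` and `rhs` are XlaOps that were
  // created through this builder):
  //   std::vector<mlir::Value> operands = builder.GetValues({lhs, rhs});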

  // Sets the location for newly built ops, until reset.
  void SetLocation(mlir::Location loc) { loc_ = loc; }

  // Updates the insertion point so that newly built ops are inserted before
  // the given op, in order, until reset.
  void setInsertionPoint(mlir::Operation* op) {
    builder_.setInsertionPoint(op);
  }
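
  // For example (illustrative), to emit the next ops directly before an
  // existing operation `user`, reusing its source location:
  //   builder.SetLocation(user->getLoc());
  //   builder.setInsertionPoint(user);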

  // Returns the shape of the given op.
  StatusOr<const Shape*> GetShapePtr(XlaOp op) const override;

  // Creates the given op at the current location.
  template <typename OpTy, typename... Args>
  OpTy create(Args&&... args) {
    return builder_.create<OpTy>(loc_, std::forward<Args>(args)...);
  }
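
  // For example (illustrative; the arguments must match one of OpTy's
  // generated build() overloads, e.g. a result type followed by operands):
  //   auto add = builder.create<mhlo::AddOp>(result_type, lhs_value, rhs_value);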

 private:
  XlaOp ConstantLiteral(const LiteralSlice& literal) override;

  StatusOr<XlaOp> ConvGeneralDilatedInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
      absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config) override;

  StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
                              FftType fft_type,
                              absl::Span<const int64_t> fft_length) override;

  StatusOr<XlaOp> TriangularSolveInternal(
      const Shape& shape, XlaOp a, XlaOp b,
      TriangularSolveOptions options) override;

  StatusOr<XlaOp> CholeskyInternal(const Shape& shape, XlaOp a,
                                   bool lower) override;

  StatusOr<XlaOp> CustomCallInternal(
      const std::string& call_target_name, absl::Span<const XlaOp> operands,
      const XlaComputation* computation, const Shape& shape,
      const std::string& opaque,
      std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, std::optional<Window> window,
      std::optional<ConvolutionDimensionNumbers> dnums,
      CustomCallSchedule schedule, CustomCallApiVersion api_version) override;

  StatusOr<XlaOp> ReduceInternal(
      const Shape& shape, absl::Span<const XlaOp> all_operands,
      const XlaComputation& computation,
      absl::Span<const int64_t> dimensions_to_reduce) override;

  StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape, XlaOp operand,
                                       XlaOp init_value,
                                       const XlaComputation& computation,
                                       Window window) override;

  XlaOp Iota(const Shape& shape, int64_t iota_dimension) override;

  StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape,
                                             XlaOp operand) override;

  StatusOr<XlaOp> TransposeInternal(
      const Shape& shape, XlaOp operand,
      absl::Span<const int64_t> permutation) override;

  StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
                              absl::Span<const int64_t> dimensions) override;

  StatusOr<XlaOp> SortInternal(const Shape& shape,
                               absl::Span<const XlaOp> operands,
                               const XlaComputation& comparator,
                               int64_t dimension, bool is_stable) override;

  StatusOr<XlaOp> WhileInternal(const Shape& shape,
                                const XlaComputation& condition,
                                const XlaComputation& body,
                                XlaOp init) override;

  StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape, XlaOp operand,
                                          const int exponent_bits,
                                          const int mantissa_bits) override;

  StatusOr<XlaOp> GatherInternal(
      const Shape& shape, XlaOp input, XlaOp start_indices,
      const GatherDimensionNumbers& dimension_numbers,
      absl::Span<const int64_t> slice_sizes, bool indices_are_sorted) override;

  StatusOr<XlaOp> ScatterInternal(
      const Shape& shape, absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
      absl::Span<const XlaOp> updates, const XlaComputation& update_computation,
      const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
      bool unique_indices) override;

  StatusOr<XlaOp> SetDimensionSizeInternal(const Shape& shape, XlaOp operand,
                                           XlaOp val,
                                           int64_t dimension) override;

  StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution,
                                absl::Span<const XlaOp> parameters,
                                const Shape& shape) override;
  StatusOr<XlaOp> RngBitGeneratorInternal(const Shape& full_result_shape,
                                          RandomAlgorithm algorithm,
                                          XlaOp initial_state) override;

  StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
                                  int64_t inferred_dimension) override;

  StatusOr<XlaOp> DotGeneralInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs,
      const DotDimensionNumbers& dimension_number,
      const PrecisionConfig* precision_config) override;

  StatusOr<XlaOp> InDimBroadcast(
      const Shape& shape, XlaOp operand,
      absl::Span<const int64_t> broadcast_dimensions) override;

  StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
                                 absl::Span<const XlaOp> operands) override;

  StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                          ComparisonDirection direction,
                          Comparison::Type type) override;

  XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, XlaOp lhs,
                            XlaOp rhs) override;

  StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
                                 absl::Span<const XlaOp> operands) override;

  XlaOp CreateToken() override;

  StatusOr<XlaOp> InfeedWithTokenInternal(const Shape& infeed_instruction_shape,
                                          XlaOp token,
                                          const std::string& config) override;
  StatusOr<XlaOp> OutfeedWithTokenInternal(
      XlaOp operand, XlaOp token, const Shape& shape_with_layout,
      const std::string& outfeed_config) override;

  StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
                                      absl::Span<const XlaOp> operands,
                                      int64_t dimension) override;

  StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape, XlaOp tuple_data,
                                          int64_t index) override;

  StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
                                absl::Span<const int64_t> start_indices,
                                absl::Span<const int64_t> limit_indices,
                                absl::Span<const int64_t> strides) override;

  StatusOr<XlaOp> DynamicSliceInternal(
      const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
      absl::Span<const int64_t> slice_sizes) override;

  StatusOr<XlaOp> DynamicUpdateSliceInternal(
      const Shape& shape, XlaOp operand, XlaOp update,
      absl::Span<const XlaOp> start_indices) override;

  StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
                              XlaOp padding_value,
                              const PaddingConfig& padding_config) override;

  StatusOr<XlaOp> TupleInternal(const Shape& shape,
                                absl::Span<const XlaOp> elements) override;

  // Creates an HLO dialect op and returns the result as an XlaOp.
  StatusOr<XlaOp> CreateOp(
      const std::string& op_name, const Shape& shape,
      llvm::ArrayRef<XlaOp> operands,
      llvm::ArrayRef<mlir::NamedAttribute> attributes = {});

  Status ImportComputation(const HloModuleProto& computation,
                           mlir::Region* region,
                           bool flatten_region_arg_tuple = false);

  Status ImportComputation(const HloModuleProto& computation,
                           mlir::ModuleOp module);

  mlir::OpBuilder builder_;
  mlir::Location loc_;
  bool build_functions_;

  absl::flat_hash_map<int64_t, std::unique_ptr<Shape>> handle_to_shape_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_