xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/elemental_ir_emitter.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
18 
19 #include <vector>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/string_view.h"
23 #include "absl/types/span.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/IR/Value.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
30 #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
31 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 
34 namespace xla {
35 
// Base class for emitters that generate the value of a single output element
// of an HLO instruction. It keeps the llvm::IRBuilder and llvm::Module handed
// to the constructor and exposes MakeElementGenerator() as its main public
// entry point. Two members are pure virtual (EmitThreadLocalCall and
// fast_min_max), so only subclasses can be instantiated.
//
// NOTE(review): the leading numbers on each line, and the signature fragments
// fused in front of a few of them, look like residue from a source-browser
// export rather than C++; they are left untouched here.
36 class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
37  public:
  // Maps an HLO instruction to the generator that produces its elements.
38   using HloToElementGeneratorMap =
39       absl::flat_hash_map<const HloInstruction*, llvm_ir::ElementGenerator>;
40 
  // Stores `module` and `b` as raw pointers. Nothing in this class frees them
  // (the destructor is defaulted).
ElementalIrEmitter(llvm::Module * module,llvm::IRBuilder<> * b)41   ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b)
42       : b_(b), module_(module) {}
43 
  // Virtual because the class is meant to be subclassed (see the pure virtual
  // members further down).
44   virtual ~ElementalIrEmitter() = default;
45 
46   // Returns a function to generate an element of the output of `hlo`, given a
47   // map of functions to generate elements of its operands.
48   llvm_ir::ElementGenerator MakeElementGenerator(
49       const HloInstruction* hlo,
50       const HloToElementGeneratorMap& operand_to_generator);
51 
  // Returns the IR builder that was passed to the constructor.
b()52   llvm::IRBuilder<>* b() { return b_; }
53 
54   // builder() is for IrBuilderMixin.
  // It returns the same pointer as b().
builder()55   llvm::IRBuilder<>* builder() { return b_; }
56 
  // Returns module_, which the constructor initializes from its `module`
  // argument.
module()57   llvm::Module* module() { return module_; }
58 
59  protected:
  // Delegates to index.SourceIndexOfBitcast(), passing the shape of `hlo`, the
  // shape of its operand 0, and the IR builder. Virtual, so a subclass may
  // substitute its own index computation.
GetSourceIndexOfBitcast(const llvm_ir::IrArray::Index & index,const HloInstruction * hlo)60   virtual llvm_ir::IrArray::Index GetSourceIndexOfBitcast(
61       const llvm_ir::IrArray::Index& index, const HloInstruction* hlo) {
62     return index.SourceIndexOfBitcast(hlo->shape(), hlo->operand(0)->shape(),
63                                       b_);
64   }
65 
  // The next three hooks are protected rather than private, so subclasses can
  // call them as well as override them.
66   virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
67                                                    llvm::Value* lhs_value,
68                                                    llvm::Value* rhs_value);
69 
70   virtual llvm::Value* EmitExtractReal(llvm::Value* value);
71   virtual llvm::Value* EmitExtractImag(llvm::Value* value);
72 
73  private:
  // Most members below are private virtuals: a subclass can override them,
  // but only code inside ElementalIrEmitter can call them.
  //
  // Unary/binary entry points. NOTE(review): presumably these dispatch to the
  // Integer/Float/Complex/Pred variants based on the operand type -- confirm
  // in the .cc file.
74   virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
75                                              llvm::Value* operand_value);
76 
77   virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
78                                               llvm::Value* lhs_value,
79                                               llvm::Value* rhs_value);
80 
81   virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op,
82                                                     llvm::Value* operand_value);
83 
84   virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(const HloInstruction* op,
85                                                   llvm::Value* operand_value);
86 
87   virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(const HloInstruction* op,
88                                                     llvm::Value* operand_value);
89 
  // Non-virtual helpers that return an llvm::Value for a predicate or for a
  // constant of the given type. Their exact semantics are defined in the .cc
  // file and are not visible from this header.
90   llvm::Value* IsZero(llvm::Value* v);
91   llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs);
92   llvm::Value* GetZero(llvm::Type* type);
93   llvm::Value* GetOne(llvm::Type* type);
94   llvm::Value* GetIntSMin(llvm::Type* type);
95   llvm::Value* GetMinusOne(llvm::Type* type);
96 
  // Non-virtual integer arithmetic helpers. NOTE(review): `is_signed`
  // presumably selects signed vs. unsigned lowering -- confirm in the .cc file.
97   llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs,
98                                  bool is_signed);
99   llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs,
100                                     bool is_signed);
101   llvm::Value* EmitIntegerPow(llvm::Value* lhs, llvm::Value* rhs,
102                               bool is_signed);
103 
104   virtual StatusOr<llvm::Value*> EmitPredBinaryOp(const HloInstruction* op,
105                                                   llvm::Value* lhs_value,
106                                                   llvm::Value* rhs_value);
107 
108   virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
109                                                      llvm::Value* lhs_value,
110                                                      llvm::Value* rhs_value,
111                                                      bool is_signed);
112 
113   virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op,
114                                                      llvm::Value* lhs_value,
115                                                      llvm::Value* rhs_value);
116 
  // Floating-point max/min are overridable; the integral versions below are
  // not virtual.
117   virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
118                                     llvm::Value* rhs_value,
119                                     absl::string_view name);
120 
121   virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
122                                     llvm::Value* rhs_value,
123                                     absl::string_view name);
124 
125   llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
126                                bool is_signed);
127 
128   llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
129                                bool is_signed);
130 
  // Overridable emitters for individual math functions. Each takes the
  // PrimitiveType of its operand(s) and returns either the emitted value or
  // an error status. NOTE(review): presumably virtual so that each backend can
  // provide its own lowering -- confirm against the subclasses.
131   virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
132                                            llvm::Value* lhs, llvm::Value* rhs,
133                                            absl::string_view name);
134 
135   virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
136                                          llvm::Value* value);
137 
138   virtual StatusOr<llvm::Value*> EmitSqrt(PrimitiveType prim_type,
139                                           llvm::Value* value);
140 
141   virtual StatusOr<llvm::Value*> EmitCbrt(PrimitiveType prim_type,
142                                           llvm::Value* value);
143 
144   virtual StatusOr<llvm::Value*> EmitRsqrt(PrimitiveType prim_type,
145                                            llvm::Value* value);
146 
147   virtual StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type,
148                                            llvm::Value* value);
149 
150   virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
151                                          llvm::Value* value);
152 
153   virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type,
154                                          llvm::Value* value);
155 
156   virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
157                                          llvm::Value* value,
158                                          absl::string_view name);
159 
160   virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
161                                            llvm::Value* value);
162 
163   virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
164                                          llvm::Value* lhs, llvm::Value* rhs,
165                                          absl::string_view name);
166 
167   virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
168                                           llvm::Value* value);
169 
170   virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
171                                                      llvm::Value* x);
172 
  // Returns three values on success; what each tuple element holds is defined
  // in the .cc file and cannot be determined from this header.
  // NOTE(review): std::tuple is used here, but <tuple> is not among this
  // header's direct #includes -- presumably it arrives transitively; consider
  // including it directly.
173   virtual StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>>
174   EmitComplexAbsHelper(PrimitiveType prim_type, llvm::Value* operand_value,
175                        bool return_sqrt);
176 
177   virtual StatusOr<llvm::Value*> EmitComplexAbs(PrimitiveType prim_type,
178                                                 llvm::Value* operand_value);
179 
180   virtual StatusOr<llvm::Value*> EmitSqrtComplexAbs(PrimitiveType prim_type,
181                                                     llvm::Value* operand_value);
182   virtual StatusOr<llvm::Value*> EmitRsqrtComplexAbs(
183       PrimitiveType prim_type, llvm::Value* operand_value);
184 
  // Overridable emitters for complex-valued arithmetic and math functions.
185   virtual StatusOr<llvm::Value*> EmitComplexAdd(const HloInstruction* op,
186                                                 llvm::Value* lhs_value,
187                                                 llvm::Value* rhs_value);
188 
189   virtual StatusOr<llvm::Value*> EmitComplexSubtract(const HloInstruction* op,
190                                                      llvm::Value* lhs_value,
191                                                      llvm::Value* rhs_value);
192 
193   virtual StatusOr<llvm::Value*> EmitComplexMultiply(const HloInstruction* op,
194                                                      llvm::Value* lhs_value,
195                                                      llvm::Value* rhs_value);
196 
197   virtual StatusOr<llvm::Value*> EmitComplexDivide(const HloInstruction* op,
198                                                    llvm::Value* lhs_value,
199                                                    llvm::Value* rhs_value);
200 
201   virtual StatusOr<llvm::Value*> EmitComplexLog(const HloInstruction* op,
202                                                 llvm::Value* operand_value);
203 
204   virtual StatusOr<llvm::Value*> EmitComplexSqrt(const HloInstruction* op,
205                                                  PrimitiveType prim_type,
206                                                  llvm::Value* operand_value);
207 
208   virtual StatusOr<llvm::Value*> EmitComplexCbrt(const HloInstruction* op,
209                                                  PrimitiveType prim_type,
210                                                  llvm::Value* operand_value);
211 
212   virtual StatusOr<llvm::Value*> EmitComplexRsqrt(const HloInstruction* op,
213                                                   PrimitiveType prim_type,
214                                                   llvm::Value* operand_value);
215 
  // Takes parallel lists of accumulator addresses and their LLVM types.
  // NOTE(review): presumably loads the accumulators and, when `is_variadic`
  // is set, packs them into a single aggregate result -- confirm in the .cc
  // file.
216   StatusOr<llvm::Value*> EmitAccumResult(
217       absl::Span<llvm::Value* const> accumulator_addrs,
218       llvm::ArrayRef<llvm::Type*> accumulator_types, bool is_variadic);
219 
220   // Composes a complex struct. imag may be nullptr for simple cast operations.
221   llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
222                                   llvm::Value* imag);
223 
224   // Emit `accumulator + lhs * rhs` for the given primitive type.
225   llvm::Value* EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
226                           llvm::Value* accumulator,
227                           xla::PrimitiveType primitive_type);
228 
229   // Identifier of the thread unique among all threads on the device
  // The default implementation returns a 128-bit integer constant with value
  // 0; subclasses can override it.
EmitThreadId()230   virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); }
231 
  // Non-virtual per-opcode element emitters. Each takes the instruction, the
  // generators for its operands, and the index of the output element, and
  // returns the emitted value or an error status.
232   StatusOr<llvm::Value*> EmitElementalSelect(
233       const HloInstruction* hlo,
234       const HloToElementGeneratorMap& operand_to_generator,
235       const llvm_ir::IrArray::Index& index);
236 
237   StatusOr<llvm::Value*> EmitElementalClamp(
238       const HloInstruction* hlo,
239       const HloToElementGeneratorMap& operand_to_generator,
240       const llvm_ir::IrArray::Index& index);
241 
242   StatusOr<llvm::Value*> EmitElementalConcatenate(
243       const HloInstruction* hlo,
244       const HloToElementGeneratorMap& operand_to_generator,
245       const llvm_ir::IrArray::Index& target_index);
246 
247   StatusOr<llvm::Value*> EmitElementalDynamicSlice(
248       const HloInstruction* hlo,
249       const HloToElementGeneratorMap& operand_to_generator,
250       const llvm_ir::IrArray::Index& index);
251 
252   StatusOr<llvm::Value*> EmitElementalGather(
253       const HloInstruction* hlo,
254       const HloToElementGeneratorMap& operand_to_generator,
255       const llvm_ir::IrArray::Index& index);
256 
257   StatusOr<llvm::Value*> EmitElementalDynamicUpdateSlice(
258       const HloInstruction* hlo,
259       const HloToElementGeneratorMap& operand_to_generator,
260       const llvm_ir::IrArray::Index& index);
261 
262   StatusOr<llvm::Value*> EmitElementalPad(
263       const HloInstruction* hlo,
264       const HloToElementGeneratorMap& operand_to_generator,
265       const llvm_ir::IrArray::Index& padded_index);
266 
267   StatusOr<llvm::Value*> EmitElementalDot(
268       const HloInstruction* hlo,
269       const HloToElementGeneratorMap& operand_to_generator,
270       const llvm_ir::IrArray::Index& dot_result_index);
271 
  // Pure virtual: every concrete subclass must define how a call to the
  // computation `callee` with the given `parameters` is emitted. On success
  // the result is a vector of values.
272   virtual StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
273       const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
274       absl::string_view name, bool is_reducer) = 0;
275 
276   StatusOr<llvm::Value*> EmitElementalMap(
277       const HloMapInstruction* map_instr,
278       absl::Span<llvm::Value* const> elemental_operands);
279 
  // Note that the two reduce emitters below take their generator vectors by
  // value, not by reference.
280   StatusOr<llvm::Value*> EmitElementalReduceWindow(
281       const HloReduceWindowInstruction* reduce_window,
282       std::vector<llvm_ir::ElementGenerator> input_generators,
283       std::vector<llvm_ir::ElementGenerator> initial_value_generators,
284       const llvm_ir::IrArray::Index& index);
285 
286   StatusOr<llvm::Value*> EmitElementalReduce(
287       const HloReduceInstruction* reduce,
288       std::vector<llvm_ir::ElementGenerator> input_generators,
289       std::vector<llvm_ir::ElementGenerator> initial_value_generators,
290       const llvm_ir::IrArray::Index& index);
291 
  // Unlike the other per-opcode emitters above, convolution is virtual and
  // can be replaced by a subclass.
292   virtual StatusOr<llvm::Value*> EmitConvolution(
293       const HloInstruction* hlo,
294       const HloToElementGeneratorMap& operand_to_generator,
295       const llvm_ir::IrArray::Index& index);
296 
297   // Computes the complex power function, returns (a + i*b)^(c + i*d).
298   StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op,
299                                           llvm::Value* a, llvm::Value* b,
300                                           llvm::Value* c, llvm::Value* d);
301 
302   // Evaluates a polynomial using Horner's method.
303   StatusOr<llvm::Value*> EvaluatePolynomial(
304       llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients);
305 
  // Pure virtual: the subclass supplies this setting.
  // NOTE(review): presumably tells the float max/min emitters whether a faster
  // lowering is acceptable -- confirm against the .cc file and the subclasses.
306   virtual bool fast_min_max() = 0;
307 
  // IR builder passed to the constructor. The pointer itself is const, so it
  // cannot be re-seated after construction.
308   llvm::IRBuilder<>* const b_;
309 
  // Module passed to the constructor. Unlike b_, this pointer is not declared
  // const.
310   llvm::Module* module_;
311 };
312 
313 }  // namespace xla
314 
315 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
316