xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
18 
19 #include "llvm/IR/IRBuilder.h"
20 
21 namespace xla {
22 
23 // Mixin class that injects more ergonomic versions of llvm::IRBuilder methods
24 // into a class.  Intended to be used as a CRTP base class, like:
25 //
26 //  class MyIrEmitter : public IrBuilderMixin<MyIrEmitter> {
27 //    llvm::IRBuilder<>* builder() { return builder_; }
28 //
29 //    void EmitFoo(HloInstruction* foo) {
30 //      Add(Mul(...), FPToUI(...));
31 //    }
32 //  };
33 
34 template <typename Derived>
35 class IrBuilderMixin {
36  protected:
37   template <class... Args>
Add(Args &&...args)38   llvm::Value* Add(Args&&... args) {
39     return mixin_builder()->CreateAdd(std::forward<Args>(args)...);
40   }
41 
42   template <class... Args>
AlignedLoad(Args &&...args)43   llvm::LoadInst* AlignedLoad(Args&&... args) {
44     return mixin_builder()->CreateAlignedLoad(std::forward<Args>(args)...);
45   }
46 
47   template <class... Args>
AlignedStore(Args &&...args)48   llvm::StoreInst* AlignedStore(Args&&... args) {
49     return mixin_builder()->CreateAlignedStore(std::forward<Args>(args)...);
50   }
51 
52   template <class... Args>
Alloca(Args &&...args)53   llvm::AllocaInst* Alloca(Args&&... args) {
54     return mixin_builder()->CreateAlloca(std::forward<Args>(args)...);
55   }
56 
57   template <class... Args>
And(Args &&...args)58   llvm::Value* And(Args&&... args) {
59     return mixin_builder()->CreateAnd(std::forward<Args>(args)...);
60   }
61 
62   template <class... Args>
AtomicCmpXchg(Args &&...args)63   llvm::Value* AtomicCmpXchg(Args&&... args) {
64     return mixin_builder()->CreateAtomicCmpXchg(std::forward<Args>(args)...);
65   }
66 
67   template <class... Args>
AtomicRMW(Args &&...args)68   llvm::Value* AtomicRMW(Args&&... args) {
69     return mixin_builder()->CreateAtomicRMW(std::forward<Args>(args)...);
70   }
71 
72   template <class... Args>
BitCast(Args &&...args)73   llvm::Value* BitCast(Args&&... args) {
74     return mixin_builder()->CreateBitCast(std::forward<Args>(args)...);
75   }
76 
77   template <class... Args>
Br(Args &&...args)78   llvm::Value* Br(Args&&... args) {
79     return mixin_builder()->CreateBr(std::forward<Args>(args)...);
80   }
81 
82   llvm::CallInst* Call(llvm::FunctionCallee func_callee,
83                        llvm::ArrayRef<llvm::Value*> args = llvm::None,
84                        const llvm::Twine& name = "",
85                        llvm::MDNode* fp_math_tag = nullptr) {
86     return mixin_builder()->CreateCall(func_callee, args, name, fp_math_tag);
87   }
88 
89   llvm::CallInst* Call(llvm::FunctionType* func_type, llvm::Value* callee,
90                        llvm::ArrayRef<llvm::Value*> args = llvm::None,
91                        const llvm::Twine& name = "",
92                        llvm::MDNode* fp_math_tag = nullptr) {
93     return mixin_builder()->CreateCall(func_type, callee, args, name,
94                                        fp_math_tag);
95   }
96 
97   template <class... Args>
CondBr(Args &&...args)98   llvm::BranchInst* CondBr(Args&&... args) {
99     return mixin_builder()->CreateCondBr(std::forward<Args>(args)...);
100   }
101 
102   template <class... Args>
ConstInBoundsGEP1_32(Args &&...args)103   llvm::Value* ConstInBoundsGEP1_32(Args&&... args) {
104     return mixin_builder()->CreateConstInBoundsGEP1_32(
105         std::forward<Args>(args)...);
106   }
107 
108   template <class... Args>
FAdd(Args &&...args)109   llvm::Value* FAdd(Args&&... args) {
110     return mixin_builder()->CreateFAdd(std::forward<Args>(args)...);
111   }
112 
113   template <class... Args>
FMul(Args &&...args)114   llvm::Value* FMul(Args&&... args) {
115     return mixin_builder()->CreateFMul(std::forward<Args>(args)...);
116   }
117 
118   llvm::Value* GEP(llvm::Type* type, llvm::Value* ptr,
119                    llvm::ArrayRef<llvm::Value*> idx_list,
120                    const llvm::Twine& name = "") {
121     return mixin_builder()->CreateGEP(type, ptr, idx_list, name);
122   }
123 
124   template <class... Args>
ICmpEQ(Args &&...args)125   llvm::Value* ICmpEQ(Args&&... args) {
126     return mixin_builder()->CreateICmpEQ(std::forward<Args>(args)...);
127   }
128 
129   template <class... Args>
ICmpNE(Args &&...args)130   llvm::Value* ICmpNE(Args&&... args) {
131     return mixin_builder()->CreateICmpNE(std::forward<Args>(args)...);
132   }
133 
134   template <class... Args>
ICmpULE(Args &&...args)135   llvm::Value* ICmpULE(Args&&... args) {
136     return mixin_builder()->CreateICmpULE(std::forward<Args>(args)...);
137   }
138 
139   template <class... Args>
ICmpULT(Args &&...args)140   llvm::Value* ICmpULT(Args&&... args) {
141     return mixin_builder()->CreateICmpULT(std::forward<Args>(args)...);
142   }
143 
144   llvm::Value* InBoundsGEP(llvm::Type* type, llvm::Value* ptr,
145                            llvm::ArrayRef<llvm::Value*> idx_list,
146                            const llvm::Twine& name = "") {
147     return mixin_builder()->CreateInBoundsGEP(type, ptr, idx_list, name);
148   }
149 
150   llvm::Value* ExtractValue(llvm::Value* agg, llvm::ArrayRef<unsigned> idxs,
151                             const llvm::Twine& name = "") {
152     return mixin_builder()->CreateExtractValue(agg, idxs, name);
153   }
154 
155   llvm::Value* InsertValue(llvm::Value* agg, llvm::Value* val,
156                            llvm::ArrayRef<unsigned> idxs,
157                            const llvm::Twine& name = "") {
158     return mixin_builder()->CreateInsertValue(agg, val, idxs, name);
159   }
160 
161   template <class... Args>
IntToPtr(Args &&...args)162   llvm::Value* IntToPtr(Args&&... args) {
163     return mixin_builder()->CreateIntToPtr(std::forward<Args>(args)...);
164   }
165 
166   template <class... Args>
Load(Args &&...args)167   llvm::LoadInst* Load(Args&&... args) {
168     return mixin_builder()->CreateLoad(std::forward<Args>(args)...);
169   }
170 
171   template <class... Args>
MemCpy(Args &&...args)172   llvm::CallInst* MemCpy(Args&&... args) {
173     return mixin_builder()->CreateMemCpy(std::forward<Args>(args)...);
174   }
175 
176   template <class... Args>
Mul(Args &&...args)177   llvm::Value* Mul(Args&&... args) {
178     return mixin_builder()->CreateMul(std::forward<Args>(args)...);
179   }
180 
181   template <class... Args>
NSWAdd(Args &&...args)182   llvm::Value* NSWAdd(Args&&... args) {
183     return mixin_builder()->CreateNSWAdd(std::forward<Args>(args)...);
184   }
185 
186   template <class... Args>
NSWMul(Args &&...args)187   llvm::Value* NSWMul(Args&&... args) {
188     return mixin_builder()->CreateNSWMul(std::forward<Args>(args)...);
189   }
190 
191   template <class... Args>
NSWSub(Args &&...args)192   llvm::Value* NSWSub(Args&&... args) {
193     return mixin_builder()->CreateNSWSub(std::forward<Args>(args)...);
194   }
195 
196   template <class... Args>
Or(Args &&...args)197   llvm::Value* Or(Args&&... args) {
198     return mixin_builder()->CreateOr(std::forward<Args>(args)...);
199   }
200 
201   template <class... Args>
PointerCast(Args &&...args)202   llvm::Value* PointerCast(Args&&... args) {
203     return mixin_builder()->CreatePointerCast(std::forward<Args>(args)...);
204   }
205 
206   template <class... Args>
PtrToInt(Args &&...args)207   llvm::Value* PtrToInt(Args&&... args) {
208     return mixin_builder()->CreatePtrToInt(std::forward<Args>(args)...);
209   }
210 
211   template <class... Args>
SDiv(Args &&...args)212   llvm::Value* SDiv(Args&&... args) {
213     return mixin_builder()->CreateSDiv(std::forward<Args>(args)...);
214   }
215 
216   template <class... Args>
Select(Args &&...args)217   llvm::Value* Select(Args&&... args) {
218     return mixin_builder()->CreateSelect(std::forward<Args>(args)...);
219   }
220 
221   template <class... Args>
SRem(Args &&...args)222   llvm::Value* SRem(Args&&... args) {
223     return mixin_builder()->CreateSRem(std::forward<Args>(args)...);
224   }
225 
226   template <class... Args>
Store(Args &&...args)227   llvm::StoreInst* Store(Args&&... args) {
228     return mixin_builder()->CreateStore(std::forward<Args>(args)...);
229   }
230 
231   template <class... Args>
UDiv(Args &&...args)232   llvm::Value* UDiv(Args&&... args) {
233     return mixin_builder()->CreateUDiv(std::forward<Args>(args)...);
234   }
235 
236   template <class... Args>
URem(Args &&...args)237   llvm::Value* URem(Args&&... args) {
238     return mixin_builder()->CreateURem(std::forward<Args>(args)...);
239   }
240 
241   template <class... Args>
VectorSplat(Args &&...args)242   llvm::Value* VectorSplat(Args&&... args) {
243     return mixin_builder()->CreateVectorSplat(std::forward<Args>(args)...);
244   }
245 
246   template <class... Args>
ZExtOrTrunc(Args &&...args)247   llvm::Value* ZExtOrTrunc(Args&&... args) {
248     return mixin_builder()->CreateZExtOrTrunc(std::forward<Args>(args)...);
249   }
250 
251   template <class... Args>
AShr(Args &&...args)252   llvm::Value* AShr(Args&&... args) {
253     return mixin_builder()->CreateAShr(std::forward<Args>(args)...);
254   }
255 
256   template <class... Args>
FCmpOEQ(Args &&...args)257   llvm::Value* FCmpOEQ(Args&&... args) {
258     return mixin_builder()->CreateFCmpOEQ(std::forward<Args>(args)...);
259   }
260 
261   template <class... Args>
FCmpOGT(Args &&...args)262   llvm::Value* FCmpOGT(Args&&... args) {
263     return mixin_builder()->CreateFCmpOGT(std::forward<Args>(args)...);
264   }
265 
266   template <class... Args>
FCmpOGE(Args &&...args)267   llvm::Value* FCmpOGE(Args&&... args) {
268     return mixin_builder()->CreateFCmpOGE(std::forward<Args>(args)...);
269   }
270 
271   template <class... Args>
FCmpOLT(Args &&...args)272   llvm::Value* FCmpOLT(Args&&... args) {
273     return mixin_builder()->CreateFCmpOLT(std::forward<Args>(args)...);
274   }
275 
276   template <class... Args>
FCmpULT(Args &&...args)277   llvm::Value* FCmpULT(Args&&... args) {
278     return mixin_builder()->CreateFCmpULT(std::forward<Args>(args)...);
279   }
280 
281   template <class... Args>
FCmpULE(Args &&...args)282   llvm::Value* FCmpULE(Args&&... args) {
283     return mixin_builder()->CreateFCmpULE(std::forward<Args>(args)...);
284   }
285 
286   template <class... Args>
FCmpOLE(Args &&...args)287   llvm::Value* FCmpOLE(Args&&... args) {
288     return mixin_builder()->CreateFCmpOLE(std::forward<Args>(args)...);
289   }
290 
291   template <class... Args>
FCmpONE(Args &&...args)292   llvm::Value* FCmpONE(Args&&... args) {
293     return mixin_builder()->CreateFCmpONE(std::forward<Args>(args)...);
294   }
295 
296   template <class... Args>
FCmpUNE(Args &&...args)297   llvm::Value* FCmpUNE(Args&&... args) {
298     return mixin_builder()->CreateFCmpUNE(std::forward<Args>(args)...);
299   }
300 
301   template <class... Args>
FCmpUNO(Args &&...args)302   llvm::Value* FCmpUNO(Args&&... args) {
303     return mixin_builder()->CreateFCmpUNO(std::forward<Args>(args)...);
304   }
305 
306   template <class... Args>
FDiv(Args &&...args)307   llvm::Value* FDiv(Args&&... args) {
308     return mixin_builder()->CreateFDiv(std::forward<Args>(args)...);
309   }
310 
311   template <class... Args>
FNeg(Args &&...args)312   llvm::Value* FNeg(Args&&... args) {
313     return mixin_builder()->CreateFNeg(std::forward<Args>(args)...);
314   }
315 
316   template <class... Args>
FPCast(Args &&...args)317   llvm::Value* FPCast(Args&&... args) {
318     return mixin_builder()->CreateFPCast(std::forward<Args>(args)...);
319   }
320 
321   template <class... Args>
FPToSI(Args &&...args)322   llvm::Value* FPToSI(Args&&... args) {
323     return mixin_builder()->CreateFPToSI(std::forward<Args>(args)...);
324   }
325 
326   template <class... Args>
FPToUI(Args &&...args)327   llvm::Value* FPToUI(Args&&... args) {
328     return mixin_builder()->CreateFPToUI(std::forward<Args>(args)...);
329   }
330 
331   template <class... Args>
FPTrunc(Args &&...args)332   llvm::Value* FPTrunc(Args&&... args) {
333     return mixin_builder()->CreateFPTrunc(std::forward<Args>(args)...);
334   }
335 
336   template <class... Args>
FRem(Args &&...args)337   llvm::Value* FRem(Args&&... args) {
338     return mixin_builder()->CreateFRem(std::forward<Args>(args)...);
339   }
340 
341   template <class... Args>
FSub(Args &&...args)342   llvm::Value* FSub(Args&&... args) {
343     return mixin_builder()->CreateFSub(std::forward<Args>(args)...);
344   }
345 
346   template <class... Args>
ICmpSGE(Args &&...args)347   llvm::Value* ICmpSGE(Args&&... args) {
348     return mixin_builder()->CreateICmpSGE(std::forward<Args>(args)...);
349   }
350 
351   template <class... Args>
ICmpSLT(Args &&...args)352   llvm::Value* ICmpSLT(Args&&... args) {
353     return mixin_builder()->CreateICmpSLT(std::forward<Args>(args)...);
354   }
355 
356   template <class... Args>
IntCast(Args &&...args)357   llvm::Value* IntCast(Args&&... args) {
358     return mixin_builder()->CreateIntCast(std::forward<Args>(args)...);
359   }
360 
361   template <class... Args>
LShr(Args &&...args)362   llvm::Value* LShr(Args&&... args) {
363     return mixin_builder()->CreateLShr(std::forward<Args>(args)...);
364   }
365 
366   template <class... Args>
MemSet(Args &&...args)367   llvm::Value* MemSet(Args&&... args) {
368     return mixin_builder()->CreateMemSet(std::forward<Args>(args)...);
369   }
370 
371   template <class... Args>
Neg(Args &&...args)372   llvm::Value* Neg(Args&&... args) {
373     return mixin_builder()->CreateNeg(std::forward<Args>(args)...);
374   }
375 
376   template <class... Args>
Not(Args &&...args)377   llvm::Value* Not(Args&&... args) {
378     return mixin_builder()->CreateNot(std::forward<Args>(args)...);
379   }
380 
381   template <class... Args>
PHI(Args &&...args)382   llvm::PHINode* PHI(Args&&... args) {
383     return mixin_builder()->CreatePHI(std::forward<Args>(args)...);
384   }
385 
386   template <class... Args>
RetVoid(Args &&...args)387   llvm::Value* RetVoid(Args&&... args) {
388     return mixin_builder()->CreateRetVoid(std::forward<Args>(args)...);
389   }
390 
391   template <class... Args>
SExtOrTrunc(Args &&...args)392   llvm::Value* SExtOrTrunc(Args&&... args) {
393     return mixin_builder()->CreateSExtOrTrunc(std::forward<Args>(args)...);
394   }
395 
396   template <class... Args>
Shl(Args &&...args)397   llvm::Value* Shl(Args&&... args) {
398     return mixin_builder()->CreateShl(std::forward<Args>(args)...);
399   }
400 
401   template <class... Args>
SIToFP(Args &&...args)402   llvm::Value* SIToFP(Args&&... args) {
403     return mixin_builder()->CreateSIToFP(std::forward<Args>(args)...);
404   }
405 
406   template <class... Args>
Sub(Args &&...args)407   llvm::Value* Sub(Args&&... args) {
408     return mixin_builder()->CreateSub(std::forward<Args>(args)...);
409   }
410 
411   template <class... Args>
Trunc(Args &&...args)412   llvm::Value* Trunc(Args&&... args) {
413     return mixin_builder()->CreateTrunc(std::forward<Args>(args)...);
414   }
415 
416   template <class... Args>
UIToFP(Args &&...args)417   llvm::Value* UIToFP(Args&&... args) {
418     return mixin_builder()->CreateUIToFP(std::forward<Args>(args)...);
419   }
420 
421   template <class... Args>
Unreachable(Args &&...args)422   llvm::Value* Unreachable(Args&&... args) {
423     return mixin_builder()->CreateUnreachable(std::forward<Args>(args)...);
424   }
425 
426   template <class... Args>
Xor(Args &&...args)427   llvm::Value* Xor(Args&&... args) {
428     return mixin_builder()->CreateXor(std::forward<Args>(args)...);
429   }
430 
431  private:
mixin_builder()432   llvm::IRBuilder<>* mixin_builder() {
433     return static_cast<Derived*>(this)->builder();
434   }
435 };
436 
437 }  // namespace xla
438 
439 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
440