xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/cpu/vector_support_library.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
18 
19 #include <string>
20 
21 #include "absl/types/span.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/Value.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 
28 namespace xla {
29 namespace cpu {
30 
// Simple wrappers around llvm::APFloat's constructors to make the calling
// code more obvious.
GetIeeeF32(float f)34 inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); }
GetIeeeF32FromBitwiseRep(int32_t bitwise_value)35 inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32_t bitwise_value) {
36   return llvm::APFloat(llvm::APFloat::IEEEsingle(),
37                        llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value));
38 }
39 
// A thin wrapper around llvm_util.h to make code generating vector math flow
// more readable.
42 class VectorSupportLibrary {
43  public:
44   // This VectorSupportLibrary instance remembers `primitive_type` and
45   // `vector_size`, and these are implicitly used by the methods on this
46   // instance (i.e. LoadVector will load a vector of type <`vector_size` x
47   // `primitive_type`>).
48   VectorSupportLibrary(PrimitiveType primitive_type, int64_t vector_size,
49                        llvm::IRBuilder<>* b, std::string name);
50 
51   llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs);
Mul(int64_t lhs,llvm::Value * rhs)52   llvm::Value* Mul(int64_t lhs, llvm::Value* rhs) {
53     return Mul(b()->getInt64(lhs), rhs);
54   }
Mul(const llvm::APFloat & lhs,llvm::Value * rhs)55   llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) {
56     return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
57   }
58 
59   // If your call resolved to these then you probably wanted the versions taking
60   // APFloat.
61   llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete;
62   llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete;
63 
64   llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
Add(int64_t lhs,llvm::Value * rhs)65   llvm::Value* Add(int64_t lhs, llvm::Value* rhs) {
66     return Add(b()->getInt64(lhs), rhs);
67   }
Add(const llvm::APFloat & lhs,llvm::Value * rhs)68   llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) {
69     return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
70   }
71 
72   // If your call resolved to these then you probably wanted the versions taking
73   // APFloat.
74   llvm::Value* Add(double lhs, llvm::Value* rhs) = delete;
75   llvm::Value* Add(float lhs, llvm::Value* rhs) = delete;
76 
77   llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
Sub(llvm::Value * lhs,const llvm::APFloat & rhs)78   llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
79     return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
80   }
81   llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs,
82                    bool enable_fast_min_max);
Max(const llvm::APFloat & lhs,llvm::Value * rhs,bool enable_fast_min_max)83   llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs,
84                    bool enable_fast_min_max) {
85     return Max(GetConstantFloat(rhs->getType(), lhs), rhs, enable_fast_min_max);
86   }
87   llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);
88 
MulAdd(llvm::Value * a,llvm::Value * b,llvm::Value * c)89   llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
90     return Add(c, Mul(a, b));
91   }
92 
MulAdd(llvm::Value * a,llvm::Value * b,const llvm::APFloat & c)93   llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) {
94     return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
95   }
96 
MulAdd(llvm::Value * a,const llvm::APFloat & b,const llvm::APFloat & c)97   llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b,
98                       const llvm::APFloat& c) {
99     return Add(GetConstantFloat(a->getType(), c),
100                Mul(a, GetConstantFloat(a->getType(), b)));
101   }
102 
103   llvm::Value* Floor(llvm::Value* a);
104 
105   // Precondition: Neither `low` nor `high` is nan.
106   llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low,
107                      const llvm::APFloat& high);
108 
SplatFloat(const llvm::APFloat & d)109   llvm::Value* SplatFloat(const llvm::APFloat& d) {
110     return GetConstantFloat(vector_type(), d);
111   }
112 
113   // These compare instructions return a floating point typed mask instead of an
114   // i1.  For instance, on a vector typed input, lanes where the predicate is
115   // true get a float with all ones and other lanes get a float with all zeros.
116   // This is slightly odd from the perspective of LLVM's type system, but it
117   // makes kernel IR generation code written using VectorSupportLibrary (its
118   // raison d'etre) less cluttered.
119 
120   llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
FCmpEQMask(llvm::Value * lhs,const llvm::APFloat & rhs)121   llvm::Value* FCmpEQMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
122     return FCmpEQMask(lhs, GetConstantFloat(lhs->getType(), rhs));
123   }
124   llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
125   llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
FCmpOLTMask(llvm::Value * lhs,const llvm::APFloat & rhs)126   llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
127     return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
128   }
129 
130   // These boolean operations operate on the bitwise values of the floating
131   // point inputs.  They return a (vector of) float(s) but like in the mask
132   // generating predicates above this type system oddity makes the kernel IR
133   // generation code less cluttered.
134   llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
FloatAnd(llvm::Value * lhs,const llvm::APFloat & rhs)135   llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) {
136     return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
137   }
138   llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
FloatOr(llvm::Value * lhs,const llvm::APFloat & rhs)139   llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) {
140     return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
141   }
142   llvm::Value* FloatNot(llvm::Value* lhs);
FloatAndNot(llvm::Value * lhs,llvm::Value * rhs)143   llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) {
144     return FloatAnd(FloatNot(lhs), rhs);
145   }
146 
147   llvm::Value* BroadcastScalar(llvm::Value* x);
BroadcastScalar(const llvm::APFloat & d)148   llvm::Value* BroadcastScalar(const llvm::APFloat& d) {
149     return BroadcastScalar(GetConstantFloat(scalar_type(), d));
150   }
151 
152   llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
153                                     llvm::Value* offset_elements);
ComputeOffsetPointer(llvm::Value * base_pointer,llvm::Value * offset_elements,int64_t scale)154   llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
155                                     llvm::Value* offset_elements,
156                                     int64_t scale) {
157     return ComputeOffsetPointer(
158         base_pointer, b_->CreateMul(b_->getInt64(scale), offset_elements));
159   }
ComputeOffsetPointer(llvm::Value * base_pointer,int64_t offset_elements)160   llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
161                                     int64_t offset_elements) {
162     return ComputeOffsetPointer(base_pointer, b()->getInt64(offset_elements));
163   }
164 
165   llvm::Value* LoadVector(llvm::Value* pointer);
166 
LoadVector(llvm::Value * base_pointer,llvm::Value * offset_elements)167   llvm::Value* LoadVector(llvm::Value* base_pointer,
168                           llvm::Value* offset_elements) {
169     return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements));
170   }
171 
LoadVector(llvm::Value * base_pointer,int64_t offset_elements)172   llvm::Value* LoadVector(llvm::Value* base_pointer, int64_t offset_elements) {
173     return LoadVector(base_pointer, b()->getInt64(offset_elements));
174   }
175 
176   llvm::Value* LoadScalar(llvm::Value* pointer);
177 
LoadScalar(llvm::Value * base_pointer,llvm::Value * offset_elements)178   llvm::Value* LoadScalar(llvm::Value* base_pointer,
179                           llvm::Value* offset_elements) {
180     return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements));
181   }
182 
LoadScalar(llvm::Value * base_pointer,int64_t offset_elements)183   llvm::Value* LoadScalar(llvm::Value* base_pointer, int64_t offset_elements) {
184     return LoadScalar(base_pointer, b()->getInt64(offset_elements));
185   }
186 
187   void StoreVector(llvm::Value* value, llvm::Value* pointer);
188 
StoreVector(llvm::Value * value,llvm::Value * base_pointer,llvm::Value * offset_elements)189   void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
190                    llvm::Value* offset_elements) {
191     StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements));
192   }
193 
StoreVector(llvm::Value * value,llvm::Value * base_pointer,int64_t offset_elements)194   void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
195                    int64_t offset_elements) {
196     StoreVector(value, base_pointer, b()->getInt64(offset_elements));
197   }
198 
199   void StoreScalar(llvm::Value* value, llvm::Value* pointer);
StoreScalar(llvm::Value * value,llvm::Value * base_pointer,llvm::Value * offset_elements)200   void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
201                    llvm::Value* offset_elements) {
202     StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements));
203   }
204 
StoreScalar(llvm::Value * value,llvm::Value * base_pointer,int64_t offset_elements)205   void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
206                    int64_t offset_elements) {
207     StoreScalar(base_pointer, b()->getInt64(offset_elements));
208   }
209 
210   llvm::Value* LoadBroadcast(llvm::Value* pointer);
LoadBroadcast(llvm::Value * base_pointer,llvm::Value * offset_elements)211   llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
212                              llvm::Value* offset_elements) {
213     return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements));
214   }
LoadBroadcast(llvm::Value * base_pointer,int64_t offset_elements)215   llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
216                              int64_t offset_elements) {
217     return LoadBroadcast(base_pointer, b()->getInt64(offset_elements));
218   }
219 
220   // Compute the horizontal sum of each vector in `vectors`.  The i'th element
221   // in the result vector is the (scalar) horizontal sum of the i'th vector in
222   // `vectors`.  If `init_values` is not nullptr then the value in the i'th lane
223   // in `init_values` is added to the i'th horizontal sum.
224   std::vector<llvm::Value*> ComputeHorizontalSums(
225       std::vector<llvm::Value*> vectors, llvm::Value* init_values = nullptr);
226 
227   llvm::Value* GetZeroVector();
228   llvm::Value* GetZeroScalar();
229 
b()230   llvm::IRBuilder<>* b() const { return b_; }
vector_size()231   int64_t vector_size() const { return vector_size_; }
vector_type()232   llvm::Type* vector_type() const { return vector_type_; }
vector_pointer_type()233   llvm::Type* vector_pointer_type() const { return vector_pointer_type_; }
scalar_type()234   llvm::Type* scalar_type() const { return scalar_type_; }
scalar_pointer_type()235   llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; }
scalar_byte_size()236   int64_t scalar_byte_size() const {
237     return primitive_util::BitWidth(primitive_type_) / 8;
238   }
239 
name()240   const std::string& name() const { return name_; }
241 
242  private:
243   llvm::Value* ExtractLowHalf(llvm::Value*);
244   llvm::Value* ExtractHighHalf(llvm::Value*);
245 
246   llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs);
247   llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs);
248 
249   llvm::Value* AddReduce(llvm::Value* vector);
250 
251   // Checks that each value in `values` is either of type scalar_type() or
252   // vector_type().  This LOG(FATAL)'s so it should only be called in cases
253   // where a mismatching type is a programmer bug.
254   void AssertCorrectTypes(std::initializer_list<llvm::Value*> values);
255 
256   // Perform an X86 AVX style horizontal add between `lhs` and `rhs`.  The
257   // resulting IR for an 8-float wide vector is expected to lower to a single
258   // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in
259   // other cases.
260   //
261   // For a vector width of 8, the result vector is computed as:
262   //   Result[0] = Lhs[0] + Lhs[1]
263   //   Result[1] = Lhs[2] + Lhs[3]
264   //   Result[2] = Rhs[0] + Rhs[1]
265   //   Result[3] = Rhs[2] + Rhs[3]
266   //   Result[4] = Lhs[4] + Lhs[5]
267   //   Result[5] = Lhs[6] + Lhs[7]
268   //   Result[6] = Rhs[4] + Rhs[5]
269   //   Result[7] = Rhs[6] + Rhs[7]
270   llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs);
271 
272   std::vector<llvm::Value*> ComputeAvxOptimizedHorizontalSums(
273       std::vector<llvm::Value*> vectors, llvm::Value* init_values);
274 
275   llvm::Type* IntegerTypeForFloatSize(bool vector);
276   llvm::Value* I1ToFloat(llvm::Value* i1);
GetConstantFloat(llvm::Type * type,const llvm::APFloat & f)277   llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) {
278     llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f);
279     if (llvm::isa<llvm::VectorType>(type)) {
280       return llvm::ConstantVector::getSplat(
281           llvm::ElementCount::getFixed(vector_size()), scalar_value);
282     }
283     return scalar_value;
284   }
285 
286   int64_t vector_size_;
287   PrimitiveType primitive_type_;
288   llvm::IRBuilder<>* b_;
289   llvm::Type* vector_type_;
290   llvm::Type* vector_pointer_type_;
291   llvm::Type* scalar_type_;
292   llvm::Type* scalar_pointer_type_;
293   std::string name_;
294 };
295 
// This wraps an alloca-backed stack variable which LLVM's SSA construction
// pass can later convert to an SSA value.
298 class LlvmVariable {
299  public:
300   LlvmVariable(llvm::Type*, llvm::IRBuilder<>* b);
301 
302   llvm::Value* Get() const;
303   void Set(llvm::Value* new_value);
304 
305  private:
306   llvm::AllocaInst* alloca_;
307   llvm::IRBuilder<>* b_;
308 };
309 
310 class VectorVariable : public LlvmVariable {
311  public:
VectorVariable(VectorSupportLibrary * vector_support,llvm::Value * initial_value)312   VectorVariable(VectorSupportLibrary* vector_support,
313                  llvm::Value* initial_value)
314       : LlvmVariable(vector_support->vector_type(), vector_support->b()) {
315     Set(initial_value);
316   }
317 };
318 
319 class ScalarVariable : public LlvmVariable {
320  public:
ScalarVariable(VectorSupportLibrary * vector_support,llvm::Value * initial_value)321   ScalarVariable(VectorSupportLibrary* vector_support,
322                  llvm::Value* initial_value)
323       : LlvmVariable(vector_support->scalar_type(), vector_support->b()) {
324     Set(initial_value);
325   }
326 };
327 
// This wraps a set of alloca-backed stack variables that can, as a whole,
// store a tile.  A "tile" is a sequence of vectors that is typically used as
// a 2D grid of scalar values (e.g. for tiled GEMMs).
331 class TileVariable {
332  public:
333   TileVariable(VectorSupportLibrary* vector_support,
334                std::vector<llvm::Value*> initial_value);
335 
336   std::vector<llvm::Value*> Get() const;
337   void Set(absl::Span<llvm::Value* const> value);
338 
339  private:
340   std::vector<VectorVariable> storage_;
341 };
342 }  // namespace cpu
343 }  // namespace xla
344 
345 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
346