/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_

#include <string>
#include <vector>

21 #include "absl/types/span.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/Value.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27
28 namespace xla {
29 namespace cpu {
30
31 // Simple wrappers around llvm::APFloat::APFloat to make the calling code more
32 // obvious.
33
GetIeeeF32(float f)34 inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); }
GetIeeeF32FromBitwiseRep(int32_t bitwise_value)35 inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32_t bitwise_value) {
36 return llvm::APFloat(llvm::APFloat::IEEEsingle(),
37 llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value));
38 }
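// For instance (illustrative), 0x3F800000 is the IEEE 754 bit pattern for
// 1.0f, so GetIeeeF32FromBitwiseRep(0x3F800000) yields the same value as
// GetIeeeF32(1.0f).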

// A thin wrapper around llvm_util.h to make code generating vector math flow
// more readable.
class VectorSupportLibrary {
 public:
  // This VectorSupportLibrary instance remembers `primitive_type` and
  // `vector_size`, and these are implicitly used by the methods on this
  // instance (i.e. LoadVector will load a vector of type <`vector_size` x
  // `primitive_type`>).
  VectorSupportLibrary(PrimitiveType primitive_type, int64_t vector_size,
                       llvm::IRBuilder<>* b, std::string name);
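
  // A minimal usage sketch (illustrative; assumes `builder` is an
  // llvm::IRBuilder<> positioned inside a function, and that `lhs_ptr`,
  // `rhs_ptr`, and `out_ptr` point to buffers of at least 8 floats):
  //
  //   VectorSupportLibrary vsl(F32, /*vector_size=*/8, builder, "my_kernel");
  //   llvm::Value* lhs = vsl.LoadVector(lhs_ptr);
  //   llvm::Value* rhs = vsl.LoadVector(rhs_ptr);
  //   vsl.StoreVector(vsl.Mul(lhs, rhs), out_ptr);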

  llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Mul(int64_t lhs, llvm::Value* rhs) {
    return Mul(b()->getInt64(lhs), rhs);
  }
  llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) {
    return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
  }

  // If your call resolved to these, then you probably wanted the versions
  // taking APFloat.
  llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete;
  llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete;

  llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Add(int64_t lhs, llvm::Value* rhs) {
    return Add(b()->getInt64(lhs), rhs);
  }
  llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) {
    return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
  }

  // If your call resolved to these, then you probably wanted the versions
  // taking APFloat.
  llvm::Value* Add(double lhs, llvm::Value* rhs) = delete;
  llvm::Value* Add(float lhs, llvm::Value* rhs) = delete;

  llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs,
                   bool enable_fast_min_max);
  llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs,
                   bool enable_fast_min_max) {
    return Max(GetConstantFloat(rhs->getType(), lhs), rhs,
               enable_fast_min_max);
  }
  llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);

  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
    return Add(c, Mul(a, b));
  }

  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) {
    return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
  }

  llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b,
                      const llvm::APFloat& c) {
    return Add(GetConstantFloat(a->getType(), c),
               Mul(a, GetConstantFloat(a->getType(), b)));
  }

  llvm::Value* Floor(llvm::Value* a);

  // Precondition: neither `low` nor `high` is NaN.
  llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low,
                     const llvm::APFloat& high);

  llvm::Value* SplatFloat(const llvm::APFloat& d) {
    return GetConstantFloat(vector_type(), d);
  }

  // These compare instructions return a floating point typed mask instead of
  // an i1. For instance, on a vector typed input, lanes where the predicate is
  // true get a float with all ones and other lanes get a float with all zeros.
  // This is slightly odd from the perspective of LLVM's type system, but it
  // makes kernel IR generation code written using VectorSupportLibrary (its
  // raison d'etre) less cluttered.

  llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpEQMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FCmpEQMask(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
  }

  // These boolean operations operate on the bitwise values of the floating
  // point inputs. They return a (vector of) float(s) but, as with the
  // mask-generating predicates above, this type system oddity makes the
  // kernel IR generation code less cluttered.
  llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FloatNot(llvm::Value* lhs);
  llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) {
    return FloatAnd(FloatNot(lhs), rhs);
  }
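
  // For example (illustrative), the masks and bitwise ops above can express a
  // branch-free select that yields `a` where `x == y` and `b` elsewhere:
  //
  //   llvm::Value* mask = vsl.FCmpEQMask(x, y);
  //   llvm::Value* result =
  //       vsl.FloatOr(vsl.FloatAnd(mask, a), vsl.FloatAndNot(mask, b));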

  llvm::Value* BroadcastScalar(llvm::Value* x);
  llvm::Value* BroadcastScalar(const llvm::APFloat& d) {
    return BroadcastScalar(GetConstantFloat(scalar_type(), d));
  }

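  // Returns `base_pointer` advanced by `offset_elements` elements of
  // scalar_type() (the element-sized units are implied by the scaling
  // overload below).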
  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    llvm::Value* offset_elements);
  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    llvm::Value* offset_elements,
                                    int64_t scale) {
    return ComputeOffsetPointer(
        base_pointer, b_->CreateMul(b_->getInt64(scale), offset_elements));
  }
  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    int64_t offset_elements) {
    return ComputeOffsetPointer(base_pointer, b()->getInt64(offset_elements));
  }

  llvm::Value* LoadVector(llvm::Value* pointer);

  llvm::Value* LoadVector(llvm::Value* base_pointer,
                          llvm::Value* offset_elements) {
    return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements));
  }

  llvm::Value* LoadVector(llvm::Value* base_pointer, int64_t offset_elements) {
    return LoadVector(base_pointer, b()->getInt64(offset_elements));
  }

  llvm::Value* LoadScalar(llvm::Value* pointer);

  llvm::Value* LoadScalar(llvm::Value* base_pointer,
                          llvm::Value* offset_elements) {
    return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements));
  }

  llvm::Value* LoadScalar(llvm::Value* base_pointer, int64_t offset_elements) {
    return LoadScalar(base_pointer, b()->getInt64(offset_elements));
  }

  void StoreVector(llvm::Value* value, llvm::Value* pointer);

  void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
                   llvm::Value* offset_elements) {
    StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements));
  }

  void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
                   int64_t offset_elements) {
    StoreVector(value, base_pointer, b()->getInt64(offset_elements));
  }

  void StoreScalar(llvm::Value* value, llvm::Value* pointer);
  void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
                   llvm::Value* offset_elements) {
    StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements));
  }

  void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
                   int64_t offset_elements) {
    StoreScalar(value, base_pointer, b()->getInt64(offset_elements));
  }

  llvm::Value* LoadBroadcast(llvm::Value* pointer);
  llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
                             llvm::Value* offset_elements) {
    return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements));
  }
  llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
                             int64_t offset_elements) {
    return LoadBroadcast(base_pointer, b()->getInt64(offset_elements));
  }

  // Compute the horizontal sum of each vector in `vectors`. The i'th element
  // in the result vector is the (scalar) horizontal sum of the i'th vector in
  // `vectors`. If `init_values` is not nullptr then the value in the i'th lane
  // in `init_values` is added to the i'th horizontal sum.
  std::vector<llvm::Value*> ComputeHorizontalSums(
      std::vector<llvm::Value*> vectors, llvm::Value* init_values = nullptr);
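
  // For example (illustrative), with vector_size() == 4 and
  // vectors = {<1, 2, 3, 4>, <5, 6, 7, 8>} the result is the two scalars
  // {10, 26}; with init_values == <100, 200, 0, 0> it would be {110, 226}.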

  llvm::Value* GetZeroVector();
  llvm::Value* GetZeroScalar();

  llvm::IRBuilder<>* b() const { return b_; }
  int64_t vector_size() const { return vector_size_; }
  llvm::Type* vector_type() const { return vector_type_; }
  llvm::Type* vector_pointer_type() const { return vector_pointer_type_; }
  llvm::Type* scalar_type() const { return scalar_type_; }
  llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; }
  int64_t scalar_byte_size() const {
    return primitive_util::BitWidth(primitive_type_) / 8;
  }

  const std::string& name() const { return name_; }

 private:
  llvm::Value* ExtractLowHalf(llvm::Value*);
  llvm::Value* ExtractHighHalf(llvm::Value*);

  llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs);

  llvm::Value* AddReduce(llvm::Value* vector);

  // Checks that each value in `values` is either of type scalar_type() or
  // vector_type(). This LOG(FATAL)'s on a mismatch, so it should only be
  // called in cases where a mismatching type is a programmer bug.
  void AssertCorrectTypes(std::initializer_list<llvm::Value*> values);

  // Perform an X86 AVX style horizontal add between `lhs` and `rhs`. The
  // resulting IR for an 8-float wide vector is expected to lower to a single
  // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in
  // other cases.
  //
  // For a vector width of 8, the result vector is computed as:
  //   Result[0] = Lhs[0] + Lhs[1]
  //   Result[1] = Lhs[2] + Lhs[3]
  //   Result[2] = Rhs[0] + Rhs[1]
  //   Result[3] = Rhs[2] + Rhs[3]
  //   Result[4] = Lhs[4] + Lhs[5]
  //   Result[5] = Lhs[6] + Lhs[7]
  //   Result[6] = Rhs[4] + Rhs[5]
  //   Result[7] = Rhs[6] + Rhs[7]
  llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs);

  std::vector<llvm::Value*> ComputeAvxOptimizedHorizontalSums(
      std::vector<llvm::Value*> vectors, llvm::Value* init_values);

  llvm::Type* IntegerTypeForFloatSize(bool vector);
  llvm::Value* I1ToFloat(llvm::Value* i1);
  llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) {
    llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f);
    if (llvm::isa<llvm::VectorType>(type)) {
      return llvm::ConstantVector::getSplat(
          llvm::ElementCount::getFixed(vector_size()), scalar_value);
    }
    return scalar_value;
  }

  int64_t vector_size_;
  PrimitiveType primitive_type_;
  llvm::IRBuilder<>* b_;
  llvm::Type* vector_type_;
  llvm::Type* vector_pointer_type_;
  llvm::Type* scalar_type_;
  llvm::Type* scalar_pointer_type_;
  std::string name_;
};

// This wraps an alloca-backed stack variable which LLVM's SSA construction
// pass can later convert to an SSA value.
class LlvmVariable {
 public:
  LlvmVariable(llvm::Type*, llvm::IRBuilder<>* b);

  llvm::Value* Get() const;
  void Set(llvm::Value* new_value);

 private:
  llvm::AllocaInst* alloca_;
  llvm::IRBuilder<>* b_;
};

class VectorVariable : public LlvmVariable {
 public:
  VectorVariable(VectorSupportLibrary* vector_support,
                 llvm::Value* initial_value)
      : LlvmVariable(vector_support->vector_type(), vector_support->b()) {
    Set(initial_value);
  }
};
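
// For instance (illustrative), a VectorVariable can serve as a loop-carried
// accumulator that LLVM's mem2reg pass later promotes out of the alloca:
//
//   VectorVariable acc(&vsl, vsl.GetZeroVector());
//   // ... inside the loop body:
//   acc.Set(vsl.Add(acc.Get(), vsl.LoadVector(ptr)));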

class ScalarVariable : public LlvmVariable {
 public:
  ScalarVariable(VectorSupportLibrary* vector_support,
                 llvm::Value* initial_value)
      : LlvmVariable(vector_support->scalar_type(), vector_support->b()) {
    Set(initial_value);
  }
};

// This wraps a set of alloca-backed stack variables that can, as a whole,
// store a tile. A "tile" is a sequence of vectors that is typically used as a
// 2D grid of scalar values (e.g. for tiled GEMMs).
class TileVariable {
 public:
  TileVariable(VectorSupportLibrary* vector_support,
               std::vector<llvm::Value*> initial_value);

  std::vector<llvm::Value*> Get() const;
  void Set(absl::Span<llvm::Value* const> value);

 private:
  std::vector<VectorVariable> storage_;
};
}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_