1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/mlir/transforms/runtime/custom_call_encoding.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21
22 #include "llvm/ADT/StringRef.h"
23 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" // from @llvm-project
24 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
25 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
26 #include "mlir/IR/Attributes.h" // from @llvm-project
27 #include "mlir/IR/Builders.h" // from @llvm-project
28 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
29 #include "mlir/IR/Types.h" // from @llvm-project
30 #include "mlir/Support/LogicalResult.h" // from @llvm-project
31 #include "tensorflow/compiler/xla/runtime/custom_call.h"
32 #include "tensorflow/compiler/xla/runtime/type_id.h"
33
34 namespace Eigen {
35 struct half;
36 } // namespace Eigen
37
38 namespace xla {
39 namespace runtime {
40
41 using namespace mlir; // NOLINT
42 using arith::ConstantOp;
43 using func::FuncOp;
44
45 using llvm::ArrayRef;
46 using llvm::StringRef;
47
48 using tfrt::DType;
49
50 //===----------------------------------------------------------------------===//
51 // Custom call arguments encoding.
52 //===----------------------------------------------------------------------===//
53
54 using EncodedArg = CustomCallArgEncodingSet::Encoded;
55
Encode(Globals & g,ImplicitLocOpBuilder & b,Value value,Value converted) const56 FailureOr<EncodedArg> CustomCallArgEncodingSet::Encode(Globals &g,
57 ImplicitLocOpBuilder &b,
58 Value value,
59 Value converted) const {
60 for (auto &encoding : encodings_)
61 if (succeeded(encoding->Match(value, converted)))
62 return encoding->Encode(g, b, value, converted);
63 return failure();
64 }
65
66 //===----------------------------------------------------------------------===//
67 // Custom call attributes encoding.
68 //===----------------------------------------------------------------------===//
69
70 using EncodedAttr = CustomCallAttrEncodingSet::Encoded;
71
Encode(Globals & g,ImplicitLocOpBuilder & b,StringRef name,Attribute attr) const72 FailureOr<EncodedAttr> CustomCallAttrEncodingSet::Encode(
73 Globals &g, ImplicitLocOpBuilder &b, StringRef name, Attribute attr) const {
74 for (auto &encoding : encodings_)
75 if (succeeded(encoding->Match(name, attr)))
76 return encoding->Encode(g, b, name, attr);
77 return failure();
78 }
79
80 //===----------------------------------------------------------------------===//
81 // A set of helper functions for packing primitive attributes.
82 //===----------------------------------------------------------------------===//
83
PackTypeId(Globals & g,ImplicitLocOpBuilder & b,TypeID type_id)84 Value PackTypeId(Globals &g, ImplicitLocOpBuilder &b, TypeID type_id) {
85 auto global = g.GetOrCreate(b, type_id);
86 return Globals::AddrOf(b, global);
87 }
88
PackString(Globals & g,ImplicitLocOpBuilder & b,StringRef strref,StringRef symbol_base)89 Value PackString(Globals &g, ImplicitLocOpBuilder &b, StringRef strref,
90 StringRef symbol_base) {
91 MLIRContext *ctx = b.getContext();
92 int64_t size = strref.size();
93
94 // Encoded string type: !llvm.struct<(i64, !llvm.ptr<array<i8 x len>>)>.
95 Type arr = LLVM::LLVMArrayType::get(b.getI8Type(), 1 + size);
96 Type ptr = LLVM::LLVMPointerType::get(arr);
97 Type type = LLVM::LLVMStructType::getLiteral(ctx, {b.getI64Type(), ptr});
98
99 // Global constant initializer for the encoded string structure
100 auto init = [&](ImplicitLocOpBuilder &ib, Attribute) {
101 // String size and pointer to a null-terminated string.
102 Value num_elements = ib.create<ConstantOp>(ib.getI64IntegerAttr(size));
103 Value str = Globals::AddrOf(ib, g.GetOrCreate(b, strref, "__rt_str"));
104
105 // Store size and pointer into the struct.
106 Value encoded = ib.create<LLVM::UndefOp>(type);
107 encoded = ib.create<LLVM::InsertValueOp>(encoded, num_elements, 0);
108 encoded = ib.create<LLVM::InsertValueOp>(encoded, str, 1);
109 ib.create<LLVM::ReturnOp>(encoded);
110 };
111
112 auto value = b.getStringAttr(strref);
113 auto global = g.GetOrCreate(b, value, type, symbol_base, init);
114 return Globals::AddrOf(b, global);
115 }
116
117 // Packs scalar attribute as a global constant. Returns `!llvm.ptr<AttrType>`.
PackScalarAttribute(Globals & g,ImplicitLocOpBuilder & b,Attribute value,StringRef symbol_base)118 Value PackScalarAttribute(Globals &g, ImplicitLocOpBuilder &b, Attribute value,
119 StringRef symbol_base) {
120 auto global = g.GetOrCreate(b, value, symbol_base);
121 return Globals::AddrOf(b, global);
122 }
123
124 // Reshape dense elements as a one-dimensional array.
Flatten(DenseIntOrFPElementsAttr dense)125 static mlir::DenseElementsAttr Flatten(DenseIntOrFPElementsAttr dense) {
126 ShapedType shaped_type = dense.getType();
127 ShapedType new_shaped_type = shaped_type.cloneWith(
128 {shaped_type.getNumElements()}, dense.getElementType());
129 return dense.reshape(new_shaped_type);
130 }
131
132 //===----------------------------------------------------------------------===//
133 // A set of helper functions for packing dense and array-like attributes.
134 //===----------------------------------------------------------------------===//
135
// Packs a dense int/fp elements attribute as a constant global and returns its
// address, `!llvm.ptr<EncodedDenseElements>`. The encoded struct holds:
//   0: a payload struct `(i64 num_elements, !llvm.ptr<array<elem x num>>)`
//      whose pointer targets a separate global with the flattened elements,
//   1: the rank as `i64`,
//   2: the shape as `!llvm.array<rank x i64>`.
static Value PackDenseElementsAttribute(Globals &g, ImplicitLocOpBuilder &b,
                                        Attribute value,
                                        StringRef symbol_base) {
  MLIRContext *ctx = b.getContext();
  DenseIntOrFPElementsAttr dense = value.cast<DenseIntOrFPElementsAttr>();

  // Payload type:
  //   !llvm.struct<(i64, !llvm.ptr<array<element_type x size>>)>.
  Type element_type = dense.getElementType();
  Type data_arr_type =
      LLVM::LLVMArrayType::get(element_type, dense.getNumElements());
  Type data_arr_ptr_type = LLVM::LLVMPointerType::get(data_arr_type);
  Type payload_type = LLVM::LLVMStructType::getLiteral(
      ctx, {b.getI64Type(), data_arr_ptr_type});

  int64_t rank = dense.getType().getRank();
  ArrayRef<int64_t> shape = dense.getType().getShape();
  Type shape_arr_type = LLVM::LLVMArrayType::get(b.getI64Type(), rank);

  // Encoded dense elements type:
  //   !llvm.struct<(payload_type, i64, array<rank x i64>)>.
  Type type = LLVM::LLVMStructType::getLiteral(
      ctx, {payload_type, b.getI64Type(), shape_arr_type});

  // Global constant initializer for the encoded dense elements structure.
  auto init = [&](ImplicitLocOpBuilder &ib, Attribute) {
    Value num_elements =
        ib.create<ConstantOp>(b.getI64IntegerAttr(dense.getNumElements()));
    // The element data lives in its own global. It is created from a rank-1
    // reshape of `dense` so that it matches the flat `data_arr_type`.
    Value data_ptr = Globals::AddrOf(
        ib, g.GetOrCreate(b, Flatten(dense), data_arr_type, symbol_base));

    // Create the payload struct.
    Value payload = ib.create<LLVM::UndefOp>(payload_type);
    payload = ib.create<LLVM::InsertValueOp>(payload, num_elements, 0);
    payload = ib.create<LLVM::InsertValueOp>(payload, data_ptr, 1);

    // Get rank and shape.
    Value rank_value = ib.create<ConstantOp>(b.getI64IntegerAttr(rank));
    Value shape_value = ib.create<LLVM::UndefOp>(shape_arr_type);

    // Store each dimension size into shape_value.
    for (int i = 0; i < rank; i++) {
      Value dim = ib.create<ConstantOp>(ib.getI64IntegerAttr(shape[i]));
      shape_value = ib.create<LLVM::InsertValueOp>(shape_value, dim, i);
    }

    // Store the payload, rank, and shape into the struct.
    Value encoded = ib.create<LLVM::UndefOp>(type);
    encoded = ib.create<LLVM::InsertValueOp>(encoded, payload, 0);
    encoded = ib.create<LLVM::InsertValueOp>(encoded, rank_value, 1);
    encoded = ib.create<LLVM::InsertValueOp>(encoded, shape_value, 2);
    ib.create<LLVM::ReturnOp>(encoded);
  };

  auto global = g.GetOrCreate(b, value, type, symbol_base, init);
  return Globals::AddrOf(b, global);
}
195
196 // Create a global for the data array in an EncodedArray.
197 // Returns `!llvm.ptr<array<element_type x size>>
CreateGlobalFromArray(Globals & g,ImplicitLocOpBuilder & b,ArrayAttr array,Type element_type,StringRef symbol_base)198 static Value CreateGlobalFromArray(Globals &g, ImplicitLocOpBuilder &b,
199 ArrayAttr array, Type element_type,
200 StringRef symbol_base) {
201 Type arr_type = LLVM::LLVMArrayType::get(element_type, array.size());
202
203 auto init = [&](ImplicitLocOpBuilder &ib, Attribute) {
204 Value data = ib.create<LLVM::UndefOp>(arr_type);
205 for (int i = 0; i < array.size(); i++) {
206 Value value = ib.create<ConstantOp>(array[i]);
207 data = ib.create<LLVM::InsertValueOp>(data, value, i);
208 }
209 ib.create<LLVM::ReturnOp>(data);
210 };
211
212 auto global = g.GetOrCreate(b, array, arr_type, symbol_base, init);
213 return Globals::AddrOf(b, global);
214 }
215
216 // Packs array attribute as a global constant. Returns `!llvm.ptr<EncodedArr>`.
PackArrayAttribute(Globals & g,ImplicitLocOpBuilder & b,ArrayAttr array,Type element_type,StringRef symbol_base)217 static Value PackArrayAttribute(Globals &g, ImplicitLocOpBuilder &b,
218 ArrayAttr array, Type element_type,
219 StringRef symbol_base) {
220 MLIRContext *ctx = b.getContext();
221
222 int64_t size = array.size();
223
224 // Encoded array type:
225 // !llvm.struct<(i64, !llvm.ptr<array<element_type x size>)>>.
226 Type arr_type = LLVM::LLVMArrayType::get(element_type, size);
227 Type arr_ptr_type = LLVM::LLVMPointerType::get(arr_type);
228 Type type =
229 LLVM::LLVMStructType::getLiteral(ctx, {b.getI64Type(), arr_ptr_type});
230
231 // Global constant initializer for the encoded array structure
232 auto init = [&](ImplicitLocOpBuilder &ib, Attribute) {
233 // Array size and the pointer to data.
234 Value num_elements = ib.create<ConstantOp>(b.getI64IntegerAttr(size));
235 Value data = CreateGlobalFromArray(g, b, array, element_type, symbol_base);
236
237 // Store size and values into the struct.
238 Value encoded = ib.create<LLVM::UndefOp>(type);
239 encoded = ib.create<LLVM::InsertValueOp>(encoded, num_elements, 0);
240 encoded = ib.create<LLVM::InsertValueOp>(encoded, data, 1);
241
242 ib.create<LLVM::ReturnOp>(encoded);
243 };
244
245 auto global = g.GetOrCreate(b, array, type, symbol_base, init);
246 return Globals::AddrOf(b, global);
247 }
248
249 template <typename T, typename AttrType, typename ArrayType>
FillDataFromDenseArrayAttr(ImplicitLocOpBuilder & b,AttrType (ImplicitLocOpBuilder::* get_attr)(T),ArrayType array,Value data)250 static Value FillDataFromDenseArrayAttr(
251 ImplicitLocOpBuilder &b, AttrType (ImplicitLocOpBuilder::*get_attr)(T),
252 ArrayType array, Value data) {
253 ArrayRef<T> array_ref = array.asArrayRef();
254 for (int i = 0; i < array_ref.size(); i++) {
255 Value value = b.create<ConstantOp>((b.*get_attr)(array_ref[i]));
256 data = b.create<LLVM::InsertValueOp>(data, value, i);
257 }
258 return data;
259 }
260
// Creates (or reuses) a constant global of LLVM array type `arr_type`. Its
// initializer inserts every element of `base_array` as a constant. Returns the
// address of that global.
static Value CreateGlobalFromDenseArray(Globals &g, ImplicitLocOpBuilder &b,
                                        DenseArrayBaseAttr base_array,
                                        Type arr_type, StringRef symbol_base) {
  auto init = [&](ImplicitLocOpBuilder &ib, Attribute) {
    Value data = ib.create<LLVM::UndefOp>(arr_type);
    // Dispatch on the element kind so that each element is materialized with
    // the matching builder attribute getter.
    // NOTE(review): element constants are created through `b`, not `ib`. This
    // relies on `Globals::GetOrCreate` handing the same builder object to the
    // initializer.
    switch (base_array.getElementType()) {
      case DenseArrayBaseAttr::EltType::I8:
        data = FillDataFromDenseArrayAttr<int8_t, IntegerAttr>(
            b, &ImplicitLocOpBuilder::getI8IntegerAttr,
            base_array.cast<mlir::DenseI8ArrayAttr>(), data);
        break;
      case DenseArrayBaseAttr::EltType::I16:
        data = FillDataFromDenseArrayAttr<int16_t, IntegerAttr>(
            b, &ImplicitLocOpBuilder::getI16IntegerAttr,
            base_array.cast<mlir::DenseI16ArrayAttr>(), data);
        break;
      case DenseArrayBaseAttr::EltType::I32:
        data = FillDataFromDenseArrayAttr<int32_t, IntegerAttr>(
            b, &ImplicitLocOpBuilder::getI32IntegerAttr,
            base_array.cast<mlir::DenseI32ArrayAttr>(), data);
        break;
      case DenseArrayBaseAttr::EltType::I64:
        data = FillDataFromDenseArrayAttr<int64_t, IntegerAttr>(
            b, &ImplicitLocOpBuilder::getI64IntegerAttr,
            base_array.cast<mlir::DenseI64ArrayAttr>(), data);
        break;
      case DenseArrayBaseAttr::EltType::F32:
        data = FillDataFromDenseArrayAttr<float, FloatAttr>(
            b, &ImplicitLocOpBuilder::getF32FloatAttr,
            base_array.cast<mlir::DenseF32ArrayAttr>(), data);
        break;
      case DenseArrayBaseAttr::EltType::F64:
        data = FillDataFromDenseArrayAttr<double, FloatAttr>(
            b, &ImplicitLocOpBuilder::getF64FloatAttr,
            base_array.cast<mlir::DenseF64ArrayAttr>(), data);
        break;
      default:
        // Any other element kind asserts in debug builds. With NDEBUG, `data`
        // stays the undefined array and is returned as the initializer.
        assert(false && "unsupported DenseArrayAttr element type");
    }
    ib.create<LLVM::ReturnOp>(data);
  };

  auto global = g.GetOrCreate(b, base_array, arr_type, symbol_base, init);
  return Globals::AddrOf(b, global);
}
306
PackDenseArrayAttribute(Globals & g,ImplicitLocOpBuilder & b,Attribute value,StringRef symbol_base)307 static Value PackDenseArrayAttribute(Globals &g, ImplicitLocOpBuilder &b,
308 Attribute value, StringRef symbol_base) {
309 MLIRContext *ctx = b.getContext();
310
311 DenseArrayBaseAttr base_array = value.cast<DenseArrayBaseAttr>();
312 int64_t size = base_array.size();
313
314 // Encoded array type:
315 // !llvm.struct<(i64, !llvm.ptr<array<element_type x size>>)>.
316 Type element_type = base_array.getType().getElementType();
317 Type arr_type = LLVM::LLVMArrayType::get(element_type, size);
318 Type arr_ptr_type = LLVM::LLVMPointerType::get(arr_type);
319 Type type =
320 LLVM::LLVMStructType::getLiteral(ctx, {b.getI64Type(), arr_ptr_type});
321
322 // Global constant initializer for the encoded array structure
323 auto init = [&](ImplicitLocOpBuilder &ib, Attribute) {
324 // Array size and values.
325 Value num_elements = ib.create<ConstantOp>(b.getI64IntegerAttr(size));
326 Value data =
327 CreateGlobalFromDenseArray(g, ib, base_array, arr_type, symbol_base);
328
329 // Store size and values into the struct.
330 Value encoded = ib.create<LLVM::UndefOp>(type);
331 encoded = ib.create<LLVM::InsertValueOp>(encoded, num_elements, 0);
332 encoded = ib.create<LLVM::InsertValueOp>(encoded, data, 1);
333
334 ib.create<LLVM::ReturnOp>(encoded);
335 };
336
337 auto global = g.GetOrCreate(b, value, type, symbol_base, init);
338 return Globals::AddrOf(b, global);
339 }
340
PackEmptyArrayAttribute(Globals & g,ImplicitLocOpBuilder & b,Attribute value,StringRef symbol_base)341 static Value PackEmptyArrayAttribute(Globals &g, ImplicitLocOpBuilder &b,
342 Attribute value, StringRef symbol_base) {
343 MLIRContext *ctx = b.getContext();
344
345 // Encoded array type: !llvm.struct<(i64, !llvm.ptr<i8>)>.
346 // The pointer is always null. We use i8 as a placeholder type.
347 Type data_type = LLVM::LLVMPointerType::get(b.getI8Type());
348 Type type =
349 LLVM::LLVMStructType::getLiteral(ctx, {b.getI64Type(), data_type});
350
351 // Global constant initializer for the encoded array structure
352 auto init = [&](ImplicitLocOpBuilder &ib, Attribute) {
353 // Array size and the pointer to data.
354 Value num_elements = ib.create<ConstantOp>(b.getI64IntegerAttr(0));
355 Value data = ib.create<LLVM::NullOp>(data_type);
356
357 // Store size and values into the struct.
358 Value encoded = ib.create<LLVM::UndefOp>(type);
359 encoded = ib.create<LLVM::InsertValueOp>(encoded, num_elements, 0);
360 encoded = ib.create<LLVM::InsertValueOp>(encoded, data, 1);
361
362 ib.create<LLVM::ReturnOp>(encoded);
363 };
364
365 auto global = g.GetOrCreate(b, value, type, symbol_base, init);
366 return Globals::AddrOf(b, global);
367 }
368
369 //===----------------------------------------------------------------------===//
370 // Packing primitive values on the stack.
371 //===----------------------------------------------------------------------===//
372
373 // Returns the parent function operation for the given value.
GetParentFunc(Value value)374 static FuncOp GetParentFunc(Value value) {
375 Block *parent_block = value.getParentBlock();
376 Operation *parent_op = parent_block->getParentOp();
377
378 return isa<FuncOp>(parent_op) ? cast<FuncOp>(parent_op)
379 : parent_op->getParentOfType<FuncOp>();
380 }
381
382 // Packs value on the stack. Returns `!llvm.ptr<ValueType>`.
PackValue(ImplicitLocOpBuilder & b,Value value)383 static Value PackValue(ImplicitLocOpBuilder &b, Value value) {
384 Type ptr = LLVM::LLVMPointerType::get(value.getType());
385
386 // Always create an `alloca` in the parent function entry block.
387 // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas
388 Value mem = [&]() -> Value {
389 Block &block = GetParentFunc(value).getBody().front();
390 OpBuilder::InsertionGuard guard(b);
391 b.setInsertionPointToStart(&block);
392 Value one = b.create<ConstantOp>(b.getI32IntegerAttr(1));
393 return b.create<LLVM::AllocaOp>(ptr, one, 0);
394 }();
395
396 b.create<LLVM::StoreOp>(value, mem);
397
398 return mem;
399 }
400
401 //===----------------------------------------------------------------------===//
402 // A helper class to create global constants in the module.
403 //===----------------------------------------------------------------------===//
404
Find(Key key)405 LLVM::GlobalOp Globals::Find(Key key) {
406 auto it = globals_.find(key);
407 if (it != globals_.end()) return it->second;
408 return nullptr;
409 }
410
GetOrCreate(ImplicitLocOpBuilder & b,StringRef strref,StringRef symbol_base)411 LLVM::GlobalOp Globals::GetOrCreate(ImplicitLocOpBuilder &b, StringRef strref,
412 StringRef symbol_base) {
413 // Create an std::string to get a null terminated sequence of characters.
414 std::string str = strref.str();
415
416 // Create a string reference that captures the null terminator.
417 StringRef ref(str.data(), str.size() + 1);
418 StringAttr attr = b.getStringAttr(ref);
419 Type arr = LLVM::LLVMArrayType::get(b.getI8Type(), ref.size());
420 return GetOrCreate(b, attr, arr, symbol_base);
421 }
422
GetOrCreate(ImplicitLocOpBuilder & b,TypedAttr attr,StringRef symbol_base)423 LLVM::GlobalOp Globals::GetOrCreate(ImplicitLocOpBuilder &b, TypedAttr attr,
424 StringRef symbol_base) {
425 return GetOrCreate(b, attr, attr.getType(), symbol_base);
426 }
427
GetOrCreate(ImplicitLocOpBuilder & b,mlir::TypeID type_id)428 LLVM::GlobalOp Globals::GetOrCreate(ImplicitLocOpBuilder &b,
429 mlir::TypeID type_id) {
430 llvm::StringRef name = type_id_names_.FindTypeIDSymbolName(type_id);
431 assert(!name.empty() && "cannot find the symbol name of type_id");
432 return GetOrCreate(b, IntegerAttr(), b.getI64Type(), name, /*initialize=*/{},
433 LLVM::Linkage::External);
434 }
435
GetOrCreate(ImplicitLocOpBuilder & b,Attribute attr,Type type,StringRef symbol_base,GlobalInitializer initialize,LLVM::Linkage linkage)436 LLVM::GlobalOp Globals::GetOrCreate(ImplicitLocOpBuilder &b, Attribute attr,
437 Type type, StringRef symbol_base,
438 GlobalInitializer initialize,
439 LLVM::Linkage linkage) {
440 if (!initialize) {
441 return *TryGetOrCreate(b, attr, type, symbol_base, /*initialize=*/{},
442 linkage);
443 }
444
445 return *TryGetOrCreate(
446 b, attr, type, symbol_base,
447 [&](ImplicitLocOpBuilder &b, Attribute) {
448 return (initialize(b, attr), success());
449 },
450 linkage);
451 }
452
TryGetOrCreate(mlir::ImplicitLocOpBuilder & b,mlir::Attribute attr,mlir::Type type,llvm::StringRef symbol_base,FailureOrGlobalInitializer initialize,mlir::LLVM::Linkage linkage)453 mlir::FailureOr<mlir::LLVM::GlobalOp> Globals::TryGetOrCreate(
454 mlir::ImplicitLocOpBuilder &b, mlir::Attribute attr, mlir::Type type,
455 llvm::StringRef symbol_base, FailureOrGlobalInitializer initialize,
456 mlir::LLVM::Linkage linkage) {
457 // We assume that this triple uniquely identifies the global value and the
458 // global initializer always produces the same value for given inputs.
459 Key key(attr, type, b.getStringAttr(symbol_base));
460
461 // Check if global value already exists ...
462 if (auto global = Find(key)) return global;
463
464 // ... otherwise create a new one.
465 OpBuilder::InsertionGuard guard(b);
466 b.setInsertionPointToStart(module_.getBody());
467
468 // If the initialize function is not provided, create constant directly.
469 if (!initialize) {
470 auto global = b.create<LLVM::GlobalOp>(type, /*isConstant=*/true, linkage,
471 symbol_base, attr);
472 return (sym_table_.insert(global), globals_[key] = global);
473 }
474
475 // Create an uninitialized global.
476 auto global = b.create<LLVM::GlobalOp>(type, /*isConstant=*/true, linkage,
477 symbol_base, nullptr);
478
479 // Call user-provided global initializer.
480 mlir::Region ®ion = global.getInitializerRegion();
481 mlir::Block *block = b.createBlock(®ion);
482
483 b.setInsertionPointToStart(block);
484 if (failed(initialize(b, attr))) return failure();
485
486 return (sym_table_.insert(global), globals_[key] = global);
487 }
488
AddrOf(ImplicitLocOpBuilder & b,LLVM::GlobalOp global)489 /*static*/ Value Globals::AddrOf(ImplicitLocOpBuilder &b,
490 LLVM::GlobalOp global) {
491 return b.create<LLVM::AddressOfOp>(
492 LLVM::LLVMPointerType::get(global.getType()), global.getSymName());
493 }
494
OpaqueAddrOf(ImplicitLocOpBuilder & b,LLVM::GlobalOp global)495 /*static*/ Value Globals::OpaqueAddrOf(ImplicitLocOpBuilder &b,
496 LLVM::GlobalOp global) {
497 return b.create<LLVM::BitcastOp>(LLVM::LLVMPointerType::get(b.getI8Type()),
498 AddrOf(b, global));
499 }
500
501 //===----------------------------------------------------------------------===//
502 // Helper functions for encoding attributes and values for custom calls.
503 //===----------------------------------------------------------------------===//
504
IsSupportedScalarType(Type type)505 static bool IsSupportedScalarType(Type type) {
506 auto is_supported_width = [](unsigned width, ArrayRef<unsigned> supported) {
507 return llvm::any_of(supported, [&](unsigned w) { return w == width; });
508 };
509
510 if (auto i = type.dyn_cast<mlir::IntegerType>())
511 return i.isUnsigned() ? is_supported_width(i.getWidth(), {8, 32, 64})
512 : is_supported_width(i.getWidth(), {1, 32, 64});
513
514 if (auto fp = type.dyn_cast<mlir::FloatType>())
515 return is_supported_width(fp.getWidth(), {16, 32, 64});
516
517 return false;
518 }
519
IsSupportedScalarAttribute(Attribute attr)520 static bool IsSupportedScalarAttribute(Attribute attr) {
521 if (auto typed = attr.dyn_cast<TypedAttr>())
522 return IsSupportedScalarType(typed.getType());
523 return false;
524 }
525
ScalarRuntimeTypeId(Type type)526 static TypeID ScalarRuntimeTypeId(Type type) {
527 if (type.isUnsignedInteger(8)) return TypeID::get<Tagged<uint8_t>>();
528 if (type.isUnsignedInteger(32)) return TypeID::get<Tagged<uint32_t>>();
529 if (type.isUnsignedInteger(64)) return TypeID::get<Tagged<uint64_t>>();
530
531 if (type.isInteger(1)) return TypeID::get<Tagged<bool>>();
532 if (type.isInteger(32)) return TypeID::get<Tagged<int32_t>>();
533 if (type.isInteger(64)) return TypeID::get<Tagged<int64_t>>();
534
535 if (type.isF16()) return TypeID::get<Tagged<Eigen::half>>();
536 if (type.isF32()) return TypeID::get<Tagged<float>>();
537 if (type.isF64()) return TypeID::get<Tagged<double>>();
538
539 assert(false && "unsupported type id");
540 return TypeID::getFromOpaquePointer(reinterpret_cast<void *>(0xDEADBEEF));
541 }
542
ScalarDType(Type type)543 static DType ScalarDType(Type type) {
544 // Unsigned integer types.
545 if (type.isUnsignedInteger(8)) return DType::UI8;
546 if (type.isUnsignedInteger(16)) return DType::UI16;
547 if (type.isUnsignedInteger(32)) return DType::UI32;
548 if (type.isUnsignedInteger(64)) return DType::UI64;
549
550 // Signed integer types.
551 if (type.isInteger(1)) return DType::I1;
552 if (type.isInteger(8)) return DType::I8;
553 if (type.isInteger(16)) return DType::I16;
554 if (type.isInteger(32)) return DType::I32;
555 if (type.isInteger(64)) return DType::I64;
556
557 // Floating point types.
558 if (type.isF16()) return DType::F16;
559 if (type.isF32()) return DType::F32;
560 if (type.isF64()) return DType::F64;
561 if (type.isBF16()) return DType::BF16;
562
563 // Complex types.
564 if (auto complex = type.dyn_cast<ComplexType>()) {
565 if (complex.getElementType().isF32()) return DType::Complex64;
566 if (complex.getElementType().isF64()) return DType::Complex128;
567 }
568
569 assert(false && "unsupported type id");
570 return DType::Invalid;
571 }
572
ArrayRuntimeTypeId(Type elem_type)573 static TypeID ArrayRuntimeTypeId(Type elem_type) {
574 if (elem_type.isInteger(8)) return TypeID::get<Tagged<ArrayRef<int8_t>>>();
575 if (elem_type.isInteger(16)) return TypeID::get<Tagged<ArrayRef<int16_t>>>();
576 if (elem_type.isInteger(32)) return TypeID::get<Tagged<ArrayRef<int32_t>>>();
577 if (elem_type.isInteger(64)) return TypeID::get<Tagged<ArrayRef<int64_t>>>();
578 if (elem_type.isF32()) return TypeID::get<Tagged<ArrayRef<float>>>();
579 if (elem_type.isF64()) return TypeID::get<Tagged<ArrayRef<double>>>();
580
581 assert(false && "unsupported type id");
582 return TypeID::getFromOpaquePointer(reinterpret_cast<void *>(0xDEADBEEF));
583 }
584
DenseElementsRuntimeTypeId(Type elem_type)585 static TypeID DenseElementsRuntimeTypeId(Type elem_type) {
586 if (elem_type.isInteger(32))
587 return TypeID::get<Tagged<CustomCall::TensorRef<int32_t>>>();
588 if (elem_type.isInteger(64))
589 return TypeID::get<Tagged<CustomCall::TensorRef<int64_t>>>();
590 if (elem_type.isF32())
591 return TypeID::get<Tagged<CustomCall::TensorRef<float>>>();
592 if (elem_type.isF64())
593 return TypeID::get<Tagged<CustomCall::TensorRef<double>>>();
594
595 assert(false && "unsupported type id");
596 return TypeID::getFromOpaquePointer(reinterpret_cast<void *>(0xDEADBEEF));
597 }
598
599 //===----------------------------------------------------------------------===//
600 // Custom call attributes encoding.
601 //===----------------------------------------------------------------------===//
602
Match(StringRef name,Attribute attr) const603 LogicalResult StringAttrEncoding::Match(StringRef name, Attribute attr) const {
604 return success(attr.isa<StringAttr>());
605 }
606
Encode(Globals & g,ImplicitLocOpBuilder & b,StringRef name,Attribute attr) const607 FailureOr<EncodedAttr> StringAttrEncoding::Encode(Globals &g,
608 ImplicitLocOpBuilder &b,
609 StringRef name,
610 Attribute attr) const {
611 auto str = attr.cast<StringAttr>();
612
613 Encoded encoded;
614 encoded.name = PackString(g, b, name, kAttrName);
615 encoded.type_id = PackTypeId(g, b, TypeID::get<Tagged<llvm::StringRef>>());
616 encoded.value = PackString(g, b, str, kAttrValue);
617 return encoded;
618 }
619
620 //===----------------------------------------------------------------------===//
621
Match(StringRef name,Attribute attr) const622 LogicalResult ScalarAttrEncoding::Match(StringRef name, Attribute attr) const {
623 return success(IsSupportedScalarAttribute(attr));
624 }
625
Encode(Globals & g,ImplicitLocOpBuilder & b,StringRef name,Attribute attr) const626 FailureOr<EncodedAttr> ScalarAttrEncoding::Encode(Globals &g,
627 ImplicitLocOpBuilder &b,
628 StringRef name,
629 Attribute attr) const {
630 Type type = attr.cast<TypedAttr>().getType();
631
632 Encoded encoded;
633 encoded.name = PackString(g, b, name, kAttrName);
634 encoded.type_id = PackTypeId(g, b, ScalarRuntimeTypeId(type));
635 encoded.value = PackScalarAttribute(g, b, attr, kAttrValue);
636
637 return encoded;
638 }
639
640 //===----------------------------------------------------------------------===//
641
Match(StringRef name,Attribute attr) const642 LogicalResult DenseElementsAttrEncoding::Match(StringRef name,
643 Attribute attr) const {
644 if (auto dense = attr.dyn_cast<DenseIntOrFPElementsAttr>())
645 return success(IsSupportedScalarType(dense.getElementType()));
646 return failure();
647 }
648
Encode(Globals & g,ImplicitLocOpBuilder & b,StringRef name,Attribute attr) const649 FailureOr<EncodedAttr> DenseElementsAttrEncoding::Encode(
650 Globals &g, ImplicitLocOpBuilder &b, StringRef name, Attribute attr) const {
651 auto dense = attr.cast<DenseIntOrFPElementsAttr>();
652 Type elem_type = dense.getType().getElementType();
653
654 Encoded encoded;
655 encoded.name = PackString(g, b, name, kAttrName);
656 encoded.type_id = PackTypeId(g, b, DenseElementsRuntimeTypeId(elem_type));
657 encoded.value = PackDenseElementsAttribute(g, b, attr, kAttrValue);
658
659 return encoded;
660 }
661
662 //===----------------------------------------------------------------------===//
663
Match(StringRef name,Attribute attr) const664 LogicalResult ArrayAttrEncoding::Match(StringRef name, Attribute attr) const {
665 if (auto array = attr.dyn_cast<ArrayAttr>();
666 array && !array.empty() && array[0].isa<TypedAttr>()) {
667 return success(IsSupportedScalarAttribute(array[0]));
668 }
669 return failure();
670 }
671
Encode(Globals & g,ImplicitLocOpBuilder & b,StringRef name,Attribute attr) const672 FailureOr<EncodedAttr> ArrayAttrEncoding::Encode(Globals &g,
673 ImplicitLocOpBuilder &b,
674 StringRef name,
675 Attribute attr) const {
676 ArrayAttr array = attr.dyn_cast<ArrayAttr>();
677 Type elem_type = array[0].cast<TypedAttr>().getType();
678
679 // We only support array attributes with elements of same type.
680 bool all_of_same_type = llvm::all_of(array, [&](Attribute attr) {
681 auto typed = attr.dyn_cast<TypedAttr>();
682 return typed && typed.getType() == elem_type;
683 });
684 if (!all_of_same_type) return failure();
685
686 Encoded encoded;
687 encoded.name = PackString(g, b, name, kAttrName);
688 encoded.type_id = PackTypeId(g, b, ArrayRuntimeTypeId(elem_type));
689 encoded.value = PackArrayAttribute(g, b, array, elem_type, kAttrValue);
690
691 return encoded;
692 }
693
694 //===----------------------------------------------------------------------===//
695
Match(StringRef name,Attribute attr) const696 LogicalResult DenseArrayAttrEncoding::Match(StringRef name,
697 Attribute attr) const {
698 if (auto array = attr.dyn_cast<DenseArrayBaseAttr>()) {
699 return success();
700 }
701 return failure();
702 }
703
Encode(Globals & g,ImplicitLocOpBuilder & b,StringRef name,Attribute attr) const704 FailureOr<EncodedAttr> DenseArrayAttrEncoding::Encode(Globals &g,
705 ImplicitLocOpBuilder &b,
706 StringRef name,
707 Attribute attr) const {
708 Type elem_type = attr.cast<DenseArrayBaseAttr>().getType().getElementType();
709
710 Encoded encoded;
711 encoded.name = PackString(g, b, name, kAttrName);
712 encoded.type_id = PackTypeId(g, b, ArrayRuntimeTypeId(elem_type));
713 encoded.value = PackDenseArrayAttribute(g, b, attr, kAttrValue);
714
715 return encoded;
716 }
717
718 //===----------------------------------------------------------------------===//
719
Match(StringRef name,Attribute attr) const720 LogicalResult EmptyArrayAttrEncoding::Match(StringRef name,
721 Attribute attr) const {
722 if (auto array = attr.dyn_cast<ArrayAttr>(); array && array.empty()) {
723 return success();
724 }
725 return failure();
726 }
727
Encode(Globals & g,ImplicitLocOpBuilder & b,StringRef name,Attribute attr) const728 FailureOr<EncodedAttr> EmptyArrayAttrEncoding::Encode(Globals &g,
729 ImplicitLocOpBuilder &b,
730 StringRef name,
731 Attribute attr) const {
732 Encoded encoded;
733 encoded.name = PackString(g, b, name, kAttrName);
734 encoded.type_id = PackTypeId(g, b, TypeID::get<Tagged<EmptyArrayRef>>());
735 encoded.value = PackEmptyArrayAttribute(g, b, attr, kAttrValue);
736
737 return encoded;
738 }
739
740 //===----------------------------------------------------------------------===//
741 // Encoding for collection of attributes.
742 //===----------------------------------------------------------------------===//
743
EncodeAttributes(Globals & g,ImplicitLocOpBuilder & b,const CustomCallAttrEncodingSet & encoding,StringRef symbol_base,ArrayRef<NamedAttribute> attrs)744 FailureOr<Value> EncodeAttributes(Globals &g, ImplicitLocOpBuilder &b,
745 const CustomCallAttrEncodingSet &encoding,
746 StringRef symbol_base,
747 ArrayRef<NamedAttribute> attrs) {
748 using EncodedAttr = std::pair<StringRef, CustomCallAttrEncoding::Encoded>;
749
750 // In addition to encoded attributes we encode the number of attributes.
751 int64_t n_attrs = attrs.size();
752
753 // We store encoded attribute as `!llvm.array<ptr<i8> x len>`.
754 Type ptr = LLVM::LLVMPointerType::get(b.getI8Type());
755 Type type = LLVM::LLVMArrayType::get(ptr, 1 + n_attrs * 3);
756
757 // Global initializer that encodes attributes as pointers.
758 auto init = [&](ImplicitLocOpBuilder &ib, Attribute) -> LogicalResult {
759 // Try to encode each individual attribute.
760 llvm::SmallVector<EncodedAttr> encoded_attrs;
761 for (auto &attr : attrs) {
762 auto encoded = encoding.Encode(g, b, attr.getName(), attr.getValue());
763 if (failed(encoded)) return failure();
764 encoded_attrs.emplace_back(attr.getName(), *encoded);
765 }
766
767 // Prepare an array for encoding attributes.
768 Value arr = b.create<LLVM::UndefOp>(type);
769 auto insert_value = [&](Value value, int64_t offset) {
770 Value bcasted = b.createOrFold<LLVM::BitcastOp>(ptr, value);
771 arr = b.create<LLVM::InsertValueOp>(arr, bcasted, offset);
772 };
773
774 // Insert the number of encoded attributes.
775 Attribute num_attrs = b.getI64IntegerAttr(n_attrs);
776 Value size = PackScalarAttribute(g, b, num_attrs, "__rt_num_attrs");
777 insert_value(size, 0);
778
779 // Insert encoded attributes into the allocated storage.
780 for (auto &pair : llvm::enumerate(encoded_attrs)) {
781 CustomCallAttrEncoding::Encoded encoded = pair.value().second;
782 int64_t offset = 1 + pair.index() * 3;
783
784 insert_value(encoded.name, offset + 0);
785 insert_value(encoded.type_id, offset + 1);
786 insert_value(encoded.value, offset + 2);
787 }
788
789 // Return attributes array from the global initializer block.
790 b.create<LLVM::ReturnOp>(arr);
791
792 return success();
793 };
794
795 // Put all attributes in a dictionary attribute, so we can rely use it as a
796 // part of the `Globals` cache key.
797 auto attrs_map = DictionaryAttr::get(b.getContext(), attrs);
798 auto global = g.TryGetOrCreate(b, attrs_map, type, symbol_base, init);
799 if (failed(global)) return failure();
800
801 // Get a pointer to the first element of the array: !llvm.ptr<ptr<i8>>.
802 Type ptr_ptr = mlir::LLVM::LLVMPointerType::get(ptr);
803 Value c0 = b.create<ConstantOp>(b.getI64IntegerAttr(0));
804 Value addr = Globals::AddrOf(b, *global);
805 Value gep = b.create<LLVM::GEPOp>(ptr_ptr, addr, ValueRange({c0, c0}));
806
807 // Return a pointer to the encoded attributes: `!llvm.ptr<ptr<i8>>` (void**).
808 return gep;
809 }
810
811 //===----------------------------------------------------------------------===//
812 // Custom call arguments encodings.
813 //===----------------------------------------------------------------------===//
814
Match(Value value,Value converted) const815 LogicalResult ScalarArgEncoding::Match(Value value, Value converted) const {
816 return success(IsSupportedScalarType(value.getType()));
817 }
818
Encode(Globals & g,ImplicitLocOpBuilder & b,Value value,Value converted) const819 FailureOr<EncodedArg> ScalarArgEncoding::Encode(Globals &g,
820 ImplicitLocOpBuilder &b,
821 Value value,
822 Value converted) const {
823 Type type = converted.getType();
824
825 Encoded encoded;
826 encoded.type_id = PackTypeId(g, b, ScalarRuntimeTypeId(type));
827 encoded.value = PackValue(b, converted);
828
829 return encoded;
830 }
831
832 //===----------------------------------------------------------------------===//
833
Match(Value value,Value converted) const834 LogicalResult MemrefArgEncoding::Match(Value value, Value converted) const {
835 return success(value.getType().isa<MemRefType>());
836 }
837
Encode(Globals & g,ImplicitLocOpBuilder & b,Value value,Value converted) const838 FailureOr<EncodedArg> MemrefArgEncoding::Encode(Globals &g,
839 ImplicitLocOpBuilder &b,
840 Value value,
841 Value converted) const {
842 auto memref_type = value.getType().cast<MemRefType>();
843
844 // If memref has non-identity layout we use `StridedMemrefView` to
845 // distinguish it from the default row-major memref.
846 auto type_id = memref_type.getLayout().isIdentity()
847 ? TypeID::get<Tagged<MemrefView>>()
848 : TypeID::get<Tagged<StridedMemrefView>>();
849
850 Encoded encoded;
851 encoded.type_id = PackTypeId(g, b, type_id);
852 encoded.value = PackValue(b, EncodeMemRef(b, memref_type, converted));
853
854 return encoded;
855 }
856
EncodeMemRef(ImplicitLocOpBuilder & b,MemRefType memref_ty,Value descriptor) const857 Value MemrefArgEncoding::EncodeMemRef(ImplicitLocOpBuilder &b,
858 MemRefType memref_ty,
859 Value descriptor) const {
860 MLIRContext *ctx = b.getContext();
861 Location loc = b.getLoc();
862
863 // Encode sizes together with strides as a single array.
864 int64_t sizes_and_strides_size = 2 * memref_ty.getRank();
865
866 // Encoded memref type: !llvm.struct<(i8, i8, ptr<i8>, array<... x i64>)>.
867 Type i8 = b.getI8Type();
868 Type ptr = LLVM::LLVMPointerType::get(b.getI8Type());
869 Type arr = LLVM::LLVMArrayType::get(b.getI64Type(), sizes_and_strides_size);
870 Type type = LLVM::LLVMStructType::getLiteral(ctx, {i8, i8, ptr, arr});
871
872 // Helper to unpack MLIR strided memref descriptor value.
873 MemRefDescriptor desc(descriptor);
874
875 DType element_dtype = ScalarDType(memref_ty.getElementType());
876
877 // Create values for filling encoded memref struct.
878 Value dtype = b.create<ConstantOp>(
879 b.getI8IntegerAttr(static_cast<uint8_t>(element_dtype)));
880 Value rank = b.create<ConstantOp>(b.getI8IntegerAttr(memref_ty.getRank()));
881 Value data = b.create<LLVM::BitcastOp>(ptr, desc.alignedPtr(b, loc));
882
883 auto i64 = [&](int64_t i) { return b.getI64IntegerAttr(i); };
884
885 // Get the statically known strides and offset from the memref type.
886 llvm::SmallVector<int64_t> strides;
887 int64_t memref_offset;
888 if (failed(getStridesAndOffset(memref_ty, strides, memref_offset)))
889 strides.resize(memref_ty.getRank(), ShapedType::kDynamicStrideOrOffset);
890
891 // Build encoded memref sizes + strides: !llvm.array<... x i64>
892 Value payload = b.create<LLVM::UndefOp>(arr);
893 for (unsigned i = 0; i < memref_ty.getRank(); ++i) {
894 int64_t dim_size = memref_ty.getDimSize(i);
895 int64_t stride_size = strides[i];
896
897 Value dim = ShapedType::isDynamic(dim_size)
898 ? desc.size(b, loc, i)
899 : b.create<ConstantOp>(i64(dim_size));
900
901 Value stride = ShapedType::isDynamic(stride_size)
902 ? desc.stride(b, loc, i)
903 : b.create<ConstantOp>(i64(stride_size));
904
905 auto stride_pos = memref_ty.getRank() + i;
906
907 payload = b.create<LLVM::InsertValueOp>(payload, dim, i);
908 payload = b.create<LLVM::InsertValueOp>(payload, stride, stride_pos);
909 }
910
911 // Construct encoded memref value.
912 Value memref = b.create<LLVM::UndefOp>(type);
913 memref = b.create<LLVM::InsertValueOp>(memref, dtype, 0);
914 memref = b.create<LLVM::InsertValueOp>(memref, rank, 1);
915 memref = b.create<LLVM::InsertValueOp>(memref, payload, 3);
916
917 // Previous values almost always are known at compile time, and inserting
918 // dynamic values into the struct after all statically know values leads to a
919 // better canonicalization and cleaner final LLVM IR.
920 memref = b.create<LLVM::InsertValueOp>(memref, data, 2);
921
922 return memref;
923 }
924
925 //===----------------------------------------------------------------------===//
926 // Default encodings for arguments and attributes.
927 //===----------------------------------------------------------------------===//
928
DefaultAttrEncodings()929 CustomCallAttrEncodingSet DefaultAttrEncodings() {
930 CustomCallAttrEncodingSet encodings;
931 encodings
932 .Add<StringAttrEncoding, ScalarAttrEncoding, DenseElementsAttrEncoding,
933 ArrayAttrEncoding, DenseArrayAttrEncoding, EmptyArrayAttrEncoding>();
934 return encodings;
935 }
936
DefaultArgEncodings()937 CustomCallArgEncodingSet DefaultArgEncodings() {
938 CustomCallArgEncodingSet encodings;
939 encodings.Add<ScalarArgEncoding, MemrefArgEncoding>();
940 return encodings;
941 }
942
943 } // namespace runtime
944 } // namespace xla
945