xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir/transforms/runtime/custom_call_encoding.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/mlir/transforms/runtime/custom_call_encoding.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "llvm/ADT/StringRef.h"
23 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"  // from @llvm-project
24 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
25 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
26 #include "mlir/IR/Attributes.h"  // from @llvm-project
27 #include "mlir/IR/Builders.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
29 #include "mlir/IR/Types.h"  // from @llvm-project
30 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
31 #include "tensorflow/compiler/xla/runtime/custom_call.h"
32 #include "tensorflow/compiler/xla/runtime/type_id.h"
33 
34 namespace Eigen {
35 struct half;
36 }  // namespace Eigen
37 
38 namespace xla {
39 namespace runtime {
40 
41 using namespace mlir;  // NOLINT
42 using arith::ConstantOp;
43 using func::FuncOp;
44 
45 using llvm::ArrayRef;
46 using llvm::StringRef;
47 
48 using tfrt::DType;
49 
50 //===----------------------------------------------------------------------===//
51 // Custom call arguments encoding.
52 //===----------------------------------------------------------------------===//
53 
54 using EncodedArg = CustomCallArgEncodingSet::Encoded;
55 
Encode(Globals & g,ImplicitLocOpBuilder & b,Value value,Value converted) const56 FailureOr<EncodedArg> CustomCallArgEncodingSet::Encode(Globals &g,
57                                                        ImplicitLocOpBuilder &b,
58                                                        Value value,
59                                                        Value converted) const {
60   for (auto &encoding : encodings_)
61     if (succeeded(encoding->Match(value, converted)))
62       return encoding->Encode(g, b, value, converted);
63   return failure();
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // Custom call attributes encoding.
68 //===----------------------------------------------------------------------===//
69 
70 using EncodedAttr = CustomCallAttrEncodingSet::Encoded;
71 
Encode(Globals & g,ImplicitLocOpBuilder & b,StringRef name,Attribute attr) const72 FailureOr<EncodedAttr> CustomCallAttrEncodingSet::Encode(
73     Globals &g, ImplicitLocOpBuilder &b, StringRef name, Attribute attr) const {
74   for (auto &encoding : encodings_)
75     if (succeeded(encoding->Match(name, attr)))
76       return encoding->Encode(g, b, name, attr);
77   return failure();
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // A set of helper functions for packing primitive attributes.
82 //===----------------------------------------------------------------------===//
83 
PackTypeId(Globals & g,ImplicitLocOpBuilder & b,TypeID type_id)84 Value PackTypeId(Globals &g, ImplicitLocOpBuilder &b, TypeID type_id) {
85   auto global = g.GetOrCreate(b, type_id);
86   return Globals::AddrOf(b, global);
87 }
88 
PackString(Globals & g,ImplicitLocOpBuilder & b,StringRef strref,StringRef symbol_base)89 Value PackString(Globals &g, ImplicitLocOpBuilder &b, StringRef strref,
90                  StringRef symbol_base) {
91   MLIRContext *ctx = b.getContext();
92   int64_t size = strref.size();
93 
94   // Encoded string type: !llvm.struct<(i64, !llvm.ptr<array<i8 x len>>)>.
95   Type arr = LLVM::LLVMArrayType::get(b.getI8Type(), 1 + size);
96   Type ptr = LLVM::LLVMPointerType::get(arr);
97   Type type = LLVM::LLVMStructType::getLiteral(ctx, {b.getI64Type(), ptr});
98 
99   // Global constant initializer for the encoded string structure
100   auto init = [&](ImplicitLocOpBuilder &ib, Attribute) {
101     // String size and pointer to a null-terminated string.
102     Value num_elements = ib.create<ConstantOp>(ib.getI64IntegerAttr(size));
103     Value str = Globals::AddrOf(ib, g.GetOrCreate(b, strref, "__rt_str"));
104 
105     // Store size and pointer into the struct.
106     Value encoded = ib.create<LLVM::UndefOp>(type);
107     encoded = ib.create<LLVM::InsertValueOp>(encoded, num_elements, 0);
108     encoded = ib.create<LLVM::InsertValueOp>(encoded, str, 1);
109     ib.create<LLVM::ReturnOp>(encoded);
110   };
111 
112   auto value = b.getStringAttr(strref);
113   auto global = g.GetOrCreate(b, value, type, symbol_base, init);
114   return Globals::AddrOf(b, global);
115 }
116 
117 // Packs scalar attribute as a global constant. Returns `!llvm.ptr<AttrType>`.
PackScalarAttribute(Globals & g,ImplicitLocOpBuilder & b,Attribute value,StringRef symbol_base)118 Value PackScalarAttribute(Globals &g, ImplicitLocOpBuilder &b, Attribute value,
119                           StringRef symbol_base) {
120   auto global = g.GetOrCreate(b, value, symbol_base);
121   return Globals::AddrOf(b, global);
122 }
123 
124 // Reshape dense elements as a one-dimensional array.
Flatten(DenseIntOrFPElementsAttr dense)125 static mlir::DenseElementsAttr Flatten(DenseIntOrFPElementsAttr dense) {
126   ShapedType shaped_type = dense.getType();
127   ShapedType new_shaped_type = shaped_type.cloneWith(
128       {shaped_type.getNumElements()}, dense.getElementType());
129   return dense.reshape(new_shaped_type);
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // A set of helper functions for packing dense and array-like attributes.
134 //===----------------------------------------------------------------------===//
135 
136 // Packs dense elements attribute as a global constant. Returns
137 // `!llvm.ptr<EncodedDenseElements>`.
PackDenseElementsAttribute(Globals & g,ImplicitLocOpBuilder & b,Attribute value,StringRef symbol_base)138 static Value PackDenseElementsAttribute(Globals &g, ImplicitLocOpBuilder &b,
139                                         Attribute value,
140                                         StringRef symbol_base) {
141   MLIRContext *ctx = b.getContext();
142   DenseIntOrFPElementsAttr dense = value.cast<DenseIntOrFPElementsAttr>();
143 
144   // Payload type:
145   // !llvm.struct<(i64, !llvm.ptr<array<element_type x size>)>>.
146   Type element_type = dense.getElementType();
147   Type data_arr_type =
148       LLVM::LLVMArrayType::get(element_type, dense.getNumElements());
149   Type data_arr_ptr_type = LLVM::LLVMPointerType::get(data_arr_type);
150   Type payload_type = LLVM::LLVMStructType::getLiteral(
151       ctx, {b.getI64Type(), data_arr_ptr_type});
152 
153   int64_t rank = dense.getType().getRank();
154   ArrayRef<int64_t> shape = dense.getType().getShape();
155   Type shape_arr_type = LLVM::LLVMArrayType::get(b.getI64Type(), rank);
156 
157   // Encoded dense elements type:
158   // !llvm.struct<encoded_array_type, i64, array<i64, rank>
159   Type type = LLVM::LLVMStructType::getLiteral(
160       ctx, {payload_type, b.getI64Type(), shape_arr_type});
161 
162   // Global constant initializer for the encoded array structure.
163   auto init = [&](ImplicitLocOpBuilder &ib, Attribute) {
164     Value num_elements =
165         ib.create<ConstantOp>(b.getI64IntegerAttr(dense.getNumElements()));
166     Value data_ptr = Globals::AddrOf(
167         ib, g.GetOrCreate(b, Flatten(dense), data_arr_type, symbol_base));
168 
169     // Create the payload struct.
170     Value payload = ib.create<LLVM::UndefOp>(payload_type);
171     payload = ib.create<LLVM::InsertValueOp>(payload, num_elements, 0);
172     payload = ib.create<LLVM::InsertValueOp>(payload, data_ptr, 1);
173 
174     // Get rank and shape.
175     Value rank_value = ib.create<ConstantOp>(b.getI64IntegerAttr(rank));
176     Value shape_value = ib.create<LLVM::UndefOp>(shape_arr_type);
177 
178     // Store each dimension size into shape_value.
179     for (int i = 0; i < rank; i++) {
180       Value dim = ib.create<ConstantOp>(ib.getI64IntegerAttr(shape[i]));
181       shape_value = ib.create<LLVM::InsertValueOp>(shape_value, dim, i);
182     }
183 
184     // Store the payload, rank, and shape into the struct.
185     Value encoded = ib.create<LLVM::UndefOp>(type);
186     encoded = ib.create<LLVM::InsertValueOp>(encoded, payload, 0);
187     encoded = ib.create<LLVM::InsertValueOp>(encoded, rank_value, 1);
188     encoded = ib.create<LLVM::InsertValueOp>(encoded, shape_value, 2);
189     ib.create<LLVM::ReturnOp>(encoded);
190   };
191 
192   auto global = g.GetOrCreate(b, value, type, symbol_base, init);
193   return Globals::AddrOf(b, global);
194 }
195 
196 // Create a global for the data array in an EncodedArray.
197 // Returns `!llvm.ptr<array<element_type x size>>
CreateGlobalFromArray(Globals & g,ImplicitLocOpBuilder & b,ArrayAttr array,Type element_type,StringRef symbol_base)198 static Value CreateGlobalFromArray(Globals &g, ImplicitLocOpBuilder &b,
199                                    ArrayAttr array, Type element_type,
200                                    StringRef symbol_base) {
201   Type arr_type = LLVM::LLVMArrayType::get(element_type, array.size());
202 
203   auto init = [&](ImplicitLocOpBuilder &ib, Attribute) {
204     Value data = ib.create<LLVM::UndefOp>(arr_type);
205     for (int i = 0; i < array.size(); i++) {
206       Value value = ib.create<ConstantOp>(array[i]);
207       data = ib.create<LLVM::InsertValueOp>(data, value, i);
208     }
209     ib.create<LLVM::ReturnOp>(data);
210   };
211 
212   auto global = g.GetOrCreate(b, array, arr_type, symbol_base, init);
213   return Globals::AddrOf(b, global);
214 }
215 
216 // Packs array attribute as a global constant. Returns `!llvm.ptr<EncodedArr>`.
PackArrayAttribute(Globals & g,ImplicitLocOpBuilder & b,ArrayAttr array,Type element_type,StringRef symbol_base)217 static Value PackArrayAttribute(Globals &g, ImplicitLocOpBuilder &b,
218                                 ArrayAttr array, Type element_type,
219                                 StringRef symbol_base) {
220   MLIRContext *ctx = b.getContext();
221 
222   int64_t size = array.size();
223 
224   // Encoded array type:
225   // !llvm.struct<(i64, !llvm.ptr<array<element_type x size>)>>.
226   Type arr_type = LLVM::LLVMArrayType::get(element_type, size);
227   Type arr_ptr_type = LLVM::LLVMPointerType::get(arr_type);
228   Type type =
229       LLVM::LLVMStructType::getLiteral(ctx, {b.getI64Type(), arr_ptr_type});
230 
231   // Global constant initializer for the encoded array structure
232   auto init = [&](ImplicitLocOpBuilder &ib, Attribute) {
233     // Array size and the pointer to data.
234     Value num_elements = ib.create<ConstantOp>(b.getI64IntegerAttr(size));
235     Value data = CreateGlobalFromArray(g, b, array, element_type, symbol_base);
236 
237     // Store size and values into the struct.
238     Value encoded = ib.create<LLVM::UndefOp>(type);
239     encoded = ib.create<LLVM::InsertValueOp>(encoded, num_elements, 0);
240     encoded = ib.create<LLVM::InsertValueOp>(encoded, data, 1);
241 
242     ib.create<LLVM::ReturnOp>(encoded);
243   };
244 
245   auto global = g.GetOrCreate(b, array, type, symbol_base, init);
246   return Globals::AddrOf(b, global);
247 }
248 
249 template <typename T, typename AttrType, typename ArrayType>
FillDataFromDenseArrayAttr(ImplicitLocOpBuilder & b,AttrType (ImplicitLocOpBuilder::* get_attr)(T),ArrayType array,Value data)250 static Value FillDataFromDenseArrayAttr(
251     ImplicitLocOpBuilder &b, AttrType (ImplicitLocOpBuilder::*get_attr)(T),
252     ArrayType array, Value data) {
253   ArrayRef<T> array_ref = array.asArrayRef();
254   for (int i = 0; i < array_ref.size(); i++) {
255     Value value = b.create<ConstantOp>((b.*get_attr)(array_ref[i]));
256     data = b.create<LLVM::InsertValueOp>(data, value, i);
257   }
258   return data;
259 }
260 
CreateGlobalFromDenseArray(Globals & g,ImplicitLocOpBuilder & b,DenseArrayBaseAttr base_array,Type arr_type,StringRef symbol_base)261 static Value CreateGlobalFromDenseArray(Globals &g, ImplicitLocOpBuilder &b,
262                                         DenseArrayBaseAttr base_array,
263                                         Type arr_type, StringRef symbol_base) {
264   auto init = [&](ImplicitLocOpBuilder &ib, Attribute) {
265     Value data = ib.create<LLVM::UndefOp>(arr_type);
266     switch (base_array.getElementType()) {
267       case DenseArrayBaseAttr::EltType::I8:
268         data = FillDataFromDenseArrayAttr<int8_t, IntegerAttr>(
269             b, &ImplicitLocOpBuilder::getI8IntegerAttr,
270             base_array.cast<mlir::DenseI8ArrayAttr>(), data);
271         break;
272       case DenseArrayBaseAttr::EltType::I16:
273         data = FillDataFromDenseArrayAttr<int16_t, IntegerAttr>(
274             b, &ImplicitLocOpBuilder::getI16IntegerAttr,
275             base_array.cast<mlir::DenseI16ArrayAttr>(), data);
276         break;
277       case DenseArrayBaseAttr::EltType::I32:
278         data = FillDataFromDenseArrayAttr<int32_t, IntegerAttr>(
279             b, &ImplicitLocOpBuilder::getI32IntegerAttr,
280             base_array.cast<mlir::DenseI32ArrayAttr>(), data);
281         break;
282       case DenseArrayBaseAttr::EltType::I64:
283         data = FillDataFromDenseArrayAttr<int64_t, IntegerAttr>(
284             b, &ImplicitLocOpBuilder::getI64IntegerAttr,
285             base_array.cast<mlir::DenseI64ArrayAttr>(), data);
286         break;
287       case DenseArrayBaseAttr::EltType::F32:
288         data = FillDataFromDenseArrayAttr<float, FloatAttr>(
289             b, &ImplicitLocOpBuilder::getF32FloatAttr,
290             base_array.cast<mlir::DenseF32ArrayAttr>(), data);
291         break;
292       case DenseArrayBaseAttr::EltType::F64:
293         data = FillDataFromDenseArrayAttr<double, FloatAttr>(
294             b, &ImplicitLocOpBuilder::getF64FloatAttr,
295             base_array.cast<mlir::DenseF64ArrayAttr>(), data);
296         break;
297       default:
298         assert(false && "unsupported DenseArrayAttr element type");
299     }
300     ib.create<LLVM::ReturnOp>(data);
301   };
302 
303   auto global = g.GetOrCreate(b, base_array, arr_type, symbol_base, init);
304   return Globals::AddrOf(b, global);
305 }
306 
PackDenseArrayAttribute(Globals & g,ImplicitLocOpBuilder & b,Attribute value,StringRef symbol_base)307 static Value PackDenseArrayAttribute(Globals &g, ImplicitLocOpBuilder &b,
308                                      Attribute value, StringRef symbol_base) {
309   MLIRContext *ctx = b.getContext();
310 
311   DenseArrayBaseAttr base_array = value.cast<DenseArrayBaseAttr>();
312   int64_t size = base_array.size();
313 
314   // Encoded array type:
315   // !llvm.struct<(i64, !llvm.ptr<array<element_type x size>>)>.
316   Type element_type = base_array.getType().getElementType();
317   Type arr_type = LLVM::LLVMArrayType::get(element_type, size);
318   Type arr_ptr_type = LLVM::LLVMPointerType::get(arr_type);
319   Type type =
320       LLVM::LLVMStructType::getLiteral(ctx, {b.getI64Type(), arr_ptr_type});
321 
322   // Global constant initializer for the encoded array structure
323   auto init = [&](ImplicitLocOpBuilder &ib, Attribute) {
324     // Array size and values.
325     Value num_elements = ib.create<ConstantOp>(b.getI64IntegerAttr(size));
326     Value data =
327         CreateGlobalFromDenseArray(g, ib, base_array, arr_type, symbol_base);
328 
329     // Store size and values into the struct.
330     Value encoded = ib.create<LLVM::UndefOp>(type);
331     encoded = ib.create<LLVM::InsertValueOp>(encoded, num_elements, 0);
332     encoded = ib.create<LLVM::InsertValueOp>(encoded, data, 1);
333 
334     ib.create<LLVM::ReturnOp>(encoded);
335   };
336 
337   auto global = g.GetOrCreate(b, value, type, symbol_base, init);
338   return Globals::AddrOf(b, global);
339 }
340 
PackEmptyArrayAttribute(Globals & g,ImplicitLocOpBuilder & b,Attribute value,StringRef symbol_base)341 static Value PackEmptyArrayAttribute(Globals &g, ImplicitLocOpBuilder &b,
342                                      Attribute value, StringRef symbol_base) {
343   MLIRContext *ctx = b.getContext();
344 
345   // Encoded array type: !llvm.struct<(i64, !llvm.ptr<i8>)>.
346   // The pointer is always null. We use i8 as a placeholder type.
347   Type data_type = LLVM::LLVMPointerType::get(b.getI8Type());
348   Type type =
349       LLVM::LLVMStructType::getLiteral(ctx, {b.getI64Type(), data_type});
350 
351   // Global constant initializer for the encoded array structure
352   auto init = [&](ImplicitLocOpBuilder &ib, Attribute) {
353     // Array size and the pointer to data.
354     Value num_elements = ib.create<ConstantOp>(b.getI64IntegerAttr(0));
355     Value data = ib.create<LLVM::NullOp>(data_type);
356 
357     // Store size and values into the struct.
358     Value encoded = ib.create<LLVM::UndefOp>(type);
359     encoded = ib.create<LLVM::InsertValueOp>(encoded, num_elements, 0);
360     encoded = ib.create<LLVM::InsertValueOp>(encoded, data, 1);
361 
362     ib.create<LLVM::ReturnOp>(encoded);
363   };
364 
365   auto global = g.GetOrCreate(b, value, type, symbol_base, init);
366   return Globals::AddrOf(b, global);
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // Packing primitive values on the stack.
371 //===----------------------------------------------------------------------===//
372 
373 // Returns the parent function operation for the given value.
GetParentFunc(Value value)374 static FuncOp GetParentFunc(Value value) {
375   Block *parent_block = value.getParentBlock();
376   Operation *parent_op = parent_block->getParentOp();
377 
378   return isa<FuncOp>(parent_op) ? cast<FuncOp>(parent_op)
379                                 : parent_op->getParentOfType<FuncOp>();
380 }
381 
382 // Packs value on the stack. Returns `!llvm.ptr<ValueType>`.
PackValue(ImplicitLocOpBuilder & b,Value value)383 static Value PackValue(ImplicitLocOpBuilder &b, Value value) {
384   Type ptr = LLVM::LLVMPointerType::get(value.getType());
385 
386   // Always create an `alloca` in the parent function entry block.
387   // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas
388   Value mem = [&]() -> Value {
389     Block &block = GetParentFunc(value).getBody().front();
390     OpBuilder::InsertionGuard guard(b);
391     b.setInsertionPointToStart(&block);
392     Value one = b.create<ConstantOp>(b.getI32IntegerAttr(1));
393     return b.create<LLVM::AllocaOp>(ptr, one, 0);
394   }();
395 
396   b.create<LLVM::StoreOp>(value, mem);
397 
398   return mem;
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // A helper class to create global constants in the module.
403 //===----------------------------------------------------------------------===//
404 
Find(Key key)405 LLVM::GlobalOp Globals::Find(Key key) {
406   auto it = globals_.find(key);
407   if (it != globals_.end()) return it->second;
408   return nullptr;
409 }
410 
GetOrCreate(ImplicitLocOpBuilder & b,StringRef strref,StringRef symbol_base)411 LLVM::GlobalOp Globals::GetOrCreate(ImplicitLocOpBuilder &b, StringRef strref,
412                                     StringRef symbol_base) {
413   // Create an std::string to get a null terminated sequence of characters.
414   std::string str = strref.str();
415 
416   // Create a string reference that captures the null terminator.
417   StringRef ref(str.data(), str.size() + 1);
418   StringAttr attr = b.getStringAttr(ref);
419   Type arr = LLVM::LLVMArrayType::get(b.getI8Type(), ref.size());
420   return GetOrCreate(b, attr, arr, symbol_base);
421 }
422 
GetOrCreate(ImplicitLocOpBuilder & b,TypedAttr attr,StringRef symbol_base)423 LLVM::GlobalOp Globals::GetOrCreate(ImplicitLocOpBuilder &b, TypedAttr attr,
424                                     StringRef symbol_base) {
425   return GetOrCreate(b, attr, attr.getType(), symbol_base);
426 }
427 
GetOrCreate(ImplicitLocOpBuilder & b,mlir::TypeID type_id)428 LLVM::GlobalOp Globals::GetOrCreate(ImplicitLocOpBuilder &b,
429                                     mlir::TypeID type_id) {
430   llvm::StringRef name = type_id_names_.FindTypeIDSymbolName(type_id);
431   assert(!name.empty() && "cannot find the symbol name of type_id");
432   return GetOrCreate(b, IntegerAttr(), b.getI64Type(), name, /*initialize=*/{},
433                      LLVM::Linkage::External);
434 }
435 
GetOrCreate(ImplicitLocOpBuilder & b,Attribute attr,Type type,StringRef symbol_base,GlobalInitializer initialize,LLVM::Linkage linkage)436 LLVM::GlobalOp Globals::GetOrCreate(ImplicitLocOpBuilder &b, Attribute attr,
437                                     Type type, StringRef symbol_base,
438                                     GlobalInitializer initialize,
439                                     LLVM::Linkage linkage) {
440   if (!initialize) {
441     return *TryGetOrCreate(b, attr, type, symbol_base, /*initialize=*/{},
442                            linkage);
443   }
444 
445   return *TryGetOrCreate(
446       b, attr, type, symbol_base,
447       [&](ImplicitLocOpBuilder &b, Attribute) {
448         return (initialize(b, attr), success());
449       },
450       linkage);
451 }
452 
TryGetOrCreate(mlir::ImplicitLocOpBuilder & b,mlir::Attribute attr,mlir::Type type,llvm::StringRef symbol_base,FailureOrGlobalInitializer initialize,mlir::LLVM::Linkage linkage)453 mlir::FailureOr<mlir::LLVM::GlobalOp> Globals::TryGetOrCreate(
454     mlir::ImplicitLocOpBuilder &b, mlir::Attribute attr, mlir::Type type,
455     llvm::StringRef symbol_base, FailureOrGlobalInitializer initialize,
456     mlir::LLVM::Linkage linkage) {
457   // We assume that this triple uniquely identifies the global value and the
458   // global initializer always produces the same value for given inputs.
459   Key key(attr, type, b.getStringAttr(symbol_base));
460 
461   // Check if global value already exists ...
462   if (auto global = Find(key)) return global;
463 
464   // ... otherwise create a new one.
465   OpBuilder::InsertionGuard guard(b);
466   b.setInsertionPointToStart(module_.getBody());
467 
468   // If the initialize function is not provided, create constant directly.
469   if (!initialize) {
470     auto global = b.create<LLVM::GlobalOp>(type, /*isConstant=*/true, linkage,
471                                            symbol_base, attr);
472     return (sym_table_.insert(global), globals_[key] = global);
473   }
474 
475   // Create an uninitialized global.
476   auto global = b.create<LLVM::GlobalOp>(type, /*isConstant=*/true, linkage,
477                                          symbol_base, nullptr);
478 
479   // Call user-provided global initializer.
480   mlir::Region &region = global.getInitializerRegion();
481   mlir::Block *block = b.createBlock(&region);
482 
483   b.setInsertionPointToStart(block);
484   if (failed(initialize(b, attr))) return failure();
485 
486   return (sym_table_.insert(global), globals_[key] = global);
487 }
488 
AddrOf(ImplicitLocOpBuilder & b,LLVM::GlobalOp global)489 /*static*/ Value Globals::AddrOf(ImplicitLocOpBuilder &b,
490                                  LLVM::GlobalOp global) {
491   return b.create<LLVM::AddressOfOp>(
492       LLVM::LLVMPointerType::get(global.getType()), global.getSymName());
493 }
494 
OpaqueAddrOf(ImplicitLocOpBuilder & b,LLVM::GlobalOp global)495 /*static*/ Value Globals::OpaqueAddrOf(ImplicitLocOpBuilder &b,
496                                        LLVM::GlobalOp global) {
497   return b.create<LLVM::BitcastOp>(LLVM::LLVMPointerType::get(b.getI8Type()),
498                                    AddrOf(b, global));
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // Helper functions for encoding attributes and values for custom calls.
503 //===----------------------------------------------------------------------===//
504 
IsSupportedScalarType(Type type)505 static bool IsSupportedScalarType(Type type) {
506   auto is_supported_width = [](unsigned width, ArrayRef<unsigned> supported) {
507     return llvm::any_of(supported, [&](unsigned w) { return w == width; });
508   };
509 
510   if (auto i = type.dyn_cast<mlir::IntegerType>())
511     return i.isUnsigned() ? is_supported_width(i.getWidth(), {8, 32, 64})
512                           : is_supported_width(i.getWidth(), {1, 32, 64});
513 
514   if (auto fp = type.dyn_cast<mlir::FloatType>())
515     return is_supported_width(fp.getWidth(), {16, 32, 64});
516 
517   return false;
518 }
519 
IsSupportedScalarAttribute(Attribute attr)520 static bool IsSupportedScalarAttribute(Attribute attr) {
521   if (auto typed = attr.dyn_cast<TypedAttr>())
522     return IsSupportedScalarType(typed.getType());
523   return false;
524 }
525 
ScalarRuntimeTypeId(Type type)526 static TypeID ScalarRuntimeTypeId(Type type) {
527   if (type.isUnsignedInteger(8)) return TypeID::get<Tagged<uint8_t>>();
528   if (type.isUnsignedInteger(32)) return TypeID::get<Tagged<uint32_t>>();
529   if (type.isUnsignedInteger(64)) return TypeID::get<Tagged<uint64_t>>();
530 
531   if (type.isInteger(1)) return TypeID::get<Tagged<bool>>();
532   if (type.isInteger(32)) return TypeID::get<Tagged<int32_t>>();
533   if (type.isInteger(64)) return TypeID::get<Tagged<int64_t>>();
534 
535   if (type.isF16()) return TypeID::get<Tagged<Eigen::half>>();
536   if (type.isF32()) return TypeID::get<Tagged<float>>();
537   if (type.isF64()) return TypeID::get<Tagged<double>>();
538 
539   assert(false && "unsupported type id");
540   return TypeID::getFromOpaquePointer(reinterpret_cast<void *>(0xDEADBEEF));
541 }
542 
ScalarDType(Type type)543 static DType ScalarDType(Type type) {
544   // Unsigned integer types.
545   if (type.isUnsignedInteger(8)) return DType::UI8;
546   if (type.isUnsignedInteger(16)) return DType::UI16;
547   if (type.isUnsignedInteger(32)) return DType::UI32;
548   if (type.isUnsignedInteger(64)) return DType::UI64;
549 
550   // Signed integer types.
551   if (type.isInteger(1)) return DType::I1;
552   if (type.isInteger(8)) return DType::I8;
553   if (type.isInteger(16)) return DType::I16;
554   if (type.isInteger(32)) return DType::I32;
555   if (type.isInteger(64)) return DType::I64;
556 
557   // Floating point types.
558   if (type.isF16()) return DType::F16;
559   if (type.isF32()) return DType::F32;
560   if (type.isF64()) return DType::F64;
561   if (type.isBF16()) return DType::BF16;
562 
563   // Complex types.
564   if (auto complex = type.dyn_cast<ComplexType>()) {
565     if (complex.getElementType().isF32()) return DType::Complex64;
566     if (complex.getElementType().isF64()) return DType::Complex128;
567   }
568 
569   assert(false && "unsupported type id");
570   return DType::Invalid;
571 }
572 
ArrayRuntimeTypeId(Type elem_type)573 static TypeID ArrayRuntimeTypeId(Type elem_type) {
574   if (elem_type.isInteger(8)) return TypeID::get<Tagged<ArrayRef<int8_t>>>();
575   if (elem_type.isInteger(16)) return TypeID::get<Tagged<ArrayRef<int16_t>>>();
576   if (elem_type.isInteger(32)) return TypeID::get<Tagged<ArrayRef<int32_t>>>();
577   if (elem_type.isInteger(64)) return TypeID::get<Tagged<ArrayRef<int64_t>>>();
578   if (elem_type.isF32()) return TypeID::get<Tagged<ArrayRef<float>>>();
579   if (elem_type.isF64()) return TypeID::get<Tagged<ArrayRef<double>>>();
580 
581   assert(false && "unsupported type id");
582   return TypeID::getFromOpaquePointer(reinterpret_cast<void *>(0xDEADBEEF));
583 }
584 
DenseElementsRuntimeTypeId(Type elem_type)585 static TypeID DenseElementsRuntimeTypeId(Type elem_type) {
586   if (elem_type.isInteger(32))
587     return TypeID::get<Tagged<CustomCall::TensorRef<int32_t>>>();
588   if (elem_type.isInteger(64))
589     return TypeID::get<Tagged<CustomCall::TensorRef<int64_t>>>();
590   if (elem_type.isF32())
591     return TypeID::get<Tagged<CustomCall::TensorRef<float>>>();
592   if (elem_type.isF64())
593     return TypeID::get<Tagged<CustomCall::TensorRef<double>>>();
594 
595   assert(false && "unsupported type id");
596   return TypeID::getFromOpaquePointer(reinterpret_cast<void *>(0xDEADBEEF));
597 }
598 
599 //===----------------------------------------------------------------------===//
600 // Custom call attributes encoding.
601 //===----------------------------------------------------------------------===//
602 
Match(StringRef name,Attribute attr) const603 LogicalResult StringAttrEncoding::Match(StringRef name, Attribute attr) const {
604   return success(attr.isa<StringAttr>());
605 }
606 
Encode(Globals & g,ImplicitLocOpBuilder & b,StringRef name,Attribute attr) const607 FailureOr<EncodedAttr> StringAttrEncoding::Encode(Globals &g,
608                                                   ImplicitLocOpBuilder &b,
609                                                   StringRef name,
610                                                   Attribute attr) const {
611   auto str = attr.cast<StringAttr>();
612 
613   Encoded encoded;
614   encoded.name = PackString(g, b, name, kAttrName);
615   encoded.type_id = PackTypeId(g, b, TypeID::get<Tagged<llvm::StringRef>>());
616   encoded.value = PackString(g, b, str, kAttrValue);
617   return encoded;
618 }
619 
620 //===----------------------------------------------------------------------===//
621 
Match(StringRef name,Attribute attr) const622 LogicalResult ScalarAttrEncoding::Match(StringRef name, Attribute attr) const {
623   return success(IsSupportedScalarAttribute(attr));
624 }
625 
Encode(Globals & g,ImplicitLocOpBuilder & b,StringRef name,Attribute attr) const626 FailureOr<EncodedAttr> ScalarAttrEncoding::Encode(Globals &g,
627                                                   ImplicitLocOpBuilder &b,
628                                                   StringRef name,
629                                                   Attribute attr) const {
630   Type type = attr.cast<TypedAttr>().getType();
631 
632   Encoded encoded;
633   encoded.name = PackString(g, b, name, kAttrName);
634   encoded.type_id = PackTypeId(g, b, ScalarRuntimeTypeId(type));
635   encoded.value = PackScalarAttribute(g, b, attr, kAttrValue);
636 
637   return encoded;
638 }
639 
640 //===----------------------------------------------------------------------===//
641 
Match(StringRef name,Attribute attr) const642 LogicalResult DenseElementsAttrEncoding::Match(StringRef name,
643                                                Attribute attr) const {
644   if (auto dense = attr.dyn_cast<DenseIntOrFPElementsAttr>())
645     return success(IsSupportedScalarType(dense.getElementType()));
646   return failure();
647 }
648 
Encode(Globals & g,ImplicitLocOpBuilder & b,StringRef name,Attribute attr) const649 FailureOr<EncodedAttr> DenseElementsAttrEncoding::Encode(
650     Globals &g, ImplicitLocOpBuilder &b, StringRef name, Attribute attr) const {
651   auto dense = attr.cast<DenseIntOrFPElementsAttr>();
652   Type elem_type = dense.getType().getElementType();
653 
654   Encoded encoded;
655   encoded.name = PackString(g, b, name, kAttrName);
656   encoded.type_id = PackTypeId(g, b, DenseElementsRuntimeTypeId(elem_type));
657   encoded.value = PackDenseElementsAttribute(g, b, attr, kAttrValue);
658 
659   return encoded;
660 }
661 
662 //===----------------------------------------------------------------------===//
663 
Match(StringRef name,Attribute attr) const664 LogicalResult ArrayAttrEncoding::Match(StringRef name, Attribute attr) const {
665   if (auto array = attr.dyn_cast<ArrayAttr>();
666       array && !array.empty() && array[0].isa<TypedAttr>()) {
667     return success(IsSupportedScalarAttribute(array[0]));
668   }
669   return failure();
670 }
671 
Encode(Globals & g,ImplicitLocOpBuilder & b,StringRef name,Attribute attr) const672 FailureOr<EncodedAttr> ArrayAttrEncoding::Encode(Globals &g,
673                                                  ImplicitLocOpBuilder &b,
674                                                  StringRef name,
675                                                  Attribute attr) const {
676   ArrayAttr array = attr.dyn_cast<ArrayAttr>();
677   Type elem_type = array[0].cast<TypedAttr>().getType();
678 
679   // We only support array attributes with elements of same type.
680   bool all_of_same_type = llvm::all_of(array, [&](Attribute attr) {
681     auto typed = attr.dyn_cast<TypedAttr>();
682     return typed && typed.getType() == elem_type;
683   });
684   if (!all_of_same_type) return failure();
685 
686   Encoded encoded;
687   encoded.name = PackString(g, b, name, kAttrName);
688   encoded.type_id = PackTypeId(g, b, ArrayRuntimeTypeId(elem_type));
689   encoded.value = PackArrayAttribute(g, b, array, elem_type, kAttrValue);
690 
691   return encoded;
692 }
693 
694 //===----------------------------------------------------------------------===//
695 
Match(StringRef name,Attribute attr) const696 LogicalResult DenseArrayAttrEncoding::Match(StringRef name,
697                                             Attribute attr) const {
698   if (auto array = attr.dyn_cast<DenseArrayBaseAttr>()) {
699     return success();
700   }
701   return failure();
702 }
703 
Encode(Globals & g,ImplicitLocOpBuilder & b,StringRef name,Attribute attr) const704 FailureOr<EncodedAttr> DenseArrayAttrEncoding::Encode(Globals &g,
705                                                       ImplicitLocOpBuilder &b,
706                                                       StringRef name,
707                                                       Attribute attr) const {
708   Type elem_type = attr.cast<DenseArrayBaseAttr>().getType().getElementType();
709 
710   Encoded encoded;
711   encoded.name = PackString(g, b, name, kAttrName);
712   encoded.type_id = PackTypeId(g, b, ArrayRuntimeTypeId(elem_type));
713   encoded.value = PackDenseArrayAttribute(g, b, attr, kAttrValue);
714 
715   return encoded;
716 }
717 
718 //===----------------------------------------------------------------------===//
719 
Match(StringRef name,Attribute attr) const720 LogicalResult EmptyArrayAttrEncoding::Match(StringRef name,
721                                             Attribute attr) const {
722   if (auto array = attr.dyn_cast<ArrayAttr>(); array && array.empty()) {
723     return success();
724   }
725   return failure();
726 }
727 
Encode(Globals & g,ImplicitLocOpBuilder & b,StringRef name,Attribute attr) const728 FailureOr<EncodedAttr> EmptyArrayAttrEncoding::Encode(Globals &g,
729                                                       ImplicitLocOpBuilder &b,
730                                                       StringRef name,
731                                                       Attribute attr) const {
732   Encoded encoded;
733   encoded.name = PackString(g, b, name, kAttrName);
734   encoded.type_id = PackTypeId(g, b, TypeID::get<Tagged<EmptyArrayRef>>());
735   encoded.value = PackEmptyArrayAttribute(g, b, attr, kAttrValue);
736 
737   return encoded;
738 }
739 
740 //===----------------------------------------------------------------------===//
741 // Encoding for collection of attributes.
742 //===----------------------------------------------------------------------===//
743 
EncodeAttributes(Globals & g,ImplicitLocOpBuilder & b,const CustomCallAttrEncodingSet & encoding,StringRef symbol_base,ArrayRef<NamedAttribute> attrs)744 FailureOr<Value> EncodeAttributes(Globals &g, ImplicitLocOpBuilder &b,
745                                   const CustomCallAttrEncodingSet &encoding,
746                                   StringRef symbol_base,
747                                   ArrayRef<NamedAttribute> attrs) {
748   using EncodedAttr = std::pair<StringRef, CustomCallAttrEncoding::Encoded>;
749 
750   // In addition to encoded attributes we encode the number of attributes.
751   int64_t n_attrs = attrs.size();
752 
753   // We store encoded attribute as `!llvm.array<ptr<i8> x len>`.
754   Type ptr = LLVM::LLVMPointerType::get(b.getI8Type());
755   Type type = LLVM::LLVMArrayType::get(ptr, 1 + n_attrs * 3);
756 
757   // Global initializer that encodes attributes as pointers.
758   auto init = [&](ImplicitLocOpBuilder &ib, Attribute) -> LogicalResult {
759     // Try to encode each individual attribute.
760     llvm::SmallVector<EncodedAttr> encoded_attrs;
761     for (auto &attr : attrs) {
762       auto encoded = encoding.Encode(g, b, attr.getName(), attr.getValue());
763       if (failed(encoded)) return failure();
764       encoded_attrs.emplace_back(attr.getName(), *encoded);
765     }
766 
767     // Prepare an array for encoding attributes.
768     Value arr = b.create<LLVM::UndefOp>(type);
769     auto insert_value = [&](Value value, int64_t offset) {
770       Value bcasted = b.createOrFold<LLVM::BitcastOp>(ptr, value);
771       arr = b.create<LLVM::InsertValueOp>(arr, bcasted, offset);
772     };
773 
774     // Insert the number of encoded attributes.
775     Attribute num_attrs = b.getI64IntegerAttr(n_attrs);
776     Value size = PackScalarAttribute(g, b, num_attrs, "__rt_num_attrs");
777     insert_value(size, 0);
778 
779     // Insert encoded attributes into the allocated storage.
780     for (auto &pair : llvm::enumerate(encoded_attrs)) {
781       CustomCallAttrEncoding::Encoded encoded = pair.value().second;
782       int64_t offset = 1 + pair.index() * 3;
783 
784       insert_value(encoded.name, offset + 0);
785       insert_value(encoded.type_id, offset + 1);
786       insert_value(encoded.value, offset + 2);
787     }
788 
789     // Return attributes array from the global initializer block.
790     b.create<LLVM::ReturnOp>(arr);
791 
792     return success();
793   };
794 
795   // Put all attributes in a dictionary attribute, so we can rely use it as a
796   // part of the `Globals` cache key.
797   auto attrs_map = DictionaryAttr::get(b.getContext(), attrs);
798   auto global = g.TryGetOrCreate(b, attrs_map, type, symbol_base, init);
799   if (failed(global)) return failure();
800 
801   // Get a pointer to the first element of the array: !llvm.ptr<ptr<i8>>.
802   Type ptr_ptr = mlir::LLVM::LLVMPointerType::get(ptr);
803   Value c0 = b.create<ConstantOp>(b.getI64IntegerAttr(0));
804   Value addr = Globals::AddrOf(b, *global);
805   Value gep = b.create<LLVM::GEPOp>(ptr_ptr, addr, ValueRange({c0, c0}));
806 
807   // Return a pointer to the encoded attributes: `!llvm.ptr<ptr<i8>>` (void**).
808   return gep;
809 }
810 
811 //===----------------------------------------------------------------------===//
812 // Custom call arguments encodings.
813 //===----------------------------------------------------------------------===//
814 
Match(Value value,Value converted) const815 LogicalResult ScalarArgEncoding::Match(Value value, Value converted) const {
816   return success(IsSupportedScalarType(value.getType()));
817 }
818 
Encode(Globals & g,ImplicitLocOpBuilder & b,Value value,Value converted) const819 FailureOr<EncodedArg> ScalarArgEncoding::Encode(Globals &g,
820                                                 ImplicitLocOpBuilder &b,
821                                                 Value value,
822                                                 Value converted) const {
823   Type type = converted.getType();
824 
825   Encoded encoded;
826   encoded.type_id = PackTypeId(g, b, ScalarRuntimeTypeId(type));
827   encoded.value = PackValue(b, converted);
828 
829   return encoded;
830 }
831 
832 //===----------------------------------------------------------------------===//
833 
Match(Value value,Value converted) const834 LogicalResult MemrefArgEncoding::Match(Value value, Value converted) const {
835   return success(value.getType().isa<MemRefType>());
836 }
837 
Encode(Globals & g,ImplicitLocOpBuilder & b,Value value,Value converted) const838 FailureOr<EncodedArg> MemrefArgEncoding::Encode(Globals &g,
839                                                 ImplicitLocOpBuilder &b,
840                                                 Value value,
841                                                 Value converted) const {
842   auto memref_type = value.getType().cast<MemRefType>();
843 
844   // If memref has non-identity layout we use `StridedMemrefView` to
845   // distinguish it from the default row-major memref.
846   auto type_id = memref_type.getLayout().isIdentity()
847                      ? TypeID::get<Tagged<MemrefView>>()
848                      : TypeID::get<Tagged<StridedMemrefView>>();
849 
850   Encoded encoded;
851   encoded.type_id = PackTypeId(g, b, type_id);
852   encoded.value = PackValue(b, EncodeMemRef(b, memref_type, converted));
853 
854   return encoded;
855 }
856 
EncodeMemRef(ImplicitLocOpBuilder & b,MemRefType memref_ty,Value descriptor) const857 Value MemrefArgEncoding::EncodeMemRef(ImplicitLocOpBuilder &b,
858                                       MemRefType memref_ty,
859                                       Value descriptor) const {
860   MLIRContext *ctx = b.getContext();
861   Location loc = b.getLoc();
862 
863   // Encode sizes together with strides as a single array.
864   int64_t sizes_and_strides_size = 2 * memref_ty.getRank();
865 
866   // Encoded memref type: !llvm.struct<(i8, i8, ptr<i8>, array<... x i64>)>.
867   Type i8 = b.getI8Type();
868   Type ptr = LLVM::LLVMPointerType::get(b.getI8Type());
869   Type arr = LLVM::LLVMArrayType::get(b.getI64Type(), sizes_and_strides_size);
870   Type type = LLVM::LLVMStructType::getLiteral(ctx, {i8, i8, ptr, arr});
871 
872   // Helper to unpack MLIR strided memref descriptor value.
873   MemRefDescriptor desc(descriptor);
874 
875   DType element_dtype = ScalarDType(memref_ty.getElementType());
876 
877   // Create values for filling encoded memref struct.
878   Value dtype = b.create<ConstantOp>(
879       b.getI8IntegerAttr(static_cast<uint8_t>(element_dtype)));
880   Value rank = b.create<ConstantOp>(b.getI8IntegerAttr(memref_ty.getRank()));
881   Value data = b.create<LLVM::BitcastOp>(ptr, desc.alignedPtr(b, loc));
882 
883   auto i64 = [&](int64_t i) { return b.getI64IntegerAttr(i); };
884 
885   // Get the statically known strides and offset from the memref type.
886   llvm::SmallVector<int64_t> strides;
887   int64_t memref_offset;
888   if (failed(getStridesAndOffset(memref_ty, strides, memref_offset)))
889     strides.resize(memref_ty.getRank(), ShapedType::kDynamicStrideOrOffset);
890 
891   // Build encoded memref sizes + strides: !llvm.array<... x i64>
892   Value payload = b.create<LLVM::UndefOp>(arr);
893   for (unsigned i = 0; i < memref_ty.getRank(); ++i) {
894     int64_t dim_size = memref_ty.getDimSize(i);
895     int64_t stride_size = strides[i];
896 
897     Value dim = ShapedType::isDynamic(dim_size)
898                     ? desc.size(b, loc, i)
899                     : b.create<ConstantOp>(i64(dim_size));
900 
901     Value stride = ShapedType::isDynamic(stride_size)
902                        ? desc.stride(b, loc, i)
903                        : b.create<ConstantOp>(i64(stride_size));
904 
905     auto stride_pos = memref_ty.getRank() + i;
906 
907     payload = b.create<LLVM::InsertValueOp>(payload, dim, i);
908     payload = b.create<LLVM::InsertValueOp>(payload, stride, stride_pos);
909   }
910 
911   // Construct encoded memref value.
912   Value memref = b.create<LLVM::UndefOp>(type);
913   memref = b.create<LLVM::InsertValueOp>(memref, dtype, 0);
914   memref = b.create<LLVM::InsertValueOp>(memref, rank, 1);
915   memref = b.create<LLVM::InsertValueOp>(memref, payload, 3);
916 
917   // Previous values almost always are known at compile time, and inserting
918   // dynamic values into the struct after all statically know values leads to a
919   // better canonicalization and cleaner final LLVM IR.
920   memref = b.create<LLVM::InsertValueOp>(memref, data, 2);
921 
922   return memref;
923 }
924 
925 //===----------------------------------------------------------------------===//
926 // Default encodings for arguments and attributes.
927 //===----------------------------------------------------------------------===//
928 
DefaultAttrEncodings()929 CustomCallAttrEncodingSet DefaultAttrEncodings() {
930   CustomCallAttrEncodingSet encodings;
931   encodings
932       .Add<StringAttrEncoding, ScalarAttrEncoding, DenseElementsAttrEncoding,
933            ArrayAttrEncoding, DenseArrayAttrEncoding, EmptyArrayAttrEncoding>();
934   return encodings;
935 }
936 
DefaultArgEncodings()937 CustomCallArgEncodingSet DefaultArgEncodings() {
938   CustomCallArgEncodingSet encodings;
939   encodings.Add<ScalarArgEncoding, MemrefArgEncoding>();
940   return encodings;
941 }
942 
943 }  // namespace runtime
944 }  // namespace xla
945