1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <functional>
17 #include <iterator>
18 #include <memory>
19 #include <string>
20 #include <utility>
21
22 #include "llvm/ADT/STLExtras.h"
23 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" // from @llvm-project
24 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
25 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
26 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project
27 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
28 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project
29 #include "mlir/IR/Attributes.h" // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
32 #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
33 #include "mlir/IR/Location.h" // from @llvm-project
34 #include "mlir/IR/PatternMatch.h" // from @llvm-project
35 #include "mlir/IR/SymbolTable.h" // from @llvm-project
36 #include "mlir/Support/LLVM.h" // from @llvm-project
37 #include "mlir/Support/LogicalResult.h" // from @llvm-project
38 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
39 #include "tensorflow/compiler/xla/mlir/ir/runtime/rt_ops.h"
40 #include "tensorflow/compiler/xla/mlir/transforms/runtime/custom_call_encoding.h"
41 #include "tensorflow/compiler/xla/mlir/transforms/runtime/passes.h"
42 #include "tensorflow/compiler/xla/runtime/custom_call.h"
43 #include "tensorflow/compiler/xla/runtime/type_id.h"
44
45 namespace xla {
46 namespace runtime {
47 namespace {
48
49 using namespace mlir; // NOLINT
50 using mlir::arith::ConstantOp;
51 using mlir::func::CallOp;
52 using mlir::func::FuncOp;
53
54 using llvm::DenseMap;
55
56 #define GEN_PASS_CLASSES
57 #include "tensorflow/compiler/xla/mlir/transforms/runtime/passes.h.inc"
58
59 //===----------------------------------------------------------------------===//
60 // Runtime C API declaration (see runtime.h header file).
61 //===----------------------------------------------------------------------===//
62
// Symbol names of the runtime C API functions called from the lowered code.
// These live in an anonymous namespace and are `constexpr`, so they already
// have internal linkage.
constexpr const char *kGetResultStorage = "runtimeGetResultStorage";
constexpr const char *kSetError = "runtimeSetError";
constexpr const char *kCustomCall = "runtimeCustomCall";
66
67 struct RuntimeAPI {
OpaquePointerTypexla::runtime::__anonb1e801730111::RuntimeAPI68 static LLVM::LLVMPointerType OpaquePointerType(MLIRContext *ctx) {
69 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
70 }
71
CustomCallArgumentsTypexla::runtime::__anonb1e801730111::RuntimeAPI72 static LLVM::LLVMPointerType CustomCallArgumentsType(MLIRContext *ctx) {
73 return LLVM::LLVMPointerType::get(RuntimeAPI::OpaquePointerType(ctx));
74 }
75
CustomCallAttributesTypexla::runtime::__anonb1e801730111::RuntimeAPI76 static LLVM::LLVMPointerType CustomCallAttributesType(MLIRContext *ctx) {
77 return LLVM::LLVMPointerType::get(RuntimeAPI::OpaquePointerType(ctx));
78 }
79
GetResultStorageFunctionTypexla::runtime::__anonb1e801730111::RuntimeAPI80 static FunctionType GetResultStorageFunctionType(MLIRContext *ctx) {
81 auto kernel_context = OpaquePointerType(ctx);
82 auto i64 = IntegerType::get(ctx, 64);
83 auto storage = OpaquePointerType(ctx);
84 return FunctionType::get(ctx, {kernel_context, i64}, {storage});
85 }
86
SetErrorFunctionTypexla::runtime::__anonb1e801730111::RuntimeAPI87 static FunctionType SetErrorFunctionType(MLIRContext *ctx) {
88 auto kernel_context = OpaquePointerType(ctx);
89 auto error_msg = OpaquePointerType(ctx);
90 return FunctionType::get(ctx, {kernel_context, error_msg}, {});
91 }
92
CustomCallFunctionTypexla::runtime::__anonb1e801730111::RuntimeAPI93 static FunctionType CustomCallFunctionType(MLIRContext *ctx) {
94 auto kernel_context = OpaquePointerType(ctx);
95 auto callee = OpaquePointerType(ctx);
96 auto args = CustomCallArgumentsType(ctx);
97 auto attrs = CustomCallAttributesType(ctx);
98 auto i1 = IntegerType::get(ctx, 1);
99 return FunctionType::get(ctx, {kernel_context, callee, args, attrs}, {i1});
100 }
101
DirectCustomCallFunctionTypexla::runtime::__anonb1e801730111::RuntimeAPI102 static FunctionType DirectCustomCallFunctionType(MLIRContext *ctx) {
103 auto kernel_context = OpaquePointerType(ctx);
104 auto args = CustomCallArgumentsType(ctx);
105 auto attrs = CustomCallAttributesType(ctx);
106 auto i1 = IntegerType::get(ctx, 1);
107 return FunctionType::get(ctx, {kernel_context, args, attrs}, {i1});
108 }
109 };
110
111 // Adds function declaration if it doesn't already exist.
AddDeclaration(ModuleOp module,StringRef name,FunctionType type)112 static void AddDeclaration(ModuleOp module, StringRef name, FunctionType type) {
113 auto b = ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
114 if (module.lookupSymbol(name)) return;
115
116 MLIRContext *ctx = module.getContext();
117 FuncOp func = b.create<FuncOp>(name, type);
118 func.setPrivate();
119
120 // TODO(ezhulenev): Add per-argument nocapture attributes?
121 func->setAttr("passthrough",
122 ArrayAttr::get(ctx, {StringAttr::get(ctx, "nounwind")}));
123 }
124
125 // Adds Runtime C API declarations to the module.
AddRuntimeApiDeclarations(ModuleOp module)126 static void AddRuntimeApiDeclarations(ModuleOp module) {
127 auto add = [&](StringRef name, FunctionType type) {
128 AddDeclaration(module, name, type);
129 };
130
131 MLIRContext *ctx = module.getContext();
132 add(kGetResultStorage, RuntimeAPI::GetResultStorageFunctionType(ctx));
133 add(kSetError, RuntimeAPI::SetErrorFunctionType(ctx));
134 add(kCustomCall, RuntimeAPI::CustomCallFunctionType(ctx));
135 }
136
137 //===----------------------------------------------------------------------===//
138
139 class RuntimeTypeConverter : public TypeConverter {
140 public:
RuntimeTypeConverter()141 RuntimeTypeConverter() {
142 addConversion([](Type type) { return type; });
143 addConversion(ConvertKernelContextType);
144 addConversion(ConvertStatusType);
145 }
146
ConvertKernelContextType(KernelContextType type)147 static llvm::Optional<Type> ConvertKernelContextType(KernelContextType type) {
148 return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
149 }
150
ConvertStatusType(StatusType type)151 static llvm::Optional<Type> ConvertStatusType(StatusType type) {
152 return IntegerType::get(type.getContext(), 1);
153 }
154 };
155
156 //===----------------------------------------------------------------------===//
157 // Convert rt.set_output to the corresponding runtime API call.
158 //===----------------------------------------------------------------------===//
159
160 class SetOutputOpLowering : public OpConversionPattern<SetOutputOp> {
161 public:
162 using OpConversionPattern::OpConversionPattern;
163
matchAndRewrite(SetOutputOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const164 LogicalResult matchAndRewrite(
165 SetOutputOp op, OpAdaptor adaptor,
166 ConversionPatternRewriter &rewriter) const override {
167 Location loc = op->getLoc();
168
169 auto kernel_context = adaptor.ctx();
170 auto index = rewriter.create<ConstantOp>(loc, adaptor.indexAttr());
171
172 // Get a pointer to the result value storage from the runtime.
173 auto result_ptr_ty = RuntimeAPI::OpaquePointerType(rewriter.getContext());
174 auto result_ptr = rewriter.create<CallOp>(
175 loc, kGetResultStorage, TypeRange(result_ptr_ty),
176 ValueRange({kernel_context, index}));
177
178 // Cast from i8* to the LLVM pointer type to store the result.
179 auto stored_type = getTypeConverter()->convertType(op.value().getType());
180 if (!stored_type)
181 return rewriter.notifyMatchFailure(
182 op, "failed to convert output type to LLVM type");
183
184 auto casted_result_ptr = rewriter.create<LLVM::BitcastOp>(
185 loc, LLVM::LLVMPointerType::get(stored_type), result_ptr.getResult(0));
186
187 // Store the output value into the result value storage.
188 auto value = adaptor.value();
189 rewriter.create<LLVM::StoreOp>(loc, value, casted_result_ptr.getResult());
190
191 // Erase the original runtime operation.
192 rewriter.eraseOp(op);
193
194 return success();
195 }
196 };
197
198 //===----------------------------------------------------------------------===//
199 // Convert rt.is_ok to the corresponding runtime API call.
200 //===----------------------------------------------------------------------===//
201
202 class IsOkOpLowering : public OpConversionPattern<IsOkOp> {
203 public:
204 using OpConversionPattern::OpConversionPattern;
205
matchAndRewrite(IsOkOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const206 LogicalResult matchAndRewrite(
207 IsOkOp op, OpAdaptor adaptor,
208 ConversionPatternRewriter &rewriter) const override {
209 // Just pass through the converted operand.
210 rewriter.replaceOp(op, adaptor.status());
211 return success();
212 }
213 };
214
215 //===----------------------------------------------------------------------===//
216 // Convert rt.custom_call to the corresponding runtime API call.
217 //===----------------------------------------------------------------------===//
218
EncodeArguments(CustomCallOp op,CustomCallArgEncodingSet & encodings,Globals & g,DenseMap<Value,CustomCallArgEncoding::Encoded> & encoded_args,ImplicitLocOpBuilder & b,ValueRange operands,ValueRange converted)219 static FailureOr<Value> EncodeArguments(
220 CustomCallOp op, CustomCallArgEncodingSet &encodings, Globals &g,
221 DenseMap<Value, CustomCallArgEncoding::Encoded> &encoded_args,
222 ImplicitLocOpBuilder &b, ValueRange operands, ValueRange converted) {
223 llvm::SmallVector<CustomCallArgEncoding::Encoded> encoded;
224
225 // Encode all arguments as a set of pointers (skip the kernel context).
226 for (auto tuple : llvm::drop_begin(llvm::zip(operands, converted))) {
227 // Check if the value was already encoded.
228 auto it = encoded_args.find(std::get<0>(tuple));
229 if (it != encoded_args.end()) {
230 encoded.push_back(it->second);
231 continue;
232 }
233
234 // Otherwise encode it right after the converted value definition.
235 OpBuilder::InsertionGuard guard(b);
236 if (auto *defining_op = std::get<1>(tuple).getDefiningOp()) {
237 b.setInsertionPointAfter(defining_op);
238 } else {
239 b.setInsertionPointToStart(std::get<1>(tuple).getParentBlock());
240 }
241
242 auto encoded_arg =
243 encodings.Encode(g, b, std::get<0>(tuple), std::get<1>(tuple));
244 if (failed(encoded_arg)) return failure();
245 encoded.push_back(*encoded_arg);
246 encoded_args.try_emplace(std::get<0>(tuple), *encoded_arg);
247 }
248
249 // We store encoded arguments as `!llvm.array<ptr<i8> x len>`.
250 Type ptr = LLVM::LLVMPointerType::get(b.getI8Type());
251 Type type = LLVM::LLVMArrayType::get(ptr, 1 + encoded.size() * 2);
252
253 // Prepare an array for encoding arguments.
254 Value arr = b.create<LLVM::UndefOp>(type);
255 auto insert_value = [&](Value value, int64_t offset) {
256 Value bcasted = b.createOrFold<LLVM::BitcastOp>(ptr, value);
257 arr = b.create<LLVM::InsertValueOp>(arr, bcasted, offset);
258 };
259
260 // Insert the number of encoded arguments.
261 Attribute num_args = b.getI64IntegerAttr(encoded.size());
262 insert_value(PackScalarAttribute(g, b, num_args, "__rt_num_args"), 0);
263
264 // Store encoded arguments into the allocated storage.
265 for (auto &pair : llvm::enumerate(encoded)) {
266 CustomCallArgEncoding::Encoded encoded = pair.value();
267 int64_t offset = 1 + pair.index() * 2;
268
269 insert_value(encoded.type_id, offset + 0);
270 insert_value(encoded.value, offset + 1);
271 }
272
273 // Always create an `alloca` in the parent function entry block.
274 // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas
275 Value mem = [&]() -> Value {
276 Block &block = op->getParentOfType<FuncOp>().getBody().front();
277 OpBuilder::InsertionGuard guard(b);
278 b.setInsertionPointToStart(&block);
279 Value c1 = b.create<ConstantOp>(b.getI32IntegerAttr(1));
280 return b.create<LLVM::AllocaOp>(LLVM::LLVMPointerType::get(type), c1, 0);
281 }();
282
283 // Store constructed arguments array on the stack and return a pointer to it.
284 b.create<LLVM::StoreOp>(arr, mem);
285
286 // Return a pointer to the first element of the arguments array.
287 Type ptr_ptr = mlir::LLVM::LLVMPointerType::get(ptr);
288 Value c0 = b.create<ConstantOp>(b.getI64IntegerAttr(0));
289 Value gep = b.create<LLVM::GEPOp>(ptr_ptr, mem, ValueRange({c0, c0}));
290 return gep;
291 }
292
293 // Encodes attributes into the global constant (array of pointers to the
294 // attributes data, which are also stored as global constants).
EncodeAttributes(CustomCallAttrEncodingSet & encodings,Globals & g,ImplicitLocOpBuilder & b,ArrayRef<NamedAttribute> attrs)295 static FailureOr<Value> EncodeAttributes(CustomCallAttrEncodingSet &encodings,
296 Globals &g, ImplicitLocOpBuilder &b,
297 ArrayRef<NamedAttribute> attrs) {
298 // Skip attributes passed explicitly as a custom call argument.
299 auto skip = [](NamedAttribute attr) {
300 return attr.getName() == "callee" || attr.getName() == "direct";
301 };
302
303 llvm::SmallVector<NamedAttribute> custom_call_attrs =
304 llvm::to_vector(llvm::make_filter_range(attrs, std::not_fn(skip)));
305
306 // Sort encoded attributes in lexicographical order so that when decoding we
307 // can efficiently find attributes by name.
308 llvm::sort(custom_call_attrs, [](NamedAttribute &a, NamedAttribute &b) {
309 return a.getName().strref() < b.getName().strref();
310 });
311
312 return EncodeAttributes(g, b, encodings, "__rt_custom_call_attrs",
313 custom_call_attrs);
314 }
315
316 class CustomCallOpLowering : public OpConversionPattern<CustomCallOp> {
317 public:
318 using OpConversionPattern::OpConversionPattern;
319
CustomCallOpLowering(TypeConverter & converter,MLIRContext * ctx,Globals & globals,CustomCallArgEncodingSet & arg_encoding,CustomCallAttrEncodingSet & attr_encoding,DenseMap<Value,CustomCallArgEncoding::Encoded> & encoded_args)320 CustomCallOpLowering(
321 TypeConverter &converter, MLIRContext *ctx, Globals &globals,
322 CustomCallArgEncodingSet &arg_encoding,
323 CustomCallAttrEncodingSet &attr_encoding,
324 DenseMap<Value, CustomCallArgEncoding::Encoded> &encoded_args)
325 : OpConversionPattern(converter, ctx),
326 globals_(globals),
327 arg_encoding_(arg_encoding),
328 attr_encoding_(attr_encoding),
329 encoded_args_(encoded_args) {}
330
matchAndRewrite(CustomCallOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const331 LogicalResult matchAndRewrite(
332 CustomCallOp op, OpAdaptor adaptor,
333 ConversionPatternRewriter &rewriter) const override {
334 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
335
336 // Encode operation arguments as a runtime API arguments.
337 auto args = EncodeArguments(op, arg_encoding_, globals_, encoded_args_, b,
338 op->getOperands(), adaptor.getOperands());
339 if (failed(args)) return op.emitOpError() << "failed to encode arguments";
340
341 // Encode operation attributes as a runtime API argument.
342 auto attrs = EncodeAttributes(attr_encoding_, globals_, b, op->getAttrs());
343 if (failed(attrs)) return op.emitOpError() << "failed to encode attributes";
344
345 if (op.direct()) {
346 // Call custom call target directly.
347 auto type = RuntimeAPI::DirectCustomCallFunctionType(op.getContext());
348 AddDeclaration(op->getParentOfType<ModuleOp>(), op.callee(), type);
349
350 rewriter.replaceOpWithNewOp<CallOp>(
351 op, op.callee(), TypeRange(rewriter.getI1Type()),
352 ValueRange({adaptor.ctx(), *args, *attrs}));
353
354 } else {
355 // Otherwise pass the custom call callee to the generic custom call API.
356 auto callee = Globals::OpaqueAddrOf(
357 b, globals_.GetOrCreate(b, op.callee(), "__rt_custom_call_callee"));
358
359 // Call runtime API to call the custom call target.
360 rewriter.replaceOpWithNewOp<CallOp>(
361 op, kCustomCall, TypeRange(rewriter.getI1Type()),
362 ValueRange({adaptor.ctx(), callee, *args, *attrs}));
363 }
364
365 return success();
366 }
367
368 private:
369 Globals &globals_;
370 CustomCallArgEncodingSet &arg_encoding_;
371 CustomCallAttrEncodingSet &attr_encoding_;
372 DenseMap<Value, CustomCallArgEncoding::Encoded> &encoded_args_;
373 };
374
375 //===----------------------------------------------------------------------===//
376 // Convert rt.set_error to the corresponding runtime API call.
377 //===----------------------------------------------------------------------===//
378
379 class SetErrorOpLowering : public OpConversionPattern<SetErrorOp> {
380 public:
SetErrorOpLowering(TypeConverter & converter,MLIRContext * ctx,Globals & globals)381 SetErrorOpLowering(TypeConverter &converter, MLIRContext *ctx,
382 Globals &globals)
383 : OpConversionPattern(converter, ctx), globals_(globals) {}
384
matchAndRewrite(SetErrorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const385 LogicalResult matchAndRewrite(
386 SetErrorOp op, OpAdaptor adaptor,
387 ConversionPatternRewriter &rewriter) const override {
388 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
389
390 // Get the error message (pointer to a null terminated string).
391 auto err = Globals::OpaqueAddrOf(
392 b, globals_.GetOrCreate(b, op.error(), "__assert_failed"));
393
394 // Call runtime API to report the error.
395 auto kernel_context = adaptor.ctx();
396 rewriter.replaceOpWithNewOp<CallOp>(op, kSetError, TypeRange(),
397 ValueRange({kernel_context, err}));
398
399 return success();
400 }
401
402 private:
403 Globals &globals_;
404 };
405
406 //===----------------------------------------------------------------------===//
407
408 class ConvertRuntimeToLLVMPass
409 : public ConvertRuntimeToLLVMPassBase<ConvertRuntimeToLLVMPass> {
410 public:
ConvertRuntimeToLLVMPass(ConvertRuntimeToLLvmOpts opts)411 explicit ConvertRuntimeToLLVMPass(ConvertRuntimeToLLvmOpts opts)
412 : opts_(std::move(opts)) {}
413
414 void runOnOperation() override;
415
416 private:
417 ConvertRuntimeToLLvmOpts opts_;
418 };
419
runOnOperation()420 void ConvertRuntimeToLLVMPass::runOnOperation() {
421 ModuleOp module = getOperation();
422 MLIRContext *ctx = module.getContext();
423
424 // Add declarations for the runtime API functions.
425 AddRuntimeApiDeclarations(module);
426
427 RuntimeTypeConverter converter;
428 RewritePatternSet patterns(ctx);
429
430 // We use conversion to LLVM type to lower all runtime operands to LLVM types.
431 LLVMTypeConverter llvm_converter(ctx);
432 llvm_converter.addConversion(RuntimeTypeConverter::ConvertKernelContextType);
433 llvm_converter.addConversion(RuntimeTypeConverter::ConvertStatusType);
434
435 // Add type conversions for user-defined types so that we can properly convert
436 // all function signatures in the module and prepare values for custom calls.
437 if (opts_.populate_type_conversions) {
438 opts_.populate_type_conversions(converter);
439 opts_.populate_type_conversions(llvm_converter);
440 }
441
442 // Register mappings from the TypeID to type names.
443 TypeIDNameRegistry type_id_names;
444 PopulateCustomCallTypeIdNames(type_id_names);
445 if (opts_.populate_type_id_names) opts_.populate_type_id_names(type_id_names);
446
447 // A helper class to create unique global constants.
448 Globals globals(module, type_id_names);
449
450 // Keep a cache of encoded values to encode each unique value just once.
451 DenseMap<Value, CustomCallArgEncoding::Encoded> encoded_args;
452
453 // Lower from the runtime operations to the runtime API function calls.
454 patterns.add<SetOutputOpLowering, IsOkOpLowering>(llvm_converter, ctx);
455 patterns.add<SetErrorOpLowering>(llvm_converter, ctx, globals);
456
457 // Use default custom call encoding for canonical types.
458 CustomCallArgEncodingSet args = DefaultArgEncodings();
459 CustomCallAttrEncodingSet attrs = DefaultAttrEncodings();
460
461 // Add user-defined arg and attr encodings.
462 if (opts_.populate_arg_encodings) opts_.populate_arg_encodings(args);
463 if (opts_.populate_attr_encodings) opts_.populate_attr_encodings(attrs);
464
465 patterns.add<CustomCallOpLowering>(llvm_converter, ctx, globals, args, attrs,
466 encoded_args);
467
468 // Convert function signatures and call sites.
469 mlir::populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
470 converter);
471 populateCallOpTypeConversionPattern(patterns, converter);
472
473 // Set up conversion target to rewrite all runtime operations.
474 ConversionTarget target(*ctx);
475 target.addIllegalDialect<RuntimeDialect>();
476 target.addLegalDialect<LLVM::LLVMDialect>();
477 target.addLegalOp<ConstantOp, UnrealizedConversionCastOp, CallOp>();
478
479 // Add dynamic legality constraints to apply conversions defined above.
480 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
481 return converter.isSignatureLegal(op.getFunctionType());
482 });
483
484 if (failed(applyPartialConversion(module, target, std::move(patterns))))
485 signalPassFailure();
486 }
487
488 } // namespace
489
CreateConvertRuntimeToLLVMPass(ConvertRuntimeToLLvmOpts opts)490 std::unique_ptr<OperationPass<ModuleOp>> CreateConvertRuntimeToLLVMPass(
491 ConvertRuntimeToLLvmOpts opts) {
492 return std::make_unique<ConvertRuntimeToLLVMPass>(std::move(opts));
493 }
494
495 } // namespace runtime
496 } // namespace xla
497