1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <stdexcept>
17
18 #include "llvm/ADT/STLExtras.h"
19 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" // from @llvm-project
20 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project
21 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project
22 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" // from @llvm-project
23 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project
24 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project
25 #include "mlir/Conversion/LLVMCommon/Pattern.h" // from @llvm-project
26 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project
27 #include "mlir/Conversion/MathToLibm/MathToLibm.h" // from @llvm-project
28 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project
29 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project
30 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project
31 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
32 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" // from @llvm-project
33 #include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project
34 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
35 #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project
36 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
37 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project
38 #include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project
39 #include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project
40 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
41 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
42 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
43 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
44 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
45
46 namespace mlir {
47 namespace kernel_gen {
48 namespace transforms {
49 namespace {
50
// Symbol name of the runtime helper that performs the actual kernel launch.
// Must match the function exported by the TF framework C interface wrapper
// library. NOTE(review): "Libary" is a typo in the identifier; it is kept
// because renaming would have to touch every use in this file at once.
constexpr StringRef kTfWrapperLibaryLaunchHelperName =
    "_mlir_ciface_tf_launch_kernel";

// Pulls in the TableGen-generated base classes for the passes in this file
// (provides TFKernelToLLVMPassBase and the blob_annotation_ pass option).
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
56
/// A rewrite pattern to convert gpu.launch_func operations into a runtime call
/// for the TensorFlow runtime.
class ConvertLaunchFuncOpToTfRuntimeCallPattern
    : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
 public:
  // `gpu_binary_annotation` names the gpu.module attribute that carries the
  // compiled device blob (e.g. a CUBIN or HSACO string).
  ConvertLaunchFuncOpToTfRuntimeCallPattern(LLVMTypeConverter &type_converter,
                                            StringRef gpu_binary_annotation)
      : ConvertOpToLLVMPattern<gpu::LaunchFuncOp>(type_converter),
        gpu_binary_annotation_(gpu_binary_annotation) {}

 private:
  // Packs the converted kernel operands into a stack-allocated array of
  // type-erased (i8*) pointers; see the definition for the generated IR shape.
  Value generateParamsArray(gpu::LaunchFuncOp launch_op, OpAdaptor adaptor,
                            OpBuilder &builder) const;
  // NOTE(review): declared but no definition is visible in this file —
  // confirm it is defined elsewhere or remove the declaration.
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult matchAndRewrite(
      gpu::LaunchFuncOp launch_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override;

  // Cached context plus frequently used LLVM dialect types. `context_` must
  // stay declared first: the type members below are initialized from it in
  // declaration order.
  MLIRContext *context_ = &this->getTypeConverter()->getContext();

  Type llvm_void_type_ = LLVM::LLVMVoidType::get(context_);
  Type llvm_pointer_type_ =
      LLVM::LLVMPointerType::get(IntegerType::get(context_, 8));  // i8*
  Type llvm_pointer_pointer_type_ =
      LLVM::LLVMPointerType::get(llvm_pointer_type_);  // i8**
  Type llvm_int8_type_ = IntegerType::get(context_, 8);
  Type llvm_int32_type_ = IntegerType::get(context_, 32);
  Type llvm_int64_type_ = IntegerType::get(context_, 64);
  // Integer type wide enough to hold a pointer in address space 0 on the
  // lowering target (used for grid/block dimension parameters).
  Type llvm_intptr_type_ = IntegerType::get(
      context_, this->getTypeConverter()->getPointerBitwidth(0));

  // Name of the gpu.module attribute holding the device binary blob.
  llvm::SmallString<32> gpu_binary_annotation_;
};
92
// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToTfRuntimeCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launch_op, OpAdaptor adaptor, OpBuilder &builder) const {
  auto loc = launch_op.getLoc();
  auto num_kernel_operands = launch_op.getNumKernelOperands();
  // Promote the kernel operands to the LLVM calling convention: the original
  // op supplies the pre-conversion values, the adaptor the converted ones.
  auto arguments = getTypeConverter()->promoteOperands(
      loc, launch_op.getOperands().take_back(num_kernel_operands),
      adaptor.operands().take_back(num_kernel_operands), builder);
  auto num_arguments = arguments.size();
  SmallVector<Type, 4> argument_types;
  argument_types.reserve(num_arguments);
  for (auto argument : arguments) argument_types.push_back(argument.getType());
  // A fresh identified struct type so each parameter becomes an addressable
  // field reachable via a GEP index.
  auto struct_type = LLVM::LLVMStructType::getNewIdentified(
      context_, StringRef(), argument_types);
  auto one = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type_,
                                              builder.getI32IntegerAttr(1));
  // %struct = alloca(sizeof(struct { Parameters... }))
  auto struct_ptr = builder.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(struct_type), one, /*alignment=*/0);
  auto array_size = builder.create<LLVM::ConstantOp>(
      loc, llvm_int32_type_, builder.getI32IntegerAttr(num_arguments));
  // %array = alloca(NumParameters * sizeof(void *))
  auto array_ptr = builder.create<LLVM::AllocaOp>(
      loc, llvm_pointer_pointer_type_, array_size, /*alignment=*/0);
  auto zero = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type_,
                                               builder.getI32IntegerAttr(0));
  for (auto en : llvm::enumerate(arguments)) {
    auto index = builder.create<LLVM::ConstantOp>(
        loc, llvm_int32_type_, builder.getI32IntegerAttr(en.index()));
    // Store parameter i into struct field i...
    auto field_ptr = builder.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(argument_types[en.index()]), struct_ptr,
        ArrayRef<Value>{zero, index.getResult()});
    builder.create<LLVM::StoreOp>(loc, en.value(), field_ptr);
    // ...then store the field's address, type-erased to i8*, into array[i].
    auto element_ptr = builder.create<LLVM::GEPOp>(
        loc, llvm_pointer_pointer_type_, array_ptr, index.getResult());
    auto casted =
        builder.create<LLVM::BitcastOp>(loc, llvm_pointer_type_, field_ptr);
    builder.create<LLVM::StoreOp>(loc, casted, element_ptr);
  }
  return array_ptr;
}
144
// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = <pointer to kernel function name>
// %2 = <see generateParamsArray>
// call %tfLaunchKernel(%ctx, %0, %1, <launch_op operands 0..5>, %2)
LogicalResult ConvertLaunchFuncOpToTfRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launch_op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // The TF runtime helper is a synchronous call; async launches cannot be
  // expressed, so bail out and let another pattern (or an error) handle it.
  if (!launch_op.asyncDependencies().empty() || launch_op.asyncToken()) {
    return rewriter.notifyMatchFailure(
        launch_op, "Cannot convert with async dependency or result.");
  }

  Location loc = launch_op.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernel_module = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launch_op, launch_op.getKernelModuleName());
  assert(kernel_module && "expected a kernel module");

  auto binary_attr =
      kernel_module->getAttrOfType<StringAttr>(gpu_binary_annotation_);
  if (!binary_attr) {
    kernel_module.emitOpError()
        << "missing " << gpu_binary_annotation_ << " attribute";
    return failure();
  }

  // Create a global for the module blob.
  SmallString<128> name_buffer(kernel_module.getName());
  name_buffer.append("_blob");
  Value module_blob =
      LLVM::createGlobalString(loc, rewriter, name_buffer.str(),
                               binary_attr.getValue(), LLVM::Linkage::Internal);

  // Make sure the trailing zero is included in the constant, since the
  // runtime receives the kernel name as a C string.
  auto kernel_name = launch_op.getKernelName().getValue();
  SmallString<128> kernel_name_buffer(kernel_name);
  kernel_name_buffer.push_back('\0');

  // Create a global for the kernel name, unique per (module, kernel) pair.
  SmallString<128> kernel_name_global_name_buffer;
  auto kernel_name_global_name =
      (kernel_module.getName() + "_" + kernel_name + "_kernel_name")
          .toStringRef(kernel_name_global_name_buffer);
  auto kernel_name_global =
      LLVM::createGlobalString(loc, rewriter, kernel_name_global_name,
                               kernel_name_buffer, LLVM::Linkage::Internal);

  // The TensorFlow OpKernelContext is the first argument of the surrounding
  // LLVMFunc.
  Value context_arg =
      launch_op->getParentOfType<LLVM::LLVMFuncOp>().getArgument(0);
  auto kernel_params = generateParamsArray(launch_op, adaptor, rewriter);

  // Look up the launch helper; declare it at module scope on first use.
  auto libraryLaunchNameAttr =
      mlir::StringAttr::get(loc.getContext(), kTfWrapperLibaryLaunchHelperName);
  auto function = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
      launch_op, libraryLaunchNameAttr);
  if (!function) {
    // Guard restores the insertion point after the declaration is emitted.
    PatternRewriter::InsertionGuard guard(rewriter);
    auto function_type = LLVM::LLVMFunctionType::get(
        llvm_void_type_,
        {
            llvm_pointer_type_,         /* void* context */
            llvm_pointer_type_,         /* void* module_blob */
            llvm_pointer_type_,         /* void* function_name */
            llvm_intptr_type_,          /* intptr_t grid_x_dim */
            llvm_intptr_type_,          /* intptr_t grid_y_dim */
            llvm_intptr_type_,          /* intptr_t grid_z_dim */
            llvm_intptr_type_,          /* intptr_t block_x_dim */
            llvm_intptr_type_,          /* intptr_t block_y_dim */
            llvm_intptr_type_,          /* intptr_t block_z_dim */
            llvm_pointer_pointer_type_, /* void **kernel_params */
        });
    rewriter.setInsertionPointToStart(
        launch_op->getParentOfType<ModuleOp>().getBody());
    function = rewriter.create<LLVM::LLVMFuncOp>(
        loc, kTfWrapperLibaryLaunchHelperName, function_type);
  }
  // Emit the call and replace the launch op. The helper returns void; launch
  // failures are reported through the OpKernelContext instead.
  rewriter.create<LLVM::CallOp>(
      loc, TypeRange(), mlir::SymbolRefAttr::get(function),
      ArrayRef<Value>{
          context_arg, module_blob, kernel_name_global, adaptor.gridSizeX(),
          adaptor.gridSizeY(), adaptor.gridSizeZ(), adaptor.blockSizeX(),
          adaptor.blockSizeY(), adaptor.blockSizeZ(), kernel_params});

  rewriter.eraseOp(launch_op);
  return success();
}
240
// Lowers host-side kernel-generator code — tf_framework, arith, func, math,
// complex, cf, vector ops and gpu.launch_func — to the LLVM dialect. GPU
// device modules are treated as opaque during conversion and stripped at the
// end, once their binaries have been embedded as globals.
class TFKernelToLLVMPass : public TFKernelToLLVMPassBase<TFKernelToLLVMPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
  }

 public:
  // `blob_annotation` overrides the name of the gpu.module attribute that
  // holds the device binary; an empty string keeps the pass-option default
  // (blob_annotation_ is generated by GEN_PASS_CLASSES).
  explicit TFKernelToLLVMPass(StringRef blob_annotation) {
    if (!blob_annotation.empty()) {
      blob_annotation_ = blob_annotation.str();
    }
  }

  void runOnOperation() override {
    ModuleOp m = getOperation();

    // Populate type conversions. The opaque TF framework types (the kernel
    // context and JIT callables) both lower to plain i8*.
    MLIRContext *ctx = m.getContext();
    LLVMTypeConverter type_converter(ctx);
    type_converter.addConversion([&](tf_framework::OpKernelContextType type) {
      return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
    });
    type_converter.addConversion([&](tf_framework::JITCallableType type) {
      return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
    });

    // Populate patterns. Expansion patterns first rewrite composite ops into
    // primitive ones, which the *ToLLVM patterns then convert.
    RewritePatternSet patterns(&getContext());
    arith::populateArithmeticExpandOpsPatterns(patterns);
    memref::populateExpandOpsPatterns(patterns);
    arith::populateArithmeticToLLVMConversionPatterns(type_converter, patterns);
    populateMemRefToLLVMConversionPatterns(type_converter, patterns);
    populateMathToLLVMConversionPatterns(type_converter, patterns);
    populateFuncToLLVMConversionPatterns(type_converter, patterns);
    cf::populateControlFlowToLLVMConversionPatterns(type_converter, patterns);
    populateComplexToLLVMConversionPatterns(type_converter, patterns);
    populateVectorToLLVMConversionPatterns(type_converter, patterns);
    // Benefit 0: libm calls are the fallback for math ops that no LLVM
    // intrinsic pattern above handled.
    populateMathToLibmConversionPatterns(patterns, /*benefit=*/0);
    tf_framework::PopulateTFFrameworkToLLVMConversionPatterns(&type_converter,
                                                              &patterns);
    patterns.add<ConvertLaunchFuncOpToTfRuntimeCallPattern>(type_converter,
                                                            blob_annotation_);
    // Set target.
    ConversionTarget target(*ctx);
    target.addLegalDialect<LLVM::LLVMDialect>();
    target.addIllegalDialect<
        arith::ArithmeticDialect, func::FuncDialect, complex::ComplexDialect,
        gpu::GPUDialect, tf_framework::TFFrameworkDialect, math::MathDialect>();
    // Mark modules as legal.
    target.addLegalOp<ModuleOp, gpu::GPUModuleOp>();
    // Do not look into gpu modules, only consider host-side.
    target.markOpRecursivelyLegal<gpu::GPUModuleOp>();
    // Unrealized conversion casts are cleaned up by a separate pass.
    target.addLegalOp<UnrealizedConversionCastOp>();

    if (failed(applyFullConversion(m, target, std::move(patterns)))) {
      signalPassFailure();
    }

    // Finally, strip the GPU modules, as they are no longer needed.
    // early_inc_range keeps iteration valid while ops are erased.
    for (auto op : llvm::make_early_inc_range(m.getOps<gpu::GPUModuleOp>())) {
      op.erase();
    }
  }
};
305
306 } // namespace
307
CreateTFKernelToLLVMPass(StringRef blob_annotation)308 std::unique_ptr<OperationPass<ModuleOp> > CreateTFKernelToLLVMPass(
309 StringRef blob_annotation) {
310 return std::make_unique<TFKernelToLLVMPass>(blob_annotation);
311 }
312
313 } // namespace transforms
314 } // namespace kernel_gen
315 } // namespace mlir
316