1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <stdexcept>
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"  // from @llvm-project
20 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"  // from @llvm-project
21 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"  // from @llvm-project
22 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"  // from @llvm-project
23 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"  // from @llvm-project
24 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"  // from @llvm-project
25 #include "mlir/Conversion/LLVMCommon/Pattern.h"  // from @llvm-project
26 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"  // from @llvm-project
27 #include "mlir/Conversion/MathToLibm/MathToLibm.h"  // from @llvm-project
28 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"  // from @llvm-project
29 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"  // from @llvm-project
30 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"  // from @llvm-project
31 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
32 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"  // from @llvm-project
33 #include "mlir/Dialect/Complex/IR/Complex.h"  // from @llvm-project
34 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
35 #include "mlir/Dialect/GPU/IR/GPUDialect.h"  // from @llvm-project
36 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
37 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"  // from @llvm-project
38 #include "mlir/Dialect/Math/IR/Math.h"  // from @llvm-project
39 #include "mlir/Dialect/MemRef/Transforms/Passes.h"  // from @llvm-project
40 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
41 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
43 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
44 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
45 
46 namespace mlir {
47 namespace kernel_gen {
48 namespace transforms {
49 namespace {
50 
// Symbol name of the externally-provided C helper that performs the actual
// kernel launch against the TensorFlow runtime; declared on first use in
// matchAndRewrite below.
// NOTE(review): "Libary" in the C++ identifier is a typo for "Library". It is
// only flagged here because renaming it would have to touch every use site.
constexpr StringRef kTfWrapperLibaryLaunchHelperName =
    "_mlir_ciface_tf_launch_kernel";
53 
54 #define GEN_PASS_CLASSES
55 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
56 
/// A rewrite pattern to convert gpu.launch_func operations into a runtime call
/// for the TensorFlow runtime.
class ConvertLaunchFuncOpToTfRuntimeCallPattern
    : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
 public:
  ConvertLaunchFuncOpToTfRuntimeCallPattern(LLVMTypeConverter &type_converter,
                                            StringRef gpu_binary_annotation)
      : ConvertOpToLLVMPattern<gpu::LaunchFuncOp>(type_converter),
        gpu_binary_annotation_(gpu_binary_annotation) {}

 private:
  // Builds the type-erased `void**` kernel-parameter array that the launch
  // helper consumes (see the definition below for the generated IR shape).
  Value generateParamsArray(gpu::LaunchFuncOp launch_op, OpAdaptor adaptor,
                            OpBuilder &builder) const;
  // NOTE(review): no definition is visible in this portion of the file —
  // confirm it is defined/used elsewhere before removing.
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult matchAndRewrite(
      gpu::LaunchFuncOp launch_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override;

  MLIRContext *context_ = &this->getTypeConverter()->getContext();

  // Commonly used LLVM dialect types, cached once per pattern instance.
  Type llvm_void_type_ = LLVM::LLVMVoidType::get(context_);
  Type llvm_pointer_type_ =
      LLVM::LLVMPointerType::get(IntegerType::get(context_, 8));  // i8*
  Type llvm_pointer_pointer_type_ =
      LLVM::LLVMPointerType::get(llvm_pointer_type_);  // i8**
  Type llvm_int8_type_ = IntegerType::get(context_, 8);
  Type llvm_int32_type_ = IntegerType::get(context_, 32);
  Type llvm_int64_type_ = IntegerType::get(context_, 64);
  // Integer type wide enough to hold a pointer on the lowering target.
  Type llvm_intptr_type_ = IntegerType::get(
      context_, this->getTypeConverter()->getPointerBitwidth(0));

  // Name of the gpu.module attribute that carries the compiled GPU binary
  // (e.g. a cubin or hsaco blob); checked in matchAndRewrite.
  llvm::SmallString<32> gpu_binary_annotation_;
};
92 
93 // Creates a struct containing all kernel parameters on the stack and returns
94 // an array of type-erased pointers to the fields of the struct. The array can
95 // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
96 // The generated code is essentially as follows:
97 //
98 // %struct = alloca(sizeof(struct { Parameters... }))
99 // %array = alloca(NumParameters * sizeof(void *))
100 // for (i : [0, NumParameters))
101 //   %fieldPtr = llvm.getelementptr %struct[0, i]
102 //   llvm.store parameters[i], %fieldPtr
103 //   %elementPtr = llvm.getelementptr %array[i]
104 //   llvm.store %fieldPtr, %elementPtr
105 // return %array
generateParamsArray(gpu::LaunchFuncOp launch_op,OpAdaptor adaptor,OpBuilder & builder) const106 Value ConvertLaunchFuncOpToTfRuntimeCallPattern::generateParamsArray(
107     gpu::LaunchFuncOp launch_op, OpAdaptor adaptor, OpBuilder &builder) const {
108   auto loc = launch_op.getLoc();
109   auto num_kernel_operands = launch_op.getNumKernelOperands();
110   auto arguments = getTypeConverter()->promoteOperands(
111       loc, launch_op.getOperands().take_back(num_kernel_operands),
112       adaptor.operands().take_back(num_kernel_operands), builder);
113   auto num_arguments = arguments.size();
114   SmallVector<Type, 4> argument_types;
115   argument_types.reserve(num_arguments);
116   for (auto argument : arguments) argument_types.push_back(argument.getType());
117   auto struct_type = LLVM::LLVMStructType::getNewIdentified(
118       context_, StringRef(), argument_types);
119   auto one = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type_,
120                                               builder.getI32IntegerAttr(1));
121   auto struct_ptr = builder.create<LLVM::AllocaOp>(
122       loc, LLVM::LLVMPointerType::get(struct_type), one, /*alignment=*/0);
123   auto array_size = builder.create<LLVM::ConstantOp>(
124       loc, llvm_int32_type_, builder.getI32IntegerAttr(num_arguments));
125   auto array_ptr = builder.create<LLVM::AllocaOp>(
126       loc, llvm_pointer_pointer_type_, array_size, /*alignment=*/0);
127   auto zero = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type_,
128                                                builder.getI32IntegerAttr(0));
129   for (auto en : llvm::enumerate(arguments)) {
130     auto index = builder.create<LLVM::ConstantOp>(
131         loc, llvm_int32_type_, builder.getI32IntegerAttr(en.index()));
132     auto field_ptr = builder.create<LLVM::GEPOp>(
133         loc, LLVM::LLVMPointerType::get(argument_types[en.index()]), struct_ptr,
134         ArrayRef<Value>{zero, index.getResult()});
135     builder.create<LLVM::StoreOp>(loc, en.value(), field_ptr);
136     auto element_ptr = builder.create<LLVM::GEPOp>(
137         loc, llvm_pointer_pointer_type_, array_ptr, index.getResult());
138     auto casted =
139         builder.create<LLVM::BitcastOp>(loc, llvm_pointer_type_, field_ptr);
140     builder.create<LLVM::StoreOp>(loc, casted, element_ptr);
141   }
142   return array_ptr;
143 }
144 
// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = <pointer to kernel function name>
// %2 = <see generateParamsArray>
// call %tfLaunchKernel(%ctx, %0, %1, <launch_op operands 0..5>, %2)
LogicalResult ConvertLaunchFuncOpToTfRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launch_op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // The runtime helper is synchronous; async launches cannot be lowered here.
  if (!launch_op.asyncDependencies().empty() || launch_op.asyncToken()) {
    return rewriter.notifyMatchFailure(
        launch_op, "Cannot convert with async dependency or result.");
  }

  Location loc = launch_op.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernel_module = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launch_op, launch_op.getKernelModuleName());
  assert(kernel_module && "expected a kernel module");

  // The GPU binary must have been attached to the gpu.module by an earlier
  // pass under the configured annotation name.
  auto binary_attr =
      kernel_module->getAttrOfType<StringAttr>(gpu_binary_annotation_);
  if (!binary_attr) {
    kernel_module.emitOpError()
        << "missing " << gpu_binary_annotation_ << " attribute";
    return failure();
  }

  // Create a global for the module blob.
  SmallString<128> name_buffer(kernel_module.getName());
  name_buffer.append("_blob");
  Value module_blob =
      LLVM::createGlobalString(loc, rewriter, name_buffer.str(),
                               binary_attr.getValue(), LLVM::Linkage::Internal);

  // Make sure the trailing zero is included in the constant.
  auto kernel_name = launch_op.getKernelName().getValue();
  SmallString<128> kernel_name_buffer(kernel_name);
  kernel_name_buffer.push_back('\0');

  // Create a global for the kernel name.
  SmallString<128> kernel_name_global_name_buffer;
  auto kernel_name_global_name =
      (kernel_module.getName() + "_" + kernel_name + "_kernel_name")
          .toStringRef(kernel_name_global_name_buffer);
  auto kernel_name_global =
      LLVM::createGlobalString(loc, rewriter, kernel_name_global_name,
                               kernel_name_buffer, LLVM::Linkage::Internal);

  // The TensorFlow OpKernelContext is the first argument of the surrounding
  // LLVMFunc.
  Value context_arg =
      launch_op->getParentOfType<LLVM::LLVMFuncOp>().getArgument(0);
  auto kernel_params = generateParamsArray(launch_op, adaptor, rewriter);

  // Look up the launch helper; declare it at module scope on first use.
  auto libraryLaunchNameAttr =
      mlir::StringAttr::get(loc.getContext(), kTfWrapperLibaryLaunchHelperName);
  auto function = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
      launch_op, libraryLaunchNameAttr);
  if (!function) {
    // The guard restores the insertion point after emitting the declaration.
    PatternRewriter::InsertionGuard guard(rewriter);
    auto function_type = LLVM::LLVMFunctionType::get(
        llvm_void_type_,
        {
            llvm_pointer_type_,         /* void* context */
            llvm_pointer_type_,         /* void* module_blob */
            llvm_pointer_type_,         /* void* function_name */
            llvm_intptr_type_,          /* intptr_t grid_x_dim */
            llvm_intptr_type_,          /* intptr_t grid_y_dim */
            llvm_intptr_type_,          /* intptr_t grid_z_dim */
            llvm_intptr_type_,          /* intptr_t block_x_dim */
            llvm_intptr_type_,          /* intptr_t block_y_dim */
            llvm_intptr_type_,          /* intptr_t block_z_dim */
            llvm_pointer_pointer_type_, /* void **kernel_params */
        });
    rewriter.setInsertionPointToStart(
        launch_op->getParentOfType<ModuleOp>().getBody());
    function = rewriter.create<LLVM::LLVMFuncOp>(
        loc, kTfWrapperLibaryLaunchHelperName, function_type);
  }
  // Emit the runtime call with grid/block dimensions taken from the adaptor.
  rewriter.create<LLVM::CallOp>(
      loc, TypeRange(), mlir::SymbolRefAttr::get(function),

      ArrayRef<Value>{
          context_arg, module_blob, kernel_name_global, adaptor.gridSizeX(),
          adaptor.gridSizeY(), adaptor.gridSizeZ(), adaptor.blockSizeX(),
          adaptor.blockSizeY(), adaptor.blockSizeZ(), kernel_params});

  // The launch op is fully replaced by the runtime call.
  rewriter.eraseOp(launch_op);
  return success();
}
240 
241 class TFKernelToLLVMPass : public TFKernelToLLVMPassBase<TFKernelToLLVMPass> {
getDependentDialects(DialectRegistry & registry) const242   void getDependentDialects(DialectRegistry &registry) const override {
243     registry.insert<LLVM::LLVMDialect>();
244   }
245 
246  public:
TFKernelToLLVMPass(StringRef blob_annotation)247   explicit TFKernelToLLVMPass(StringRef blob_annotation) {
248     if (!blob_annotation.empty()) {
249       blob_annotation_ = blob_annotation.str();
250     }
251   }
252 
runOnOperation()253   void runOnOperation() override {
254     ModuleOp m = getOperation();
255 
256     // Populate type conversions.
257     MLIRContext *ctx = m.getContext();
258     LLVMTypeConverter type_converter(ctx);
259     type_converter.addConversion([&](tf_framework::OpKernelContextType type) {
260       return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
261     });
262     type_converter.addConversion([&](tf_framework::JITCallableType type) {
263       return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
264     });
265 
266     // Populate patterns.
267     RewritePatternSet patterns(&getContext());
268     arith::populateArithmeticExpandOpsPatterns(patterns);
269     memref::populateExpandOpsPatterns(patterns);
270     arith::populateArithmeticToLLVMConversionPatterns(type_converter, patterns);
271     populateMemRefToLLVMConversionPatterns(type_converter, patterns);
272     populateMathToLLVMConversionPatterns(type_converter, patterns);
273     populateFuncToLLVMConversionPatterns(type_converter, patterns);
274     cf::populateControlFlowToLLVMConversionPatterns(type_converter, patterns);
275     populateComplexToLLVMConversionPatterns(type_converter, patterns);
276     populateVectorToLLVMConversionPatterns(type_converter, patterns);
277     populateMathToLibmConversionPatterns(patterns, 0);
278     tf_framework::PopulateTFFrameworkToLLVMConversionPatterns(&type_converter,
279                                                               &patterns);
280     patterns.add<ConvertLaunchFuncOpToTfRuntimeCallPattern>(type_converter,
281                                                             blob_annotation_);
282     //  Set target.
283     ConversionTarget target(*ctx);
284     target.addLegalDialect<LLVM::LLVMDialect>();
285     target.addIllegalDialect<
286         arith::ArithmeticDialect, func::FuncDialect, complex::ComplexDialect,
287         gpu::GPUDialect, tf_framework::TFFrameworkDialect, math::MathDialect>();
288     // Mark modules as legal.
289     target.addLegalOp<ModuleOp, gpu::GPUModuleOp>();
290     // Do not look into gpu modules, only consider host-side.
291     target.markOpRecursivelyLegal<gpu::GPUModuleOp>();
292     // Unrealized conversion casts are cleaned up by a separate pass.
293     target.addLegalOp<UnrealizedConversionCastOp>();
294 
295     if (failed(applyFullConversion(m, target, std::move(patterns)))) {
296       signalPassFailure();
297     }
298 
299     // Finally, strip the GPU modules, as they are no longer needed.
300     for (auto op : llvm::make_early_inc_range(m.getOps<gpu::GPUModuleOp>())) {
301       op.erase();
302     }
303   }
304 };
305 
306 }  // namespace
307 
CreateTFKernelToLLVMPass(StringRef blob_annotation)308 std::unique_ptr<OperationPass<ModuleOp> > CreateTFKernelToLLVMPass(
309     StringRef blob_annotation) {
310   return std::make_unique<TFKernelToLLVMPass>(blob_annotation);
311 }
312 
313 }  // namespace transforms
314 }  // namespace kernel_gen
315 }  // namespace mlir
316