1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"  // from @llvm-project
19 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
21 #include "mlir/IR/Attributes.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
24 #include "mlir/IR/Operation.h"  // from @llvm-project
25 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
27 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
28 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.h"
29 
30 namespace mlir {
31 namespace kernel_gen {
32 namespace tf_framework {
33 namespace {
34 
35 using transforms::CreateOrFindGlobalStringConstant;
36 using transforms::GetGlobalName;
37 using transforms::GetOrInsertLLVMFunction;
38 
39 static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc";
40 static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc";
41 static constexpr StringRef kCInterfaceReportError =
42     "_mlir_ciface_tf_report_error";
43 static constexpr StringRef kCInterfaceJITCompile =
44     "_mlir_ciface_tf_jit_compile";
45 static constexpr StringRef kCInterfaceJITExecute =
46     "_mlir_ciface_tf_jit_execute";
47 static constexpr StringRef kJITCodeGlobalBaseName = "jit_module_code";
48 static constexpr StringRef kErrorMessageGlobalBaseName = "error_message";
49 
50 /// Base class for patterns converting TF Framework ops to function calls.
51 template <typename OpTy>
52 class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern<OpTy> {
53  public:
54   using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
55 
56  protected:
57   virtual StringRef GetFuncName() const = 0;
58   virtual Type GetFuncType() const = 0;
59 
ConvertArrayAttrToStackAllocatedArray(Location loc,Type size_ty,Type element_ty,llvm::Optional<ArrayAttr> attr,ConversionPatternRewriter * rewriter,std::function<Value (Attribute)> create_element) const60   std::pair<Value, Value> ConvertArrayAttrToStackAllocatedArray(
61       Location loc, Type size_ty, Type element_ty,
62       llvm::Optional<ArrayAttr> attr, ConversionPatternRewriter *rewriter,
63       std::function<Value(Attribute)> create_element) const {
64     Type element_ptr_ty = LLVM::LLVMPointerType::get(element_ty);
65 
66     // If the attribute is missing or empty, set the element count to 0 and
67     // return NULL.
68     if (!attr.has_value() || attr.getValue().empty()) {
69       Value zero = rewriter->create<LLVM::ConstantOp>(
70           loc, size_ty, rewriter->getIntegerAttr(size_ty, 0));
71       Value null_ptr = rewriter->create<LLVM::NullOp>(loc, element_ptr_ty);
72       return std::make_pair(zero, null_ptr);
73     }
74 
75     // Allocate array to store the elements.
76     auto &array_attr = attr.getValue();
77     Value array_size = rewriter->create<LLVM::ConstantOp>(
78         loc, size_ty, rewriter->getIntegerAttr(size_ty, array_attr.size()));
79     Value array_ptr = rewriter->create<LLVM::AllocaOp>(
80         loc, element_ptr_ty, array_size, /*alignment=*/0);
81     for (auto &e : llvm::enumerate(array_attr)) {
82       Value index = rewriter->create<LLVM::ConstantOp>(
83           loc, size_ty, rewriter->getIntegerAttr(size_ty, e.index()));
84       Value element_ptr =
85           rewriter->create<LLVM::GEPOp>(loc, element_ptr_ty, array_ptr, index);
86       Value element = create_element(e.value());
87       rewriter->create<LLVM::StoreOp>(loc, element, element_ptr);
88     }
89     return std::make_pair(array_size, array_ptr);
90   }
91 
ConvertIntegerArrayAttrToStackAllocatedArray(Location loc,Type size_ty,Type element_ty,llvm::Optional<ArrayAttr> attr,ConversionPatternRewriter * rewriter) const92   std::pair<Value, Value> ConvertIntegerArrayAttrToStackAllocatedArray(
93       Location loc, Type size_ty, Type element_ty,
94       llvm::Optional<ArrayAttr> attr,
95       ConversionPatternRewriter *rewriter) const {
96     assert(size_ty.isa<IntegerType>() && "expect integer size type");
97     assert(element_ty.isa<IntegerType>() && "expect integer element type");
98     return ConvertArrayAttrToStackAllocatedArray(
99         loc, size_ty, element_ty, attr, rewriter, [&](Attribute attr) {
100           return rewriter->create<LLVM::ConstantOp>(
101               loc, element_ty,
102               rewriter->getIntegerAttr(element_ty,
103                                        attr.cast<IntegerAttr>().getInt()));
104         });
105   }
106 };
107 
108 class TFAllocOpConverter : public ConvertToLLVMCallOpPattern<TFAllocOp> {
109  public:
110   using ConvertToLLVMCallOpPattern<TFAllocOp>::ConvertToLLVMCallOpPattern;
111 
matchAndRewrite(TFAllocOp tf_alloc_op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const112   LogicalResult matchAndRewrite(
113       TFAllocOp tf_alloc_op, OpAdaptor adaptor,
114       ConversionPatternRewriter &rewriter) const override {
115     mlir::Operation *op = tf_alloc_op.getOperation();
116     Location loc = op->getLoc();
117     MemRefType memref_type = tf_alloc_op.getType();
118 
119     // Get memref descriptor sizes.
120     SmallVector<Value, 4> sizes;
121     SmallVector<Value, 4> strides;
122     Value sizeBytes;
123     getMemRefDescriptorSizes(loc, memref_type,
124                              llvm::to_vector<4>(adaptor.dyn_sizes()), rewriter,
125                              sizes, strides, sizeBytes);
126     // Get number of elements.
127     Value num_elements = getNumElements(loc, sizes, rewriter);
128     // Get element size.
129     Value element_size =
130         getSizeInBytes(loc, memref_type.getElementType(), rewriter);
131 
132     // Convert `output_index` or set it to -1 if the attribute is missing.
133     Type llvmInt32Type = IntegerType::get(rewriter.getContext(), 32);
134     Value output_index = rewriter.create<LLVM::ConstantOp>(
135         loc, llvmInt32Type,
136         rewriter.getI32IntegerAttr(tf_alloc_op.output_index().has_value()
137                                        ? tf_alloc_op.output_index().getValue()
138                                        : -1));
139 
140     // Convert `candidate_input_indices`.
141     auto candidates_count_and_ptr =
142         ConvertIntegerArrayAttrToStackAllocatedArray(
143             loc, rewriter.getI32Type(), rewriter.getI32Type(),
144             tf_alloc_op.input_indices(), &rewriter);
145 
146     // Insert function call.
147     FlatSymbolRefAttr tf_func_ref =
148         GetOrInsertLLVMFunction(GetFuncName(), GetFuncType(), op, &rewriter);
149     Value allocated_byte_ptr =
150         rewriter
151             .create<LLVM::CallOp>(
152                 loc, getVoidPtrType(), tf_func_ref,
153                 llvm::makeArrayRef({adaptor.ctx(), num_elements, element_size,
154                                     output_index,
155                                     candidates_count_and_ptr.first,
156                                     candidates_count_and_ptr.second}))
157             .getResult();
158 
159     MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor(
160         loc, rewriter, memref_type, allocated_byte_ptr, sizes);
161 
162     // Return the final value of the descriptor.
163     rewriter.replaceOp(op, {memRefDescriptor});
164     return success();
165   }
166 
167  protected:
GetFuncName() const168   StringRef GetFuncName() const override { return kCInterfaceAlloc; }
169 
GetFuncType() const170   Type GetFuncType() const override {
171     Type llvm_i32_type = IntegerType::get(getDialect().getContext(), 32);
172     Type llvm_i32_ptr_type = LLVM::LLVMPointerType::get(llvm_i32_type);
173     Type llvm_void_ptr_type = getVoidPtrType();
174     return LLVM::LLVMFunctionType::get(
175         llvm_void_ptr_type,
176         llvm::makeArrayRef(
177             {/*void* op_kernel_ctx*/ llvm_void_ptr_type,
178              /*size_t num_elements*/ getIndexType(),
179              /*size_t element_size*/ getIndexType(),
180              /*int32_t output_index*/ llvm_i32_type,
181              /*int32_t num_candidates*/ llvm_i32_type,
182              /*int32_t* candidate_input_indices*/ llvm_i32_ptr_type}));
183   }
184 
185  private:
186   // TODO(pifon): Remove strides computation.
CreateMemRefDescriptor(Location loc,ConversionPatternRewriter & rewriter,MemRefType memref_type,Value allocated_byte_ptr,ArrayRef<Value> sizes) const187   MemRefDescriptor CreateMemRefDescriptor(Location loc,
188                                           ConversionPatternRewriter &rewriter,
189                                           MemRefType memref_type,
190                                           Value allocated_byte_ptr,
191                                           ArrayRef<Value> sizes) const {
192     auto memref_desc = MemRefDescriptor::undef(
193         rewriter, loc, typeConverter->convertType(memref_type));
194 
195     // TF AllocateRaw returns aligned pointer => AllocatedPtr == AlignedPtr.
196     Value allocated_type_ptr = rewriter.create<LLVM::BitcastOp>(
197         loc, getElementPtrType(memref_type), allocated_byte_ptr);
198     memref_desc.setAllocatedPtr(rewriter, loc, allocated_type_ptr);
199     memref_desc.setAlignedPtr(rewriter, loc, allocated_type_ptr);
200     memref_desc.setConstantOffset(rewriter, loc, 0);
201 
202     if (memref_type.getRank() == 0) {
203       return memref_desc;
204     }
205 
206     // Compute strides and populate descriptor `size` and `stride` fields.
207     Value stride_carried = createIndexConstant(rewriter, loc, 1);
208     for (int pos = sizes.size() - 1; pos >= 0; --pos) {
209       Value size = sizes[pos];
210       memref_desc.setSize(rewriter, loc, pos, size);
211       memref_desc.setStride(rewriter, loc, pos, stride_carried);
212       // Update stride
213       if (pos > 0) {
214         stride_carried =
215             rewriter.create<LLVM::MulOp>(loc, stride_carried, size);
216       }
217     }
218     return memref_desc;
219   }
220 };
221 
222 class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern<TFDeallocOp> {
223  public:
224   using ConvertToLLVMCallOpPattern<TFDeallocOp>::ConvertToLLVMCallOpPattern;
225 
matchAndRewrite(TFDeallocOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const226   LogicalResult matchAndRewrite(
227       TFDeallocOp op, OpAdaptor adaptor,
228       ConversionPatternRewriter &rewriter) const override {
229     // TODO(herhut) Support unranked memrefs.
230     if (!op.memref().getType().isa<MemRefType>()) return failure();
231     MemRefDescriptor memref(adaptor.memref());
232 
233     Value allocated_bytes_ptr = rewriter.create<LLVM::BitcastOp>(
234         op.getLoc(), getVoidPtrType(),
235         memref.allocatedPtr(rewriter, op.getLoc()));
236 
237     // Insert function call.
238     FlatSymbolRefAttr tf_func_ref =
239         GetOrInsertLLVMFunction(GetFuncName(), GetFuncType(), op, &rewriter);
240     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
241         op, llvm::None, tf_func_ref,
242         llvm::makeArrayRef({adaptor.ctx(), allocated_bytes_ptr}));
243     return success();
244   }
245 
246  protected:
GetFuncName() const247   StringRef GetFuncName() const override { return kCInterfaceDealloc; }
GetFuncType() const248   Type GetFuncType() const override {
249     return LLVM::LLVMFunctionType::get(getVoidType(),
250                                        {getVoidPtrType(), getVoidPtrType()});
251   }
252 };
253 
254 class JITCompileFromStrOpConverter
255     : public ConvertToLLVMCallOpPattern<JITCompileFromStrOp> {
256   using ConvertToLLVMCallOpPattern<
257       JITCompileFromStrOp>::ConvertToLLVMCallOpPattern;
258 
matchAndRewrite(JITCompileFromStrOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const259   LogicalResult matchAndRewrite(
260       JITCompileFromStrOp op, OpAdaptor adaptor,
261       ConversionPatternRewriter &rewriter) const override {
262     if (adaptor.ctx() == nullptr) return failure();
263     auto loc = op.getLoc();
264     std::string zero_terminated_code = op.code().str() + '\00';
265     Value jit_module_code = CreateOrFindGlobalStringConstant(
266         loc, GetGlobalName(kJITCodeGlobalBaseName, zero_terminated_code),
267         zero_terminated_code, &rewriter);
268     std::pair<Value, Value> tile_sizes =
269         ConvertIntegerArrayAttrToStackAllocatedArray(loc, rewriter.getI64Type(),
270                                                      rewriter.getI64Type(),
271                                                      op.tileSizes(), &rewriter);
272     std::pair<Value, Value> unroll_factors =
273         ConvertIntegerArrayAttrToStackAllocatedArray(
274             loc, rewriter.getI64Type(), rewriter.getI64Type(),
275             op.unrollFactors(), &rewriter);
276     Value max_supported_rank = rewriter.create<LLVM::ConstantOp>(
277         loc, rewriter.getI64Type(), op.maxSupportedRankAttr());
278     Value enable_ftz = rewriter.create<LLVM::ConstantOp>(
279         loc, rewriter.getI1Type(), op.enableFtzAttr());
280     Value index_64bit = rewriter.create<LLVM::ConstantOp>(
281         loc, rewriter.getI1Type(), op.index64BitAttr());
282     Value cpu_codegen = rewriter.create<LLVM::ConstantOp>(
283         loc, rewriter.getI1Type(), op.cpuCodegenAttr());
284     FlatSymbolRefAttr tf_func_ref =
285         GetOrInsertLLVMFunction(GetFuncName(), GetFuncType(), op, &rewriter);
286     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
287         op, getVoidPtrType(), tf_func_ref,
288         llvm::makeArrayRef({adaptor.ctx(), jit_module_code, tile_sizes.first,
289                             tile_sizes.second, unroll_factors.first,
290                             unroll_factors.second, max_supported_rank,
291                             enable_ftz, index_64bit, cpu_codegen}));
292     return success();
293   }
294 
295  protected:
GetFuncName() const296   StringRef GetFuncName() const override { return kCInterfaceJITCompile; }
297 
GetFuncType() const298   Type GetFuncType() const override {
299     auto i8_ptr_ty =
300         LLVM::LLVMPointerType::get(IntegerType::get(getContext(), 8));
301     auto i64_ty = IntegerType::get(getContext(), 64);
302     Type i64_ptr_ty = LLVM::LLVMPointerType::get(i64_ty);
303     auto i1_ty = IntegerType::get(getContext(), 1);
304     return LLVM::LLVMFunctionType::get(
305         getVoidPtrType(), {/*void* op_kernel_ctx*/ getVoidPtrType(),
306                            /*char* code*/ i8_ptr_ty,
307                            /*int64_t num_tile_sizes*/ i64_ty,
308                            /*int64_t* tile_sizes_ptr*/ i64_ptr_ty,
309                            /*int64_t num_unroll_factors*/ i64_ty,
310                            /*int64_t* unroll_factors_ptr*/ i64_ptr_ty,
311                            /*int64_t max_supported_rank*/ i64_ty,
312                            /*bool enable_ftz*/ i1_ty,
313                            /*bool index_64bit*/ i1_ty,
314                            /*bool cpu_codegen*/ i1_ty});
315   }
316 };
317 
318 class JITExecuteOpConverter : public ConvertToLLVMCallOpPattern<JITExecuteOp> {
319  public:
320   using ConvertToLLVMCallOpPattern<JITExecuteOp>::ConvertToLLVMCallOpPattern;
321 
matchAndRewrite(JITExecuteOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const322   LogicalResult matchAndRewrite(
323       JITExecuteOp op, OpAdaptor adaptor,
324       ConversionPatternRewriter &rewriter) const override {
325     // The TF context must be known for a successful lowering.
326     if (adaptor.ctx() == nullptr || op.operands().empty()) {
327       return failure();
328     }
329 
330     // Allocate result on stack.
331     auto loc = op.getLoc();
332     Type result_ty =
333         getTypeConverter()->convertType(op->getResultTypes().front());
334     Type result_ptr_ty = LLVM::LLVMPointerType::get(result_ty);
335     Type i64_ty = rewriter.getI64Type();
336     Value one = rewriter.create<LLVM::ConstantOp>(
337         loc, i64_ty, rewriter.getI64IntegerAttr(1));
338     auto result_ptr =
339         rewriter.create<LLVM::AllocaOp>(loc, result_ptr_ty, one, llvm::None);
340     Type void_ptr_ty = getVoidPtrType();
341     auto result_void_ptr =
342         rewriter.create<LLVM::BitcastOp>(loc, void_ptr_ty, result_ptr);
343 
344     // Pass the buffer arguments as a stack-allocated array.
345     Type arg_ptr_ty =
346         LLVM::LLVMPointerType::get(adaptor.operands().front().getType());
347     Value num_args = rewriter.create<LLVM::ConstantOp>(
348         loc, i64_ty,
349         rewriter.getI64IntegerAttr(
350             static_cast<int64_t>(adaptor.operands().size())));
351     Value args_ptr = rewriter.create<LLVM::AllocaOp>(loc, arg_ptr_ty, num_args,
352                                                      /*alignment=*/0);
353     for (const auto &it : llvm::enumerate(adaptor.operands())) {
354       Value index = rewriter.create<LLVM::ConstantOp>(
355           loc, i64_ty, rewriter.getI64IntegerAttr(it.index()));
356       Value element_ptr =
357           rewriter.create<LLVM::GEPOp>(loc, arg_ptr_ty, args_ptr, index);
358       rewriter.create<LLVM::StoreOp>(loc, it.value(), element_ptr);
359     }
360     auto args_void_ptr =
361         rewriter.create<LLVM::BitcastOp>(loc, void_ptr_ty, args_ptr);
362 
363     // Materialize runtime call.
364     FlatSymbolRefAttr tf_func_ref =
365         GetOrInsertLLVMFunction(GetFuncName(), GetFuncType(), op, &rewriter);
366     rewriter.create<LLVM::CallOp>(
367         loc, llvm::None, tf_func_ref,
368         ValueRange{adaptor.ctx(), adaptor.callable(), result_void_ptr, num_args,
369                    args_void_ptr});
370 
371     // Copy result (including the descriptor) to a stack-allocated buffer and
372     // free the old descriptor.
373     llvm::SmallVector<Value, 1> final_result = {
374         rewriter.create<LLVM::LoadOp>(loc, result_ptr)};
375     if (failed(copyUnrankedDescriptors(rewriter, loc, op->getResultTypes(),
376                                        final_result,
377                                        /*toDynamic=*/false))) {
378       return failure();
379     }
380 
381     rewriter.replaceOp(op, final_result.front());
382     return success();
383   }
384 
385  protected:
GetFuncName() const386   StringRef GetFuncName() const override { return kCInterfaceJITExecute; }
387 
GetFuncType() const388   Type GetFuncType() const override {
389     auto i64_ty = IntegerType::get(getContext(), 64);
390     auto void_ptr_ty = getVoidPtrType();
391     return LLVM::LLVMFunctionType::get(getVoidType(),
392                                        {/*void* op_kernel_ctx*/ void_ptr_ty,
393                                         /*void* callable*/ void_ptr_ty,
394                                         /*void* result*/ void_ptr_ty,
395                                         /*int64_t num_args*/ i64_ty,
396                                         /*void* args_ptr*/ void_ptr_ty});
397   }
398 };
399 
400 class ReportErrorOpConverter
401     : public ConvertToLLVMCallOpPattern<ReportErrorOp> {
402  public:
403   using ConvertToLLVMCallOpPattern<ReportErrorOp>::ConvertToLLVMCallOpPattern;
404 
matchAndRewrite(ReportErrorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const405   LogicalResult matchAndRewrite(
406       ReportErrorOp op, OpAdaptor adaptor,
407       ConversionPatternRewriter &rewriter) const override {
408     Location loc = op.getLoc();
409     auto module = op->getParentOfType<ModuleOp>();
410     Value message_constant =
411         GenerateErrorMessageConstant(loc, module, adaptor.msg(), rewriter);
412 
413     // Insert function call.
414     FlatSymbolRefAttr tf_func_ref =
415         GetOrInsertLLVMFunction(GetFuncName(), GetFuncType(), op, &rewriter);
416     Value error_code = rewriter.create<LLVM::ConstantOp>(
417         loc, typeConverter->convertType(rewriter.getI32Type()),
418         adaptor.error_codeAttr());
419     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
420         op, llvm::None, tf_func_ref,
421         llvm::makeArrayRef({adaptor.ctx(), error_code, message_constant}));
422     return success();
423   }
424 
425  protected:
GetFuncName() const426   StringRef GetFuncName() const override { return kCInterfaceReportError; }
GetFuncType() const427   Type GetFuncType() const override {
428     MLIRContext *ctx = &getTypeConverter()->getContext();
429     auto i8_ptr_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
430     auto i32_type = IntegerType::get(ctx, 32);
431     return LLVM::LLVMFunctionType::get(
432         getVoidType(), {getVoidPtrType(), i32_type, i8_ptr_type});
433   }
434 
435  private:
436   // Generates an LLVM IR dialect global that contains the name of the given
437   // kernel function as a C string, and returns a pointer to its beginning.
GenerateErrorMessageConstant(Location loc,Operation * module,StringRef message,OpBuilder & builder) const438   Value GenerateErrorMessageConstant(Location loc, Operation *module,
439                                      StringRef message,
440                                      OpBuilder &builder) const {
441     std::string err_str;
442     llvm::raw_string_ostream err_stream(err_str);
443     err_stream << message;
444     if (!loc.isa<UnknownLoc>()) {
445       err_stream << " at ";
446       loc.print(err_stream);
447     }
448     err_stream << '\00';
449     StringRef generated_error(err_stream.str());
450     return CreateOrFindGlobalStringConstant(
451         loc, GetGlobalName(kErrorMessageGlobalBaseName, generated_error),
452         generated_error, &builder);
453   }
454 };
455 
456 class NullContextOpConverter : public ConvertOpToLLVMPattern<NullContextOp> {
457  public:
458   using ConvertOpToLLVMPattern<NullContextOp>::ConvertOpToLLVMPattern;
459 
matchAndRewrite(NullContextOp op,OpAdaptor,ConversionPatternRewriter & rewriter) const460   LogicalResult matchAndRewrite(
461       NullContextOp op, OpAdaptor /*adaptor*/,
462       ConversionPatternRewriter &rewriter) const override {
463     rewriter.replaceOpWithNewOp<LLVM::NullOp>(op, getVoidPtrType());
464     return success();
465   }
466 };
467 
468 class NullMemRefOpConverter : public ConvertOpToLLVMPattern<NullMemRefOp> {
469  public:
470   using ConvertOpToLLVMPattern<NullMemRefOp>::ConvertOpToLLVMPattern;
471 
matchAndRewrite(NullMemRefOp null_memref_op,OpAdaptor,ConversionPatternRewriter & rewriter) const472   LogicalResult matchAndRewrite(
473       NullMemRefOp null_memref_op, OpAdaptor /*adaptor*/,
474       ConversionPatternRewriter &rewriter) const override {
475     Location loc = null_memref_op->getLoc();
476     LLVMTypeConverter type_converter = *getTypeConverter();
477     mlir::Operation *op = null_memref_op.getOperation();
478 
479     auto shaped_result_type = null_memref_op.getType().cast<BaseMemRefType>();
480     unsigned address_space = shaped_result_type.getMemorySpaceAsInt();
481 
482     Type elem_type = shaped_result_type.getElementType();
483     Type llvm_elem_type = type_converter.convertType(elem_type);
484 
485     Value zero = createIndexConstant(rewriter, loc, 0);
486     if (auto result_type = null_memref_op.getType().dyn_cast<MemRefType>()) {
487       // Set all dynamic sizes to 1 and compute fake strides.
488       SmallVector<Value, 4> dyn_sizes(result_type.getNumDynamicDims(),
489                                       createIndexConstant(rewriter, loc, 1));
490       SmallVector<Value, 4> sizes, strides;
491       Value sizeBytes;
492       getMemRefDescriptorSizes(loc, result_type, dyn_sizes, rewriter, sizes,
493                                strides, sizeBytes);
494 
495       // Prepare packed args [allocatedPtr, alignedPtr, offset, sizes, strides]
496       // to create a memref descriptor.
497       Value null = rewriter.create<LLVM::NullOp>(
498           loc, LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
499       SmallVector<Value, 12> packed_values{null, null, zero};
500       packed_values.append(sizes);
501       packed_values.append(strides);
502 
503       rewriter.replaceOp(
504           op, MemRefDescriptor::pack(rewriter, loc, type_converter, result_type,
505                                      packed_values));
506       return success();
507     }
508 
509     auto result_type = null_memref_op.getType().cast<UnrankedMemRefType>();
510     Type llvm_result_type = type_converter.convertType(result_type);
511 
512     auto desc =
513         UnrankedMemRefDescriptor::undef(rewriter, loc, llvm_result_type);
514     desc.setRank(rewriter, loc, zero);
515 
516     // Due to the current way of handling unranked memref results escaping, we
517     // have to actually construct a ranked underlying descriptor instead of just
518     // setting its pointer to NULL.
519     SmallVector<Value, 4> sizes;
520     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
521                                            desc, sizes);
522     Value underlying_desc_ptr = rewriter.create<LLVM::AllocaOp>(
523         loc, getVoidPtrType(), sizes.front(), llvm::None);
524 
525     // Populate underlying ranked descriptor.
526     Type elem_ptr_ptr_type = LLVM::LLVMPointerType::get(
527         LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
528 
529     Value null = rewriter.create<LLVM::NullOp>(
530         loc, LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
531     UnrankedMemRefDescriptor::setAllocatedPtr(
532         rewriter, loc, underlying_desc_ptr, elem_ptr_ptr_type, null);
533     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
534                                             underlying_desc_ptr,
535                                             elem_ptr_ptr_type, null);
536     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
537                                         underlying_desc_ptr, elem_ptr_ptr_type,
538                                         zero);
539 
540     desc.setMemRefDescPtr(rewriter, loc, underlying_desc_ptr);
541     rewriter.replaceOp(op, {desc});
542     return success();
543   }
544 };
545 
546 class IsValidMemRefOpConverter
547     : public ConvertOpToLLVMPattern<IsValidMemRefOp> {
548  public:
549   using ConvertOpToLLVMPattern<IsValidMemRefOp>::ConvertOpToLLVMPattern;
550 
matchAndRewrite(IsValidMemRefOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const551   LogicalResult matchAndRewrite(
552       IsValidMemRefOp op, OpAdaptor adaptor,
553       ConversionPatternRewriter &rewriter) const override {
554     Location loc = op.getLoc();
555     MemRefDescriptor desc(adaptor.arg());
556 
557     // Compare every size in the descriptor to 0 to check num_elements == 0.
558     int64_t rank = op.arg().getType().cast<MemRefType>().getRank();
559     Value is_empty_shape = rewriter.create<LLVM::ConstantOp>(
560         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
561     Value zero = createIndexConstant(rewriter, loc, 0);
562     for (int i = 0; i < rank; ++i) {
563       Value size = desc.size(rewriter, loc, i);
564       Value is_zero_size = rewriter.create<LLVM::ICmpOp>(
565           loc, rewriter.getI1Type(), LLVM::ICmpPredicate::eq, size, zero);
566       is_empty_shape =
567           rewriter.create<LLVM::OrOp>(loc, is_empty_shape, is_zero_size);
568     }
569 
570     Value ptr = rewriter.create<LLVM::BitcastOp>(
571         loc, getVoidPtrType(), desc.allocatedPtr(rewriter, loc));
572     Value null = rewriter.create<LLVM::NullOp>(loc, getVoidPtrType());
573     Value is_not_nullptr = rewriter.create<LLVM::ICmpOp>(
574         loc, rewriter.getI1Type(), LLVM::ICmpPredicate::ne, ptr, null);
575 
576     // Valid memref = ptr != NULL || num_elements == 0;
577     rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, is_not_nullptr, is_empty_shape);
578     return success();
579   }
580 };
581 
582 }  // namespace
583 
PopulateTFFrameworkToLLVMConversionPatterns(LLVMTypeConverter * converter,RewritePatternSet * patterns)584 void PopulateTFFrameworkToLLVMConversionPatterns(LLVMTypeConverter *converter,
585                                                  RewritePatternSet *patterns) {
586   // clang-format off
587   patterns->add<
588       IsValidMemRefOpConverter,
589       JITCompileFromStrOpConverter,
590       JITExecuteOpConverter,
591       NullContextOpConverter,
592       NullMemRefOpConverter,
593       ReportErrorOpConverter,
594       TFAllocOpConverter,
595       TFDeallocOpConverter>(*converter);
596   // clang-format on
597 }
598 
599 }  // namespace tf_framework
600 }  // namespace kernel_gen
601 }  // namespace mlir
602